From 9394bc6c5a1f25733a0bc5007148d12dbf11a218 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Sun, 6 Apr 2025 00:05:43 +0300
Subject: [PATCH] :alembic: | Fat architecture. Hopefully better results.

---
 README.md        |  1 +
 discriminator.py | 59 ++++++++++++++++++++++++++++----------------------
 generator.py     | 48 ++++++++++++++++++++++++++++-----------
 requirements.txt |  4 +---
 training.py      |  2 +-
 5 files changed, 70 insertions(+), 44 deletions(-)

diff --git a/README.md b/README.md
index cd3b819..f747a42 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
 1. **Set Up**:
    - Make sure you have Python installed (version 3.8 or higher).
    - Install needed packages: `pip install -r requirements.txt`
+   - Install the current version of PyTorch (CUDA/ROCm/whatever your device supports)
 
 2. **Prepare Audio Data**:
    - Put your audio files in the `dataset/good` folder.
diff --git a/discriminator.py b/discriminator.py
index 58b95f0..777abf2 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -2,23 +2,34 @@ import torch
 import torch.nn as nn
 import torch.nn.utils as utils
 
-def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True):
+def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
     padding = (kernel_size // 2) * dilation
-    conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)
+    conv_layer = nn.Conv1d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        dilation=dilation,
+        padding=padding
+    )
+
     if spectral_norm:
         conv_layer = utils.spectral_norm(conv_layer)
-    return nn.Sequential(
-        conv_layer,
-        nn.LeakyReLU(0.2, inplace=True),
-        nn.BatchNorm1d(out_channels)
-    )
+
+    layers = [conv_layer]
+    layers.append(nn.LeakyReLU(0.2, inplace=True))
+
+    if use_instance_norm:
+        layers.append(nn.InstanceNorm1d(out_channels))
+
+    return nn.Sequential(*layers)
 
 class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
             nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
+            nn.ReLU(inplace=True),
             nn.Conv1d(channels // 4, channels, kernel_size=1),
             nn.Sigmoid()
         )
@@ -28,31 +39,25 @@ class AttentionBlock(nn.Module):
         return x * attention_weights
 
 class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=4): #Increased base layer count
+    def __init__(self, base_channels=64):
         super(SISUDiscriminator, self).__init__()
+        layers = base_channels
         self.model = nn.Sequential(
-            discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling
-            discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
-            discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4),
-
-            #AttentionBlock(layers * 4), #Added attention
-
-            #discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
-            #AttentionBlock(layers * 8), #Added attention
-            #discriminator_block(layers * 8, layers * 16, kernel_size=5, dilation=8),
-            #discriminator_block(layers * 16, layers * 16, kernel_size=3, dilation=1),
-            #discriminator_block(layers * 16, layers * 8, kernel_size=3, dilation=2),
-            #discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=1),
-            discriminator_block(layers * 4, layers * 2, kernel_size=5, stride=2),
-            discriminator_block(layers * 2, layers, kernel_size=3, stride=1),
-            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False) #last layer no spectral norm.
+            discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
+            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
+            AttentionBlock(layers * 4),
+            discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
         )
+
         self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
-        self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1)
-        x = self.sigmoid(x)
+        x = x.view(x.size(0), -1)
         return x
diff --git a/generator.py b/generator.py
index 950530a..cd4d48c 100644
--- a/generator.py
+++ b/generator.py
@@ -1,18 +1,28 @@
+import torch
 import torch.nn as nn
 
 def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
     return nn.Sequential(
-        nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=(kernel_size // 2) * dilation),
-        nn.BatchNorm1d(out_channels),
+        nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=(kernel_size // 2) * dilation
+        ),
+        nn.InstanceNorm1d(out_channels),
         nn.PReLU()
     )
 
 class AttentionBlock(nn.Module):
+    """
+    Simple channel attention block. Learns to weight channels based on their importance.
+    """
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
             nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
+            nn.ReLU(inplace=True),
             nn.Conv1d(channels // 4, channels, kernel_size=1),
             nn.Sigmoid()
         )
@@ -24,7 +34,11 @@ class AttentionBlock(nn.Module):
 class ResidualInResidualBlock(nn.Module):
     def __init__(self, channels, num_convs=3):
         super(ResidualInResidualBlock, self).__init__()
-        self.conv_layers = nn.Sequential(*[conv_block(channels, channels) for _ in range(num_convs)])
+
+        self.conv_layers = nn.Sequential(
+            *[conv_block(channels, channels) for _ in range(num_convs)]
+        )
+
         self.attention = AttentionBlock(channels)
 
     def forward(self, x):
@@ -34,19 +48,27 @@ class ResidualInResidualBlock(nn.Module):
         return x + residual
 
 class SISUGenerator(nn.Module):
-    def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
+    def __init__(self, channels=64, num_rirb=8, alpha=1.0):
         super(SISUGenerator, self).__init__()
+        self.alpha = alpha
+
         self.conv1 = nn.Sequential(
-            nn.Conv1d(1, layer, kernel_size=7, padding=3),
-            nn.BatchNorm1d(layer),
+            nn.Conv1d(1, channels, kernel_size=7, padding=3),
+            nn.InstanceNorm1d(channels),
             nn.PReLU(),
         )
-        self.rir_blocks = nn.Sequential(*[ResidualInResidualBlock(layer) for _ in range(num_rirb)])
-        self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1)
+
+        self.rir_blocks = nn.Sequential(
+            *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
+        )
+
+        self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
 
     def forward(self, x):
-        residual = x
+        residual_input = x
         x = self.conv1(x)
-        x = self.rir_blocks(x)
-        x = self.final_layer(x)
-        return x + residual
+        x_rirb_out = self.rir_blocks(x)
+        learned_residual = self.final_layer(x_rirb_out)
+        output = residual_input + self.alpha * learned_residual
+
+        return output
diff --git a/requirements.txt b/requirements.txt
index eacfc3b..21f6bef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,10 +5,8 @@ MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
 numpy==2.2.3
-pytorch-triton-rocm==3.2.0+git4b3bb1f8
+pillow==11.0.0
 setuptools==70.2.0
 sympy==1.13.3
-torch==2.7.0.dev20250226+rocm6.3
-torchaudio==2.6.0.dev20250226+rocm6.3
 tqdm==4.67.1
 typing_extensions==4.12.2
diff --git a/training.py b/training.py
index 814fcda..380f738 100644
--- a/training.py
+++ b/training.py
@@ -101,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)
 
 # ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
 
 # Initialize models and move them to device
 generator = SISUGenerator()