From f615b39ded89b05ce23c6ebdd72849d5ff82ff2d Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Wed, 8 Jan 2025 15:33:18 +0200 Subject: [PATCH] :alembic: | Experimenting with larger model architecture. --- discriminator.py | 40 +++++++++++++++++++++++++--------------- generator.py | 38 +++++++++++++++++++++++--------------- training.py | 11 +++++------ 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/discriminator.py b/discriminator.py index 4e78a42..af29f5d 100644 --- a/discriminator.py +++ b/discriminator.py @@ -2,29 +2,39 @@ import torch import torch.nn as nn import torch.nn.utils as utils +def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): + padding = (kernel_size // 2) * dilation + return nn.Sequential( + utils.spectral_norm(nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(0.2, inplace=True) # Changed activation to LeakyReLU + ) + class SISUDiscriminator(nn.Module): def __init__(self): super(SISUDiscriminator, self).__init__() - layers = 8 + layers = 32 # Increased base layer count self.model = nn.Sequential( - utils.spectral_norm(nn.Conv1d(1, layers, kernel_size=7, stride=2, padding=3)), - nn.BatchNorm1d(layers), - nn.PReLU(), - nn.Conv1d(layers, layers * 2, kernel_size=7, padding=3), - nn.BatchNorm1d(layers * 2), - nn.PReLU(), - nn.Conv1d(layers * 2, layers * 4, kernel_size=5, padding=2), - nn.BatchNorm1d(layers * 4), - nn.PReLU(), - nn.Conv1d(layers * 4, layers * 8, kernel_size=3, padding=1), - nn.BatchNorm1d(layers * 8), - nn.PReLU(), - nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1), + # Initial Convolution + discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1), # Downsample + + # Core Discriminator Blocks with varied kernels and dilations + discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1), # Downsample + discriminator_block(layers * 2, layers * 2, kernel_size=3, dilation=2), + discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4), + discriminator_block(layers * 4, layers * 4, kernel_size=3, dilation=8), + discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=16), + discriminator_block(layers * 8, layers * 8, kernel_size=3, dilation=8), + discriminator_block(layers * 8, layers * 4, kernel_size=5, dilation=4), + discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=2), + discriminator_block(layers * 2, layers, kernel_size=5, dilation=1), + # Final Convolution + discriminator_block(layers, 1, kernel_size=3, stride=1), ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) def forward(self, x): - x = x + 0.01 * torch.randn_like(x) + # Gaussian noise is not necessary here for discriminator as it is already implicit in the training process x = self.model(x) x = self.global_avg_pool(x) x = x.view(-1, 1) diff --git a/generator.py b/generator.py index 68c50dd..6ea267d 100644 --- a/generator.py +++ b/generator.py @@ -1,31 +1,39 @@ import torch.nn as nn +def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): + return nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=(kernel_size // 2) * dilation), + nn.BatchNorm1d(out_channels), + nn.PReLU() + ) + class SISUGenerator(nn.Module): def __init__(self): super(SISUGenerator, self).__init__() - layer = 16 - # Convolution layers with BatchNorm and Residuals + layer = 32 # Increased base layer count self.conv1 = nn.Sequential( - nn.Conv1d(1, layer * 2, kernel_size=7, padding=3), - nn.BatchNorm1d(layer * 2), - nn.PReLU(), - nn.Conv1d(layer * 2, layer * 5, kernel_size=7, padding=3), - nn.BatchNorm1d(layer * 5), - nn.PReLU(), - nn.Conv1d(layer * 5, layer * 5, kernel_size=7, padding=3), - nn.BatchNorm1d(layer * 5), + nn.Conv1d(1, layer, kernel_size=7, padding=3), + nn.BatchNorm1d(layer), nn.PReLU(), ) + self.conv_blocks = nn.Sequential( + conv_block(layer, layer, kernel_size=3, dilation=1), # Local details + conv_block(layer, layer*2, kernel_size=5, dilation=2), # Local Context + conv_block(layer*2, layer*2, kernel_size=3, dilation=4), # Wider context + conv_block(layer*2, layer*4, kernel_size=7, dilation=8), # Longer range dependencies + conv_block(layer*4, layer*4, kernel_size=3, dilation=16), # Longer range dependencies + conv_block(layer*4, layer*2, kernel_size=5, dilation=8), # Wider context + conv_block(layer*2, layer*2, kernel_size=3, dilation=4), # Wider context + conv_block(layer*2, layer, kernel_size=5, dilation=2), # Local Context + conv_block(layer, layer, kernel_size=3, dilation=1), # Local details + ) self.final_layer = nn.Sequential( - nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=2), - nn.BatchNorm1d(layer * 2), - nn.PReLU(), - nn.Conv1d(layer * 2, 1, kernel_size=3, padding=1), - # nn.Tanh() # Normalize audio... if needed... + nn.Conv1d(layer, 1, kernel_size=3, padding=1), ) def forward(self, x): residual = x x = self.conv1(x) + x = self.conv_blocks(x) x = self.final_layer(x) return x + residual diff --git a/training.py b/training.py index 8ce6874..e114817 100644 --- a/training.py +++ b/training.py @@ -175,16 +175,15 @@ def start_training(): if generator_epoch % 10 == 0: print(f"Saved epoch {generator_epoch}!") - torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from what ever that low_quality had to high_quality + torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1]) torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1]) - if generator_epoch % 50 == 0: - torch.save(discriminator.state_dict(), f"models/epoch-{generator_epoch}-discriminator.pt") - torch.save(generator.state_dict(), f"models/epoch-{generator_epoch}-generator.pt") + torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt") + torch.save(generator.state_dict(), f"models/current-epoch-generator.pt") - torch.save(discriminator.state_dict(), "models/epoch-500-discriminator.pt") - torch.save(generator.state_dict(), "models/epoch-500-generator.pt") + torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt") + torch.save(generator.state_dict(), "models/epoch-5000-generator.pt") print("Training complete!") start_training()