From 7e1c7e935a0e43a9696f424f9757c6ad344dc2c9 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sat, 15 Mar 2025 18:01:19 +0200 Subject: [PATCH] :albemic: | Experimenting with other model layouts. --- data.py | 26 +++++++++++++++++----- discriminator.py | 57 ++++++++++++++++++++++++++++++++---------------- generator.py | 44 +++++++++++++++++++++++++------------ training.py | 27 +++++++---------------- 4 files changed, 96 insertions(+), 58 deletions(-) diff --git a/data.py b/data.py index 2f05581..9ca5ee5 100644 --- a/data.py +++ b/data.py @@ -4,23 +4,20 @@ import torch import torchaudio import os import random - import torchaudio.transforms as T import AudioUtils class AudioDataset(Dataset): - #audio_sample_rates = [8000, 11025, 16000, 22050] audio_sample_rates = [11025] + MAX_LENGTH = 88200 # Define your desired maximum length here def __init__(self, input_dir, device): self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] self.device = device - def __len__(self): return len(self.input_files) - def __getitem__(self, idx): # Load high-quality audio high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) @@ -33,7 +30,24 @@ class AudioDataset(Dataset): resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) low_quality_audio = resample_transform_high(low_quality_audio) - high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device) - low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device) + high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) + low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio) + + # Pad or truncate high-quality audio + if high_quality_audio.shape[1] < self.MAX_LENGTH: + padding = self.MAX_LENGTH - high_quality_audio.shape[1] + high_quality_audio = F.pad(high_quality_audio, (0, padding)) + elif high_quality_audio.shape[1] > self.MAX_LENGTH: + high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH] + + # Pad or truncate low-quality audio + if low_quality_audio.shape[1] < self.MAX_LENGTH: + padding = self.MAX_LENGTH - low_quality_audio.shape[1] + low_quality_audio = F.pad(low_quality_audio, (0, padding)) + elif low_quality_audio.shape[1] > self.MAX_LENGTH: + low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH] + + high_quality_audio = high_quality_audio.to(self.device) + low_quality_audio = low_quality_audio.to(self.device) return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate) diff --git a/discriminator.py b/discriminator.py index d090372..b1ec6eb 100644 --- a/discriminator.py +++ b/discriminator.py @@ -2,35 +2,54 @@ import torch import torch.nn as nn import torch.nn.utils as utils -def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): +def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True): padding = (kernel_size // 2) * dilation + conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding) + if spectral_norm: + conv_layer = utils.spectral_norm(conv_layer) return nn.Sequential( - utils.spectral_norm(nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)), + conv_layer, nn.LeakyReLU(0.2, inplace=True), nn.BatchNorm1d(out_channels) ) -class SISUDiscriminator(nn.Module): - def __init__(self): - super(SISUDiscriminator, self).__init__() - layers = 4 # Increased base layer count - self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=2), # Initial downsampling - discriminator_block(layers, layers * 2, kernel_size=5, stride=2), # Downsampling - discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), # Increased dilation - discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=4), # Increased dilation - discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=8), # Deeper layer! - discriminator_block(layers * 8, layers * 8, kernel_size=5, dilation=1), # Deeper layer! - discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=2), # Reduced dilation - discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=1), - discriminator_block(layers * 2, layers, kernel_size=3, stride=1), # Final convolution - discriminator_block(layers, 1, kernel_size=3, stride=1) +class AttentionBlock(nn.Module): + def __init__(self, channels): + super(AttentionBlock, self).__init__() + self.attention = nn.Sequential( + nn.Conv1d(channels, channels // 4, kernel_size=1), + nn.ReLU(), + nn.Conv1d(channels // 4, channels, kernel_size=1), + nn.Sigmoid() ) - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) def forward(self, x): - # Gaussian noise is not necessary here for discriminator as it is already implicit in the training process + attention_weights = self.attention(x) + return x * attention_weights + +class SISUDiscriminator(nn.Module): + def __init__(self, layers=4): #Increased base layer count + super(SISUDiscriminator, self).__init__() + self.model = nn.Sequential( + discriminator_block(1, layers, kernel_size=7, stride=4), #Aggressive downsampling + discriminator_block(layers, layers * 2, kernel_size=5, stride=2), + discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), + discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), + AttentionBlock(layers * 8), #Added attention + discriminator_block(layers * 8, layers * 16, kernel_size=5, dilation=8), + discriminator_block(layers * 16, layers * 16, kernel_size=3, dilation=1), + discriminator_block(layers * 16, layers * 8, kernel_size=3, dilation=2), + discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=1), + discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1), + discriminator_block(layers * 2, layers, kernel_size=3, stride=1), + discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False) #last layer no spectral norm. + ) + self.global_avg_pool = nn.AdaptiveAvgPool1d(1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): x = self.model(x) x = self.global_avg_pool(x) x = x.view(-1, 1) + x = self.sigmoid(x) return x diff --git a/generator.py b/generator.py index 03fa279..950530a 100644 --- a/generator.py +++ b/generator.py @@ -7,30 +7,46 @@ def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): nn.PReLU() ) +class AttentionBlock(nn.Module): + def __init__(self, channels): + super(AttentionBlock, self).__init__() + self.attention = nn.Sequential( + nn.Conv1d(channels, channels // 4, kernel_size=1), + nn.ReLU(), + nn.Conv1d(channels // 4, channels, kernel_size=1), + nn.Sigmoid() + ) + + def forward(self, x): + attention_weights = self.attention(x) + return x * attention_weights + +class ResidualInResidualBlock(nn.Module): + def __init__(self, channels, num_convs=3): + super(ResidualInResidualBlock, self).__init__() + self.conv_layers = nn.Sequential(*[conv_block(channels, channels) for _ in range(num_convs)]) + self.attention = AttentionBlock(channels) + + def forward(self, x): + residual = x + x = self.conv_layers(x) + x = self.attention(x) + return x + residual + class SISUGenerator(nn.Module): - def __init__(self): + def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts super(SISUGenerator, self).__init__() - layer = 4 # Increased base layer count self.conv1 = nn.Sequential( nn.Conv1d(1, layer, kernel_size=7, padding=3), nn.BatchNorm1d(layer), nn.PReLU(), ) - self.conv_blocks = nn.Sequential( - conv_block(layer, layer, kernel_size=3, dilation=1), # Local details - conv_block(layer, layer*2, kernel_size=5, dilation=2), # Local Context - conv_block(layer*2, layer*2, kernel_size=3, dilation=16), # Longer range dependencies - conv_block(layer*2, layer*2, kernel_size=5, dilation=8), # Wider context - conv_block(layer*2, layer, kernel_size=5, dilation=2), # Local Context - conv_block(layer, layer, kernel_size=3, dilation=1), # Local details - ) - self.final_layer = nn.Sequential( - nn.Conv1d(layer, 1, kernel_size=3, padding=1), - ) + self.rir_blocks = nn.Sequential(*[ResidualInResidualBlock(layer) for _ in range(num_rirb)]) + self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1) def forward(self, x): residual = x x = self.conv1(x) - x = self.conv_blocks(x) + x = self.rir_blocks(x) x = self.final_layer(x) return x + residual diff --git a/training.py b/training.py index bf60c5c..50743be 100644 --- a/training.py +++ b/training.py @@ -38,7 +38,7 @@ device = torch.device(args.device if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") mfcc_transform = T.MFCC( - sample_rate=16000, # Adjust to your sample rate + sample_rate=44100, # Adjust to your sample rate n_mfcc=20, melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs. ).to(device) @@ -97,20 +97,9 @@ debug = args.verbose dataset_dir = './dataset/good' dataset = AudioDataset(dataset_dir, device) -# ========= MULTIPLE ========= - -# dataset_size = len(dataset) -# train_size = int(dataset_size * .9) -# val_size = int(dataset_size-train_size) - -#train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) - -# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) -# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True) - # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=128, shuffle=True) # Initialize models and move them to device generator = SISUGenerator() @@ -175,17 +164,17 @@ def start_training(): scheduler_g.step(combined_loss) # ========= SAVE LATEST AUDIO ========= - high_quality_audio = high_quality_clip - low_quality_audio = low_quality_clip - ai_enhanced_audio = (generator_output, high_quality_clip[1]) + high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) + low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) + ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0]) new_epoch = generator_epoch+epoch if generator_epoch % 10 == 0: print(f"Saved epoch {new_epoch}!") - torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. - torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1]) - torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1]) + torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. + torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1]) + torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1]) if debug: print(generator.state_dict().keys())