diff --git a/data.py b/data.py index ac69730..4cc07cf 100644 --- a/data.py +++ b/data.py @@ -4,32 +4,49 @@ import torch import torchaudio import os import random - -import torchaudio.transforms as T -import AudioUtils +from AudioUtils import stereo_tensor_to_mono, stretch_tensor class AudioDataset(Dataset): - #audio_sample_rates = [8000, 11025, 16000, 22050] audio_sample_rates = [11025] def __init__(self, input_dir): - self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] - + self.input_files = [ + os.path.join(root, f) + for root, _, files in os.walk(input_dir) + for f in files if f.endswith('.wav') + ] def __len__(self): return len(self.input_files) - def __getitem__(self, idx): # Load high-quality audio - high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) + high_quality_path = self.input_files[idx] + high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path) + high_quality_audio = stereo_tensor_to_mono(high_quality_audio) # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - low_quality_audio = resample_transform_low(high_quality_audio) + resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) + low_quality_audio = resample_low(high_quality_audio) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) - low_quality_audio = resample_transform_high(low_quality_audio) + resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) + low_quality_audio = resample_high(low_quality_audio) - return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate) + # Pad or truncate to match a fixed length + target_length = 44100 # Adjust this based on your data + high_quality_audio = self.pad_or_truncate(high_quality_audio, target_length) + low_quality_audio = self.pad_or_truncate(low_quality_audio, target_length) + + return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate) + + def pad_or_truncate(self, tensor, target_length): + current_length = tensor.size(1) + if current_length < target_length: + # Pad with zeros + padding = target_length - current_length + tensor = F.pad(tensor, (0, padding)) + else: + # Truncate to target length + tensor = tensor[:, :target_length] + return tensor diff --git a/discriminator.py b/discriminator.py index b1d82e1..dcc430d 100644 --- a/discriminator.py +++ b/discriminator.py @@ -5,33 +5,34 @@ import torch.nn.utils as utils def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): padding = (kernel_size // 2) * dilation return nn.Sequential( - utils.spectral_norm(nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)), + utils.spectral_norm( + nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding + ) + ), nn.BatchNorm1d(out_channels), - nn.LeakyReLU(0.2, inplace=True) # Changed activation to LeakyReLU + nn.LeakyReLU(0.2, inplace=True) ) class SISUDiscriminator(nn.Module): def __init__(self): super(SISUDiscriminator, self).__init__() - layers = 4 # Increased base layer count + layers = 4 self.model = nn.Sequential( - # Initial Convolution - discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1), # Downsample - - # Core Discriminator Blocks with varied kernels and dilations - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1), # Downsample - discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4), - discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=16), - discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=8), - discriminator_block(layers * 2, layers, kernel_size=3, dilation=1), - # Final Convolution - discriminator_block(layers, 1, kernel_size=3, stride=1), + discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1), + discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1), + discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4), + discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8), + discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16), + discriminator_block(layers * 2, layers, kernel_size=5, dilation=2), + discriminator_block(layers, 1, kernel_size=3, stride=1) ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) def forward(self, x): - # Gaussian noise is not necessary here for discriminator as it is already implicit in the training process x = self.model(x) x = self.global_avg_pool(x) - x = x.view(-1, 1) - return x + return x.view(-1, 1) diff --git a/generator.py b/generator.py index 03fa279..c3b8085 100644 --- a/generator.py +++ b/generator.py @@ -1,36 +1,41 @@ import torch.nn as nn -def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): +def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1): + padding = (kernel_size // 2) * dilation return nn.Sequential( - nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=(kernel_size // 2) * dilation), + nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding), nn.BatchNorm1d(out_channels), - nn.PReLU() + nn.PReLU(), + nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding), + nn.BatchNorm1d(out_channels) ) class SISUGenerator(nn.Module): def __init__(self): super(SISUGenerator, self).__init__() - layer = 4 # Increased base layer count + layers = 4 self.conv1 = nn.Sequential( - nn.Conv1d(1, layer, kernel_size=7, padding=3), - nn.BatchNorm1d(layer), - nn.PReLU(), + nn.Conv1d(1, layers, kernel_size=7, padding=3), + nn.BatchNorm1d(layers), + nn.PReLU() ) + self.conv_blocks = nn.Sequential( - conv_block(layer, layer, kernel_size=3, dilation=1), # Local details - conv_block(layer, layer*2, kernel_size=5, dilation=2), # Local Context - conv_block(layer*2, layer*2, kernel_size=3, dilation=16), # Longer range dependencies - conv_block(layer*2, layer*2, kernel_size=5, dilation=8), # Wider context - conv_block(layer*2, layer, kernel_size=5, dilation=2), # Local Context - conv_block(layer, layer, kernel_size=3, dilation=1), # Local details + conv_residual_block(layers, layers, kernel_size=3, dilation=1), + conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2), + conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16), + conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8), + conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2), + conv_residual_block(layers, layers, kernel_size=3, dilation=1) ) + self.final_layer = nn.Sequential( - nn.Conv1d(layer, 1, kernel_size=3, padding=1), + nn.Conv1d(layers, 1, kernel_size=3, padding=1) ) def forward(self, x): residual = x x = self.conv1(x) - x = self.conv_blocks(x) + x = self.conv_blocks(x) + x # Adding residual connection after blocks x = self.final_layer(x) return x + residual diff --git a/training.py b/training.py index 6ee7116..29b24d8 100644 --- a/training.py +++ b/training.py @@ -55,6 +55,11 @@ def generator_train(low_quality, real_labels): optimizer_g.step() return generator_output +def first(objects): + if len(objects) >= 1: + return objects[0] + return objects + # Init script argument parser parser = argparse.ArgumentParser(description="Training script") parser.add_argument("--generator", type=str, default=None, @@ -72,17 +77,6 @@ print(f"Using device: {device}") dataset_dir = './dataset/good' dataset = AudioDataset(dataset_dir) -# ========= MULTIPLE ========= - -# dataset_size = len(dataset) -# train_size = int(dataset_size * .9) -# val_size = int(dataset_size-train_size) - -#train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) - -# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) -# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True) - # ========= SINGLE ========= train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True) @@ -112,31 +106,6 @@ scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min' scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) def start_training(): - - # Training loop - - # ========= DISCRIMINATOR PRE-TRAINING ========= - # discriminator_epochs = 1 - # for discriminator_epoch in range(discriminator_epochs): - - # # ========= TRAINING ========= - # for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"): - # high_quality_sample = high_quality_clip[0].to(device) - # low_quality_sample = low_quality_clip[0].to(device) - - # scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2] - - # # ========= LABELS ========= - # batch_size = high_quality_sample.size(0) - # real_labels = torch.ones(batch_size, 1).to(device) - # fake_labels = torch.zeros(batch_size, 1).to(device) - - # # ========= DISCRIMINATOR ========= - # discriminator.train() - # discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels) - - # torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt") - generator_epochs = 5000 for generator_epoch in range(generator_epochs): low_quality_audio = (torch.empty((1)), 1) @@ -165,9 +134,15 @@ def start_training(): generator_output = generator_train(low_quality_sample, real_labels) # ========= SAVE LATEST AUDIO ========= - high_quality_audio = high_quality_clip - low_quality_audio = low_quality_clip - ai_enhanced_audio = (generator_output, high_quality_clip[1]) + high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0]) + low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0]) + ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0]) + print(high_quality_audio) + + print(f"Saved epoch {generator_epoch}!") + torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. + torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1]) + torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1]) #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0]) #print(f"Generator metric {metric}!")