From eca71ff5ea3a0f2ef4e25247f897b055591e317b Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Wed, 25 Dec 2024 00:09:57 +0200
Subject: [PATCH] :alembic: | Experimenting still...

---
 AudioUtils.py    |  18 ++++
 data.py          |  36 +++--------
 discriminator.py |  28 +++++----
 generator.py     |  53 +++++++---------
 requirements.txt |  26 ++++----
 training.py      | 155 ++++++++++++++++++++++++++---------------------
 6 files changed, 167 insertions(+), 149 deletions(-)
 create mode 100644 AudioUtils.py

diff --git a/AudioUtils.py b/AudioUtils.py
new file mode 100644
index 0000000..04f75db
--- /dev/null
+++ b/AudioUtils.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn.functional as F
+
+def stereo_tensor_to_mono(waveform):
+    if waveform.shape[0] > 1:
+        # Average across channels
+        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
+    else:
+        # Already mono
+        mono_waveform = waveform
+    return mono_waveform
+
+def stretch_tensor(tensor, target_length):
+    scale_factor = target_length / tensor.size(1)
+
+    tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
+
+    return tensor
diff --git a/data.py b/data.py
index dfd5c2b..02f77d3 100644
--- a/data.py
+++ b/data.py
@@ -1,49 +1,31 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
+import torch
 import torchaudio
 import os
 import random
+import torchaudio.transforms as T
+import AudioUtils

 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    #audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]

-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
+    def __init__(self, input_dir):
         self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value

     def __len__(self):
         return len(self.input_files)

     def __getitem__(self, idx):
+        # Load high-quality audio
         high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)

+        # Generate low-quality audio with random downsampling
         mangled_sample_rate = random.choice(self.audio_sample_rates)
         resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
         low_quality_audio = resample_transform(high_quality_audio)

-        # Calculate target length based on desired duration and 16000 Hz
-        # if self.target_duration is not None:
-        #     target_length = int(self.target_duration * 44100)
-        # else:
-        #     # Calculate duration of original high quality audio
-        #     target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        # high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        # low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-
-        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
-
-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-
-        return tensor
+        return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
diff --git a/discriminator.py b/discriminator.py
index a63d1e9..b72b05b 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -3,22 +3,28 @@ import torch.nn as nn

 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
+        layers = 32
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
+            nn.Conv1d(1, layers, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers),
             nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers, layers * 2, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 4),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 4, layers * 8, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 8),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1),
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1) # Flatten to (batch_size, 1)
+        x = x.view(-1, 1)
+        x = self.sigmoid(x)
         return x
diff --git a/generator.py b/generator.py
index ecd4e22..fd13467 100644
--- a/generator.py
+++ b/generator.py
@@ -1,39 +1,32 @@
 import torch.nn as nn

 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1):
+    def __init__(self, upscale_scale=4): # No noise_dim parameter
         super(SISUGenerator, self).__init__()
-        self.layers1 = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(128), # Batch Norm
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(256), # Batch Norm
+        layer = 32
+        # Convolution layers
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, layer * 2, kernel_size=7, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, layer * 5, kernel_size=5, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
+            nn.PReLU()
         )

-        self.layers2 = nn.Sequential(
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(128), # Batch Norm
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(64), # Batch Norm
-            nn.Conv1d(64, upscale_scale * 2, kernel_size=3, padding=1), # Output channels scaled
+        # Transposed convolution for upsampling
+        self.upsample = nn.ConvTranspose1d(layer * 5, layer * 5, kernel_size=upscale_scale, stride=upscale_scale)
+
+        self.conv2 = nn.Sequential(
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, 1, kernel_size=7, padding=1)
         )

-        self.upscale_factor = upscale_scale
-    def pixel_shuffle_1d(self, input, upscale_factor):
-        batch_size, channels, in_width = input.size()
-        out_width = in_width * upscale_factor
-        input_view = input.contiguous().view(batch_size, channels // upscale_factor, upscale_factor, in_width)
-        shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
-        return shuffle_out.view(batch_size, channels // upscale_factor, out_width)
-
-    def forward(self, x, scale):
-        x = self.layers1(x)
-        upsample = nn.Upsample(scale_factor=scale, mode='nearest')
-        x = upsample(x)
-        x = self.layers2(x)
-        x = self.pixel_shuffle_1d(x, self.upscale_factor)
+    def forward(self, x, upscale_scale=4):
+        x = self.conv1(x)
+        x = self.upsample(x)
+        x = self.conv2(x)
         return x
diff --git a/requirements.txt b/requirements.txt
index dc0c01f..5cb5df1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,14 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.1
+pytorch-triton-rocm==3.2.0+git0d4682f0
+setuptools==70.2.0
+sympy==1.13.1
+torch==2.6.0.dev20241222+rocm6.2.4
+torchaudio==2.6.0.dev20241222+rocm6.2.4
+tqdm==4.67.1
+typing_extensions==4.12.2
diff --git a/training.py b/training.py
index 9d16004..07af268 100644
--- a/training.py
+++ b/training.py
@@ -6,66 +6,73 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm

+import argparse
+
+import math
+
 from torch.utils.data import random_split
 from torch.utils.data import DataLoader

+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator

-# Mel Spectrogram Loss
-class MelSpectrogramLoss(nn.Module):
-    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
-        super(MelSpectrogramLoss, self).__init__()
-        self.mel_transform = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            n_mels=n_mels
-        ).to(device) # Move to device
+def perceptual_loss(y_true, y_pred):
+    return torch.mean((y_true - y_pred) ** 2)

-    def forward(self, y_pred, y_true):
-        mel_pred = self.mel_transform(y_pred)
-        mel_true = self.mel_transform(y_true)
-        return F.l1_loss(mel_pred, mel_true)
-
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     optimizer_d.zero_grad()

-    discriminator_decision_from_real = discriminator(high_quality)
-    # TODO: Experiment with criterions HERE!
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    generator_output = generator(low_quality, scale)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    # TODO: Experiment with criterions HERE!
+    integer_scale = math.ceil(high_quality[1]/low_quality[1])
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0], integer_scale)
+    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
+    resampled = resample_transform(generator_output.detach())
+
+    discriminator_decision_from_fake = discriminator(resampled)
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

+    # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0

+    # Backward pass and optimization
     d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
     optimizer_d.step()
+
     return d_loss

-def generator_train(low_quality, scale, real_labels):
+def generator_train(low_quality, real_labels, target_sample_rate=44100):
     optimizer_g.zero_grad()

-    generator_output = generator(low_quality, scale)
-    discriminator_decision = discriminator(generator_output)
-    # TODO: Fix this shit
+    scale = math.ceil(target_sample_rate/low_quality[1])
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0], scale)
+    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
+    resampled = resample_transform(generator_output)
+
+    discriminator_decision = discriminator(resampled)
     g_loss = criterion_g(discriminator_decision, real_labels)

     g_loss.backward()
     optimizer_g.step()
-    return generator_output
+    return resampled
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+
+args = parser.parse_args()

 # Check for CUDA availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -73,28 +80,38 @@ print(f"Using device: {device}")

 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
+dataset = AudioDataset(dataset_dir)

-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size-train_size)
+# ========= MULTIPLE =========

-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+# dataset_size = len(dataset)
+# train_size = int(dataset_size * .9)
+# val_size = int(dataset_size-train_size)

-train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
+#train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
+
+# ========= SINGLE =========
+
+train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()

+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
+
 generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
 criterion_g = nn.L1Loss()
-criterion_g_mel = MelSpectrogramLoss().to(device)
-criterion_d = nn.BCEWithLogitsLoss()
+criterion_d = nn.BCELoss()

 # Optimizers
 optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
@@ -109,39 +126,40 @@ def start_training():

     # Training loop

     # ========= DISCRIMINATOR PRE-TRAINING =========
-    discriminator_epochs = 1
-    for discriminator_epoch in range(discriminator_epochs):
+    # discriminator_epochs = 1
+    # for discriminator_epoch in range(discriminator_epochs):

-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
+    #     # ========= TRAINING =========
+    #     for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
+    #         high_quality_sample = high_quality_clip[0].to(device)
+    #         low_quality_sample = low_quality_clip[0].to(device)

-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+    #         scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]

-            # ========= LABELS =========
-            batch_size = high_quality_sample.size(0)
-            real_labels = torch.ones(batch_size, 1).to(device)
-            fake_labels = torch.zeros(batch_size, 1).to(device)
+    #         # ========= LABELS =========
+    #         batch_size = high_quality_sample.size(0)
+    #         real_labels = torch.ones(batch_size, 1).to(device)
+    #         fake_labels = torch.zeros(batch_size, 1).to(device)

-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+    #         # ========= DISCRIMINATOR =========
+    #         discriminator.train()
+    #         discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)

-        torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
+    #     torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")

-    generator_epochs = 500
+    generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
         low_quality_audio = (torch.empty((1)), 1)
         high_quality_audio = (torch.empty((1)), 1)
         ai_enhanced_audio = (torch.empty((1)), 1)

+        times_correct = 0
+
         # ========= TRAINING =========
         for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-
-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+        # for high_quality_clip, low_quality_clip in train_data_loader:
+            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])

             # ========= LABELS =========
             batch_size = high_quality_clip[0].size(0)
@@ -150,21 +168,20 @@ def start_training():

             # ========= DISCRIMINATOR =========
             discriminator.train()
-            for _ in range(3):
-                discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+            discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)

             # ========= GENERATOR =========
             generator.train()
-            generator_output = generator_train(low_quality_sample, scale, real_labels)
+            generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])

             # ========= SAVE LATEST AUDIO =========
             high_quality_audio = high_quality_clip
             low_quality_audio = low_quality_clip
             ai_enhanced_audio = (generator_output, high_quality_clip[1])

-        metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
-        print(f"Generator metric {metric}!")
-        scheduler_g.step(metric)
+        #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
+        #print(f"Generator metric {metric}!")
+        #scheduler_g.step(metric)

         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
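
As a rough usage sketch (not part of the patch itself): the snippet below mirrors how generator_train() in the patched training.py drives the model, i.e. mono input via AudioUtils.stereo_tensor_to_mono, an integer upscale factor from math.ceil, and a torchaudio Resample back to the target rate. The checkpoint path, input file, and output file names are placeholders, and 44100 Hz matches the default target_sample_rate in generator_train().

import math

import torch
import torchaudio

import AudioUtils
from generator import SISUGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder checkpoint path; training.py only loads weights via the --generator flag.
generator = SISUGenerator()
generator.load_state_dict(torch.load("models/generator.pt", weights_only=True))
generator = generator.to(device).eval()

target_sample_rate = 44100  # same default as generator_train()

# "input.wav" is a placeholder low-quality clip; collapse to mono as data.py does.
low_quality_audio, low_sample_rate = torchaudio.load("input.wav", normalize=True)
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).unsqueeze(0).to(device)

with torch.no_grad():
    # Integer upscale factor, then resample to the exact target rate, as in generator_train().
    scale = math.ceil(target_sample_rate / low_sample_rate)
    enhanced = generator(low_quality_audio, scale)
    resample = torchaudio.transforms.Resample(low_sample_rate * scale, target_sample_rate).to(device)
    enhanced = resample(enhanced)

torchaudio.save("enhanced.wav", enhanced.squeeze(0).cpu(), target_sample_rate)

Note that the patched generator builds its ConvTranspose1d with the constructor's upscale_scale (default 4), so the scale argument passed to forward() is effectively unused and the trailing Resample performs the final rate correction.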