From 741dcce7b45af6738c4b8a5df33cb6d3770fccc7 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sun, 23 Feb 2025 13:52:01 +0200 Subject: [PATCH] :alembic: | Increase discriminator size and implement mfcc_loss for generator. --- discriminator.py | 25 ++++++------- training.py | 96 ++++++++++++++++++++++++++---------------------- 2 files changed, 65 insertions(+), 56 deletions(-) diff --git a/discriminator.py b/discriminator.py index b1d82e1..d090372 100644 --- a/discriminator.py +++ b/discriminator.py @@ -6,8 +6,8 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila padding = (kernel_size // 2) * dilation return nn.Sequential( utils.spectral_norm(nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)), - nn.BatchNorm1d(out_channels), - nn.LeakyReLU(0.2, inplace=True) # Changed activation to LeakyReLU + nn.LeakyReLU(0.2, inplace=True), + nn.BatchNorm1d(out_channels) ) class SISUDiscriminator(nn.Module): @@ -15,17 +15,16 @@ class SISUDiscriminator(nn.Module): super(SISUDiscriminator, self).__init__() layers = 4 # Increased base layer count self.model = nn.Sequential( - # Initial Convolution - discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1), # Downsample - - # Core Discriminator Blocks with varied kernels and dilations - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1), # Downsample - discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4), - discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=16), - discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=8), - discriminator_block(layers * 2, layers, kernel_size=3, dilation=1), - # Final Convolution - discriminator_block(layers, 1, kernel_size=3, stride=1), + discriminator_block(1, layers, kernel_size=7, stride=2), # Initial downsampling + discriminator_block(layers, layers * 2, kernel_size=5, stride=2), # Downsampling + discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), # Increased dilation + discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=4), # Increased dilation + discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=8), # Deeper layer! + discriminator_block(layers * 8, layers * 8, kernel_size=5, dilation=1), # Deeper layer! + discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=2), # Reduced dilation + discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=1), + discriminator_block(layers * 2, layers, kernel_size=3, stride=1), # Final convolution + discriminator_block(layers, 1, kernel_size=3, stride=1) ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) diff --git a/training.py b/training.py index 6ee7116..3992829 100644 --- a/training.py +++ b/training.py @@ -10,6 +10,8 @@ import argparse import math +import os + from torch.utils.data import random_split from torch.utils.data import DataLoader @@ -18,8 +20,26 @@ from data import AudioDataset from generator import SISUGenerator from discriminator import SISUDiscriminator -def perceptual_loss(y_true, y_pred): - return torch.mean((y_true - y_pred) ** 2) +import librosa + +def mfcc_loss(y_true, y_pred, sr): + # 1. Ensure sr is a NumPy scalar (not a Tensor) + if isinstance(sr, torch.Tensor): # Check if it's a Tensor + sr = sr.item() # Extract the value as a Python number + + # 2. Convert y_true and y_pred to NumPy arrays + y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() is crucial! + y_pred_np = y_pred.cpu().detach().numpy()[0] + + + mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_mfcc=20) + mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_mfcc=20) + + # 3. Convert MFCCs back to PyTorch tensors and ensure correct device + mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32) + mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32) + + return torch.mean((mfccs_true - mfccs_pred)**2) def discriminator_train(high_quality, low_quality, real_labels, fake_labels): optimizer_d.zero_grad() @@ -43,17 +63,23 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels): return d_loss -def generator_train(low_quality, real_labels): +def generator_train(low_quality, high_quality, real_labels): optimizer_g.zero_grad() # Forward pass for fake samples (from generator output) generator_output = generator(low_quality[0]) - discriminator_decision = discriminator(generator_output) - g_loss = criterion_g(discriminator_decision, real_labels) - g_loss.backward() + mfcc_l = mfcc_loss(high_quality[0], generator_output, high_quality[1]) + + discriminator_decision = discriminator(generator_output) + adversarial_loss = criterion_g(discriminator_decision, real_labels) + + combined_loss = adversarial_loss + 0.5 * mfcc_l + + combined_loss.backward() optimizer_g.step() - return generator_output + + return (generator_output, combined_loss, adversarial_loss, mfcc_l) # Init script argument parser parser = argparse.ArgumentParser(description="Training script") @@ -61,6 +87,7 @@ parser.add_argument("--generator", type=str, default=None, help="Path to the generator model file") parser.add_argument("--discriminator", type=str, default=None, help="Path to the discriminator model file") +parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") args = parser.parse_args() @@ -68,6 +95,8 @@ args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") +debug = args.verbose + # Initialize dataset and dataloader dataset_dir = './dataset/good' dataset = AudioDataset(dataset_dir) @@ -85,7 +114,7 @@ dataset = AudioDataset(dataset_dir) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True) # Initialize models and move them to device generator = SISUGenerator() @@ -111,32 +140,10 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) +models_dir = "models" +os.makedirs(models_dir, exist_ok=True) + def start_training(): - - # Training loop - - # ========= DISCRIMINATOR PRE-TRAINING ========= - # discriminator_epochs = 1 - # for discriminator_epoch in range(discriminator_epochs): - - # # ========= TRAINING ========= - # for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"): - # high_quality_sample = high_quality_clip[0].to(device) - # low_quality_sample = low_quality_clip[0].to(device) - - # scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2] - - # # ========= LABELS ========= - # batch_size = high_quality_sample.size(0) - # real_labels = torch.ones(batch_size, 1).to(device) - # fake_labels = torch.zeros(batch_size, 1).to(device) - - # # ========= DISCRIMINATOR ========= - # discriminator.train() - # discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels) - - # torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt") - generator_epochs = 5000 for generator_epoch in range(generator_epochs): low_quality_audio = (torch.empty((1)), 1) @@ -158,32 +165,35 @@ def start_training(): # ========= DISCRIMINATOR ========= discriminator.train() - discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels) + d_loss = discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels) # ========= GENERATOR ========= generator.train() - generator_output = generator_train(low_quality_sample, real_labels) + generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) + + if debug: + print(d_loss, combined_loss, adversarial_loss, mfcc_l) + scheduler_d.step(d_loss) + scheduler_g.step(combined_loss) # ========= SAVE LATEST AUDIO ========= high_quality_audio = high_quality_clip low_quality_audio = low_quality_clip ai_enhanced_audio = (generator_output, high_quality_clip[1]) - #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0]) - #print(f"Generator metric {metric}!") - #scheduler_g.step(metric) - if generator_epoch % 10 == 0: print(f"Saved epoch {generator_epoch}!") torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1]) torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1]) - torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt") - torch.save(generator.state_dict(), f"models/current-epoch-generator.pt") + torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{generator_epoch}.pt") + torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{generator_epoch}.pt") + torch.save(discriminator, f"{models_dir}/discriminator_epoch_{generator_epoch}_full.pt") + torch.save(generator, f"{models_dir}/generator_epoch_{generator_epoch}_full.pt") - torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt") - torch.save(generator.state_dict(), "models/epoch-5000-generator.pt") + torch.save(discriminator, "models/epoch-5000-discriminator.pt") + torch.save(generator, "models/epoch-5000-generator.pt") print("Training complete!") start_training()