From c04b072de6d156cb88f2c2652a33c04e80c68a31 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Wed, 16 Apr 2025 17:08:13 +0300
Subject: [PATCH] :sparkles: | Added smarter loss functions and training utilities that should have been there from the beginning.

---
 training.py       |  33 ++++++++---
 training_utils.py | 139 +++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 148 insertions(+), 24 deletions(-)

diff --git a/training.py b/training.py
index 17843e0..db7cb86 100644
--- a/training.py
+++ b/training.py
@@ -41,11 +41,24 @@ args = parser.parse_args()
 device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
-# mfcc_transform = T.MFCC(
-#     sample_rate=44100,
-#     n_mfcc=20,
-#     melkwargs={'n_fft': 2048, 'hop_length': 256}
-# ).to(device)
+# Parameters
+sample_rate = 44100
+n_fft = 2048
+hop_length = 256
+win_length = n_fft
+n_mels = 128
+n_mfcc = 20 # If using MFCC
+
+mfcc_transform = T.MFCC(
+    sample_rate=sample_rate,
+    n_mfcc=n_mfcc,
+    melkwargs={'n_fft': n_fft, 'hop_length': hop_length}
+).to(device)
+
+mel_transform = T.MelSpectrogram(
+    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
+    win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
+).to(device)
 
 debug = args.debug
 
@@ -130,18 +143,20 @@ def start_training():
 
             # ========= GENERATOR =========
             generator.train()
-            generator_output, adversarial_loss = generator_train(
+            generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
                 low_quality_sample,
                 high_quality_sample,
                 real_labels,
                 generator,
                 discriminator,
-                criterion_g,
-                optimizer_g
+                criterion_d,
+                optimizer_g,
+                device,
+                mel_transform
             )
 
             if debug:
-                print(d_loss, adversarial_loss)
+                print(combined_loss, adversarial_loss, mel_l1_tensor)
             scheduler_d.step(d_loss.detach())
             scheduler_g.step(adversarial_loss.detach())
 
diff --git a/training_utils.py b/training_utils.py
index a1d2c19..be402d9 100644
--- a/training_utils.py
+++ b/training_utils.py
@@ -3,16 +3,73 @@ import torch.nn as nn
 import torch.optim as optim
 import torchaudio
+import torchaudio.transforms as T
 
 def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
     mfccs_true = mfcc_transform(y_true)
     mfccs_pred = mfcc_transform(y_pred)
+
     min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
     mfccs_true = mfccs_true[:, :, :min_len]
     mfccs_pred = mfccs_pred[:, :, :min_len]
+
     loss = torch.mean((mfccs_true - mfccs_pred)**2)
     return loss
 
+def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+    mel_spec_true = mel_transform(y_true)
+    mel_spec_pred = mel_transform(y_pred)
+
+    # Ensure same time dimension length (due to potential framing differences)
+    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
+    mel_spec_true = mel_spec_true[..., :min_len]
+    mel_spec_pred = mel_spec_pred[..., :min_len]
+
+    # L1 Loss (Mean Absolute Error)
+    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
+    return loss
+
+def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+    mel_spec_true = mel_transform(y_true)
+    mel_spec_pred = mel_transform(y_pred)
+
+    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
+    mel_spec_true = mel_spec_true[..., :min_len]
+    mel_spec_pred = mel_spec_pred[..., :min_len]
+
+    # L2 Loss (Mean Squared Error)
+    loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
+    return loss
+
+def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    stft_mag_true = stft_transform(y_true)
+    stft_mag_pred = stft_transform(y_pred)
+
+    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
+    stft_mag_true = stft_mag_true[..., :min_len]
+    stft_mag_pred = stft_mag_pred[..., :min_len]
+
+    # Log Magnitude L1 Loss
+    loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
+    return loss
+
+def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    stft_mag_true = stft_transform(y_true)
+    stft_mag_pred = stft_transform(y_pred)
+
+    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
+    stft_mag_true = stft_mag_true[..., :min_len]
+    stft_mag_pred = stft_mag_pred[..., :min_len]
+
+    # Calculate Frobenius norms and the loss
+    # Ensure norms are calculated over frequency and time dims ([..., freq, time])
+    norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
+    norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
+
+    # Average loss over the batch
+    loss = torch.mean(norm_diff / (norm_true + eps))
+    return loss
+
 def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
     optimizer.zero_grad()
 
@@ -21,35 +78,87 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
     d_loss_real = criterion(discriminator_decision_from_real, real_labels)
 
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0])
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
+    with torch.no_grad(): # Generate fake samples without building a generator graph
+        generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output) # No detach needed; generated under no_grad
+    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
 
     # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0
 
     # Backward pass and optimization
     d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
+    # Optional: Gradient Clipping (can be helpful)
+    # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
     optimizer.step()
 
     return d_loss
 
-def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
-    optimizer.zero_grad()
+def generator_train(
+    low_quality,
+    high_quality,
+    real_labels,
+    generator,
+    discriminator,
+    adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss)
+    g_optimizer,
+    device,
+    # --- Pass necessary transforms and loss weights ---
+    mel_transform: T.MelSpectrogram, # Example: Pass Mel transform
+    # stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses
+    # mfcc_transform: T.MFCC, # Pass MFCC transform if using MFCC loss
+    lambda_adv: float = 1.0, # Weight for adversarial loss
+    lambda_mel_l1: float = 10.0, # Example: Weight for Mel L1 loss
+    # lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use
+    # lambda_mfcc: float = 0.0
+):
+    g_optimizer.zero_grad()
 
-    # Forward pass for fake samples (from generator output)
+    # 1. Generate high-quality audio from low-quality input
     generator_output = generator(low_quality[0])
 
-    #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
-
+    # 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
     discriminator_decision = discriminator(generator_output)
-    adversarial_loss = criterion(discriminator_decision, real_labels)
+    # Generator wants discriminator to output "real" labels for its fakes
+    adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
 
-    #combined_loss = adversarial_loss + 0.5 * mfcc_l
+    # 3. Calculate Reconstruction/Spectrogram Loss(es)
+    # --- Choose and calculate the losses you want to include ---
+    mel_l1 = 0.0
+    # log_stft_l1 = 0.0
+    # mfcc_l = 0.0
 
-    adversarial_loss.backward()
-    optimizer.step()
+    # Calculate Mel L1 Loss if weight is positive
+    if lambda_mel_l1 > 0:
+        mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
 
-    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
-    return (generator_output, adversarial_loss)
+    # # Calculate Log STFT L1 Loss if weight is positive
+    # if lambda_log_stft > 0:
+    #     log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
+
+    # # Calculate MFCC Loss if weight is positive
+    # if lambda_mfcc > 0:
+    #     mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
+    # --- End of Loss Calculation Choices ---
+
+
+    # 4. Combine Losses
+    # Make sure calculated losses are tensors even if weights are 0 initially
+    # (or handle appropriately in the sum)
+    mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
+    # log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
+    # mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
+
+    combined_loss = (lambda_adv * adversarial_loss) + \
+                    (lambda_mel_l1 * mel_l1_tensor)
+                    # + (lambda_log_stft * log_stft_l1_tensor) \
+                    # + (lambda_mfcc * mfcc_l_tensor)
+
+    # 5. Backward Pass and Optimization
+    combined_loss.backward()
+    # Optional: Gradient Clipping
+    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
+    g_optimizer.step()
+
+    # 6. Return values for logging
+    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor
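
Note (not part of the patch): generator_train() leaves the log-STFT and MFCC terms commented out, and spectral_convergence_loss() is defined but not yet referenced anywhere. Below is a minimal sketch of how the STFT-based terms could be wired in, assuming a magnitude T.Spectrogram (power=1.0) built next to mel_transform in training.py; the helper name stft_loss_terms and the lambda_sc weight are illustrative and not part of this commit.

import torch
import torchaudio.transforms as T

from training_utils import log_stft_magnitude_loss, spectral_convergence_loss

# Build once (e.g. alongside mel_transform in training.py) and move it to the
# training device with .to(device) before use.
stft_transform = T.Spectrogram(n_fft=2048, hop_length=256, power=1.0)

def stft_loss_terms(stft_transform: T.Spectrogram,
                    high_quality_audio: torch.Tensor,
                    generator_output: torch.Tensor,
                    lambda_log_stft: float = 1.0,
                    lambda_sc: float = 1.0) -> torch.Tensor:
    # Weighted sum of the two STFT-based losses added in training_utils.py.
    log_mag = log_stft_magnitude_loss(stft_transform, high_quality_audio, generator_output)
    sc = spectral_convergence_loss(stft_transform, high_quality_audio, generator_output)
    return lambda_log_stft * log_mag + lambda_sc * sc

The returned term would simply be added to combined_loss in generator_train(), which is what the commented-out lambda_log_stft branch does for the log-magnitude term alone.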
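
Note (not part of the patch): a quick, self-contained sanity check of the new loss functions on dummy audio; the transform parameters mirror the values configured in training.py, while the batch shape is an arbitrary example.

import torch
import torchaudio.transforms as T

from training_utils import (
    mel_spectrogram_l1_loss,
    mel_spectrogram_l2_loss,
    log_stft_magnitude_loss,
    spectral_convergence_loss,
)

sample_rate, n_fft, hop_length = 44100, 2048, 256

mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft,
                                 hop_length=hop_length, win_length=n_fft,
                                 n_mels=128, power=1.0)
stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=1.0)

# Batch of four 1-second mono clips standing in for the target and the generator output.
real = torch.randn(4, 1, sample_rate)
fake = torch.randn(4, 1, sample_rate)

print(mel_spectrogram_l1_loss(mel_transform, real, fake))    # positive scalar tensor
print(mel_spectrogram_l2_loss(mel_transform, real, fake))    # positive scalar tensor
print(log_stft_magnitude_loss(stft_transform, real, fake))   # positive scalar tensor
print(spectral_convergence_loss(stft_transform, real, fake)) # positive scalar tensor
print(mel_spectrogram_l1_loss(mel_transform, real, real))    # 0 for identical inputs

Each function crops both spectrograms to the shorter time axis before comparing them, so inputs whose lengths differ slightly (a common result of framing differences) are handled without error.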