import torch import torch.nn as nn import torch.optim as optim import torchaudio import torchaudio.transforms as T def gpu_mfcc_loss(mfcc_transform, y_true, y_pred): mfccs_true = mfcc_transform(y_true) mfccs_pred = mfcc_transform(y_pred) min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2]) mfccs_true = mfccs_true[:, :, :min_len] mfccs_pred = mfccs_pred[:, :, :min_len] loss = torch.mean((mfccs_true - mfccs_pred)**2) return loss def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: mel_spec_true = mel_transform(y_true) mel_spec_pred = mel_transform(y_pred) min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) mel_spec_true = mel_spec_true[..., :min_len] mel_spec_pred = mel_spec_pred[..., :min_len] loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) return loss def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: mel_spec_true = mel_transform(y_true) mel_spec_pred = mel_transform(y_pred) min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) mel_spec_true = mel_spec_true[..., :min_len] mel_spec_pred = mel_spec_pred[..., :min_len] loss = torch.mean((mel_spec_true - mel_spec_pred)**2) return loss def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: stft_mag_true = stft_transform(y_true) stft_mag_pred = stft_transform(y_pred) min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) stft_mag_true = stft_mag_true[..., :min_len] stft_mag_pred = stft_mag_pred[..., :min_len] loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps))) return loss def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: stft_mag_true = stft_transform(y_true) stft_mag_pred = stft_transform(y_pred) min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) stft_mag_true = stft_mag_true[..., :min_len] stft_mag_pred = stft_mag_pred[..., :min_len] norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1)) loss = torch.mean(norm_diff / (norm_true + eps)) return loss def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer): optimizer.zero_grad() # Forward pass for real samples discriminator_decision_from_real = discriminator(high_quality) d_loss_real = criterion(discriminator_decision_from_real, real_labels) with torch.no_grad(): generator_output = generator(low_quality) discriminator_decision_from_fake = discriminator(generator_output) d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) d_loss = (d_loss_real + d_loss_fake) / 2.0 d_loss.backward() # Optional: Gradient Clipping (can be helpful) # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping optimizer.step() return d_loss def generator_train( low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, g_optimizer, device, mel_transform: T.MelSpectrogram, stft_transform: T.Spectrogram, mfcc_transform: T.MFCC, lambda_adv: float = 1.0, lambda_mel_l1: float = 10.0, lambda_log_stft: float = 1.0, lambda_mfcc: float = 1.0 ): g_optimizer.zero_grad() generator_output = generator(low_quality) discriminator_decision = discriminator(generator_output) adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) mel_l1 = 0.0 log_stft_l1 = 0.0 mfcc_l = 0.0 # Calculate Mel L1 Loss if weight is positive if lambda_mel_l1 > 0: mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output) # Calculate Log STFT L1 Loss if weight is positive if lambda_log_stft > 0: log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output) # Calculate MFCC Loss if weight is positive if lambda_mfcc > 0: mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output) mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l combined_loss = (lambda_adv * adversarial_loss) + \ (lambda_mel_l1 * mel_l1_tensor) + \ (lambda_log_stft * log_stft_l1_tensor) + \ (lambda_mfcc * mfcc_l_tensor) combined_loss.backward() # Optional: Gradient Clipping # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) g_optimizer.step() # 6. Return values for logging return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor