import torch
import torch.nn.functional as F
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[64, 128, 256],
    win_lengths=[256, 512, 1024]
)


def feature_matching_loss(fmap_r, fmap_g):
    """
    Computes the L1 distance between real and fake feature maps.
    """
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            # Stop gradient on real features to save memory/computation
            rl = rl.detach()
            loss += torch.mean(torch.abs(rl - gl))
    # The factor of 2 follows the HiFi-GAN reference implementation
    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Least Squares GAN loss (LSGAN) for the discriminator.
    Objective: Real -> 1, Fake -> 0
    """
    loss = 0
    r_losses = []
    g_losses = []
    # Iterate over both MPD and MSD outputs
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        # Real should be 1.0
        r_loss = torch.mean((dr - 1) ** 2)
        # Fake should be 0.0
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses


def generator_adv_loss(disc_generated_outputs):
    """
    Least Squares GAN loss for the generator.
    Objective: Fake -> 1 (fool the discriminator)
    """
    loss = 0
    for dg in disc_generated_outputs:
        loss += torch.mean((dg - 1) ** 2)
    return loss


def discriminator_train(
    high_quality, low_quality, discriminator, generator_output
):
    # 1. Forward pass through the ensemble discriminator.
    #    Inputs are passed separately: (real target, fake candidate).
    #    generator_output is detached because only D is optimized here.
    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())

    # 2. Calculate loss (LSGAN)
    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)

    return d_loss


def generator_train(
    low_quality, high_quality, generator, discriminator, generator_output
):
    # 1. Forward pass through the discriminator.
    #    generator_output is NOT detached here; gradients must flow back to G.
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)

    # 2. Adversarial loss (try to fool D into classifying G's output as real)
    loss_gen_adv = generator_adv_loss(y_d_gs)

    # 3. Feature matching loss (force G to match D's internal features)
    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)

    # 4. Multi-resolution STFT loss (spectral fidelity / audio quality)
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    # -----------------------------------------
    # 5. Combine losses
    # -----------------------------------------
    # Recommended weights for HiFi-GAN/EnCodec style architectures:
    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
    lambda_stft = 45.0
    lambda_fm = 2.0
    lambda_adv = 1.0

    combined_loss = (lambda_stft * stft_loss) \
        + (lambda_fm * loss_fm) \
        + (lambda_adv * loss_gen_adv)

    return combined_loss, loss_gen_adv
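

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one possible way
# to wire the helpers above into alternating D/G updates for a single batch.
# `generator`, `discriminator`, the two optimizers, and the batch tensors are
# assumed to be created elsewhere in the training script; the names below are
# hypothetical placeholders.
# -----------------------------------------------------------------------------
def train_step(generator, discriminator, optim_g, optim_d, high_quality, low_quality):
    # Generate once; discriminator_train() detaches internally, so the same
    # output (with its autograd graph intact) can be reused for the G update.
    generator_output = generator(low_quality)

    # --- Discriminator update ---
    optim_d.zero_grad()
    d_loss = discriminator_train(high_quality, low_quality, discriminator, generator_output)
    d_loss.backward()
    optim_d.step()

    # --- Generator update ---
    optim_g.zero_grad()
    g_loss, g_adv = generator_train(low_quality, high_quality, generator, discriminator, generator_output)
    g_loss.backward()
    optim_g.step()

    return d_loss.item(), g_loss.item(), g_adv.item()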