Files
SISU/utils/TrainingTools.py

114 lines
3.3 KiB
Python

import torch
import torch.nn.functional as F
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# Module-level multi-resolution STFT loss used by generator_train.
# Three resolutions; at each one the hop is fft_size / 8 and the window
# length is fft_size / 2.
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[64, 128, 256],
    win_lengths=[256, 512, 1024],
)
def feature_matching_loss(fmap_r, fmap_g, scale=2.0):
    """
    Compute the feature-matching loss between real and generated feature maps.

    For every pair of (real, generated) activations, the mean absolute (L1)
    difference is computed. These per-layer distances are summed and the
    total is multiplied by ``scale``.

    Args:
        fmap_r: Nested iterable of real feature maps. Presumably the outer
            level is one entry per sub-discriminator and the inner level one
            tensor per layer (confirm against the discriminator).
        fmap_g: Generated feature maps with the same nesting as ``fmap_r``.
        scale: Constant multiplier applied to the summed loss. Defaults to
            2.0, the value that was previously hard-coded.

    Returns:
        The scaled sum of per-layer L1 distances, or ``0`` when there is
        nothing to compare.
    """
    loss = 0
    for real_layers, fake_layers in zip(fmap_r, fmap_g):
        for real_feat, fake_feat in zip(real_layers, fake_layers):
            # Treat real activations as fixed targets: only the generated
            # branch receives gradients from this loss.
            loss += torch.mean(torch.abs(real_feat.detach() - fake_feat))
    # Constant weight, NOT a normalisation by the number of feature maps.
    # Callers may apply an additional weight on top of this.
    return loss * scale
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Least-squares GAN (LSGAN) objective for the discriminator.

    Scores on real inputs are pushed toward 1 and scores on generated inputs
    toward 0. The squared-error terms of every paired output are summed.

    Args:
        disc_real_outputs: Iterable of discriminator outputs on real audio.
        disc_generated_outputs: Iterable of discriminator outputs on generated
            audio, paired element-wise with ``disc_real_outputs``.

    Returns:
        Tuple ``(loss, r_losses, g_losses)``: the summed loss, plus the
        per-output real and fake terms as Python floats.
    """
    total = 0
    real_terms = []
    fake_terms = []
    for real_score, fake_score in zip(disc_real_outputs, disc_generated_outputs):
        real_term = ((real_score - 1) ** 2).mean()  # target 1.0 for real
        fake_term = (fake_score ** 2).mean()        # target 0.0 for fake
        total += real_term + fake_term
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
def generator_adv_loss(disc_generated_outputs):
    """
    Least-squares GAN (LSGAN) adversarial loss for the generator.

    The generator tries to make the discriminator score its output as real,
    so each generated score is pushed toward 1.

    Args:
        disc_generated_outputs: Iterable of discriminator outputs on generated
            audio (presumably one entry per sub-discriminator).

    Returns:
        Sum over the outputs of ``mean((score - 1) ** 2)``; ``0`` if the
        iterable is empty.
    """
    loss = 0
    for dg in disc_generated_outputs:
        loss += torch.mean((dg - 1) ** 2)
    return loss
def discriminator_train(
    high_quality,
    low_quality,
    discriminator,
    generator_output
):
    """
    Compute the discriminator's LSGAN loss for one training step.

    Args:
        high_quality: Target audio, passed to the discriminator as the "real"
            input.
        low_quality: Unused here; kept for call-site compatibility.
        discriminator: Callable taking ``(real, fake)`` and returning a
            4-tuple ``(real_scores, fake_scores, real_fmaps, fake_fmaps)``.
        generator_output: Generator output. It is detached, so this loss
            sends no gradients back into the generator.

    Returns:
        The summed loss from ``discriminator_loss``.
    """
    # Only the discriminator is optimised on this loss.
    fake_input = generator_output.detach()
    scores_real, scores_fake, _, _ = discriminator(high_quality, fake_input)
    total, _, _ = discriminator_loss(scores_real, scores_fake)
    return total
def generator_train(
    low_quality,
    high_quality,
    generator,
    discriminator,
    generator_output,
    lambda_stft=45.0,
    lambda_fm=2.0,
    lambda_adv=1.0,
):
    """
    Compute the generator's combined training loss for one step.

    The result is a weighted sum of three terms: a multi-resolution STFT
    reconstruction loss, a discriminator feature-matching loss, and an LSGAN
    adversarial loss.

    Args:
        low_quality: Unused here; kept for call-site compatibility.
        high_quality: Target audio. Used as the discriminator's "real" input
            and as the STFT-loss reference.
        generator: Unused here; kept for call-site compatibility.
        discriminator: Callable taking ``(real, fake)`` and returning a
            4-tuple ``(real_scores, fake_scores, real_fmaps, fake_fmaps)``.
        generator_output: Generator output. It is NOT detached, so gradients
            flow back into the generator.
        lambda_stft: Weight of the STFT loss (default 45.0).
        lambda_fm: Weight of the feature-matching loss (default 2.0).
            ``feature_matching_loss`` already multiplies by 2 internally, so
            the effective weight on the raw L1 sum is ``2 * lambda_fm``.
        lambda_adv: Weight of the adversarial loss (default 1.0).

    Returns:
        Tuple ``(combined_loss, loss_gen_adv)``: the weighted total and the
        unweighted adversarial term.
    """
    # Discriminator pass without detaching generator_output: G needs gradients.
    _, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)

    # Adversarial term: push D's scores on generated audio toward "real".
    loss_gen_adv = generator_adv_loss(y_d_gs)
    # Feature matching: match D's intermediate activations on real audio.
    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
    # Spectral reconstruction. NOTE(review): argument order assumed to be
    # (target, prediction); confirm against MultiResolutionSTFTLoss.
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    # Default weights follow the HiFi-GAN-style recommendation:
    # STFT 45, FM 2, Adv 1.
    # NOTE(review): the FM factor of 2 is applied twice (x2 inside
    # feature_matching_loss and lambda_fm=2.0), giving an effective x4. The
    # default is kept to preserve existing training behaviour; pass
    # lambda_fm=1.0 if an effective weight of 2 is intended.
    combined_loss = (
        lambda_stft * stft_loss
        + lambda_fm * loss_fm
        + lambda_adv * loss_gen_adv
    )
    return combined_loss, loss_gen_adv