⚗️ Added MultiPeriodDiscriminator implementation from HiFi-GAN
@@ -1,58 +1,113 @@
import torch
import torch.nn.functional as F
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

# stft_loss_fn = MultiResolutionSTFTLoss(
#     fft_sizes=[512, 1024, 2048, 4096],
#     hop_sizes=[128, 256, 512, 1024],
#     win_lengths=[512, 1024, 2048, 4096]
# )

stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[64, 128, 256],
    win_lengths=[256, 512, 1024]
)
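
# Illustrative usage (assumption: as called further below, the loss returns a
# dict of per-resolution terms with an aggregate "total" entry):
#
#   losses = stft_loss_fn(clean_waveform, generated_waveform)
#   stft_loss = losses["total"]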


def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Mean absolute error (L1) between two signals."""
    absolute_difference = torch.abs(input_one - input_two)
    return torch.mean(absolute_difference)


def feature_matching_loss(fmap_r, fmap_g):
    """
    Computes the L1 distance between real and generated feature maps.
    """
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            # Stop gradient on real features to save memory/computation
            rl = rl.detach()
            loss += torch.mean(torch.abs(rl - gl))

    # The x2 scaling follows the reference HiFi-GAN feature-matching loss
    return loss * 2
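
# Illustrative call (hypothetical shapes): fmap_r / fmap_g are nested lists,
# one inner list of layer activations per sub-discriminator, as produced by the
# ensemble discriminator used below.
#
#   fmap_r = [[torch.randn(1, 32, 100), torch.randn(1, 64, 50)]]
#   fmap_g = [[torch.randn(1, 32, 100), torch.randn(1, 64, 50)]]
#   fm = feature_matching_loss(fmap_r, fmap_g)   # scalar tensor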


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Least Squares GAN Loss (LSGAN) for the Discriminator.
    Objective: Real -> 1, Fake -> 0
    """
    loss = 0
    r_losses = []
    g_losses = []

    # Iterate over both MPD and MSD outputs
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        # Real should be 1.0
        r_loss = torch.mean((dr - 1) ** 2)
        # Fake should be 0.0
        g_loss = torch.mean(dg ** 2)

        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses
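
# Quick sanity check (illustrative values, not from the commit):
#   discriminator_loss([torch.ones(2, 1)], [torch.zeros(2, 1)])
# returns (tensor(0.), [0.0], [0.0]) -- a perfect discriminator scores zero.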


def generator_adv_loss(disc_generated_outputs):
    """
    Least Squares GAN Loss for the Generator.
    Objective: Fake -> 1 (fool the discriminator)
    """
    loss = 0
    for dg in disc_generated_outputs:
        loss += torch.mean((dg - 1) ** 2)
    return loss


def discriminator_train(
    high_quality,
    low_quality,
    high_labels,
    low_labels,
    discriminator,
    criterion,
    generator_output
):
    # high_labels, low_labels and criterion are currently unused by the
    # ensemble (MPD/MSD) path below.

    # 1. Forward pass through the Ensemble Discriminator.
    #    Inputs are passed separately: (Real_Target, Fake_Candidate).
    #    We detach generator_output because we are only optimizing D here.
    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())

    # 2. Calculate Loss (LSGAN)
    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)

    # Previous pair-based objective (removed in this commit):
    #   real_pair = torch.cat((low_quality, high_quality), dim=1)
    #   d_loss_real = criterion(discriminator(real_pair), high_labels)
    #   fake_pair = torch.cat((low_quality, generator_output), dim=1)
    #   d_loss_fake = criterion(discriminator(fake_pair), low_labels)
    #   d_loss = (d_loss_real + d_loss_fake) / 2.0

    return d_loss
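
# Note: discriminator_train above and generator_train below assume the ensemble
# discriminator exposes
#   forward(real_waveform, fake_waveform)
#     -> (real_outputs, fake_outputs, real_feature_maps, fake_feature_maps)
# where the first two are lists of per-sub-discriminator score tensors and the
# last two are lists of per-sub-discriminator lists of layer activations,
# matching the return convention of HiFi-GAN's MultiPeriodDiscriminator /
# MultiScaleDiscriminator forward passes.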


def generator_train(
    low_quality,
    high_quality,
    generator,
    discriminator,
    generator_output
):
    # low_quality and generator are currently unused in this loss computation.

    # 1. Forward pass through Discriminator.
    #    We do NOT detach generator_output here; we need gradients for G.
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)

    # 2. Adversarial Loss (try to fool D into thinking G's output is real)
    loss_gen_adv = generator_adv_loss(y_d_gs)

    # 3. Feature Matching Loss (force G to match internal features of D)
    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)

    # 4. Multi-resolution STFT Loss (audio quality)
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    # Previous objective (removed in this commit): a criterion-based adversarial
    # loss on a concatenated (low_quality, generator_output) pair plus a
    # signal_mae term, weighted lambda_mae=10.0, lambda_stft=2.5, lambda_adv=2.5.

    # -----------------------------------------
    # 5. Combine Losses
    # -----------------------------------------
    # Recommended weights for HiFi-GAN/EnCodec style architectures:
    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
    lambda_stft = 45.0
    lambda_fm = 2.0
    lambda_adv = 1.0

    combined_loss = (lambda_stft * stft_loss) + \
                    (lambda_fm * loss_fm) + \
                    (lambda_adv * loss_gen_adv)

    return combined_loss, loss_gen_adv
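

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original commit.
    # It assumes (a) the ensemble discriminator follows the interface noted above,
    # i.e. forward(real, fake) -> (real_outputs, fake_outputs, real_fmaps, fake_fmaps),
    # and (b) MultiResolutionSTFTLoss accepts a pair of (batch, 1, samples) waveforms.
    class _DummyEnsembleDiscriminator(torch.nn.Module):
        """Hypothetical stand-in with the same return structure as MPD/MSD."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(1, 4, kernel_size=3, padding=1)

        def _score(self, x):
            feat = self.conv(x)
            # One score per batch element, one feature map per "layer"
            return feat.mean(dim=(1, 2)), [feat]

        def forward(self, real, fake):
            real_score, real_fmap = self._score(real)
            fake_score, fake_fmap = self._score(fake)
            return [real_score], [fake_score], [real_fmap], [fake_fmap]

    disc = _DummyEnsembleDiscriminator()
    clean = torch.randn(2, 1, 4096)
    degraded = torch.randn(2, 1, 4096)
    fake = torch.randn(2, 1, 4096)

    d_loss = discriminator_train(clean, degraded, None, None, disc, None, fake)
    g_loss, g_adv = generator_train(degraded, clean, None, disc, fake)
    print(f"d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f} g_adv={g_adv.item():.4f}")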