Files
SISU/training_utils.py
2025-09-10 19:52:53 +03:00

155 lines
4.8 KiB
Python

import torch
import torchaudio.transforms as T
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# Transforms injected at runtime by ``init``. They are only annotated here,
# never assigned, so the names are unbound until ``init`` has been called
# (reading them earlier raises NameError / AttributeError).
mel_transform: T.MelSpectrogram
stft_transform: T.Spectrogram
# mfcc_transform: T.MFCC
# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
# """Initializes the global transform variables for the module."""
# global mel_transform, stft_transform, mfcc_transform
# mel_transform = mel_trans
# stft_transform = stft_trans
# mfcc_transform = mfcc_trans


def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
    """Bind the module-level transform globals.

    Args:
        mel_trans: Stored as the module global ``mel_transform``.
        stft_trans: Stored as the module global ``stft_transform``.
    """
    global mel_transform, stft_transform
    mel_transform, stft_transform = mel_trans, stft_trans
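# NOTE: the disabled loss helpers below call `F.*`, but `F` (torch.nn.functional)
# is not imported in this module; add `import torch.nn.functional as F` before re-enabling them.
# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: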
# """Computes the Mean Squared Error (MSE) loss on MFCCs."""
# mfccs_true = mfcc_transform(y_true)
# mfccs_pred = mfcc_transform(y_pred)
# return F.mse_loss(mfccs_pred, mfccs_true)
# def mel_spectrogram_loss(
# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
# ) -> torch.Tensor:
# """Calculates L1 or L2 loss on the Mel Spectrogram."""
# mel_spec_true = mel_transform(y_true)
# mel_spec_pred = mel_transform(y_pred)
# if loss_type == "l1":
# return F.l1_loss(mel_spec_pred, mel_spec_true)
# elif loss_type == "l2":
# return F.mse_loss(mel_spec_pred, mel_spec_true)
# else:
# raise ValueError("loss_type must be 'l1' or 'l2'")
# def log_stft_magnitude_loss(
# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
# ) -> torch.Tensor:
# """Calculates L1 loss on the log STFT magnitude."""
# stft_mag_true = stft_transform(y_true)
# stft_mag_pred = stft_transform(y_pred)
# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))
# One row per STFT resolution: (fft_size, hop_size, win_length).
# In each row the window length does not exceed the FFT size.
_STFT_RESOLUTIONS = (
    (1024, 120, 600),
    (2048, 240, 1200),
    (512, 50, 240),
)

# Shared multi-resolution STFT loss, built once at import time and used by
# ``generator_train``.
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[fft for fft, _hop, _win in _STFT_RESOLUTIONS],
    hop_sizes=[hop for _fft, hop, _win in _STFT_RESOLUTIONS],
    win_lengths=[win for _fft, _hop, win in _STFT_RESOLUTIONS],
)
def discriminator_train(
    high_quality,
    low_quality,
    real_labels,
    fake_labels,
    discriminator,
    generator,
    criterion,
):
    """Compute the discriminator loss for one batch.

    The discriminator scores the real ``high_quality`` batch and a generated
    batch produced from ``low_quality``. Each score is compared with its
    label by ``criterion``, and the two terms are averaged. Only the loss is
    computed; ``backward`` and the optimizer step are left to the caller.

    Args:
        high_quality: Real samples shown to the discriminator.
        low_quality: Generator input used to produce fake samples.
        real_labels: Target tensor for real samples. It is broadcast with
            ``expand_as`` to the discriminator's output shape.
        fake_labels: Target tensor for fake samples. It is broadcast in the
            same way.
        discriminator: Callable that scores a batch.
        generator: Callable that maps ``low_quality`` to fake samples.
        criterion: Loss ``(decision, target) -> tensor``.

    Returns:
        ``(loss_real + loss_fake) / 2``.
    """
    decision_real = discriminator(high_quality)
    # Broadcast both label tensors the same way, so shape-strict criteria
    # (e.g. BCE) accept broadcastable labels. This is a no-op when the
    # shapes already match.
    d_loss_real = criterion(decision_real, real_labels.expand_as(decision_real))

    # The generator is only sampled here. Running it under no_grad keeps its
    # parameters out of the graph, so this loss updates only the discriminator.
    with torch.no_grad():
        generator_output = generator(low_quality)
    decision_fake = discriminator(generator_output)
    d_loss_fake = criterion(decision_fake, fake_labels.expand_as(decision_fake))

    return (d_loss_real + d_loss_fake) / 2.0
def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    lambda_adv: float = 1.0,
    lambda_feat: float = 10.0,
    lambda_stft: float = 2.5,
):
    """Compute the generator loss for one batch.

    The generator maps ``low_quality`` to an output, which the discriminator
    then scores. The adversarial term compares that score with
    ``real_labels``, rewarding the generator for fooling the discriminator.
    A multi-resolution STFT term then compares the output with
    ``high_quality`` through the module-level ``stft_loss_fn``. Only the
    losses are computed; ``backward`` and the optimizer step are left to the
    caller.

    Args:
        low_quality: Generator input.
        high_quality: Reference signal for the STFT loss.
        real_labels: Target tensor for the adversarial term. It is broadcast
            with ``expand_as`` to the discriminator's output shape.
        generator: Callable that maps ``low_quality`` to its output.
        discriminator: Callable that scores the generator output.
        adv_criterion: Loss ``(decision, target) -> tensor`` for the
            adversarial term.
        lambda_adv: Weight of the adversarial loss.
        lambda_feat: Currently unused; no feature-matching term is computed.
            Kept so that existing callers passing it keep working.
        lambda_stft: Weight of the STFT loss. When it is 0, the STFT loss is
            skipped entirely, avoiding its forward and backward cost.

    Returns:
        ``(generator_output, combined_loss, adversarial_loss)``, where
        ``combined_loss = lambda_adv * adversarial_loss
        + lambda_stft * stft_loss["total"]``. The second term is omitted when
        ``lambda_stft == 0``.
    """
    generator_output = generator(low_quality)
    decision = discriminator(generator_output)
    # Broadcast the target to the decision shape, the same way
    # discriminator_train handles its labels. This is a no-op when the shapes
    # already match.
    adversarial_loss = adv_criterion(decision, real_labels.expand_as(decision))
    combined_loss = lambda_adv * adversarial_loss

    if lambda_stft != 0:
        # Arguments are (target, prediction), the (y_true, y_pred) order used
        # by the other loss helpers in this module.
        # NOTE(review): confirm MultiResolutionSTFTLoss expects this order;
        # the loss may not be symmetric in its arguments.
        stft_losses = stft_loss_fn(high_quality, generator_output)
        combined_loss = combined_loss + lambda_stft * stft_losses["total"]

    return generator_output, combined_loss, adversarial_loss
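# NOTE: legacy variant kept for reference. Before re-enabling it, add a `lambda_mfcc`
# parameter (the body reads it, but the signature below does not define it) and
# re-enable the loss helpers it calls.
# def generator_train(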
# low_quality,
# high_quality,
# real_labels,
# generator,
# discriminator,
# adv_criterion,
# lambda_adv: float = 1.0,
# lambda_mel_l1: float = 10.0,
# lambda_log_stft: float = 1.0,
# ):
# generator_output = generator(low_quality)
# discriminator_decision = discriminator(generator_output)
# adversarial_loss = adv_criterion(
# discriminator_decision, real_labels.expand_as(discriminator_decision)
# )
# combined_loss = lambda_adv * adversarial_loss
# if lambda_mel_l1 > 0:
# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
# combined_loss += lambda_mel_l1 * mel_l1_loss
# else:
# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging
# if lambda_log_stft > 0:
# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
# combined_loss += lambda_log_stft * log_stft_loss
# else:
# log_stft_loss = torch.tensor(0.0, device=low_quality.device)
# if lambda_mfcc > 0:
# mfcc_loss_val = mfcc_loss(high_quality, generator_output)
# combined_loss += lambda_mfcc * mfcc_loss_val
# else:
# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)
# return generator_output, combined_loss, adversarial_loss