155 lines
4.8 KiB
Python
155 lines
4.8 KiB
Python
import torch
|
|
import torchaudio.transforms as T
|
|
|
|
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
|
|
|
# Module-level transforms, bound by init(). These are annotation-only
# declarations: no value is assigned here, so reading either name before
# init() has been called raises NameError.
mel_transform: T.MelSpectrogram

stft_transform: T.Spectrogram
|
|
# mfcc_transform: T.MFCC
|
|
|
|
|
|
# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
|
|
# """Initializes the global transform variables for the module."""
|
|
# global mel_transform, stft_transform, mfcc_transform
|
|
# mel_transform = mel_trans
|
|
# stft_transform = stft_trans
|
|
# mfcc_transform = mfcc_trans
|
|
|
|
|
|
def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
    """Store the given transforms in this module's globals.

    Rebinds ``mel_transform`` to ``mel_trans`` and ``stft_transform`` to
    ``stft_trans``. In the visible code, only the commented-out spectral loss
    helpers read these names.
    """
    global mel_transform, stft_transform
    mel_transform, stft_transform = mel_trans, stft_trans
|
|
|
|
|
|
# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
|
# """Computes the Mean Squared Error (MSE) loss on MFCCs."""
|
|
# mfccs_true = mfcc_transform(y_true)
|
|
# mfccs_pred = mfcc_transform(y_pred)
|
|
# return F.mse_loss(mfccs_pred, mfccs_true)
|
|
|
|
|
|
# def mel_spectrogram_loss(
|
|
# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
|
|
# ) -> torch.Tensor:
|
|
# """Calculates L1 or L2 loss on the Mel Spectrogram."""
|
|
# mel_spec_true = mel_transform(y_true)
|
|
# mel_spec_pred = mel_transform(y_pred)
|
|
# if loss_type == "l1":
|
|
# return F.l1_loss(mel_spec_pred, mel_spec_true)
|
|
# elif loss_type == "l2":
|
|
# return F.mse_loss(mel_spec_pred, mel_spec_true)
|
|
# else:
|
|
# raise ValueError("loss_type must be 'l1' or 'l2'")
|
|
|
|
|
|
# def log_stft_magnitude_loss(
|
|
# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
|
|
# ) -> torch.Tensor:
|
|
# """Calculates L1 loss on the log STFT magnitude."""
|
|
# stft_mag_true = stft_transform(y_true)
|
|
# stft_mag_pred = stft_transform(y_pred)
|
|
# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))
|
|
|
|
|
|
# Shared multi-resolution STFT loss; generator_train reads its "total" entry.
# Presumably the i-th entries of the three lists together form one STFT
# resolution (FFT size, hop, window length) — confirm against
# utils.MultiResolutionSTFTLoss.
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 512],
    hop_sizes=[120, 240, 50],
    win_lengths=[600, 1200, 240],
)
|
|
|
|
|
|
def discriminator_train(
    high_quality,
    low_quality,
    real_labels,
    fake_labels,
    discriminator,
    generator,
    criterion,
):
    """Compute the discriminator loss for one training step.

    The discriminator is scored on reference audio against ``real_labels`` and
    on generator output against ``fake_labels``. The two criterion values are
    averaged.

    Args:
        high_quality: Reference input, scored as "real".
        low_quality: Input fed to ``generator``; its output is scored as "fake".
        real_labels: Targets for the real branch. Expanded to the shape of the
            discriminator decision, so they must be expand-compatible with it.
        fake_labels: Targets for the fake branch, expanded the same way.
        discriminator: Callable/module producing a decision tensor.
        generator: Callable/module producing fake samples.
        criterion: Loss ``(decision, target) -> tensor``.

    Returns:
        ``(d_loss_real + d_loss_fake) / 2.0``.
    """
    # Real branch. Expand the labels to the decision's shape, as the fake
    # branch does: BCE-style criteria reject mismatched target shapes, and
    # MSELoss would silently broadcast them.
    decision_real = discriminator(high_quality)
    d_loss_real = criterion(decision_real, real_labels.expand_as(decision_real))

    # Fake branch. Only the generator runs under no_grad, so this loss does
    # not backpropagate into the generator. The discriminator call stays
    # outside no_grad so its parameters still receive gradients.
    with torch.no_grad():
        generator_output = generator(low_quality)
    decision_fake = discriminator(generator_output)
    d_loss_fake = criterion(decision_fake, fake_labels.expand_as(decision_fake))

    return (d_loss_real + d_loss_fake) / 2.0
|
|
|
|
|
|
def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    lambda_adv: float = 1.0,
    lambda_feat: float = 10.0,
    lambda_stft: float = 2.5,
):
    """Run the generator forward pass and compute its training loss.

    The loss combines two terms:

    * adversarial: ``adv_criterion`` applied to the discriminator's decision
      on the generated output, against ``real_labels``;
    * spectral: the ``"total"`` entry returned by the module-level
      ``stft_loss_fn``.

    Args:
        low_quality: Input fed to ``generator``.
        high_quality: Reference signal passed to ``stft_loss_fn`` together
            with the generator output.
        real_labels: Adversarial targets. Expanded to the shape of the
            discriminator decision, so they must be expand-compatible with it.
        generator: Callable/module producing the output signal.
        discriminator: Callable/module scoring the generator output.
        adv_criterion: Loss ``(decision, target) -> tensor``.
        lambda_adv: Weight of the adversarial term.
        lambda_feat: Currently unused (no feature-matching term is computed);
            kept so existing callers that pass it keep working.
        lambda_stft: Weight of the STFT term.

    Returns:
        ``(generator_output, combined_loss, adversarial_loss)``, where
        ``combined_loss = lambda_adv * adversarial_loss + lambda_stft * stft_loss``.
    """
    generator_output = generator(low_quality)

    # No torch.no_grad() here: gradients must flow back through the
    # discriminator into the generator.
    discriminator_decision = discriminator(generator_output)
    # Expand the targets to the decision's shape, as discriminator_train does
    # for its fake labels. BCE-style criteria reject mismatched shapes, and
    # MSELoss would silently broadcast them.
    adversarial_loss = adv_criterion(
        discriminator_decision, real_labels.expand_as(discriminator_decision)
    )

    # NOTE(review): the call order is (reference, prediction). Confirm this
    # matches MultiResolutionSTFTLoss.forward — it matters if that loss
    # contains asymmetric terms such as spectral convergence.
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    combined_loss = lambda_adv * adversarial_loss + lambda_stft * stft_loss

    return generator_output, combined_loss, adversarial_loss
|
|
|
|
|
|
# def generator_train(
|
|
# low_quality,
|
|
# high_quality,
|
|
# real_labels,
|
|
# generator,
|
|
# discriminator,
|
|
# adv_criterion,
|
|
# lambda_adv: float = 1.0,
|
|
# lambda_mel_l1: float = 10.0,
|
|
# lambda_log_stft: float = 1.0,
|
|
|
|
# ):
|
|
# generator_output = generator(low_quality)
|
|
|
|
# discriminator_decision = discriminator(generator_output)
|
|
# adversarial_loss = adv_criterion(
|
|
# discriminator_decision, real_labels.expand_as(discriminator_decision)
|
|
# )
|
|
|
|
# combined_loss = lambda_adv * adversarial_loss
|
|
|
|
# if lambda_mel_l1 > 0:
|
|
# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
|
|
# combined_loss += lambda_mel_l1 * mel_l1_loss
|
|
# else:
|
|
# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging
|
|
|
|
# if lambda_log_stft > 0:
|
|
# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
|
|
# combined_loss += lambda_log_stft * log_stft_loss
|
|
# else:
|
|
# log_stft_loss = torch.tensor(0.0, device=low_quality.device)
|
|
|
|
# if lambda_mfcc > 0:
|
|
# mfcc_loss_val = mfcc_loss(high_quality, generator_output)
|
|
# combined_loss += lambda_mfcc * mfcc_loss_val
|
|
# else:
|
|
# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)
|
|
|
|
# return generator_output, combined_loss, adversarial_loss
|