import torch
import torchaudio.transforms as T

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

# Module-level transforms, populated by `init()`. The training functions below
# do not read them; they are kept so existing callers of `init()` still work.
mel_transform: T.MelSpectrogram
stft_transform: T.Spectrogram


def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram) -> None:
    """Initialize the module-level transform globals.

    Args:
        mel_trans: Transform stored in the module global ``mel_transform``.
        stft_trans: Transform stored in the module global ``stft_transform``.
    """
    global mel_transform, stft_transform
    mel_transform = mel_trans
    stft_transform = stft_trans


# Multi-resolution STFT loss shared by every `generator_train` call, built once
# at import time with three (fft_size, hop_size, win_length) resolutions.
# NOTE(review): if MultiResolutionSTFTLoss keeps tensors (e.g. STFT windows) as
# module state, it may need `.to(device)` before being used on GPU — confirm.
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 512],
    hop_sizes=[120, 240, 50],
    win_lengths=[600, 1200, 240],
)


def _match_labels(labels, decision):
    """Expand ``labels`` to ``decision``'s shape when they broadcast exactly.

    Only tensors whose shape broadcasts to exactly ``decision.shape`` are
    expanded; anything else (non-tensors, incompatible shapes) is returned
    unchanged so the criterion sees the same labels it would have without
    this helper.
    """
    if not (isinstance(labels, torch.Tensor) and isinstance(decision, torch.Tensor)):
        return labels
    try:
        if torch.broadcast_shapes(labels.shape, decision.shape) == decision.shape:
            return labels.expand_as(decision)
    except RuntimeError:
        # Shapes are not broadcast-compatible: deliberately fall through and
        # leave the labels as given.
        pass
    return labels


def discriminator_train(
    high_quality,
    low_quality,
    real_labels,
    fake_labels,
    discriminator,
    generator,
    criterion,
):
    """Compute the discriminator loss for one batch.

    Args:
        high_quality: Reference batch scored by the discriminator against
            ``real_labels``.
        low_quality: Input batch passed through ``generator``.
        real_labels: Targets for the discriminator's decision on
            ``high_quality``.
        fake_labels: Targets for the discriminator's decision on the
            generator output.
        discriminator: Callable returning a decision for a batch.
        generator: Callable mapping ``low_quality`` to a generated batch.
        criterion: Loss callable invoked as ``criterion(decision, labels)``.

    Returns:
        The average of the real-branch and fake-branch losses.
    """
    decision_real = discriminator(high_quality)
    d_loss_real = criterion(decision_real, _match_labels(real_labels, decision_real))

    # The generator runs without autograd tracking, so this loss cannot
    # backpropagate into the generator's parameters.
    with torch.no_grad():
        generator_output = generator(low_quality)
    decision_fake = discriminator(generator_output)
    d_loss_fake = criterion(decision_fake, _match_labels(fake_labels, decision_fake))

    return (d_loss_real + d_loss_fake) / 2.0


def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    lambda_adv: float = 1.0,
    lambda_feat: float = 10.0,
    lambda_stft: float = 2.5,
):
    """Compute the generator loss for one batch.

    The combined loss is ``lambda_adv * adversarial + lambda_stft * stft``.
    The adversarial term scores the generator output against
    ``real_labels``. The STFT term is the ``"total"`` entry of the dict
    returned by ``stft_loss_fn(high_quality, generator_output)``.

    Args:
        low_quality: Input batch passed through ``generator``.
        high_quality: Reference batch for the multi-resolution STFT loss.
        real_labels: Targets for the discriminator's decision on the
            generator output.
        generator: Callable mapping ``low_quality`` to a generated batch.
        discriminator: Callable returning a decision for a batch.
        adv_criterion: Loss callable invoked as
            ``adv_criterion(decision, labels)``.
        lambda_adv: Weight of the adversarial term.
        lambda_feat: Currently unused (no feature-matching term is
            computed); kept so existing callers remain compatible.
        lambda_stft: Weight of the multi-resolution STFT term.

    Returns:
        A tuple ``(generator_output, combined_loss, adversarial_loss)``.
    """
    generator_output = generator(low_quality)
    decision = discriminator(generator_output)
    adversarial_loss = adv_criterion(decision, _match_labels(real_labels, decision))

    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    combined_loss = lambda_adv * adversarial_loss + lambda_stft * stft_loss
    return generator_output, combined_loss, adversarial_loss