import torch

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

# Alternative STFT configuration (adds a 4096-point resolution), kept for reference:
# stft_loss_fn = MultiResolutionSTFTLoss(
#     fft_sizes=[512, 1024, 2048, 4096],
#     hop_sizes=[128, 256, 512, 1024],
#     win_lengths=[512, 1024, 2048, 4096]
# )

# Module-level multi-resolution STFT loss used by generator_train.
# NOTE(review): if MultiResolutionSTFTLoss stores tensors (e.g. windows) as
# buffers, it may need to be moved to the training device — confirm.
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[64, 128, 256],
    win_lengths=[256, 512, 1024],
)


def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error between two tensors.

    Args:
        input_one: First tensor.
        input_two: Second tensor; must be broadcast-compatible with ``input_one``.

    Returns:
        A scalar tensor: the mean of ``|input_one - input_two|`` over all elements.
    """
    absolute_difference = torch.abs(input_one - input_two)
    return torch.mean(absolute_difference)


def discriminator_train(
    high_quality,
    low_quality,
    high_labels,
    low_labels,
    discriminator,
    criterion,
    generator_output,
):
    """Compute the conditional discriminator loss for one batch.

    The discriminator scores (condition, target) pairs built by concatenating
    along ``dim=1``. Presumably dim 1 is the channel axis, e.g.
    (batch, channels, time) — TODO confirm against the data pipeline.

    Args:
        high_quality: Real target signal.
        low_quality: Conditioning input signal.
        high_labels: Targets for the real pair.
        low_labels: Targets for the fake pair.
        discriminator: Module mapping a concatenated pair to decisions.
        criterion: Loss applied to (decision, labels).
        generator_output: Generator prediction for ``low_quality``. It is
            detached here, so this loss never sends gradients into the
            generator.

    Returns:
        The mean of the real-pair and fake-pair losses.
    """
    real_pair = torch.cat((low_quality, high_quality), dim=1)
    decision_real = discriminator(real_pair)
    d_loss_real = criterion(decision_real, high_labels)

    # Detach the generator output. Without this, d_loss.backward() would
    # (a) accumulate discriminator-objective gradients into the generator's
    # parameters, and (b) free the generator's graph, breaking a later
    # backward through the same generator_output in generator_train.
    # Harmless if the caller already passed a detached tensor.
    fake_pair = torch.cat((low_quality, generator_output.detach()), dim=1)
    decision_fake = discriminator(fake_pair)
    d_loss_fake = criterion(decision_fake, low_labels)

    d_loss = (d_loss_real + d_loss_fake) / 2.0
    return d_loss


def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    generator_output,
    *,
    lambda_mae: float = 10.0,
    lambda_stft: float = 2.5,
    lambda_adv: float = 2.5,
):
    """Compute the generator's weighted training loss for one batch.

    The loss combines three terms:
    - adversarial: ``adv_criterion`` on the discriminator's decision for the
      (low_quality, generator_output) pair against ``real_labels``;
    - time-domain: ``signal_mae(generator_output, high_quality)``;
    - spectral: the ``"total"`` entry returned by the module-level
      ``stft_loss_fn``.

    Args:
        low_quality: Conditioning input signal.
        high_quality: Real target signal.
        real_labels: Labels marking a decision as "real" for the adversarial term.
        generator: Not used in this function; kept for interface compatibility.
        discriminator: Module mapping a concatenated pair to decisions.
        adv_criterion: Loss applied to (decision, real_labels).
        generator_output: Generator prediction for ``low_quality`` (not detached,
            so gradients flow back into the generator).
        lambda_mae: Weight of the MAE term (default 10.0).
        lambda_stft: Weight of the STFT term (default 2.5).
        lambda_adv: Weight of the adversarial term (default 2.5).

    Returns:
        A tuple ``(combined_loss, adversarial_loss)``. The second element is the
        unweighted adversarial term, returned separately for logging.
    """
    fake_pair = torch.cat((low_quality, generator_output), dim=1)
    discriminator_decision = discriminator(fake_pair)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    mae_loss = signal_mae(generator_output, high_quality)
    # NOTE(review): argument order (target, prediction) depends on
    # MultiResolutionSTFTLoss's signature — confirm it is the intended order.
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    combined_loss = (
        (lambda_mae * mae_loss)
        + (lambda_stft * stft_loss)
        + (lambda_adv * adversarial_loss)
    )
    return combined_loss, adversarial_loss