import torch

# Kept in case it is needed again:
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
#     fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )


def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute difference between two tensors.

    The subtraction follows normal torch broadcasting rules. The mean is
    taken over every element, so the result is a 0-dim tensor.
    """
    absolute_difference = torch.abs(input_one - input_two)
    return torch.mean(absolute_difference)


def discriminator_train(
    high_quality,
    low_quality,
    high_labels,
    low_labels,
    discriminator,
    generator,
    criterion,
    *,
    noise_scale: float = 0.08,
):
    """Compute the discriminator loss for one training step.

    The loss is the plain average of four criterion terms:

    1. ``discriminator(high_quality)`` scored against ``high_labels``.
    2. ``discriminator(low_quality)`` scored against ``low_labels``.
    3. ``discriminator(generator(low_quality))`` scored against ``low_labels``.
       The generator forward pass runs under ``torch.no_grad()``, so no
       gradients flow back into the generator from this term.
    4. ``discriminator(high_quality + noise)`` scored against ``low_labels``.
       ``noise`` is uniform on ``[0, noise_scale)`` (``torch.rand_like``), so
       it is always non-negative.

    Args:
        high_quality: Batch the discriminator should label as ``high_labels``.
        low_quality: Batch labelled ``low_labels``; also fed to the generator.
        high_labels: Targets for the high-quality term.
        low_labels: Targets for the low-quality, generated and noisy terms.
        discriminator: Callable mapping a batch to decisions.
        generator: Callable mapping ``low_quality`` to a generated batch.
        criterion: Callable ``criterion(decision, labels) -> loss``.
        noise_scale: Upper bound of the uniform noise added to
            ``high_quality`` in term 4. The default matches the original
            hard-coded value of 0.08.

    Returns:
        The averaged loss. It is presumably a scalar tensor when
        ``criterion`` returns scalars (TODO confirm with callers).
    """
    decision_high = discriminator(high_quality)
    d_loss_high = criterion(decision_high, high_labels)

    decision_low = discriminator(low_quality)
    d_loss_low = criterion(decision_low, low_labels)

    # Generate without tracking gradients so only the discriminator is trained
    # by this loss; the discriminator call itself stays outside no_grad.
    with torch.no_grad():
        generator_quality = generator(low_quality)
    decision_gen = discriminator(generator_quality)
    d_loss_gen = criterion(decision_gen, low_labels)

    # Slightly perturbed high-quality input is treated as "low" so the
    # discriminator does not accept near-copies of real data.
    noise = torch.rand_like(high_quality) * noise_scale
    decision_noise = discriminator(high_quality + noise)
    d_loss_noise = criterion(decision_noise, low_labels)

    d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
    return d_loss


def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    *,
    similarity_weight: float = 100.0,
):
    """Compute the generator loss for one training step.

    Args:
        low_quality: Input batch passed to the generator.
        high_quality: Reference batch the generator output is compared to.
        real_labels: Targets the discriminator's decision on the generated
            batch is scored against (adversarial term).
        generator: Callable producing the generated batch.
        discriminator: Callable scoring the generated batch. Gradients also
            reach its parameters here. Presumably the caller steps only the
            generator optimizer with this loss (verify).
        adv_criterion: Callable ``adv_criterion(decision, labels) -> loss``.
        similarity_weight: Multiplier on the mean-absolute-error term. The
            default matches the original hard-coded value of 100.

    Returns:
        A tuple ``(combined_loss, adversarial_loss)``. ``combined_loss`` is
        ``adversarial_loss + similarity_weight * mean(|G(low) - high|)``.
    """
    generator_output = generator(low_quality)
    discriminator_decision = discriminator(generator_output)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    # Signal similarity (L1 / MAE between generated and reference signal).
    similarity_loss = signal_mae(generator_output, high_quality)

    combined_loss = adversarial_loss + (similarity_loss * similarity_weight)
    return combined_loss, adversarial_loss