59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import torch
|
|
|
|
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
|
|
|
# Alternative four-resolution configuration (currently disabled):
# stft_loss_fn = MultiResolutionSTFTLoss(
#     fft_sizes=[512, 1024, 2048, 4096],
#     hop_sizes=[128, 256, 512, 1024],
#     win_lengths=[512, 1024, 2048, 4096]
# )
|
|
# Shared multi-resolution STFT loss instance, created once at import time.
# Each list has three entries; presumably the i-th entries of the three lists
# together describe one STFT resolution -- confirm in MultiResolutionSTFTLoss.
_STFT_LOSS_CONFIG = {
    "fft_sizes": [512, 1024, 2048],
    "hop_sizes": [64, 128, 256],
    "win_lengths": [256, 512, 1024],
}
stft_loss_fn = MultiResolutionSTFTLoss(**_STFT_LOSS_CONFIG)
|
|
|
|
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error between two tensors.

    The element-wise difference ``input_one - input_two`` is taken (with the
    usual tensor broadcasting), its absolute value is averaged over every
    element, and the result is returned as a zero-dimensional tensor.
    """
    return (input_one - input_two).abs().mean()
|
|
|
|
|
|
def discriminator_train(
    high_quality: torch.Tensor,
    low_quality: torch.Tensor,
    high_labels: torch.Tensor,
    low_labels: torch.Tensor,
    discriminator,
    criterion,
    generator_output: torch.Tensor,
) -> torch.Tensor:
    """Compute the discriminator loss for one batch.

    The discriminator is evaluated on two inputs, each built by concatenating
    along ``dim=1``:

    * ``(low_quality, high_quality)``, scored against ``high_labels``;
    * ``(low_quality, generator_output)``, scored against ``low_labels``.

    This function only computes the loss; it does not call ``backward()`` or
    step any optimizer.

    Args:
        high_quality: Reference tensor paired with ``low_quality``.
        low_quality: Tensor placed first in both concatenated pairs.
        high_labels: Target passed to ``criterion`` for the real pair.
        low_labels: Target passed to ``criterion`` for the fake pair.
        discriminator: Callable applied to each concatenated pair.
        criterion: Callable ``criterion(decision, labels)`` returning a loss.
        generator_output: Generator result for ``low_quality``. It is detached
            here, so the returned loss never backpropagates into the generator.

    Returns:
        The average of the real-pair and fake-pair losses.
    """
    real_pair = torch.cat((low_quality, high_quality), dim=1)
    d_loss_real = criterion(discriminator(real_pair), high_labels)

    # Detach so that backpropagating the discriminator loss neither deposits
    # gradients on the generator nor frees the generator's graph before the
    # generator's own backward pass. A no-op if the caller already detached.
    fake_pair = torch.cat((low_quality, generator_output.detach()), dim=1)
    d_loss_fake = criterion(discriminator(fake_pair), low_labels)

    return (d_loss_real + d_loss_fake) / 2.0
|
|
|
|
|
|
def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    generator_output,
    *,
    lambda_mae=10.0,
    lambda_stft=2.5,
    lambda_adv=2.5,
):
    """Compute the generator's combined loss for one batch.

    Three terms are summed with their weights:

    * adversarial: ``adv_criterion`` applied to the discriminator's decision
      on ``(low_quality, generator_output)`` (concatenated along ``dim=1``)
      against ``real_labels``;
    * MAE: ``signal_mae(generator_output, high_quality)``;
    * STFT: the ``"total"`` entry of the module-level ``stft_loss_fn`` result.

    This function only computes losses; it does not call ``backward()`` or
    step any optimizer.

    Args:
        low_quality: Tensor placed first in the pair given to the discriminator.
        high_quality: Reference tensor for the MAE and STFT terms.
        real_labels: Target passed to ``adv_criterion``.
        generator: Not used by this function; kept so existing call sites
            continue to work.
        discriminator: Callable applied to the concatenated pair.
        adv_criterion: Callable ``adv_criterion(decision, labels)``.
        generator_output: Generator result for ``low_quality``.
        lambda_mae: Weight of the MAE term (keyword-only, default 10.0).
        lambda_stft: Weight of the STFT term (keyword-only, default 2.5).
        lambda_adv: Weight of the adversarial term (keyword-only, default 2.5).

    Returns:
        A tuple ``(combined_loss, adversarial_loss)``; the second element is
        the unweighted adversarial term.
    """
    fake_pair = torch.cat((low_quality, generator_output), dim=1)
    discriminator_decision = discriminator(fake_pair)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    mae_loss = signal_mae(generator_output, high_quality)
    # NOTE(review): argument order (reference first, prediction second) is kept
    # exactly as before -- confirm it matches MultiResolutionSTFTLoss's
    # expected order, since the loss may not be symmetric.
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    combined_loss = (
        (lambda_mae * mae_loss)
        + (lambda_stft * stft_loss)
        + (lambda_adv * adversarial_loss)
    )
    return combined_loss, adversarial_loss
|