61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
import torch
|
|
|
|
# Kept in case the multi-resolution STFT loss below is needed again:
|
|
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
|
#
|
|
# stft_loss_fn = MultiResolutionSTFTLoss(
|
|
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
|
# )
|
|
|
|
|
|
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error between two tensors.

    The element-wise absolute difference ``|input_one - input_two|`` is
    averaged over every element, yielding a 0-dim tensor.
    """
    return (input_one - input_two).abs().mean()
|
|
|
|
|
|
def discriminator_train(
    high_quality,
    low_quality,
    high_labels,
    low_labels,
    discriminator,
    generator,
    criterion,
    *,
    noise_scale: float = 0.08,
):
    """Compute the discriminator loss for one training step.

    The discriminator is scored on four inputs and the four criterion
    values are averaged:

    1. ``high_quality`` against ``high_labels``;
    2. ``low_quality`` against ``low_labels``;
    3. ``generator(low_quality)`` against ``low_labels``;
    4. ``high_quality`` plus random noise against ``low_labels``.

    Args:
        high_quality: Batch the discriminator should score as ``high_labels``.
        low_quality: Batch the discriminator should score as ``low_labels``;
            also the input passed to ``generator``.
        high_labels: Targets for the decision on ``high_quality``.
        low_labels: Targets for the low-quality, generated and noisy decisions.
        discriminator: Callable that scores a batch.
        generator: Callable applied to ``low_quality``; it runs under
            ``torch.no_grad()`` so it receives no gradients from this loss.
        criterion: Loss function called as ``criterion(decision, labels)``.
        noise_scale: Multiplier for the uniform noise added to
            ``high_quality`` in the fourth term (default ``0.08``).

    Returns:
        The mean of the four criterion values. Gradients from it reach the
        discriminator only.
    """
    decision_high = discriminator(high_quality)
    d_loss_high = criterion(decision_high, high_labels)

    decision_low = discriminator(low_quality)
    d_loss_low = criterion(decision_low, low_labels)

    # Only the generator runs without autograd. Its weights must not get
    # gradients from the discriminator step, but the discriminator's own
    # forward pass on the generated batch must stay in the graph; otherwise
    # d_loss_gen would contribute nothing to the discriminator update.
    with torch.no_grad():
        generator_quality = generator(low_quality)
    decision_gen = discriminator(generator_quality)
    d_loss_gen = criterion(decision_gen, low_labels)

    # NOTE(review): torch.rand_like samples uniformly from [0, 1), so this
    # perturbation is strictly non-negative (mean noise_scale / 2). Use
    # torch.randn_like instead if zero-mean Gaussian noise was intended.
    noise = torch.rand_like(high_quality) * noise_scale
    decision_noise = discriminator(high_quality + noise)
    d_loss_noise = criterion(decision_noise, low_labels)

    return (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
|
|
|
|
|
|
def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    *,
    similarity_weight: float = 100.0,
):
    """Compute the generator loss for one training step.

    Args:
        low_quality: Input batch passed to ``generator``.
        high_quality: Reference batch the generator output is compared to
            with ``signal_mae``.
        real_labels: Targets for the discriminator's decision on the
            generator output.
        generator: Callable producing the output from ``low_quality``.
        discriminator: Callable that scores the generator output.
        adv_criterion: Loss called as ``adv_criterion(decision, real_labels)``.
        similarity_weight: Multiplier on the MAE term (default ``100.0``,
            matching the previously hard-coded factor).

    Returns:
        Tuple ``(combined_loss, adversarial_loss)`` where
        ``combined_loss = adversarial_loss + similarity_weight * MAE``.
    """
    generator_output = generator(low_quality)

    # Adversarial term: compare the discriminator's decision on the
    # generated batch with real_labels.
    discriminator_decision = discriminator(generator_output)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    # Signal similarity: mean absolute error to the high-quality reference.
    similarity_loss = signal_mae(generator_output, high_quality)

    combined_loss = adversarial_loss + similarity_loss * similarity_weight
    return combined_loss, adversarial_loss
|