⚗️ | Added some stupid ways for training + some makeup
This commit is contained in:
60
utils/TrainingTools.py
Normal file
60
utils/TrainingTools.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
# In case if needed again...
|
||||
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
#
|
||||
# stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
||||
# )
|
||||
|
||||
|
||||
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
|
||||
absolute_difference = torch.abs(input_one - input_two)
|
||||
return torch.mean(absolute_difference)
|
||||
|
||||
|
||||
def discriminator_train(
|
||||
high_quality,
|
||||
low_quality,
|
||||
high_labels,
|
||||
low_labels,
|
||||
discriminator,
|
||||
generator,
|
||||
criterion,
|
||||
):
|
||||
decision_high = discriminator(high_quality)
|
||||
d_loss_high = criterion(decision_high, high_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
|
||||
|
||||
decision_low = discriminator(low_quality)
|
||||
d_loss_low = criterion(decision_low, low_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
|
||||
|
||||
with torch.no_grad():
|
||||
generator_quality = generator(low_quality)
|
||||
decision_gen = discriminator(generator_quality)
|
||||
d_loss_gen = criterion(decision_gen, low_labels)
|
||||
|
||||
noise = torch.rand_like(high_quality) * 0.08
|
||||
decision_noise = discriminator(high_quality + noise)
|
||||
d_loss_noise = criterion(decision_noise, low_labels)
|
||||
|
||||
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
|
||||
|
||||
return d_loss
|
||||
|
||||
|
||||
def generator_train(
|
||||
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
|
||||
):
|
||||
generator_output = generator(low_quality)
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
|
||||
|
||||
# Signal similarity
|
||||
similarity_loss = signal_mae(generator_output, high_quality)
|
||||
|
||||
combined_loss = adversarial_loss + (similarity_loss * 100)
|
||||
|
||||
return combined_loss, adversarial_loss
|
Reference in New Issue
Block a user