⚗️ | More architectural changes
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
import torch
|
||||
|
||||
# In case if needed again...
|
||||
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
#
|
||||
# stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
||||
# )
|
||||
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
|
||||
# stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
# fft_sizes=[512, 1024, 2048, 4096],
|
||||
# hop_sizes=[128, 256, 512, 1024],
|
||||
# win_lengths=[512, 1024, 2048, 4096]
|
||||
# )
|
||||
stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
fft_sizes=[512, 1024, 2048],
|
||||
hop_sizes=[64, 128, 256],
|
||||
win_lengths=[256, 512, 1024]
|
||||
)
|
||||
|
||||
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
|
||||
absolute_difference = torch.abs(input_one - input_two)
|
||||
@@ -19,42 +24,35 @@ def discriminator_train(
|
||||
high_labels,
|
||||
low_labels,
|
||||
discriminator,
|
||||
generator,
|
||||
criterion,
|
||||
generator_output
|
||||
):
|
||||
decision_high = discriminator(high_quality)
|
||||
d_loss_high = criterion(decision_high, high_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
|
||||
|
||||
decision_low = discriminator(low_quality)
|
||||
d_loss_low = criterion(decision_low, low_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
|
||||
real_pair = torch.cat((low_quality, high_quality), dim=1)
|
||||
decision_real = discriminator(real_pair)
|
||||
d_loss_real = criterion(decision_real, high_labels)
|
||||
|
||||
with torch.no_grad():
|
||||
generator_quality = generator(low_quality)
|
||||
decision_gen = discriminator(generator_quality)
|
||||
d_loss_gen = criterion(decision_gen, low_labels)
|
||||
|
||||
noise = torch.rand_like(high_quality) * 0.08
|
||||
decision_noise = discriminator(high_quality + noise)
|
||||
d_loss_noise = criterion(decision_noise, low_labels)
|
||||
|
||||
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
|
||||
fake_pair = torch.cat((low_quality, generator_output), dim=1)
|
||||
decision_fake = discriminator(fake_pair)
|
||||
d_loss_fake = criterion(decision_fake, low_labels)
|
||||
|
||||
d_loss = (d_loss_real + d_loss_fake) / 2.0
|
||||
return d_loss
|
||||
|
||||
|
||||
def generator_train(
|
||||
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
|
||||
):
|
||||
generator_output = generator(low_quality)
|
||||
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
fake_pair = torch.cat((low_quality, generator_output), dim=1)
|
||||
|
||||
discriminator_decision = discriminator(fake_pair)
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
|
||||
|
||||
# Signal similarity
|
||||
similarity_loss = signal_mae(generator_output, high_quality)
|
||||
|
||||
combined_loss = adversarial_loss + (similarity_loss * 100)
|
||||
mae_loss = signal_mae(generator_output, high_quality)
|
||||
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
|
||||
|
||||
lambda_mae = 10.0
|
||||
lambda_stft = 2.5
|
||||
lambda_adv = 2.5
|
||||
combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
|
||||
return combined_loss, adversarial_loss
|
||||
|
||||
Reference in New Issue
Block a user