:albemic: | Real-time testing...
This commit is contained in:
@ -20,12 +20,10 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
|
||||
mel_spec_true = mel_transform(y_true)
|
||||
mel_spec_pred = mel_transform(y_pred)
|
||||
|
||||
# Ensure same time dimension length (due to potential framing differences)
|
||||
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
|
||||
mel_spec_true = mel_spec_true[..., :min_len]
|
||||
mel_spec_pred = mel_spec_pred[..., :min_len]
|
||||
|
||||
# L1 Loss (Mean Absolute Error)
|
||||
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
|
||||
return loss
|
||||
|
||||
@ -69,11 +67,11 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass for real samples
|
||||
discriminator_decision_from_real = discriminator(high_quality[0])
|
||||
discriminator_decision_from_real = discriminator(high_quality)
|
||||
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
|
||||
|
||||
with torch.no_grad():
|
||||
generator_output = generator(low_quality[0])
|
||||
generator_output = generator(low_quality)
|
||||
discriminator_decision_from_fake = discriminator(generator_output)
|
||||
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
|
||||
|
||||
@ -105,7 +103,7 @@ def generator_train(
|
||||
):
|
||||
g_optimizer.zero_grad()
|
||||
|
||||
generator_output = generator(low_quality[0])
|
||||
generator_output = generator(low_quality)
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
|
||||
@ -116,15 +114,15 @@ def generator_train(
|
||||
|
||||
# Calculate Mel L1 Loss if weight is positive
|
||||
if lambda_mel_l1 > 0:
|
||||
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
|
||||
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output)
|
||||
|
||||
# Calculate Log STFT L1 Loss if weight is positive
|
||||
if lambda_log_stft > 0:
|
||||
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
|
||||
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output)
|
||||
|
||||
# Calculate MFCC Loss if weight is positive
|
||||
if lambda_mfcc > 0:
|
||||
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
|
||||
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output)
|
||||
|
||||
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
|
||||
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
|
||||
|
Reference in New Issue
Block a user