56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
import torchaudio
|
|
|
|
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
    """Return the mean squared error between the MFCC features of two signals.

    Nothing here is device-specific: the computation runs wherever the inputs
    and ``mfcc_transform`` live.

    Args:
        mfcc_transform: Callable mapping a waveform tensor to a feature tensor
            whose LAST axis is time, e.g. a ``torchaudio.transforms.MFCC``
            instance (output shape ``(..., n_mfcc, time)``).
        y_true: Reference waveform(s).
        y_pred: Predicted waveform(s).

    Returns:
        Scalar tensor: the mean of the squared element-wise difference between
        the two feature tensors, after both are cropped to the shorter frame
        count along the last axis.
    """
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)

    # The two feature tensors may differ in frame count (e.g. when the
    # waveforms differ in length); compare only the overlapping frames.
    # Crop on the last axis rather than a fixed index: this matches dim 2 for
    # 3-D (batch, n_mfcc, time) features and also works for 2-D (unbatched)
    # and 4-D (batch, channel, n_mfcc, time) features.
    min_len = min(mfccs_true.shape[-1], mfccs_pred.shape[-1])
    mfccs_true = mfccs_true[..., :min_len]
    mfccs_pred = mfccs_pred[..., :min_len]

    return torch.mean((mfccs_true - mfccs_pred) ** 2)
|
|
|
|
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer, max_grad_norm=1.0):
    """Run one optimisation step for the discriminator.

    The discriminator is scored on a real batch (target ``real_labels``) and on
    generator output (target ``fake_labels``). The two losses are averaged,
    back-propagated, clipped and applied with ``optimizer``. The generator is
    only sampled here; no gradients flow into it.

    Args:
        high_quality: Indexable batch; element ``[0]`` is fed to the
            discriminator as the real sample.
        low_quality: Indexable batch; element ``[0]`` is the generator input.
        real_labels: Target passed to ``criterion`` for the real decisions.
        fake_labels: Target passed to ``criterion`` for the fake decisions.
        discriminator: Module scored and updated by this step.
        generator: Module used only to produce fake samples.
        criterion: Loss callable ``criterion(prediction, target)``.
        optimizer: Optimizer that is zeroed and stepped here. It is assumed to
            hold the discriminator's parameters (TODO confirm at call site).
        max_grad_norm: Max total gradient norm for the discriminator's
            parameters; ``None`` disables clipping. Defaults to 1.0.

    Returns:
        The averaged discriminator loss tensor. It still carries its
        ``grad_fn``, so use ``.item()`` or ``.detach()`` before accumulating
        it across steps.
    """
    optimizer.zero_grad()

    # Real samples: the discriminator should output ``real_labels``.
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion(discriminator_decision_from_real, real_labels)

    # Fake samples. The generator is not trained in this step, so run it without
    # building an autograd graph (the original built one and then discarded it
    # via .detach(); the produced values are the same).
    with torch.no_grad():
        generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)

    # Weight real and fake terms equally.
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    d_loss.backward()
    if max_grad_norm is not None:
        nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    return d_loss
|
|
|
|
def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
    """Run one adversarial optimisation step for the generator.

    The generator's output is scored by the discriminator against
    ``real_labels``, so minimising the loss pushes the generator towards
    samples the discriminator accepts as real.

    Args:
        low_quality: Indexable batch; element ``[0]`` is the generator input.
        high_quality: Currently unused (kept for interface compatibility).
        real_labels: Target passed to ``criterion`` for the discriminator's
            decision on the generated batch.
        generator: Module producing the fake batch.
        discriminator: Module scoring the fake batch. Backward also populates
            ``.grad`` on its parameters; only ``optimizer`` is stepped here.
        criterion: Loss callable ``criterion(prediction, target)``.
        optimizer: Optimizer that is zeroed and stepped here. It is assumed to
            hold only the generator's parameters (TODO confirm at call site).

    Returns:
        Tuple ``(generator_output, adversarial_loss)``.
    """
    # TODO: an MFCC reconstruction term is not wired in. It would be
    # gpu_mfcc_loss(mfcc_transform, high_quality[0], fake_batch), added to the
    # adversarial loss with some weight; high_quality stays unused until then.
    optimizer.zero_grad()

    fake_batch = generator(low_quality[0])
    g_loss = criterion(discriminator(fake_batch), real_labels)

    g_loss.backward()
    optimizer.step()

    return fake_batch, g_loss
|