# SISU/training_utils.py
# 2025-04-14 17:51:34 +03:00
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
    """Mean-squared error between the MFCCs of two audio batches.

    Both signals are passed through ``mfcc_transform``. Because the two
    signals may produce a different number of time frames, the coefficient
    tensors are truncated along the time axis (last dimension of the
    ``(batch, n_mfcc, time)`` output) to the shorter of the two before
    comparison.

    Args:
        mfcc_transform: callable mapping a waveform batch to MFCC tensors.
        y_true: reference audio batch.
        y_pred: predicted audio batch.

    Returns:
        Scalar tensor holding the MSE over all retained coefficients.
    """
    true_coeffs = mfcc_transform(y_true)
    pred_coeffs = mfcc_transform(y_pred)

    # Align on the time dimension before taking the difference.
    frames = min(true_coeffs.shape[2], pred_coeffs.shape[2])
    diff = true_coeffs[:, :, :frames] - pred_coeffs[:, :, :frames]
    return (diff ** 2).mean()
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
    """Run one optimization step for the discriminator.

    Args:
        high_quality: sequence whose first element is a batch of real (target) samples.
        low_quality: sequence whose first element is a batch of degraded inputs for the generator.
        real_labels: criterion targets for real samples.
        fake_labels: criterion targets for generated samples.
        discriminator: the module being optimized.
        generator: generator module, used for inference only in this step.
        criterion: loss comparing discriminator output against labels.
        optimizer: optimizer over the discriminator's parameters.

    Returns:
        The combined discriminator loss tensor (average of real and fake losses).
    """
    optimizer.zero_grad()

    # Forward pass for real samples.
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion(discriminator_decision_from_real, real_labels)

    # Forward pass for fake samples. The generator is not updated in this
    # step (the original detached its output), so skip building its autograd
    # graph entirely — identical numerics, less memory/compute.
    with torch.no_grad():
        generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)

    # Combine real and fake losses.
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    # Backward pass with gradient clipping for training stability.
    d_loss.backward()
    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
    optimizer.step()
    return d_loss
def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
    """Run one adversarial optimization step for the generator.

    The generator tries to make the discriminator judge its output as real,
    so the adversarial target is ``real_labels``. ``high_quality`` is kept in
    the signature for interface compatibility but is not used by the current
    adversarial-only objective.

    Args:
        low_quality: sequence whose first element is a batch of degraded inputs.
        high_quality: sequence of reference samples (currently unused).
        real_labels: criterion targets that mark samples as real.
        generator: the module being optimized.
        discriminator: discriminator used to score the generated batch.
        criterion: loss comparing discriminator output against labels.
        optimizer: optimizer over the generator's parameters.

    Returns:
        Tuple of (generator output batch, adversarial loss tensor).
    """
    optimizer.zero_grad()

    # Produce an enhanced batch from the degraded input.
    enhanced = generator(low_quality[0])

    # Adversarial objective: discriminator should label the fake as real.
    verdict = discriminator(enhanced)
    g_loss = criterion(verdict, real_labels)

    g_loss.backward()
    optimizer.step()

    return (enhanced, g_loss)