♻️ | Restructured procject code.

This commit is contained in:
2025-04-14 17:51:34 +03:00
parent 3936b6c160
commit b6d16e4f11
3 changed files with 137 additions and 73 deletions

55
training_utils.py Normal file
View File

@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
loss = torch.mean((mfccs_true - mfccs_pred)**2)
return loss
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output.detach())
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
# Combine real and fake losses
d_loss = (d_loss_real + d_loss_fake) / 2.0
# Backward pass and optimization
d_loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer.step()
return d_loss
def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
#mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
discriminator_decision = discriminator(generator_output)
adversarial_loss = criterion(discriminator_decision, real_labels)
#combined_loss = adversarial_loss + 0.5 * mfcc_l
adversarial_loss.backward()
optimizer.step()
#return (generator_output, combined_loss, adversarial_loss, mfcc_l)
return (generator_output, adversarial_loss)