♻️ | Restructured project code.
This commit is contained in:
55
training_utils.py
Normal file
55
training_utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
import torchaudio
|
||||
|
||||
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
    """Return the mean squared error between the features of two inputs.

    Both inputs are passed through ``mfcc_transform``. The two results are
    cropped along their last axis to the shorter of the two lengths and
    then compared element-wise.

    Args:
        mfcc_transform: Callable applied to each input to produce its
            feature tensor (presumably a ``torchaudio.transforms.MFCC``
            instance -- confirm against the caller).
        y_true: Reference input handed to ``mfcc_transform``.
        y_pred: Predicted input handed to ``mfcc_transform``.

    Returns:
        A scalar tensor: the mean of the squared feature differences.
    """
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)

    # Crop along the last axis instead of a hard-coded axis 2, so the
    # function also accepts feature tensors that are not exactly 3-D.
    # For 3-D tensors the last axis *is* axis 2, so behaviour is unchanged.
    # NOTE(review): assumes the frame/time axis is the last one -- confirm.
    min_len = min(mfccs_true.shape[-1], mfccs_pred.shape[-1])
    mfccs_true = mfccs_true[..., :min_len]
    mfccs_pred = mfccs_pred[..., :min_len]

    return torch.mean((mfccs_true - mfccs_pred) ** 2)
|
||||
|
||||
def discriminator_train(
    high_quality,
    low_quality,
    real_labels,
    fake_labels,
    discriminator,
    generator,
    criterion,
    optimizer,
):
    """Run a single optimisation step for the discriminator.

    The discriminator is scored on ``high_quality[0]`` against
    ``real_labels`` and on the generator's output for ``low_quality[0]``
    against ``fake_labels``. The two losses are averaged, back-propagated,
    the discriminator's gradient norm is clipped to 1.0, and ``optimizer``
    takes one step.

    Args:
        high_quality: Indexable whose element 0 is fed to ``discriminator``.
        low_quality: Indexable whose element 0 is fed to ``generator``.
        real_labels: Target passed to ``criterion`` for the real samples.
        fake_labels: Target passed to ``criterion`` for the generated samples.
        discriminator: Module being trained.
        generator: Module used only to produce fake samples; it receives
            no gradients from this step.
        criterion: Loss callable taking ``(prediction, target)``.
        optimizer: Optimizer that is zeroed and stepped here (presumably
            the one holding the discriminator's parameters -- confirm
            against the caller).

    Returns:
        The averaged discriminator loss tensor.
    """
    optimizer.zero_grad()

    # Real samples.
    decision_on_real = discriminator(high_quality[0])
    d_loss_real = criterion(decision_on_real, real_labels)

    # Fake samples. The generator output is only consumed detached, so
    # skip building an autograd graph through the generator altogether.
    with torch.no_grad():
        generator_output = generator(low_quality[0])
    decision_on_fake = discriminator(generator_output.detach())
    d_loss_fake = criterion(decision_on_fake, fake_labels)

    # Average the two halves so that real and fake weigh equally.
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    d_loss.backward()
    # Clip the gradient norm before the optimizer consumes the gradients.
    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
    optimizer.step()

    return d_loss
|
||||
|
||||
def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
    """Run a single adversarial optimisation step for the generator.

    ``low_quality[0]`` is pushed through ``generator``, the result is
    scored by ``discriminator``, and the score is compared with
    ``real_labels`` via ``criterion``. That loss is back-propagated and
    ``optimizer`` takes one step.

    Args:
        low_quality: Indexable whose element 0 is fed to ``generator``.
        high_quality: Currently unused; the MFCC-based loss term that
            would consume it is disabled.
        real_labels: Target passed to ``criterion`` for the generated batch.
        generator: Module producing the fake samples.
        discriminator: Module scoring the generated samples.
        criterion: Loss callable taking ``(prediction, target)``.
        optimizer: Optimizer that is zeroed and stepped here (presumably
            the one holding the generator's parameters -- confirm against
            the caller).

    Returns:
        A ``(generator_output, adversarial_loss)`` tuple.
    """
    optimizer.zero_grad()

    fake_batch = generator(low_quality[0])

    # Score the fakes against the "real" target.
    g_loss = criterion(discriminator(fake_batch), real_labels)

    g_loss.backward()
    optimizer.step()

    return fake_batch, g_loss
|
Reference in New Issue
Block a user