✨ | Made training bit... spicier.
This commit is contained in:
@@ -1,89 +1,88 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
|
||||
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
|
||||
mfccs_true = mfcc_transform(y_true)
|
||||
mfccs_pred = mfcc_transform(y_pred)
|
||||
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
|
||||
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
|
||||
mfccs_true = mfccs_true[:, :, :min_len]
|
||||
mfccs_pred = mfccs_pred[:, :, :min_len]
|
||||
mel_transform: T.MelSpectrogram
|
||||
stft_transform: T.Spectrogram
|
||||
# mfcc_transform: T.MFCC
|
||||
|
||||
loss = torch.mean((mfccs_true - mfccs_pred)**2)
|
||||
return loss
|
||||
|
||||
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||
mel_spec_true = mel_transform(y_true)
|
||||
mel_spec_pred = mel_transform(y_pred)
|
||||
# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
|
||||
# """Initializes the global transform variables for the module."""
|
||||
# global mel_transform, stft_transform, mfcc_transform
|
||||
# mel_transform = mel_trans
|
||||
# stft_transform = stft_trans
|
||||
# mfcc_transform = mfcc_trans
|
||||
|
||||
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
|
||||
mel_spec_true = mel_spec_true[..., :min_len]
|
||||
mel_spec_pred = mel_spec_pred[..., :min_len]
|
||||
|
||||
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
|
||||
return loss
|
||||
def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
|
||||
"""Initializes the global transform variables for the module."""
|
||||
global mel_transform, stft_transform
|
||||
mel_transform = mel_trans
|
||||
stft_transform = stft_trans
|
||||
|
||||
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||
mel_spec_true = mel_transform(y_true)
|
||||
mel_spec_pred = mel_transform(y_pred)
|
||||
|
||||
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
|
||||
mel_spec_true = mel_spec_true[..., :min_len]
|
||||
mel_spec_pred = mel_spec_pred[..., :min_len]
|
||||
# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||
# """Computes the Mean Squared Error (MSE) loss on MFCCs."""
|
||||
# mfccs_true = mfcc_transform(y_true)
|
||||
# mfccs_pred = mfcc_transform(y_pred)
|
||||
# return F.mse_loss(mfccs_pred, mfccs_true)
|
||||
|
||||
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
|
||||
return loss
|
||||
|
||||
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
||||
stft_mag_true = stft_transform(y_true)
|
||||
stft_mag_pred = stft_transform(y_pred)
|
||||
# def mel_spectrogram_loss(
|
||||
# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
|
||||
# ) -> torch.Tensor:
|
||||
# """Calculates L1 or L2 loss on the Mel Spectrogram."""
|
||||
# mel_spec_true = mel_transform(y_true)
|
||||
# mel_spec_pred = mel_transform(y_pred)
|
||||
# if loss_type == "l1":
|
||||
# return F.l1_loss(mel_spec_pred, mel_spec_true)
|
||||
# elif loss_type == "l2":
|
||||
# return F.mse_loss(mel_spec_pred, mel_spec_true)
|
||||
# else:
|
||||
# raise ValueError("loss_type must be 'l1' or 'l2'")
|
||||
|
||||
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
|
||||
stft_mag_true = stft_mag_true[..., :min_len]
|
||||
stft_mag_pred = stft_mag_pred[..., :min_len]
|
||||
|
||||
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
|
||||
return loss
|
||||
# def log_stft_magnitude_loss(
|
||||
# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
|
||||
# ) -> torch.Tensor:
|
||||
# """Calculates L1 loss on the log STFT magnitude."""
|
||||
# stft_mag_true = stft_transform(y_true)
|
||||
# stft_mag_pred = stft_transform(y_pred)
|
||||
# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))
|
||||
|
||||
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
||||
stft_mag_true = stft_transform(y_true)
|
||||
stft_mag_pred = stft_transform(y_pred)
|
||||
|
||||
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
|
||||
stft_mag_true = stft_mag_true[..., :min_len]
|
||||
stft_mag_pred = stft_mag_pred[..., :min_len]
|
||||
stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
||||
)
|
||||
|
||||
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
|
||||
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
|
||||
|
||||
loss = torch.mean(norm_diff / (norm_true + eps))
|
||||
return loss
|
||||
|
||||
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass for real samples
|
||||
def discriminator_train(
|
||||
high_quality,
|
||||
low_quality,
|
||||
real_labels,
|
||||
fake_labels,
|
||||
discriminator,
|
||||
generator,
|
||||
criterion,
|
||||
):
|
||||
discriminator_decision_from_real = discriminator(high_quality)
|
||||
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
|
||||
|
||||
with torch.no_grad():
|
||||
generator_output = generator(low_quality)
|
||||
discriminator_decision_from_fake = discriminator(generator_output)
|
||||
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
|
||||
d_loss_fake = criterion(
|
||||
discriminator_decision_from_fake,
|
||||
fake_labels.expand_as(discriminator_decision_from_fake),
|
||||
)
|
||||
|
||||
d_loss = (d_loss_real + d_loss_fake) / 2.0
|
||||
|
||||
d_loss.backward()
|
||||
# Optional: Gradient Clipping (can be helpful)
|
||||
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
|
||||
optimizer.step()
|
||||
|
||||
return d_loss
|
||||
|
||||
|
||||
def generator_train(
|
||||
low_quality,
|
||||
high_quality,
|
||||
@@ -91,52 +90,65 @@ def generator_train(
|
||||
generator,
|
||||
discriminator,
|
||||
adv_criterion,
|
||||
g_optimizer,
|
||||
device,
|
||||
mel_transform: T.MelSpectrogram,
|
||||
stft_transform: T.Spectrogram,
|
||||
mfcc_transform: T.MFCC,
|
||||
lambda_adv: float = 1.0,
|
||||
lambda_mel_l1: float = 10.0,
|
||||
lambda_log_stft: float = 1.0,
|
||||
lambda_mfcc: float = 1.0
|
||||
lambda_feat: float = 10.0,
|
||||
lambda_stft: float = 2.5,
|
||||
):
|
||||
g_optimizer.zero_grad()
|
||||
|
||||
generator_output = generator(low_quality)
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
|
||||
# adversarial_loss = adv_criterion(
|
||||
# discriminator_decision, real_labels.expand_as(discriminator_decision)
|
||||
# )
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
|
||||
|
||||
mel_l1 = 0.0
|
||||
log_stft_l1 = 0.0
|
||||
mfcc_l = 0.0
|
||||
combined_loss = lambda_adv * adversarial_loss
|
||||
|
||||
# Calculate Mel L1 Loss if weight is positive
|
||||
if lambda_mel_l1 > 0:
|
||||
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output)
|
||||
stft_losses = stft_loss_fn(high_quality, generator_output)
|
||||
stft_loss = stft_losses["total"]
|
||||
|
||||
# Calculate Log STFT L1 Loss if weight is positive
|
||||
if lambda_log_stft > 0:
|
||||
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output)
|
||||
combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss)
|
||||
|
||||
# Calculate MFCC Loss if weight is positive
|
||||
if lambda_mfcc > 0:
|
||||
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output)
|
||||
return generator_output, combined_loss, adversarial_loss
|
||||
|
||||
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
|
||||
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
|
||||
mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
|
||||
|
||||
combined_loss = (lambda_adv * adversarial_loss) + \
|
||||
(lambda_mel_l1 * mel_l1_tensor) + \
|
||||
(lambda_log_stft * log_stft_l1_tensor) + \
|
||||
(lambda_mfcc * mfcc_l_tensor)
|
||||
# def generator_train(
|
||||
# low_quality,
|
||||
# high_quality,
|
||||
# real_labels,
|
||||
# generator,
|
||||
# discriminator,
|
||||
# adv_criterion,
|
||||
# lambda_adv: float = 1.0,
|
||||
# lambda_mel_l1: float = 10.0,
|
||||
# lambda_log_stft: float = 1.0,
|
||||
|
||||
combined_loss.backward()
|
||||
# Optional: Gradient Clipping
|
||||
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
|
||||
g_optimizer.step()
|
||||
# ):
|
||||
# generator_output = generator(low_quality)
|
||||
|
||||
# 6. Return values for logging
|
||||
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
|
||||
# discriminator_decision = discriminator(generator_output)
|
||||
# adversarial_loss = adv_criterion(
|
||||
# discriminator_decision, real_labels.expand_as(discriminator_decision)
|
||||
# )
|
||||
|
||||
# combined_loss = lambda_adv * adversarial_loss
|
||||
|
||||
# if lambda_mel_l1 > 0:
|
||||
# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
|
||||
# combined_loss += lambda_mel_l1 * mel_l1_loss
|
||||
# else:
|
||||
# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging
|
||||
|
||||
# if lambda_log_stft > 0:
|
||||
# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
|
||||
# combined_loss += lambda_log_stft * log_stft_loss
|
||||
# else:
|
||||
# log_stft_loss = torch.tensor(0.0, device=low_quality.device)
|
||||
|
||||
# if lambda_mfcc > 0:
|
||||
# mfcc_loss_val = mfcc_loss(high_quality, generator_output)
|
||||
# combined_loss += lambda_mfcc * mfcc_loss_val
|
||||
# else:
|
||||
# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)
|
||||
|
||||
# return generator_output, combined_loss, adversarial_loss
|
||||
|
Reference in New Issue
Block a user