Merge new-arch, because it has proven to give the best results #1
33
training.py
33
training.py
@ -41,11 +41,24 @@ args = parser.parse_args()
|
|||||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
# mfcc_transform = T.MFCC(
|
# Parameters
|
||||||
# sample_rate=44100,
|
sample_rate = 44100
|
||||||
# n_mfcc=20,
|
n_fft = 2048
|
||||||
# melkwargs={'n_fft': 2048, 'hop_length': 256}
|
hop_length = 256
|
||||||
# ).to(device)
|
win_length = n_fft
|
||||||
|
n_mels = 128
|
||||||
|
n_mfcc = 20 # If using MFCC
|
||||||
|
|
||||||
|
mfcc_transform = T.MFCC(
|
||||||
|
sample_rate,
|
||||||
|
n_mfcc,
|
||||||
|
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
mel_transform = T.MelSpectrogram(
|
||||||
|
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
||||||
|
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
|
||||||
|
).to(device)
|
||||||
|
|
||||||
debug = args.debug
|
debug = args.debug
|
||||||
|
|
||||||
@ -130,18 +143,20 @@ def start_training():
|
|||||||
|
|
||||||
# ========= GENERATOR =========
|
# ========= GENERATOR =========
|
||||||
generator.train()
|
generator.train()
|
||||||
generator_output, adversarial_loss = generator_train(
|
generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
|
||||||
low_quality_sample,
|
low_quality_sample,
|
||||||
high_quality_sample,
|
high_quality_sample,
|
||||||
real_labels,
|
real_labels,
|
||||||
generator,
|
generator,
|
||||||
discriminator,
|
discriminator,
|
||||||
criterion_g,
|
criterion_d,
|
||||||
optimizer_g
|
optimizer_g,
|
||||||
|
device,
|
||||||
|
mel_transform
|
||||||
)
|
)
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(d_loss, adversarial_loss)
|
print(combined_loss, adversarial_loss, mel_l1_tensor)
|
||||||
scheduler_d.step(d_loss.detach())
|
scheduler_d.step(d_loss.detach())
|
||||||
scheduler_g.step(adversarial_loss.detach())
|
scheduler_g.step(adversarial_loss.detach())
|
||||||
|
|
||||||
|
@ -3,16 +3,73 @@ import torch.nn as nn
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import torchaudio.transforms as T
|
||||||
|
|
||||||
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
|
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
|
||||||
mfccs_true = mfcc_transform(y_true)
|
mfccs_true = mfcc_transform(y_true)
|
||||||
mfccs_pred = mfcc_transform(y_pred)
|
mfccs_pred = mfcc_transform(y_pred)
|
||||||
|
|
||||||
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
|
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
|
||||||
mfccs_true = mfccs_true[:, :, :min_len]
|
mfccs_true = mfccs_true[:, :, :min_len]
|
||||||
mfccs_pred = mfccs_pred[:, :, :min_len]
|
mfccs_pred = mfccs_pred[:, :, :min_len]
|
||||||
|
|
||||||
loss = torch.mean((mfccs_true - mfccs_pred)**2)
|
loss = torch.mean((mfccs_true - mfccs_pred)**2)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||||
|
mel_spec_true = mel_transform(y_true)
|
||||||
|
mel_spec_pred = mel_transform(y_pred)
|
||||||
|
|
||||||
|
# Ensure same time dimension length (due to potential framing differences)
|
||||||
|
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
|
||||||
|
mel_spec_true = mel_spec_true[..., :min_len]
|
||||||
|
mel_spec_pred = mel_spec_pred[..., :min_len]
|
||||||
|
|
||||||
|
# L1 Loss (Mean Absolute Error)
|
||||||
|
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||||
|
mel_spec_true = mel_transform(y_true)
|
||||||
|
mel_spec_pred = mel_transform(y_pred)
|
||||||
|
|
||||||
|
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
|
||||||
|
mel_spec_true = mel_spec_true[..., :min_len]
|
||||||
|
mel_spec_pred = mel_spec_pred[..., :min_len]
|
||||||
|
|
||||||
|
# L2 Loss (Mean Squared Error)
|
||||||
|
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
||||||
|
stft_mag_true = stft_transform(y_true)
|
||||||
|
stft_mag_pred = stft_transform(y_pred)
|
||||||
|
|
||||||
|
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
|
||||||
|
stft_mag_true = stft_mag_true[..., :min_len]
|
||||||
|
stft_mag_pred = stft_mag_pred[..., :min_len]
|
||||||
|
|
||||||
|
# Log Magnitude L1 Loss
|
||||||
|
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
||||||
|
stft_mag_true = stft_transform(y_true)
|
||||||
|
stft_mag_pred = stft_transform(y_pred)
|
||||||
|
|
||||||
|
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
|
||||||
|
stft_mag_true = stft_mag_true[..., :min_len]
|
||||||
|
stft_mag_pred = stft_mag_pred[..., :min_len]
|
||||||
|
|
||||||
|
# Calculate Frobenius norms and the loss
|
||||||
|
# Ensure norms are calculated over frequency and time dims ([..., freq, time])
|
||||||
|
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
|
||||||
|
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
|
||||||
|
|
||||||
|
# Average loss over the batch
|
||||||
|
loss = torch.mean(norm_diff / (norm_true + eps))
|
||||||
|
return loss
|
||||||
|
|
||||||
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
|
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
@ -21,35 +78,87 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
|
|||||||
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
|
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
|
||||||
|
|
||||||
# Forward pass for fake samples (from generator output)
|
# Forward pass for fake samples (from generator output)
|
||||||
generator_output = generator(low_quality[0])
|
with torch.no_grad(): # Detach generator output within no_grad context
|
||||||
discriminator_decision_from_fake = discriminator(generator_output.detach())
|
generator_output = generator(low_quality[0])
|
||||||
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
|
discriminator_decision_from_fake = discriminator(generator_output) # No need to detach again if inside no_grad
|
||||||
|
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
|
||||||
|
|
||||||
# Combine real and fake losses
|
# Combine real and fake losses
|
||||||
d_loss = (d_loss_real + d_loss_fake) / 2.0
|
d_loss = (d_loss_real + d_loss_fake) / 2.0
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
d_loss.backward()
|
d_loss.backward()
|
||||||
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
|
# Optional: Gradient Clipping (can be helpful)
|
||||||
|
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
return d_loss
|
return d_loss
|
||||||
|
|
||||||
def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
|
def generator_train(
|
||||||
optimizer.zero_grad()
|
low_quality,
|
||||||
|
high_quality,
|
||||||
|
real_labels,
|
||||||
|
generator,
|
||||||
|
discriminator,
|
||||||
|
adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss)
|
||||||
|
g_optimizer,
|
||||||
|
device,
|
||||||
|
# --- Pass necessary transforms and loss weights ---
|
||||||
|
mel_transform: T.MelSpectrogram, # Example: Pass Mel transform
|
||||||
|
# stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses
|
||||||
|
# mfcc_transform: T.MFCC, # Pass MFCC transform if using MFCC loss
|
||||||
|
lambda_adv: float = 1.0, # Weight for adversarial loss
|
||||||
|
lambda_mel_l1: float = 10.0, # Example: Weight for Mel L1 loss
|
||||||
|
# lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use
|
||||||
|
# lambda_mfcc: float = 0.0
|
||||||
|
):
|
||||||
|
g_optimizer.zero_grad()
|
||||||
|
|
||||||
# Forward pass for fake samples (from generator output)
|
# 1. Generate high-quality audio from low-quality input
|
||||||
generator_output = generator(low_quality[0])
|
generator_output = generator(low_quality[0])
|
||||||
|
|
||||||
#mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
|
# 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
|
||||||
|
|
||||||
discriminator_decision = discriminator(generator_output)
|
discriminator_decision = discriminator(generator_output)
|
||||||
adversarial_loss = criterion(discriminator_decision, real_labels)
|
# Generator wants discriminator to output "real" labels for its fakes
|
||||||
|
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
|
||||||
|
|
||||||
#combined_loss = adversarial_loss + 0.5 * mfcc_l
|
# 3. Calculate Reconstruction/Spectrogram Loss(es)
|
||||||
|
# --- Choose and calculate the losses you want to include ---
|
||||||
|
mel_l1 = 0.0
|
||||||
|
# log_stft_l1 = 0.0
|
||||||
|
# mfcc_l = 0.0
|
||||||
|
|
||||||
adversarial_loss.backward()
|
# Calculate Mel L1 Loss if weight is positive
|
||||||
optimizer.step()
|
if lambda_mel_l1 > 0:
|
||||||
|
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
|
||||||
|
|
||||||
#return (generator_output, combined_loss, adversarial_loss, mfcc_l)
|
# # Calculate Log STFT L1 Loss if weight is positive
|
||||||
return (generator_output, adversarial_loss)
|
# if lambda_log_stft > 0:
|
||||||
|
# log_stft_l1 = log_stft_magnitude_loss(stft_transform, hq_audio, generator_output)
|
||||||
|
|
||||||
|
# # Calculate MFCC Loss if weight is positive
|
||||||
|
# if lambda_mfcc > 0:
|
||||||
|
# mfcc_l = gpu_mfcc_loss(mfcc_transform, hq_audio, generator_output)
|
||||||
|
# --- End of Loss Calculation Choices ---
|
||||||
|
|
||||||
|
|
||||||
|
# 4. Combine Losses
|
||||||
|
# Make sure calculated losses are tensors even if weights are 0 initially
|
||||||
|
# (or handle appropriately in the sum)
|
||||||
|
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
|
||||||
|
# log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
|
||||||
|
# mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
|
||||||
|
|
||||||
|
combined_loss = (lambda_adv * adversarial_loss) + \
|
||||||
|
(lambda_mel_l1 * mel_l1_tensor)
|
||||||
|
# + (lambda_log_stft * log_stft_l1_tensor) \
|
||||||
|
# + (lambda_mfcc * mfcc_l_tensor)
|
||||||
|
|
||||||
|
# 5. Backward Pass and Optimization
|
||||||
|
combined_loss.backward()
|
||||||
|
# Optional: Gradient Clipping
|
||||||
|
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
|
||||||
|
g_optimizer.step()
|
||||||
|
|
||||||
|
# 6. Return values for logging
|
||||||
|
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor
|
||||||
|
Loading…
Reference in New Issue
Block a user