Implemented MFCC and STFT.

This commit is contained in:
NikkeDoy 2025-04-26 17:03:28 +03:00
parent c04b072de6
commit d70c86c257
4 changed files with 40 additions and 54 deletions

View File

@ -39,7 +39,7 @@ class AttentionBlock(nn.Module):
return x * attention_weights return x * attention_weights
class SISUDiscriminator(nn.Module): class SISUDiscriminator(nn.Module):
def __init__(self, base_channels=64): def __init__(self, base_channels=16):
super(SISUDiscriminator, self).__init__() super(SISUDiscriminator, self).__init__()
layers = base_channels layers = base_channels
self.model = nn.Sequential( self.model = nn.Sequential(

View File

@ -48,7 +48,7 @@ class ResidualInResidualBlock(nn.Module):
return x + residual return x + residual
class SISUGenerator(nn.Module): class SISUGenerator(nn.Module):
def __init__(self, channels=64, num_rirb=8, alpha=1.0): def __init__(self, channels=16, num_rirb=4, alpha=1.0):
super(SISUGenerator, self).__init__() super(SISUGenerator, self).__init__()
self.alpha = alpha self.alpha = alpha

View File

@ -34,7 +34,7 @@ parser.add_argument("--discriminator", type=str, default=None,
parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs") parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", type=bool, default=False, help="Continue training using temp_generator and temp_discriminator models") parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args() args = parser.parse_args()
@ -60,6 +60,10 @@ mel_transform = T.MelSpectrogram(
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device) ).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug debug = args.debug
# Initialize dataset and dataloader # Initialize dataset and dataloader
@ -72,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE ========= # ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True) train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS ========= # ========= MODELS =========
@ -143,7 +147,7 @@ def start_training():
# ========= GENERATOR ========= # ========= GENERATOR =========
generator.train() generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train( generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample, low_quality_sample,
high_quality_sample, high_quality_sample,
real_labels, real_labels,
@ -152,11 +156,13 @@ def start_training():
criterion_d, criterion_d,
optimizer_g, optimizer_g,
device, device,
mel_transform mel_transform,
stft_transform,
mfcc_transform
) )
if debug: if debug:
print(combined_loss, adversarial_loss, mel_l1_tensor) print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach()) scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach()) scheduler_g.step(adversarial_loss.detach())
@ -173,9 +179,9 @@ def start_training():
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
if debug: #if debug:
print(generator.state_dict().keys()) # print(generator.state_dict().keys())
print(discriminator.state_dict().keys()) # print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})

View File

@ -37,7 +37,6 @@ def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
mel_spec_true = mel_spec_true[..., :min_len] mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len] mel_spec_pred = mel_spec_pred[..., :min_len]
# L2 Loss (Mean Squared Error)
loss = torch.mean((mel_spec_true - mel_spec_pred)**2) loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
return loss return loss
@ -49,7 +48,6 @@ def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor,
stft_mag_true = stft_mag_true[..., :min_len] stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len] stft_mag_pred = stft_mag_pred[..., :min_len]
# Log Magnitude L1 Loss
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps))) loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
return loss return loss
@ -61,12 +59,9 @@ def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tenso
stft_mag_true = stft_mag_true[..., :min_len] stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len] stft_mag_pred = stft_mag_pred[..., :min_len]
# Calculate Frobenius norms and the loss
# Ensure norms are calculated over frequency and time dims ([..., freq, time])
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1)) norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
# Average loss over the batch
loss = torch.mean(norm_diff / (norm_true + eps)) loss = torch.mean(norm_diff / (norm_true + eps))
return loss return loss
@ -77,16 +72,13 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
discriminator_decision_from_real = discriminator(high_quality[0]) discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels) d_loss_real = criterion(discriminator_decision_from_real, real_labels)
# Forward pass for fake samples (from generator output) with torch.no_grad():
with torch.no_grad(): # Detach generator output within no_grad context
generator_output = generator(low_quality[0]) generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output) # No need to detach again if inside no_grad discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
# Combine real and fake losses
d_loss = (d_loss_real + d_loss_fake) / 2.0 d_loss = (d_loss_real + d_loss_fake) / 2.0
# Backward pass and optimization
d_loss.backward() d_loss.backward()
# Optional: Gradient Clipping (can be helpful) # Optional: Gradient Clipping (can be helpful)
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
@ -100,65 +92,53 @@ def generator_train(
real_labels, real_labels,
generator, generator,
discriminator, discriminator,
adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss) adv_criterion,
g_optimizer, g_optimizer,
device, device,
# --- Pass necessary transforms and loss weights --- mel_transform: T.MelSpectrogram,
mel_transform: T.MelSpectrogram, # Example: Pass Mel transform stft_transform: T.Spectrogram,
# stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses mfcc_transform: T.MFCC,
# mfcc_transform: T.MFCC, # Pass MFCC transform if using MFCC loss lambda_adv: float = 1.0,
lambda_adv: float = 1.0, # Weight for adversarial loss lambda_mel_l1: float = 10.0,
lambda_mel_l1: float = 10.0, # Example: Weight for Mel L1 loss lambda_log_stft: float = 1.0,
# lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use lambda_mfcc: float = 1.0
# lambda_mfcc: float = 0.0
): ):
g_optimizer.zero_grad() g_optimizer.zero_grad()
# 1. Generate high-quality audio from low-quality input
generator_output = generator(low_quality[0]) generator_output = generator(low_quality[0])
# 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
discriminator_decision = discriminator(generator_output) discriminator_decision = discriminator(generator_output)
# Generator wants discriminator to output "real" labels for its fakes
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
# 3. Calculate Reconstruction/Spectrogram Loss(es)
# --- Choose and calculate the losses you want to include ---
mel_l1 = 0.0 mel_l1 = 0.0
# log_stft_l1 = 0.0 log_stft_l1 = 0.0
# mfcc_l = 0.0 mfcc_l = 0.0
# Calculate Mel L1 Loss if weight is positive # Calculate Mel L1 Loss if weight is positive
if lambda_mel_l1 > 0: if lambda_mel_l1 > 0:
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output) mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
# # Calculate Log STFT L1 Loss if weight is positive # Calculate Log STFT L1 Loss if weight is positive
# if lambda_log_stft > 0: if lambda_log_stft > 0:
# log_stft_l1 = log_stft_magnitude_loss(stft_transform, hq_audio, generator_output) log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
# # Calculate MFCC Loss if weight is positive # Calculate MFCC Loss if weight is positive
# if lambda_mfcc > 0: if lambda_mfcc > 0:
# mfcc_l = gpu_mfcc_loss(mfcc_transform, hq_audio, generator_output) mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
# --- End of Loss Calculation Choices ---
# 4. Combine Losses
# Make sure calculated losses are tensors even if weights are 0 initially
# (or handle appropriately in the sum)
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
# log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
# mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
combined_loss = (lambda_adv * adversarial_loss) + \ combined_loss = (lambda_adv * adversarial_loss) + \
(lambda_mel_l1 * mel_l1_tensor) (lambda_mel_l1 * mel_l1_tensor) + \
# + (lambda_log_stft * log_stft_l1_tensor) \ (lambda_log_stft * log_stft_l1_tensor) + \
# + (lambda_mfcc * mfcc_l_tensor) (lambda_mfcc * mfcc_l_tensor)
# 5. Backward Pass and Optimization
combined_loss.backward() combined_loss.backward()
# Optional: Gradient Clipping # Optional: Gradient Clipping
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
g_optimizer.step() g_optimizer.step()
# 6. Return values for logging # 6. Return values for logging
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor