From d70c86c2578bdd94772069917bed040aa8046275 Mon Sep 17 00:00:00 2001
From: NikkeDoy <niklas.siltala@gmail.com>
Date: Sat, 26 Apr 2025 17:03:28 +0300
Subject: [PATCH] :sparkles: | Implement log-STFT magnitude and MFCC losses in generator training.

---
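Notes:

This patch extends generator training with two additional reconstruction
losses (log-STFT magnitude and MFCC) alongside the existing Mel L1 loss,
shrinks both models (discriminator base_channels 64 -> 16; generator
channels 64 -> 16, num_rirb 8 -> 4) to make room for a larger batch size
(12 -> 64), and fixes --continue_training, which was declared with
type=bool and therefore parsed any non-empty string as True;
action="store_true" gives the intended flag behaviour.

The updated generator_train call passes mfcc_transform, but its
construction lies outside the diff context. A minimal sketch of the
assumed construction, mirroring the existing mel_transform (sample_rate
and n_mfcc are assumptions, not visible in this diff):

    # Assumed construction, not part of this diff; reuses the same STFT
    # parameters as mel_transform so all losses see matching frames.
    mfcc_transform = T.MFCC(
        sample_rate=sample_rate,  # assumed to be defined alongside n_fft
        n_mfcc=20,                # assumed coefficient count
        melkwargs={"n_fft": n_fft, "hop_length": hop_length,
                   "win_length": win_length, "n_mels": n_mels},
    ).to(device)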
 discriminator.py  |  2 +-
 generator.py      |  2 +-
 training.py       | 22 +++++++++------
 training_utils.py | 68 +++++++++++++++++------------------------------
 4 files changed, 40 insertions(+), 54 deletions(-)

diff --git a/discriminator.py b/discriminator.py
index 777abf2..dfd0126 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -39,7 +39,7 @@ class AttentionBlock(nn.Module):
         return x * attention_weights
 
 class SISUDiscriminator(nn.Module):
-    def __init__(self, base_channels=64):
+    def __init__(self, base_channels=16):
         super(SISUDiscriminator, self).__init__()
         layers = base_channels
         self.model = nn.Sequential(
diff --git a/generator.py b/generator.py
index cd4d48c..a53feb7 100644
--- a/generator.py
+++ b/generator.py
@@ -48,7 +48,7 @@ class ResidualInResidualBlock(nn.Module):
         return x + residual
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=64, num_rirb=8, alpha=1.0):
+    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
         super(SISUGenerator, self).__init__()
         self.alpha = alpha
 
diff --git a/training.py b/training.py
index db7cb86..01ea749 100644
--- a/training.py
+++ b/training.py
@@ -34,7 +34,7 @@ parser.add_argument("--discriminator", type=str, default=None,
 parser.add_argument("--device", type=str, default="cpu",  help="Select device")
 parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
 parser.add_argument("--debug", action="store_true",  help="Print debug logs")
-parser.add_argument("--continue_training", type=bool, default=False, help="Continue training using temp_generator and temp_discriminator models")
+parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
 
 args = parser.parse_args()
 
@@ -60,6 +60,10 @@ mel_transform = T.MelSpectrogram(
     win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
 ).to(device)
 
+stft_transform = T.Spectrogram(
+    n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=1.0 # Magnitude STFT
+).to(device)
+
 debug = args.debug
 
 # Initialize dataset and dataloader
@@ -72,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
 
 
 # ========= MODELS =========
@@ -143,7 +147,7 @@ def start_training():
 
             # ========= GENERATOR =========
             generator.train()
-            generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
+            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
                 low_quality_sample,
                 high_quality_sample,
                 real_labels,
@@ -152,11 +156,13 @@ def start_training():
                 criterion_d,
                 optimizer_g,
                 device,
-                mel_transform
+                mel_transform,
+                stft_transform,
+                mfcc_transform
             )
 
             if debug:
-                print(combined_loss, adversarial_loss, mel_l1_tensor)
+                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
             scheduler_d.step(d_loss.detach())
             scheduler_g.step(adversarial_loss.detach())
 
@@ -173,9 +179,9 @@ def start_training():
             torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
             torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
 
-        if debug:
-            print(generator.state_dict().keys())
-            print(discriminator.state_dict().keys())
+        #if debug:
+        #    print(generator.state_dict().keys())
+        #    print(discriminator.state_dict().keys())
         torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
         torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
         Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
diff --git a/training_utils.py b/training_utils.py
index be402d9..6f26f58 100644
--- a/training_utils.py
+++ b/training_utils.py
@@ -37,7 +37,6 @@ def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
     mel_spec_true = mel_spec_true[..., :min_len]
     mel_spec_pred = mel_spec_pred[..., :min_len]
 
-    # L2 Loss (Mean Squared Error)
     loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
     return loss
 
@@ -49,7 +48,6 @@ def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor,
     stft_mag_true = stft_mag_true[..., :min_len]
     stft_mag_pred = stft_mag_pred[..., :min_len]
 
-    # Log Magnitude L1 Loss
     loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
     return loss
 
@@ -61,12 +59,9 @@ def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tenso
     stft_mag_true = stft_mag_true[..., :min_len]
     stft_mag_pred = stft_mag_pred[..., :min_len]
 
-    # Calculate Frobenius norms and the loss
-    # Ensure norms are calculated over frequency and time dims ([..., freq, time])
     norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
     norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
 
-    # Average loss over the batch
     loss = torch.mean(norm_diff / (norm_true + eps))
     return loss
 
@@ -77,16 +72,13 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion(discriminator_decision_from_real, real_labels)
 
-    # Forward pass for fake samples (from generator output)
-    with torch.no_grad(): # Detach generator output within no_grad context
+    with torch.no_grad():
         generator_output = generator(low_quality[0])
-    discriminator_decision_from_fake = discriminator(generator_output) # No need to detach again if inside no_grad
+    discriminator_decision_from_fake = discriminator(generator_output)
     d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
 
-    # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0
 
-    # Backward pass and optimization
     d_loss.backward()
     # Optional: Gradient Clipping (can be helpful)
     # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
@@ -100,65 +92,53 @@ def generator_train(
     real_labels,
     generator,
     discriminator,
-    adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss)
+    adv_criterion,
     g_optimizer,
     device,
-    # --- Pass necessary transforms and loss weights ---
-    mel_transform: T.MelSpectrogram, # Example: Pass Mel transform
-    # stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses
-    # mfcc_transform: T.MFCC,      # Pass MFCC transform if using MFCC loss
-    lambda_adv: float = 1.0,       # Weight for adversarial loss
-    lambda_mel_l1: float = 10.0,   # Example: Weight for Mel L1 loss
-    # lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use
-    # lambda_mfcc: float = 0.0
+    mel_transform: T.MelSpectrogram,
+    stft_transform: T.Spectrogram,
+    mfcc_transform: T.MFCC,
+    lambda_adv: float = 1.0,
+    lambda_mel_l1: float = 10.0,
+    lambda_log_stft: float = 1.0,
+    lambda_mfcc: float = 1.0
 ):
     g_optimizer.zero_grad()
 
-    # 1. Generate high-quality audio from low-quality input
     generator_output = generator(low_quality[0])
 
-    # 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
     discriminator_decision = discriminator(generator_output)
-    # Generator wants discriminator to output "real" labels for its fakes
     adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
 
-    # 3. Calculate Reconstruction/Spectrogram Loss(es)
-    # --- Choose and calculate the losses you want to include ---
     mel_l1 = 0.0
-    # log_stft_l1 = 0.0
-    # mfcc_l = 0.0
+    log_stft_l1 = 0.0
+    mfcc_l = 0.0
 
     # Calculate Mel L1 Loss if weight is positive
     if lambda_mel_l1 > 0:
         mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
 
-    # # Calculate Log STFT L1 Loss if weight is positive
-    # if lambda_log_stft > 0:
-    #     log_stft_l1 = log_stft_magnitude_loss(stft_transform, hq_audio, generator_output)
+    # Calculate Log STFT L1 Loss if weight is positive
+    if lambda_log_stft > 0:
+        log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
 
-    # # Calculate MFCC Loss if weight is positive
-    # if lambda_mfcc > 0:
-    #     mfcc_l = gpu_mfcc_loss(mfcc_transform, hq_audio, generator_output)
-    # --- End of Loss Calculation Choices ---
+    # Calculate MFCC Loss if weight is positive
+    if lambda_mfcc > 0:
+        mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
 
-
-    # 4. Combine Losses
-    # Make sure calculated losses are tensors even if weights are 0 initially
-    # (or handle appropriately in the sum)
     mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
-    # log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
-    # mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
+    log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
+    mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
 
     combined_loss = (lambda_adv * adversarial_loss) + \
-                    (lambda_mel_l1 * mel_l1_tensor)
-                    # + (lambda_log_stft * log_stft_l1_tensor) \
-                    # + (lambda_mfcc * mfcc_l_tensor)
+                    (lambda_mel_l1 * mel_l1_tensor) + \
+                    (lambda_log_stft * log_stft_l1_tensor) + \
+                    (lambda_mfcc * mfcc_l_tensor)
 
-    # 5. Backward Pass and Optimization
     combined_loss.backward()
     # Optional: Gradient Clipping
     # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
     g_optimizer.step()
 
     # 6. Return values for logging
-    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor
+    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
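
Note: generator_train now calls gpu_mfcc_loss, whose definition is
outside the diff context. A minimal sketch of what it is assumed to
compute, following the same truncate-and-compare pattern as the other
loss helpers in training_utils.py (the L1 reduction is an assumption):

    def gpu_mfcc_loss(mfcc_transform: T.MFCC, y_true: torch.Tensor,
                      y_pred: torch.Tensor) -> torch.Tensor:
        # Compute MFCCs on-device and truncate both to the shorter time
        # axis, since generator output length can differ by a frame.
        mfcc_true = mfcc_transform(y_true)
        mfcc_pred = mfcc_transform(y_pred)
        min_len = min(mfcc_true.shape[-1], mfcc_pred.shape[-1])
        return torch.mean(torch.abs(mfcc_true[..., :min_len]
                                    - mfcc_pred[..., :min_len]))

With lambda_log_stft and lambda_mfcc defaulting to 1.0 against
lambda_mel_l1 = 10.0, the new terms enter the combined loss at a tenth
of the Mel weight; the raw magnitudes of these losses differ, so the
weights may need retuning.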