diff --git a/discriminator.py b/discriminator.py index 8e9e2ec..0572d17 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,98 +1,179 @@ import torch import torch.nn as nn -import torch.nn.utils as utils -import numpy as np +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm, spectral_norm -class PatchEmbedding(nn.Module): - """ - Converts raw audio into a sequence of embeddings (tokens). - Small patch_size = Higher Precision (more tokens, finer detail). - Large patch_size = Lower Precision (fewer tokens, more global). - """ - def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True): - super().__init__() - # We use a Conv1d with stride=patch_size to create non-overlapping patches - self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) +# ------------------------------------------------------------------- +# 1. Multi-Period Discriminator (MPD) +# Captures periodic structures (pitch/timbre) by folding audio. +# ------------------------------------------------------------------- - if spectral_norm: - self.proj = utils.spectral_norm(self.proj) +class DiscriminatorP(nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + + # Use spectral_norm for stability, or weight_norm for performance + norm_f = spectral_norm if use_spectral_norm else weight_norm + + # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time) + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): - # x shape: (batch, 1, 8000) - x = self.proj(x) # shape: (batch, embed_dim, num_patches) - x = x.transpose(1, 2) # shape: (batch, num_patches, embed_dim) - return x + fmap = [] -class TransformerDiscriminator(nn.Module): - def __init__( - self, - audio_length=8000, - patch_size=16, # Lower this for higher precision (e.g., 8 or 16) - embed_dim=128, # Dimension of the transformer tokens - depth=4, # Number of Transformer blocks - heads=4, # Number of attention heads - mlp_dim=256, # Hidden dimension of the feed-forward layer - spectral_norm=True - ): - super().__init__() - - # 1. Calculate sequence length - self.num_patches = audio_length // patch_size - - # 2. Patch Embedding (Tokenizer) - self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm) - - # 3. Class Token (like in BERT/ViT) to aggregate global info - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - - # 4. Positional Embedding (Learnable) - self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) - - # 5. Transformer Encoder - encoder_layer = nn.TransformerEncoderLayer( - d_model=embed_dim, - nhead=heads, - dim_feedforward=mlp_dim, - dropout=0.1, - activation='gelu', - batch_first=True - ) - self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) - - # 6. Final Classification Head - self.norm = nn.LayerNorm(embed_dim) - self.head = nn.Linear(embed_dim, 1) - - if spectral_norm: - self.head = utils.spectral_norm(self.head) - - # Initialize weights - self._init_weights() - - def _init_weights(self): - nn.init.normal_(self.cls_token, std=0.02) - nn.init.normal_(self.pos_embed, std=0.02) - - def forward(self, x): + # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P] b, c, t = x.shape + if t % self.period != 0: # Pad if not divisible by period + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad - # --- 1. Tokenize Audio --- - x = self.patch_embed(x) # (Batch, Num_Patches, Embed_Dim) + x = x.view(b, c, t // self.period, self.period) - # --- 2. Add CLS Token --- - cls_tokens = self.cls_token.expand(b, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) # (Batch, Num_Patches + 1, Embed_Dim) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) # Store feature map for Feature Matching Loss - # --- 3. Add Positional Embeddings --- - x = x + self.pos_embed + x = self.conv_post(x) + fmap.append(x) - # --- 4. Transformer Layers --- - x = self.transformer(x) + # Flatten back to 1D for score + x = torch.flatten(x, 1, -1) - # --- 5. Classification (Use only CLS token) --- - cls_output = x[:, 0] # Take the first token - cls_output = self.norm(cls_output) + return x, fmap - score = self.head(cls_output) # (Batch, 1) - return score +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, periods=[2, 3, 5, 7, 11]): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(p) for p in periods + ]) + + def forward(self, y, y_hat): + y_d_rs = [] # Real scores + y_d_gs = [] # Generated (Fake) scores + fmap_rs = [] # Real feature maps + fmap_gs = [] # Generated (Fake) feature maps + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# ------------------------------------------------------------------- +# 2. Multi-Scale Discriminator (MSD) +# Captures structure at different audio resolutions (raw, x0.5, x0.25). +# ------------------------------------------------------------------- + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = spectral_norm if use_spectral_norm else weight_norm + + # Standard 1D Convolutions with large receptive field + self.convs = nn.ModuleList([ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + # 3 Scales: Original, Downsampled x2, Downsampled x4 + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + nn.AvgPool1d(4, 2, padding=2), + nn.AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + if i != 0: + # Downsample input for subsequent discriminators + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# ------------------------------------------------------------------- +# 3. Master Wrapper +# Combines MPD and MSD into one class to fit your training script. +# ------------------------------------------------------------------- + +class SISUDiscriminator(nn.Module): + def __init__(self): + super(SISUDiscriminator, self).__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, y, y_hat): + # Return format: + # scores_real, scores_fake, features_real, features_fake + + # Run Multi-Period + mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat) + + # Run Multi-Scale + msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat) + + # Combine all results + return ( + mpd_y_d_rs + msd_y_d_rs, # All real scores + mpd_y_d_gs + msd_y_d_gs, # All fake scores + mpd_fmap_rs + msd_fmap_rs, # All real feature maps + mpd_fmap_gs + msd_fmap_gs # All fake feature maps + ) diff --git a/training.py b/training.py index 35733be..8107b03 100644 --- a/training.py +++ b/training.py @@ -3,6 +3,7 @@ import datetime import os import torch +import torch.nn as nn import torch.optim as optim import tqdm from accelerate import Accelerator @@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16") # Models # --------------------------- generator = SISUGenerator() +# Note: SISUDiscriminator is now an Ensemble (MPD + MSD) discriminator = SISUDiscriminator() accelerator.print("๐Ÿ”จ | Compiling models...") +# Torch compile is great, but if you hit errors with the new List/Tuple outputs +# of the discriminator, you might need to disable it for D. generator = torch.compile(generator) discriminator = torch.compile(discriminator) @@ -108,21 +112,24 @@ models_dir = "./models" os.makedirs(models_dir, exist_ok=True) -def save_ckpt(path, epoch): +def save_ckpt(path, epoch, loss=None, is_best=False): accelerator.wait_for_everyone() if accelerator.is_main_process: - accelerator.save( - { - "epoch": epoch, - "G": accelerator.unwrap_model(generator).state_dict(), - "D": accelerator.unwrap_model(discriminator).state_dict(), - "optG": optimizer_g.state_dict(), - "optD": optimizer_d.state_dict(), - "schedG": scheduler_g.state_dict(), - "schedD": scheduler_d.state_dict(), - }, - path, - ) + state = { + "epoch": epoch, + "G": accelerator.unwrap_model(generator).state_dict(), + "D": accelerator.unwrap_model(discriminator).state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict() + } + + accelerator.save(state, os.path.join(models_dir, "last.pt")) + + if is_best: + accelerator.save(state, os.path.join(models_dir, "best.pt")) + accelerator.print(f"๐ŸŒŸ | New best model saved with G Loss: {loss:.4f}") start_epoch = 0 @@ -143,9 +150,8 @@ if args.resume: else: accelerator.print("โš ๏ธ | Resume requested but no checkpoint found. Starting fresh.") -accelerator.print("๐Ÿ‹๏ธ | Started training...") -smallest_loss = float('inf') +accelerator.print("๐Ÿ‹๏ธ | Started training...") try: for epoch in range(start_epoch, args.epochs): @@ -172,7 +178,6 @@ try: with accelerator.autocast(): d_loss = discriminator_train( high_quality, - low_quality.detach(), discriminator, generator_output.detach() ) @@ -218,7 +223,6 @@ try: steps += 1 progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") - # epoch averages & schedulers if steps == 0: accelerator.print("๐Ÿชน | No steps in epoch (empty dataloader?). Exiting.") break @@ -230,9 +234,6 @@ try: scheduler_g.step(mean_g) save_ckpt(os.path.join(models_dir, "last.pt"), epoch) - if smallest_loss > mean_g: - smallest_loss = mean_g - save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch) accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") except Exception: diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py index 7ed2d0f..cd0350a 100644 --- a/utils/TrainingTools.py +++ b/utils/TrainingTools.py @@ -2,13 +2,14 @@ import torch import torch.nn.functional as F from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - +# Keep STFT settings as is stft_loss_fn = MultiResolutionSTFTLoss( fft_sizes=[512, 1024, 2048], hop_sizes=[64, 128, 256], win_lengths=[256, 512, 1024] ) + def feature_matching_loss(fmap_r, fmap_g): """ Computes L1 distance between real and fake feature maps. @@ -16,11 +17,9 @@ def feature_matching_loss(fmap_r, fmap_g): loss = 0 for dr, dg in zip(fmap_r, fmap_g): for rl, gl in zip(dr, dg): - # Stop gradient on real features to save memory/computation rl = rl.detach() loss += torch.mean(torch.abs(rl - gl)) - # Scale by number of feature maps to keep loss magnitude reasonable return loss * 2 @@ -33,11 +32,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): r_losses = [] g_losses = [] - # Iterate over both MPD and MSD outputs for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - # Real should be 1.0 r_loss = torch.mean((dr - 1) ** 2) - # Fake should be 0.0 g_loss = torch.mean(dg ** 2) loss += (r_loss + g_loss) @@ -61,16 +57,11 @@ def generator_adv_loss(disc_generated_outputs): def discriminator_train( high_quality, - low_quality, discriminator, generator_output ): - # 1. Forward pass through the Ensemble Discriminator - # Note: We pass inputs separately now: (Real_Target, Fake_Candidate) - # We detach generator_output because we are only optimizing D here y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach()) - # 2. Calculate Loss (LSGAN) d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs) return d_loss @@ -83,25 +74,14 @@ def generator_train( discriminator, generator_output ): - # 1. Forward pass through Discriminator - # We do NOT detach generator_output here, we need gradients for G y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output) - # 2. Adversarial Loss (Try to fool D into thinking G is Real) loss_gen_adv = generator_adv_loss(y_d_gs) - # 3. Feature Matching Loss (Force G to match internal features of D) loss_fm = feature_matching_loss(fmap_rs, fmap_gs) - # 4. Mel-Spectrogram / STFT Loss (Audio Quality) stft_loss = stft_loss_fn(high_quality, generator_output)["total"] - # ----------------------------------------- - # 5. Combine Losses - # ----------------------------------------- - # Recommended weights for HiFi-GAN/EnCodec style architectures: - # STFT is dominant (45), FM provides stability (2), Adv provides texture (1) - lambda_stft = 45.0 lambda_fm = 2.0 lambda_adv = 1.0