diff --git a/discriminator.py b/discriminator.py
index 69de5ce..8e9e2ec 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -1,70 +1,98 @@
+import torch
 import torch.nn as nn
 import torch.nn.utils as utils
+import numpy as np
+
+
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = Higher Precision (more tokens, finer detail).
+    Large patch_size = Lower Precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # We use a Conv1d with stride=patch_size to create non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
 
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-
-    return nn.Sequential(conv_layer, leaky_relu)
-
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
 
     def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)        # shape: (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)   # shape: (batch, num_patches, embed_dim)
+        return x
 
 
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,      # Lower this for higher precision (e.g., 8)
+        embed_dim=128,      # Dimension of the transformer tokens
+        depth=4,            # Number of Transformer blocks
+        heads=4,            # Number of attention heads
+        mlp_dim=256,        # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch Embedding (Tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class Token (like in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional Embedding (Learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer Encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
         )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
 
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
+        # 6. Final Classification Head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
 
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
 
     def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+        b, c, t = x.shape
+
+        # --- 1. Tokenize Audio ---
+        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+
+        # --- 2. Add CLS Token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+
+        # --- 3. Add Positional Embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer Layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (Use only CLS token) ---
+        cls_output = x[:, 0]  # Take the first token
+        cls_output = self.norm(cls_output)
+
+        score = self.head(cls_output)  # (Batch, 1)
+
+        return score
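A quick shape sanity-check for the TransformerDiscriminator defined above (a sketch, not part of the patch; it assumes the class is importable from discriminator.py and mono clips of 8000 samples, the default audio_length):

    import torch
    from discriminator import TransformerDiscriminator

    disc = TransformerDiscriminator(audio_length=8000, patch_size=16, embed_dim=128, depth=4, heads=4)
    audio = torch.randn(2, 1, 8000)   # (batch, channels, samples)
    score = disc(audio)               # 8000 // 16 = 500 patches, +1 CLS token -> 501 tokens
    print(score.shape)                # torch.Size([2, 1]): one realness score per clip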
diff --git a/generator.py b/generator.py
index bc994ac..15279b1 100644
--- a/generator.py
+++ b/generator.py
@@ -1,19 +1,20 @@
 import torch
 import torch.nn as nn
-
+from torch.nn.utils.parametrizations import weight_norm
 
 def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
     padding = (kernel_size - 1) // 2 * dilation
+
     return nn.Sequential(
-        nn.Conv1d(
+
+        weight_norm(nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1),
     )
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
             nn.ReLU(inplace=True),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
             nn.Sigmoid(),
         )
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
         x = self.attention(x)
         return x + residual
 
-def UpsampleBlock(in_channels, out_channels):
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
     return nn.Sequential(
-        nn.ConvTranspose1d(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=4,
-            stride=2,
+            kernel_size=3,
+            stride=1,
             padding=1
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1)
     )
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=32, num_rirb=1):
+    def __init__(self, channels=32, num_rirb=4):
         super(SISUGenerator, self).__init__()
         self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@
         self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
         self.downsample_2_attn = AttentionBlock(channels * 4)
 
-        self.rirb = ResidualInResidualBlock(channels * 4)
-        # self.rirb = nn.Sequential(
-        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
-        # )
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
 
         self.upsample = UpsampleBlock(channels * 4, channels * 2)
         self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@
         self.compress_2 = GeneratorBlock(channels * 2, channels)
 
         self.final_conv = nn.Sequential(
-            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
             nn.Tanh()
         )
 
     def forward(self, x):
         residual_input = x
+
+        # Encoding
         x1 = self.first_conv(x)
         x2 = self.downsample(x1)
@@ -102,8 +104,10 @@
         x3 = self.downsample_2(x2)
         x3 = self.downsample_2_attn(x3)
 
+        # Bottleneck (Deep Residual processing)
         x_rirb = self.rirb(x3)
 
+        # Decoding with Skip Connections
         up1 = self.upsample(x_rirb)
         up1 = self.upsample_attn(up1)
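The UpsampleBlock change above swaps a transposed convolution for nearest-neighbour upsampling followed by a plain convolution. Both double the sequence length, but the resample-then-convolve form avoids the checkerboard artifacts that overlapping transposed-conv kernels can produce. A small sketch comparing the two (the channel counts and length here are made up for illustration):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 64, 1000)
    old_block = nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)
    new_block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv1d(64, 32, kernel_size=3, stride=1, padding=1),
    )
    print(old_block(x).shape)  # torch.Size([1, 32, 2000])
    print(new_block(x).shape)  # torch.Size([1, 32, 2000])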
diff --git a/training.py b/training.py
index 6f962f3..35733be 100644
--- a/training.py
+++ b/training.py
@@ -3,7 +3,6 @@ import datetime
 import os
 
 import torch
-import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
     "--epochs", type=int, default=5000, help="Number of training epochs"
 )
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")  # Increased workers slightly
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
 parser.add_argument(
     "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
 )
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
 
-criterion_d = nn.MSELoss()
-
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,23 +128,25 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
-    ckpt = torch.load(ckpt_path)
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path)
 
-    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
-    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
-    optimizer_g.load_state_dict(ckpt["optG"])
-    optimizer_d.load_state_dict(ckpt["optD"])
-    scheduler_g.load_state_dict(ckpt["schedG"])
-    scheduler_d.load_state_dict(ckpt["schedD"])
+        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
+        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
+        optimizer_g.load_state_dict(ckpt["optG"])
+        optimizer_d.load_state_dict(ckpt["optD"])
+        scheduler_g.load_state_dict(ckpt["schedG"])
+        scheduler_d.load_state_dict(ckpt["schedD"])
 
-    start_epoch = ckpt.get("epoch", 1)
-    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
-
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
+        start_epoch = ckpt.get("epoch", 1)
+        accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
 accelerator.print("🏋️ | Started training...")
 
+smallest_loss = float('inf')
+
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
@@ -179,10 +173,7 @@ try:
                 d_loss = discriminator_train(
                     high_quality,
                     low_quality.detach(),
-                    real_labels,
-                    fake_labels,
                     discriminator,
-                    criterion_d,
                     generator_output.detach()
                 )
@@ -197,10 +188,8 @@ try:
                 g_total, g_adv = generator_train(
                     low_quality,
                     high_quality,
-                    real_labels,
                     generator,
                     discriminator,
-                    criterion_d,
                     generator_output
                 )
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
 
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 
 except Exception:
diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py
index 581a890..7ed2d0f 100644
--- a/utils/TrainingTools.py
+++ b/utils/TrainingTools.py
@@ -1,58 +1,113 @@
 import torch
-
+import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
-# stft_loss_fn = MultiResolutionSTFTLoss(
-#     fft_sizes=[512, 1024, 2048, 4096],
-#     hop_sizes=[128, 256, 512, 1024],
-#     win_lengths=[512, 1024, 2048, 4096]
-# )
+
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
-def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
-    absolute_difference = torch.abs(input_one - input_two)
-    return torch.mean(absolute_difference)
+def feature_matching_loss(fmap_r, fmap_g):
+    """
+    Computes L1 distance between real and fake feature maps.
+    """
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
+            rl = rl.detach()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    # Fixed scale factor of 2 (HiFi-GAN convention) keeps this term at a useful magnitude
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    """
+    Least Squares GAN Loss (LSGAN) for the Discriminator.
+    Objective: Real -> 1, Fake -> 0
+    """
+    loss = 0
+    r_losses = []
+    g_losses = []
+
+    # Iterate over the outputs of each sub-discriminator (e.g. MPD and MSD)
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        # Real should be 1.0
+        r_loss = torch.mean((dr - 1) ** 2)
+        # Fake should be 0.0
+        g_loss = torch.mean(dg ** 2)
+
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_adv_loss(disc_generated_outputs):
+    """
+    Least Squares GAN Loss for the Generator.
+    Objective: Fake -> 1 (Fool the discriminator)
+    """
+    loss = 0
+    for dg in disc_generated_outputs:
+        loss += torch.mean((dg - 1) ** 2)
+    return loss
 
 def discriminator_train(
     high_quality,
     low_quality,
-    high_labels,
-    low_labels,
     discriminator,
-    criterion,
     generator_output
 ):
+    # 1. Forward pass through the Ensemble Discriminator
+    # Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
+    # We detach generator_output because we are only optimizing D here
+    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
 
-    real_pair = torch.cat((low_quality, high_quality), dim=1)
-    decision_real = discriminator(real_pair)
-    d_loss_real = criterion(decision_real, high_labels)
+    # 2. Calculate Loss (LSGAN)
+    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
 
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    decision_fake = discriminator(fake_pair)
-    d_loss_fake = criterion(decision_fake, low_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
     return d_loss
 
 def generator_train(
-    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
+    low_quality,
+    high_quality,
+    generator,
+    discriminator,
+    generator_output
+):
+    # 1. Forward pass through Discriminator
+    # We do NOT detach generator_output here, we need gradients for G
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
 
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
+    # 2. Adversarial Loss (Try to fool D into thinking G is Real)
+    loss_gen_adv = generator_adv_loss(y_d_gs)
 
-    discriminator_decision = discriminator(fake_pair)
-    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
+    # 3. Feature Matching Loss (Force G to match internal features of D)
+    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
 
-    mae_loss = signal_mae(generator_output, high_quality)
+    # 4. Multi-Resolution STFT Loss (Audio Quality)
     stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
 
-    lambda_mae = 10.0
-    lambda_stft = 2.5
-    lambda_adv = 2.5
-    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
-    return combined_loss, adversarial_loss
+    # -----------------------------------------
+    # 5. Combine Losses
+    # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec style architectures:
+    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
+
+    lambda_stft = 45.0
+    lambda_fm = 2.0
+    lambda_adv = 1.0
+
+    combined_loss = (lambda_stft * stft_loss) + \
+                    (lambda_fm * loss_fm) + \
+                    (lambda_adv * loss_gen_adv)
+
+    return combined_loss, loss_gen_adv
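A rough sketch of how the new loss helpers fit together, using dummy per-scale tensors in place of a discriminator ensemble that returns score maps and feature maps (the shapes are illustrative assumptions, and the helpers are assumed importable from utils.TrainingTools):

    import torch
    from utils.TrainingTools import discriminator_loss, generator_adv_loss, feature_matching_loss

    real_scores = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]   # per-scale outputs for real audio
    fake_scores = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]   # per-scale outputs for generated audio

    d_loss, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)  # pushes real -> 1, fake -> 0
    adv_loss = generator_adv_loss(fake_scores)                                 # pushes fake -> 1

    real_fmaps = [[torch.randn(2, 16, 100)], [torch.randn(2, 16, 50)]]         # per-scale feature maps
    fake_fmaps = [[torch.randn(2, 16, 100)], [torch.randn(2, 16, 50)]]
    fm_loss = feature_matching_loss(real_fmaps, fake_fmaps)

    # generator_train above then combines: 45.0 * stft + 2.0 * fm + 1.0 * adv.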