⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

2025-12-06 18:04:18 +02:00
parent bf0a6e58e9
commit e3e555794e
3 changed files with 187 additions and 125 deletions

View File

@@ -1,98 +1,179 @@
 import torch
 import torch.nn as nn
-import torch.nn.utils as utils
-import numpy as np
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm, spectral_norm
 
-class PatchEmbedding(nn.Module):
-    """
-    Converts raw audio into a sequence of embeddings (tokens).
-    Small patch_size = Higher Precision (more tokens, finer detail).
-    Large patch_size = Lower Precision (fewer tokens, more global).
-    """
-    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
-        super().__init__()
-        # We use a Conv1d with stride=patch_size to create non-overlapping patches
-        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
-        if spectral_norm:
-            self.proj = utils.spectral_norm(self.proj)
-
-    def forward(self, x):
-        # x shape: (batch, 1, 8000)
-        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
-        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
-        return x
-
-class TransformerDiscriminator(nn.Module):
-    def __init__(
-        self,
-        audio_length=8000,
-        patch_size=16,      # Lower this for higher precision (e.g., 8 or 16)
-        embed_dim=128,      # Dimension of the transformer tokens
-        depth=4,            # Number of Transformer blocks
-        heads=4,            # Number of attention heads
-        mlp_dim=256,        # Hidden dimension of the feed-forward layer
-        spectral_norm=True
-    ):
-        super().__init__()
-        # 1. Calculate sequence length
-        self.num_patches = audio_length // patch_size
-        # 2. Patch Embedding (Tokenizer)
-        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
-        # 3. Class Token (like in BERT/ViT) to aggregate global info
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        # 4. Positional Embedding (Learnable)
-        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
-        # 5. Transformer Encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=heads,
-            dim_feedforward=mlp_dim,
-            dropout=0.1,
-            activation='gelu',
-            batch_first=True
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
-        # 6. Final Classification Head
-        self.norm = nn.LayerNorm(embed_dim)
-        self.head = nn.Linear(embed_dim, 1)
-        if spectral_norm:
-            self.head = utils.spectral_norm(self.head)
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self):
-        nn.init.normal_(self.cls_token, std=0.02)
-        nn.init.normal_(self.pos_embed, std=0.02)
-
-    def forward(self, x):
-        b, c, t = x.shape
-        # --- 1. Tokenize Audio ---
-        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
-        # --- 2. Add CLS Token ---
-        cls_tokens = self.cls_token.expand(b, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
-        # --- 3. Add Positional Embeddings ---
-        x = x + self.pos_embed
-        # --- 4. Transformer Layers ---
-        x = self.transformer(x)
-        # --- 5. Classification (Use only CLS token) ---
-        cls_output = x[:, 0]  # Take the first token
-        cls_output = self.norm(cls_output)
-        score = self.head(cls_output)  # (Batch, 1)
-        return score
+
+# -------------------------------------------------------------------
+# 1. Multi-Period Discriminator (MPD)
+#    Captures periodic structures (pitch/timbre) by folding audio.
+# -------------------------------------------------------------------
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        # Use spectral_norm for stability, or weight_norm for performance
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+        # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
+        b, c, t = x.shape
+        if t % self.period != 0:  # Pad if not divisible by period
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)  # Store feature map for Feature Matching Loss
+        x = self.conv_post(x)
+        fmap.append(x)
+        # Flatten back to 1D for score
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
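
The heart of DiscriminatorP is the view() call that folds the 1D waveform into a 2D (time/period x period) grid: the (kernel_size, 1) convolutions then only mix samples that lie an exact multiple of the period apart. A minimal standalone sketch of that fold (illustrative only, not part of the commit; the toy sizes are made up):

    import torch
    import torch.nn.functional as F

    x = torch.arange(10.0).reshape(1, 1, 10)  # batch=1, channels=1, 10 samples
    period = 3

    # Same reflect padding as DiscriminatorP.forward: 10 -> 12 samples
    b, c, t = x.shape
    n_pad = period - (t % period)  # 2
    x = F.pad(x, (0, n_pad), "reflect")
    t = t + n_pad

    # Fold: [1, 1, 12] -> [1, 1, 4, 3]; column j now holds samples j, j+3, j+6, ...
    print(x.view(b, c, t // period, period)[0, 0])
    # tensor([[0., 1., 2.],
    #         [3., 4., 5.],
    #         [6., 7., 8.],
    #         [9., 8., 7.]])  <- last two entries come from the reflect pad
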
+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(p) for p in periods
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []   # Real scores
+        y_d_gs = []   # Generated (Fake) scores
+        fmap_rs = []  # Real feature maps
+        fmap_gs = []  # Generated (Fake) feature maps
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
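
MultiPeriodDiscriminator simply runs one DiscriminatorP per period; the periods 2, 3, 5, 7, 11 are prime, so no two sub-discriminators fold the signal onto the same grid. A quick shape check (a sketch; the 2x1x8000 input shape is an assumption carried over from the old code's comments):

    mpd = MultiPeriodDiscriminator()
    y, y_hat = torch.randn(2, 1, 8000), torch.randn(2, 1, 8000)
    scores_r, scores_g, fmaps_r, fmaps_g = mpd(y, y_hat)
    print(len(scores_r))    # 5 score tensors, one per period
    print(len(fmaps_r[0]))  # 6 feature maps: five conv layers + conv_post
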
+# -------------------------------------------------------------------
+# 2. Multi-Scale Discriminator (MSD)
+#    Captures structure at different audio resolutions (raw, x0.5, x0.25).
+# -------------------------------------------------------------------
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        # Standard 1D Convolutions with large receptive field
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
+            norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+            norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+            norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+        # 3 Scales: Original, Downsampled x2, Downsampled x4
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True),
+            DiscriminatorS(),
+            DiscriminatorS(),
+        ])
+        self.meanpools = nn.ModuleList([
+            nn.AvgPool1d(4, 2, padding=2),
+            nn.AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                # Downsample input for subsequent discriminators
+                y = self.meanpools[i-1](y)
+                y_hat = self.meanpools[i-1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
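
Each AvgPool1d(4, 2, padding=2) roughly halves the sample rate (L_out = floor((L + 4 - 4) / 2) + 1), so the three DiscriminatorS instances see the raw, ~x0.5 and ~x0.25 signals; note that only the raw-scale discriminator uses spectral norm, the downsampled ones use weight norm. A quick check of the pooled lengths (illustrative sketch, input size assumed):

    pool = nn.AvgPool1d(4, 2, padding=2)
    y = torch.randn(2, 1, 8000)
    print(pool(y).shape)        # torch.Size([2, 1, 4001])  ~ half rate
    print(pool(pool(y)).shape)  # torch.Size([2, 1, 2001])  ~ quarter rate
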
+# -------------------------------------------------------------------
+# 3. Master Wrapper
+#    Combines MPD and MSD into one class to fit your training script.
+# -------------------------------------------------------------------
+class SISUDiscriminator(nn.Module):
+    def __init__(self):
+        super(SISUDiscriminator, self).__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def forward(self, y, y_hat):
+        # Return format:
+        # scores_real, scores_fake, features_real, features_fake
+        # Run Multi-Period
+        mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
+        # Run Multi-Scale
+        msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
+        # Combine all results
+        return (
+            mpd_y_d_rs + msd_y_d_rs,    # All real scores
+            mpd_y_d_gs + msd_y_d_gs,    # All fake scores
+            mpd_fmap_rs + msd_fmap_rs,  # All real feature maps
+            mpd_fmap_gs + msd_fmap_gs   # All fake feature maps
+        )
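
Because the wrapper concatenates the MPD and MSD lists, downstream losses see flat lists of 5 + 3 = 8 score tensors and 8 feature-map lists, which is exactly what feature_matching_loss and discriminator_loss below iterate over. Sketch (input shape assumed as above):

    disc = SISUDiscriminator()
    y, y_hat = torch.randn(2, 1, 8000), torch.randn(2, 1, 8000)
    scores_r, scores_g, fmaps_r, fmaps_g = disc(y, y_hat)
    print(len(scores_r))  # 8 = 5 (MPD periods) + 3 (MSD scales)
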

View File

@@ -3,6 +3,7 @@ import datetime
 import os
 import torch
 import torch.nn as nn
+import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16")
 # Models
 # ---------------------------
 generator = SISUGenerator()
+# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
 discriminator = SISUDiscriminator()
 accelerator.print("🔨 | Compiling models...")
+# Torch compile is great, but if you hit errors with the new List/Tuple outputs
+# of the discriminator, you might need to disable it for D.
 generator = torch.compile(generator)
 discriminator = torch.compile(discriminator)
@@ -108,21 +112,24 @@ models_dir = "./models"
 os.makedirs(models_dir, exist_ok=True)
 
-def save_ckpt(path, epoch):
+def save_ckpt(path, epoch, loss=None, is_best=False):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        accelerator.save(
-            {
+        state = {
             "epoch": epoch,
             "G": accelerator.unwrap_model(generator).state_dict(),
             "D": accelerator.unwrap_model(discriminator).state_dict(),
             "optG": optimizer_g.state_dict(),
             "optD": optimizer_d.state_dict(),
             "schedG": scheduler_g.state_dict(),
-            "schedD": scheduler_d.state_dict(),
-            },
-            path,
-        )
+            "schedD": scheduler_d.state_dict()
+        }
+        accelerator.save(state, os.path.join(models_dir, "last.pt"))
+        if is_best:
+            accelerator.save(state, os.path.join(models_dir, "best.pt"))
+            accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
start_epoch = 0
@@ -143,9 +150,8 @@ if args.resume:
 else:
     accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
-accelerator.print("🏋️ | Started training...")
 smallest_loss = float('inf')
+accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
@@ -172,7 +178,6 @@ try:
 with accelerator.autocast():
     d_loss = discriminator_train(
         high_quality,
-        low_quality.detach(),
         discriminator,
         generator_output.detach()
     )
@@ -218,7 +223,6 @@ try:
     steps += 1
     progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
 
-# epoch averages & schedulers
 if steps == 0:
     accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
     break
@@ -230,9 +234,6 @@ try:
 scheduler_g.step(mean_g)
 
-save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
-if smallest_loss > mean_g:
-    smallest_loss = mean_g
-    save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
 accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
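
The hunk above removes the old per-epoch save_ckpt calls, but the replacement call site falls outside the diff. Given the new save_ckpt(path, epoch, loss=None, is_best=False) signature, which writes last.pt itself and best.pt when is_best is set, a call consistent with it might look like this (hypothetical reconstruction, not shown in the commit):

    # Hypothetical epoch-end call; `smallest_loss` tracks the best G loss so far.
    is_best = mean_g < smallest_loss
    if is_best:
        smallest_loss = mean_g
    save_ckpt(None, epoch, loss=mean_g, is_best=is_best)  # `path` is unused by the new body
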

View File

@@ -2,13 +2,14 @@ import torch
 import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
+# Keep STFT settings as is
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
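
MultiResolutionSTFTLoss is a project-local utility, but losses in this family (Parallel WaveGAN style) typically average a spectral-convergence term and a log-magnitude L1 term over the listed FFT/hop/window resolutions. A single-resolution sketch under that assumption (the function name and exact formulation are illustrative, not the utility's API):

    import torch
    import torch.nn.functional as F

    def single_res_stft_loss(y, y_hat, n_fft=1024, hop=128, win=512):
        # y, y_hat: [B, T] waveforms (squeeze the channel dim first)
        window = torch.hann_window(win, device=y.device)
        mag_r = torch.stft(y, n_fft, hop, win, window=window, return_complex=True).abs().clamp(min=1e-7)
        mag_g = torch.stft(y_hat, n_fft, hop, win, window=window, return_complex=True).abs().clamp(min=1e-7)
        sc = torch.norm(mag_r - mag_g, p="fro") / torch.norm(mag_r, p="fro")  # spectral convergence
        log_mag = F.l1_loss(mag_g.log(), mag_r.log())                         # log-magnitude L1
        return sc + log_mag
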
def feature_matching_loss(fmap_r, fmap_g):
"""
Computes L1 distance between real and fake feature maps.
@@ -16,11 +17,9 @@ def feature_matching_loss(fmap_r, fmap_g):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
             rl = rl.detach()
             loss += torch.mean(torch.abs(rl - gl))
 
+    # Scale by number of feature maps to keep loss magnitude reasonable
     return loss * 2
@@ -33,11 +32,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     r_losses = []
     g_losses = []
 
+    # Iterate over both MPD and MSD outputs
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
         # Real should be 1.0
         r_loss = torch.mean((dr - 1) ** 2)
         # Fake should be 0.0
         g_loss = torch.mean(dg ** 2)
         loss += (r_loss + g_loss)
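
discriminator_loss is the LSGAN objective: real scores are pulled toward 1, fake scores toward 0. The matching generator_adv_loss (its body falls outside these hunks) would push fake scores toward 1; assuming this codebase follows the standard LSGAN form, it would be:

    def generator_adv_loss(disc_generated_outputs):
        # LSGAN generator term: make fake scores look "real" (1.0)
        loss = 0
        for dg in disc_generated_outputs:
            loss += torch.mean((dg - 1) ** 2)
        return loss
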
@@ -61,16 +57,11 @@ def generator_adv_loss(disc_generated_outputs):
 def discriminator_train(
     high_quality,
-    low_quality,
     discriminator,
     generator_output
 ):
+    # 1. Forward pass through the Ensemble Discriminator
+    # Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
+    # We detach generator_output because we are only optimizing D here
     y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
 
+    # 2. Calculate Loss (LSGAN)
     d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
     return d_loss
@@ -83,25 +74,14 @@ def generator_train(
     discriminator,
     generator_output
 ):
+    # 1. Forward pass through Discriminator
+    # We do NOT detach generator_output here, we need gradients for G
     y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
 
+    # 2. Adversarial Loss (Try to fool D into thinking G is Real)
     loss_gen_adv = generator_adv_loss(y_d_gs)
 
+    # 3. Feature Matching Loss (Force G to match internal features of D)
     loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
 
+    # 4. Mel-Spectrogram / STFT Loss (Audio Quality)
     stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
 
     # -----------------------------------------
     # 5. Combine Losses
     # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec style architectures:
+    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
     lambda_stft = 45.0
     lambda_fm = 2.0
     lambda_adv = 1.0
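
The diff cuts off before the final sum, but with these weights the combined objective is presumably total = 45·STFT + 2·FM + 1·Adv, i.e. something like:

    g_loss = (lambda_stft * stft_loss
              + lambda_fm * loss_fm
              + lambda_adv * loss_gen_adv)
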