⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

2025-12-06 18:04:18 +02:00
parent bf0a6e58e9
commit e3e555794e
3 changed files with 187 additions and 125 deletions

View File

@@ -1,98 +1,179 @@
 import torch
 import torch.nn as nn
-import torch.nn.utils as utils
-import numpy as np
-
-class PatchEmbedding(nn.Module):
-    """
-    Converts raw audio into a sequence of embeddings (tokens).
-    Small patch_size = Higher Precision (more tokens, finer detail).
-    Large patch_size = Lower Precision (fewer tokens, more global).
-    """
-    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
-        super().__init__()
-        # We use a Conv1d with stride=patch_size to create non-overlapping patches
-        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
-        if spectral_norm:
-            self.proj = utils.spectral_norm(self.proj)
-
-    def forward(self, x):
-        # x shape: (batch, 1, 8000)
-        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
-        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
-        return x
-
-class TransformerDiscriminator(nn.Module):
-    def __init__(
-        self,
-        audio_length=8000,
-        patch_size=16,      # Lower this for higher precision (e.g., 8 or 16)
-        embed_dim=128,      # Dimension of the transformer tokens
-        depth=4,            # Number of Transformer blocks
-        heads=4,            # Number of attention heads
-        mlp_dim=256,        # Hidden dimension of the feed-forward layer
-        spectral_norm=True
-    ):
-        super().__init__()
-        # 1. Calculate sequence length
-        self.num_patches = audio_length // patch_size
-        # 2. Patch Embedding (Tokenizer)
-        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
-        # 3. Class Token (like in BERT/ViT) to aggregate global info
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        # 4. Positional Embedding (Learnable)
-        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
-        # 5. Transformer Encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=heads,
-            dim_feedforward=mlp_dim,
-            dropout=0.1,
-            activation='gelu',
-            batch_first=True
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
-        # 6. Final Classification Head
-        self.norm = nn.LayerNorm(embed_dim)
-        self.head = nn.Linear(embed_dim, 1)
-        if spectral_norm:
-            self.head = utils.spectral_norm(self.head)
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self):
-        nn.init.normal_(self.cls_token, std=0.02)
-        nn.init.normal_(self.pos_embed, std=0.02)
-
-    def forward(self, x):
-        b, c, t = x.shape
-        # --- 1. Tokenize Audio ---
-        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
-        # --- 2. Add CLS Token ---
-        cls_tokens = self.cls_token.expand(b, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
-        # --- 3. Add Positional Embeddings ---
-        x = x + self.pos_embed
-        # --- 4. Transformer Layers ---
-        x = self.transformer(x)
-        # --- 5. Classification (Use only CLS token) ---
-        cls_output = x[:, 0]  # Take the first token
-        cls_output = self.norm(cls_output)
-        score = self.head(cls_output)  # (Batch, 1)
-        return score
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+
+
+# -------------------------------------------------------------------
+# 1. Multi-Period Discriminator (MPD)
+#    Captures periodic structures (pitch/timbre) by folding audio.
+# -------------------------------------------------------------------
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        # Use spectral_norm for stability, or weight_norm for performance
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
+        b, c, t = x.shape
+        if t % self.period != 0:  # Pad if not divisible by period
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)  # Store feature map for Feature Matching Loss
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        # Flatten back to 1D for score
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(p) for p in periods
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []   # Real scores
+        y_d_gs = []   # Generated (Fake) scores
+        fmap_rs = []  # Real feature maps
+        fmap_gs = []  # Generated (Fake) feature maps
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 2. Multi-Scale Discriminator (MSD)
+#    Captures structure at different audio resolutions (raw, x0.5, x0.25).
+# -------------------------------------------------------------------
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        # Standard 1D Convolutions with large receptive field
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
+            norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+            norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+            norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+        # 3 Scales: Original, Downsampled x2, Downsampled x4
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True),
+            DiscriminatorS(),
+            DiscriminatorS(),
+        ])
+        self.meanpools = nn.ModuleList([
+            nn.AvgPool1d(4, 2, padding=2),
+            nn.AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                # Downsample input for subsequent discriminators
+                y = self.meanpools[i-1](y)
+                y_hat = self.meanpools[i-1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 3. Master Wrapper
+#    Combines MPD and MSD into one class to fit your training script.
+# -------------------------------------------------------------------
+class SISUDiscriminator(nn.Module):
+    def __init__(self):
+        super(SISUDiscriminator, self).__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def forward(self, y, y_hat):
+        # Return format:
+        #   scores_real, scores_fake, features_real, features_fake
+
+        # Run Multi-Period
+        mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
+        # Run Multi-Scale
+        msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
+
+        # Combine all results
+        return (
+            mpd_y_d_rs + msd_y_d_rs,    # All real scores
+            mpd_y_d_gs + msd_y_d_gs,    # All fake scores
+            mpd_fmap_rs + msd_fmap_rs,  # All real feature maps
+            mpd_fmap_gs + msd_fmap_gs   # All fake feature maps
+        )
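
For orientation, a quick smoke test of the new wrapper's contract (hypothetical, not part of this commit): with five MPD periods and three MSD scales, SISUDiscriminator returns four lists of eight entries each; the 8000-sample length below is only illustrative.

import torch

# Hypothetical smoke test of SISUDiscriminator (not part of this commit).
disc = SISUDiscriminator()
y = torch.randn(2, 1, 8000)      # real audio: (batch, channels, samples)
y_hat = torch.randn(2, 1, 8000)  # generator output, same shape
scores_r, scores_g, fmaps_r, fmaps_g = disc(y, y_hat)
assert len(scores_r) == len(scores_g) == 8  # 5 MPD periods + 3 MSD scales
assert len(fmaps_r[0]) == 6                 # 5 conv blocks + conv_post in DiscriminatorP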

View File

@@ -3,6 +3,7 @@ import datetime
 import os
 import torch
+import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16")
 # Models
 # ---------------------------
 generator = SISUGenerator()
+# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
 discriminator = SISUDiscriminator()
 
 accelerator.print("🔨 | Compiling models...")
+# Torch compile is great, but if you hit errors with the new List/Tuple outputs
+# of the discriminator, you might need to disable it for D.
 generator = torch.compile(generator)
 discriminator = torch.compile(discriminator)
@@ -108,21 +112,24 @@ models_dir = "./models"
 os.makedirs(models_dir, exist_ok=True)
 
-def save_ckpt(path, epoch):
+def save_ckpt(path, epoch, loss=None, is_best=False):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        accelerator.save(
-            {
-                "epoch": epoch,
-                "G": accelerator.unwrap_model(generator).state_dict(),
-                "D": accelerator.unwrap_model(discriminator).state_dict(),
-                "optG": optimizer_g.state_dict(),
-                "optD": optimizer_d.state_dict(),
-                "schedG": scheduler_g.state_dict(),
-                "schedD": scheduler_d.state_dict(),
-            },
-            path,
-        )
+        state = {
+            "epoch": epoch,
+            "G": accelerator.unwrap_model(generator).state_dict(),
+            "D": accelerator.unwrap_model(discriminator).state_dict(),
+            "optG": optimizer_g.state_dict(),
+            "optD": optimizer_d.state_dict(),
+            "schedG": scheduler_g.state_dict(),
+            "schedD": scheduler_d.state_dict()
+        }
+
+        accelerator.save(state, os.path.join(models_dir, "last.pt"))
+
+        if is_best:
+            accelerator.save(state, os.path.join(models_dir, "best.pt"))
+            accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
 
 start_epoch = 0
@@ -143,9 +150,8 @@ if args.resume:
 else:
     accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
 accelerator.print("🏋️ | Started training...")
-smallest_loss = float('inf')
 
 try:
     for epoch in range(start_epoch, args.epochs):
@@ -172,7 +178,6 @@ try:
             with accelerator.autocast():
                 d_loss = discriminator_train(
                     high_quality,
-                    low_quality.detach(),
                     discriminator,
                     generator_output.detach()
                 )
@@ -218,7 +223,6 @@ try:
             steps += 1
             progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
 
-        # epoch averages & schedulers
         if steps == 0:
            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
            break
@@ -230,9 +234,6 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
-        if smallest_loss > mean_g:
-            smallest_loss = mean_g
-            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
 
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 except Exception:
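
A wiring note on the checkpoint hunks above: save_ckpt now writes last.pt itself (its path argument is effectively unused) and only writes best.pt when is_best=True, but with the smallest_loss tracking removed from the epoch loop, the remaining call passes neither loss nor is_best, so best.pt is never produced. A minimal sketch of how the loop could drive the new signature (the helper name and best_g variable are hypothetical, not part of this commit):

# Hypothetical glue, not part of this commit: track the best epoch-mean
# generator loss and forward it to the new save_ckpt signature.
best_g = float("inf")

def save_epoch_ckpt(epoch, mean_g):
    global best_g
    is_best = mean_g < best_g
    best_g = min(best_g, mean_g)
    # save_ckpt writes models_dir/last.pt itself, plus models_dir/best.pt
    # when is_best is True.
    save_ckpt(os.path.join(models_dir, "last.pt"), epoch, loss=mean_g, is_best=is_best)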

View File

@@ -2,13 +2,14 @@ import torch
 import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
+# Keep STFT settings as is
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
 def feature_matching_loss(fmap_r, fmap_g):
     """
     Computes L1 distance between real and fake feature maps.
@@ -16,11 +17,9 @@ def feature_matching_loss(fmap_r, fmap_g):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
-            # Stop gradient on real features to save memory/computation
             rl = rl.detach()
             loss += torch.mean(torch.abs(rl - gl))
 
-    # Scale by number of feature maps to keep loss magnitude reasonable
     return loss * 2
@@ -33,11 +32,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     r_losses = []
     g_losses = []
 
-    # Iterate over both MPD and MSD outputs
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        # Real should be 1.0
         r_loss = torch.mean((dr - 1) ** 2)
-        # Fake should be 0.0
         g_loss = torch.mean(dg ** 2)
         loss += (r_loss + g_loss)
@@ -61,16 +57,11 @@ def generator_adv_loss(disc_generated_outputs):
 def discriminator_train(
     high_quality,
-    low_quality,
     discriminator,
     generator_output
 ):
-    # 1. Forward pass through the Ensemble Discriminator
-    # Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
-    # We detach generator_output because we are only optimizing D here
     y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
 
-    # 2. Calculate Loss (LSGAN)
     d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
 
     return d_loss
@@ -83,25 +74,14 @@ def generator_train(
     discriminator,
     generator_output
 ):
-    # 1. Forward pass through Discriminator
-    # We do NOT detach generator_output here, we need gradients for G
     y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
 
-    # 2. Adversarial Loss (Try to fool D into thinking G is Real)
     loss_gen_adv = generator_adv_loss(y_d_gs)
 
-    # 3. Feature Matching Loss (Force G to match internal features of D)
     loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
 
-    # 4. Mel-Spectrogram / STFT Loss (Audio Quality)
     stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
 
-    # -----------------------------------------
-    # 5. Combine Losses
-    # -----------------------------------------
-    # Recommended weights for HiFi-GAN/EnCodec style architectures:
-    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
     lambda_stft = 45.0
     lambda_fm = 2.0
     lambda_adv = 1.0
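
The last hunk stops at the loss weights, so the combination step is not visible here. Given the weighting scheme the removed comments described (STFT dominant at 45, feature matching at 2, adversarial at 1), the total presumably reduces to a weighted sum along these lines (the g_loss name and return are assumptions, not shown in the diff):

    # Presumed continuation of generator_train (an assumption; the diff hunk
    # is truncated at the lambdas above):
    g_loss = (lambda_stft * stft_loss
              + lambda_fm * loss_fm
              + lambda_adv * loss_gen_adv)
    return g_loss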