⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

2025-12-04 14:22:48 +02:00
parent 782a3bab28
commit bf0a6e58e9
4 changed files with 210 additions and 131 deletions

View File

@@ -1,70 +1,98 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
-import numpy as np
-
-
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-    return nn.Sequential(conv_layer, leaky_relu)
-
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
-
-
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
-        )
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+
+
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = Higher Precision (more tokens, finer detail).
+    Large patch_size = Lower Precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # We use a Conv1d with stride=patch_size to create non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
+
+    def forward(self, x):
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
+        return x
+
+
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,      # Lower this for higher precision (e.g., 8 or 16)
+        embed_dim=128,      # Dimension of the transformer tokens
+        depth=4,            # Number of Transformer blocks
+        heads=4,            # Number of attention heads
+        mlp_dim=256,        # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch Embedding (Tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class Token (like in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional Embedding (Learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer Encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
+
+        # 6. Final Classification Head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x):
+        b, c, t = x.shape
+
+        # --- 1. Tokenize Audio ---
+        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+
+        # --- 2. Add CLS Token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+
+        # --- 3. Add Positional Embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer Layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (Use only CLS token) ---
+        cls_output = x[:, 0]  # Take the first token
+        cls_output = self.norm(cls_output)
+
+        score = self.head(cls_output)  # (Batch, 1)
+        return score
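
Editorial note: as a quick sanity check of the new discriminator, a minimal smoke test could look like the sketch below. It is not part of the commit; it assumes the class above is importable as-is and uses its default hyperparameters.

    import torch

    disc = TransformerDiscriminator(audio_length=8000, patch_size=16)
    wave = torch.randn(4, 1, 8000)  # (batch, channels, samples)
    score = disc(wave)
    print(score.shape)  # torch.Size([4, 1]): one raw realness score per clip
    # 8000 / 16 = 500 patches, plus the CLS token -> sequence length 501 inside the transformer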

View File

@@ -1,19 +1,20 @@
 import torch
 import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
 
 
 def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
     padding = (kernel_size - 1) // 2 * dilation
     return nn.Sequential(
-        nn.Conv1d(
+        weight_norm(nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1),
     )
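
Editorial note on the weight_norm swap: the parametrizations API (available in recent PyTorch, roughly 2.1+) recomputes the conv weight from a learned magnitude/direction split on every access, which is why the explicit InstanceNorm1d can be dropped here. A minimal sketch, not from the commit:

    import torch
    import torch.nn as nn
    from torch.nn.utils.parametrizations import weight_norm

    conv = weight_norm(nn.Conv1d(8, 16, kernel_size=3, padding=1))
    print(sorted(name for name, _ in conv.named_parameters()))
    # ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
    y = conv(torch.randn(2, 8, 100))  # output shape unchanged: (2, 16, 100)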
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
             nn.ReLU(inplace=True),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
             nn.Sigmoid(),
         )
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
         x = self.attention(x)
         return x + residual
 
 
-def UpsampleBlock(in_channels, out_channels):
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
     return nn.Sequential(
-        nn.ConvTranspose1d(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=4,
-            stride=2,
+            kernel_size=3,
+            stride=1,
             padding=1
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1)
     )
 
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=32, num_rirb=1):
+    def __init__(self, channels=32, num_rirb=4):
         super(SISUGenerator, self).__init__()
 
         self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@ class SISUGenerator(nn.Module):
         self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
         self.downsample_2_attn = AttentionBlock(channels * 4)
 
-        self.rirb = ResidualInResidualBlock(channels * 4)
-        # self.rirb = nn.Sequential(
-        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
-        # )
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
 
         self.upsample = UpsampleBlock(channels * 4, channels * 2)
         self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@ class SISUGenerator(nn.Module):
         self.compress_2 = GeneratorBlock(channels * 2, channels)
 
         self.final_conv = nn.Sequential(
-            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
             nn.Tanh()
         )
 
     def forward(self, x):
         residual_input = x
 
+        # Encoding
         x1 = self.first_conv(x)
         x2 = self.downsample(x1)
@@ -102,8 +104,10 @@ class SISUGenerator(nn.Module):
         x3 = self.downsample_2(x2)
         x3 = self.downsample_2_attn(x3)
 
+        # Bottleneck (Deep Residual processing)
         x_rirb = self.rirb(x3)
 
+        # Decoding with Skip Connections
         up1 = self.upsample(x_rirb)
         up1 = self.upsample_attn(up1)
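
Editorial note: the UpsampleBlock change swaps the strided ConvTranspose1d for nearest-neighbor upsampling followed by a stride-1 convolution. Both double the temporal length, but the upsample-then-conv form is a common remedy for the periodic checkerboard artifacts transposed convolutions can introduce into audio. A minimal shape check, assuming the block definition above:

    import torch

    up = UpsampleBlock(in_channels=128, out_channels=64)
    x = torch.randn(2, 128, 2000)
    print(up(x).shape)  # torch.Size([2, 64, 4000]): same 2x upsampling as the old ConvTranspose1d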

View File

@@ -3,7 +3,6 @@ import datetime
 import os
 import torch
-import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
"--epochs", type=int, default=5000, help="Number of training epochs" "--epochs", type=int, default=5000, help="Number of training epochs"
) )
parser.add_argument("--batch_size", type=int, default=8, help="Batch size") parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers") # Increased workers slightly
parser.add_argument("--debug", action="store_true", help="Print debug logs") parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument( parser.add_argument(
"--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA" "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
 
-criterion_d = nn.MSELoss()
-
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,6 +128,7 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
+    if os.path.exists(ckpt_path):
         ckpt = torch.load(ckpt_path)
         accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
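
Editorial note, not part of this commit: when a run is resumed on a different device than it was saved from, loading the checkpoint to CPU first is a common defensive variant:

    ckpt = torch.load(ckpt_path, map_location="cpu")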
@@ -142,12 +140,13 @@ if args.resume:
         start_epoch = ckpt.get("epoch", 1)
         accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
-
 accelerator.print("🏋️ | Started training...")
+smallest_loss = float('inf')
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
try: try:
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
generator.train() generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
@@ -179,10 +173,7 @@ try:
                 d_loss = discriminator_train(
                     high_quality,
                     low_quality.detach(),
-                    real_labels,
-                    fake_labels,
                     discriminator,
-                    criterion_d,
                     generator_output.detach()
                 )
@@ -197,10 +188,8 @@ try:
                 g_total, g_adv = generator_train(
                     low_quality,
                     high_quality,
-                    real_labels,
                     generator,
                     discriminator,
-                    criterion_d,
                     generator_output
                 )
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 except Exception:
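
Editorial note: the deleted real_buf/fake_buf label tensors and criterion_d become unnecessary because the new LSGAN helpers in the loss module bake the targets into the formula. A small equivalence check (illustrative values, not from the commit):

    import torch
    import torch.nn as nn

    d_real = torch.rand(8, 1)               # stand-in discriminator outputs
    labels = torch.ones_like(d_real)        # the old explicit "real" labels
    old = nn.MSELoss()(d_real, labels)      # old criterion_d path
    new = torch.mean((d_real - 1) ** 2)     # new discriminator_loss path
    assert torch.allclose(old, new)         # same objective, no label tensors needed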

View File

@@ -1,58 +1,113 @@
 import torch
+import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
-# stft_loss_fn = MultiResolutionSTFTLoss(
-#     fft_sizes=[512, 1024, 2048, 4096],
-#     hop_sizes=[128, 256, 512, 1024],
-#     win_lengths=[512, 1024, 2048, 4096]
-# )
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
 
-def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
-    absolute_difference = torch.abs(input_one - input_two)
-    return torch.mean(absolute_difference)
+def feature_matching_loss(fmap_r, fmap_g):
+    """
+    Computes the L1 distance between real and fake feature maps.
+    """
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
+            rl = rl.detach()
+            loss += torch.mean(torch.abs(rl - gl))
+    # Scale by number of feature maps to keep loss magnitude reasonable
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    """
+    Least Squares GAN (LSGAN) loss for the discriminator.
+    Objective: Real -> 1, Fake -> 0
+    """
+    loss = 0
+    r_losses = []
+    g_losses = []
+    # Iterate over both MPD and MSD outputs
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        # Real should be 1.0
+        r_loss = torch.mean((dr - 1) ** 2)
+        # Fake should be 0.0
+        g_loss = torch.mean(dg ** 2)
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+    return loss, r_losses, g_losses
+
+
+def generator_adv_loss(disc_generated_outputs):
+    """
+    Least Squares GAN loss for the generator.
+    Objective: Fake -> 1 (fool the discriminator)
+    """
+    loss = 0
+    for dg in disc_generated_outputs:
+        loss += torch.mean((dg - 1) ** 2)
+    return loss
 
 
 def discriminator_train(
     high_quality,
     low_quality,
-    high_labels,
-    low_labels,
     discriminator,
-    criterion,
     generator_output
 ):
-    real_pair = torch.cat((low_quality, high_quality), dim=1)
-    decision_real = discriminator(real_pair)
-    d_loss_real = criterion(decision_real, high_labels)
-
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    decision_fake = discriminator(fake_pair)
-    d_loss_fake = criterion(decision_fake, low_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
+    # 1. Forward pass through the ensemble discriminator
+    # Note: inputs are now passed separately: (Real_Target, Fake_Candidate)
+    # We detach generator_output because we are only optimizing D here
+    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
+
+    # 2. Calculate loss (LSGAN)
+    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
 
     return d_loss
 
 
 def generator_train(
-    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    discriminator_decision = discriminator(fake_pair)
-    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
-
-    mae_loss = signal_mae(generator_output, high_quality)
-    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
-
-    lambda_mae = 10.0
-    lambda_stft = 2.5
-    lambda_adv = 2.5
-
-    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
-    return combined_loss, adversarial_loss
+    low_quality,
+    high_quality,
+    generator,
+    discriminator,
+    generator_output
+):
+    # 1. Forward pass through the discriminator
+    # We do NOT detach generator_output here; we need gradients for G
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
+
+    # 2. Adversarial loss (try to fool D into scoring G's output as real)
+    loss_gen_adv = generator_adv_loss(y_d_gs)
+
+    # 3. Feature matching loss (force G to match D's internal features)
+    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
+
+    # 4. Multi-resolution STFT loss (audio quality)
+    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
+
+    # -----------------------------------------
+    # 5. Combine losses
+    # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec style architectures:
+    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
+    lambda_stft = 45.0
+    lambda_fm = 2.0
+    lambda_adv = 1.0
+
+    combined_loss = (lambda_stft * stft_loss) + \
+                    (lambda_fm * loss_fm) + \
+                    (lambda_adv * loss_gen_adv)
+
+    return combined_loss, loss_gen_adv
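
Editorial note: to exercise the loss plumbing in isolation, a hedged smoke test follows. The nested-list feature-map layout and all shapes are illustrative assumptions about what the ensemble discriminator returns; they are not taken from the commit.

    import torch

    # Two sub-discriminators, three feature maps each; shapes are arbitrary
    fmap_r = [[torch.randn(4, 16, 100) for _ in range(3)] for _ in range(2)]
    fmap_g = [[torch.randn(4, 16, 100) for _ in range(3)] for _ in range(2)]
    print(feature_matching_loss(fmap_r, fmap_g))  # scalar tensor

    # Rough scale of the 45/2/1 weighting (made-up magnitudes): with stft ~ 1.0,
    # fm ~ 0.5 and adv ~ 1.0, the total is 45*1.0 + 2*0.5 + 1*1.0 = 47.0,
    # i.e. reconstruction dominates while the adversarial term adds texture.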