⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

2025-12-04 14:22:48 +02:00
parent 782a3bab28
commit bf0a6e58e9
4 changed files with 210 additions and 131 deletions

(file 1 of 4: discriminator model)

@@ -1,70 +1,98 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
+import numpy as np
 
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-    return nn.Sequential(conv_layer, leaky_relu)
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
-
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
-        )
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = higher precision (more tokens, finer detail).
+    Large patch_size = lower precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # A Conv1d with stride=patch_size creates non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
+
+    def forward(self, x):
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)        # (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)   # (batch, num_patches, embed_dim)
+        return x
+
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,      # Lower this for higher precision (e.g., 8)
+        embed_dim=128,      # Dimension of the transformer tokens
+        depth=4,            # Number of Transformer blocks
+        heads=4,            # Number of attention heads
+        mlp_dim=256,        # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
+
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch embedding (tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class token (as in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional embedding (learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
+
+        # 6. Final classification head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x):
+        b, c, t = x.shape
+
+        # --- 1. Tokenize audio ---
+        x = self.patch_embed(x)                 # (batch, num_patches, embed_dim)
+
+        # --- 2. Prepend the CLS token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)   # (batch, num_patches + 1, embed_dim)
+
+        # --- 3. Add positional embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (use only the CLS token) ---
+        cls_output = x[:, 0]                    # Take the first token
+        cls_output = self.norm(cls_output)
+        score = self.head(cls_output)           # (batch, 1)
+        return score
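
A quick smoke test for the new discriminator (a minimal sketch, not part of the commit; it assumes TransformerDiscriminator above is in scope). Note that the input length must equal audio_length, because pos_embed is sized for exactly num_patches + 1 tokens:

# Not part of the commit: minimal shape check for TransformerDiscriminator.
import torch

disc = TransformerDiscriminator(audio_length=8000, patch_size=16, embed_dim=128)
wave = torch.randn(4, 1, 8000)   # (batch, channels, samples)
score = disc(wave)
print(score.shape)               # torch.Size([4, 1]) -- one realness score per clip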

(file 2 of 4: generator model)

@@ -1,19 +1,20 @@
 import torch
 import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
 
 def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
     padding = (kernel_size - 1) // 2 * dilation
     return nn.Sequential(
-        nn.Conv1d(
+        weight_norm(nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1),
     )
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
             nn.ReLU(inplace=True),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
             nn.Sigmoid(),
         )
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
         x = self.attention(x)
         return x + residual
 
-def UpsampleBlock(in_channels, out_channels):
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
     return nn.Sequential(
-        nn.ConvTranspose1d(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=4,
-            stride=2,
+            kernel_size=3,
+            stride=1,
             padding=1
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1)
     )
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=32, num_rirb=1):
+    def __init__(self, channels=32, num_rirb=4):
         super(SISUGenerator, self).__init__()
         self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@ class SISUGenerator(nn.Module):
         self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
         self.downsample_2_attn = AttentionBlock(channels * 4)
 
-        self.rirb = ResidualInResidualBlock(channels * 4)
-        # self.rirb = nn.Sequential(
-        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
-        # )
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
 
         self.upsample = UpsampleBlock(channels * 4, channels * 2)
         self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@ class SISUGenerator(nn.Module):
         self.compress_2 = GeneratorBlock(channels * 2, channels)
 
         self.final_conv = nn.Sequential(
-            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
             nn.Tanh()
        )
 
     def forward(self, x):
+        residual_input = x
+
         # Encoding
         x1 = self.first_conv(x)
         x2 = self.downsample(x1)
@@ -102,8 +104,10 @@ class SISUGenerator(nn.Module):
         x3 = self.downsample_2(x2)
         x3 = self.downsample_2_attn(x3)
 
+        # Bottleneck (Deep Residual processing)
         x_rirb = self.rirb(x3)
 
+        # Decoding with Skip Connections
         up1 = self.upsample(x_rirb)
         up1 = self.upsample_attn(up1)

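Swapping strided ConvTranspose1d for nearest-neighbour Upsample followed by a stride-1 Conv1d is the standard fix for checkerboard artifacts in upsampling layers. The sketch below (not part of the commit) shows that both blocks produce the same output length, so the swap is drop-in:

# Not part of the commit: both upsampling blocks double the sequence length.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

old_up = nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)
new_up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    weight_norm(nn.Conv1d(64, 32, kernel_size=3, stride=1, padding=1)),
)

x = torch.randn(2, 64, 1000)
print(old_up(x).shape, new_up(x).shape)  # both: torch.Size([2, 32, 2000])
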
(file 3 of 4: training script)

@@ -3,7 +3,6 @@ import datetime
 import os
 
 import torch
-import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
     "--epochs", type=int, default=5000, help="Number of training epochs"
 )
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")  # Increased workers slightly
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
 parser.add_argument(
     "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
-criterion_d = nn.MSELoss()
-
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,23 +128,25 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
-    ckpt = torch.load(ckpt_path)
-    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
-    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
-    optimizer_g.load_state_dict(ckpt["optG"])
-    optimizer_d.load_state_dict(ckpt["optD"])
-    scheduler_g.load_state_dict(ckpt["schedG"])
-    scheduler_d.load_state_dict(ckpt["schedD"])
-    start_epoch = ckpt.get("epoch", 1)
-    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
-
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path)
+        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
+        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
+        optimizer_g.load_state_dict(ckpt["optG"])
+        optimizer_d.load_state_dict(ckpt["optD"])
+        scheduler_g.load_state_dict(ckpt["schedG"])
+        scheduler_d.load_state_dict(ckpt["schedD"])
+        start_epoch = ckpt.get("epoch", 1)
+        accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
 accelerator.print("🏋️ | Started training...")
+smallest_loss = float('inf')
 
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
@@ -179,10 +173,7 @@ try:
                 d_loss = discriminator_train(
                     high_quality,
                     low_quality.detach(),
-                    real_labels,
-                    fake_labels,
                     discriminator,
-                    criterion_d,
                     generator_output.detach()
                 )
@@ -197,10 +188,8 @@ try:
                 g_total, g_adv = generator_train(
                     low_quality,
                     high_quality,
-                    real_labels,
                     generator,
                     discriminator,
-                    criterion_d,
                     generator_output
                 )
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
 
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 except Exception:

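The checkpoint policy this commit introduces, distilled into a self-contained sketch (not part of the commit; train_one_epoch and the stub save_ckpt are hypothetical stand-ins for the script's real epoch loop and save_ckpt): always write last.pt, and additionally snapshot whenever the mean generator loss improves.

# Sketch only -- not part of the commit.
import os

def train_one_epoch(epoch):          # hypothetical: returns mean generator loss
    return 1.0 / (epoch + 1)

def save_ckpt(path, epoch):          # hypothetical stand-in for the real save_ckpt
    print(f"saved {path} at epoch {epoch}")

models_dir = "models"
smallest_loss = float('inf')
for epoch in range(10):
    mean_g = train_one_epoch(epoch)
    save_ckpt(os.path.join(models_dir, "last.pt"), epoch)    # always keep the latest
    if smallest_loss > mean_g:                               # new in this commit:
        smallest_loss = mean_g                               # keep the best-so-far too
        save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
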
(file 4 of 4: training loss utilities)

@@ -1,58 +1,113 @@
 import torch
 import torch.nn.functional as F
 
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
-# stft_loss_fn = MultiResolutionSTFTLoss(
-#     fft_sizes=[512, 1024, 2048, 4096],
-#     hop_sizes=[128, 256, 512, 1024],
-#     win_lengths=[512, 1024, 2048, 4096]
-# )
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
 def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
     absolute_difference = torch.abs(input_one - input_two)
     return torch.mean(absolute_difference)
 
+def feature_matching_loss(fmap_r, fmap_g):
+    """
+    Computes the L1 distance between real and fake feature maps.
+    """
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
+            rl = rl.detach()
+            loss += torch.mean(torch.abs(rl - gl))
+    # Scale the FM term by 2, following the HiFi-GAN recipe
+    return loss * 2
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    """
+    Least Squares GAN (LSGAN) loss for the discriminator.
+    Objective: Real -> 1, Fake -> 0
+    """
+    loss = 0
+    r_losses = []
+    g_losses = []
+    # Iterate over both MPD and MSD outputs
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        # Real should be 1.0
+        r_loss = torch.mean((dr - 1) ** 2)
+        # Fake should be 0.0
+        g_loss = torch.mean(dg ** 2)
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+    return loss, r_losses, g_losses
+
+def generator_adv_loss(disc_generated_outputs):
+    """
+    Least Squares GAN loss for the generator.
+    Objective: Fake -> 1 (fool the discriminator)
+    """
+    loss = 0
+    for dg in disc_generated_outputs:
+        loss += torch.mean((dg - 1) ** 2)
+    return loss
+
 def discriminator_train(
     high_quality,
     low_quality,
-    high_labels,
-    low_labels,
     discriminator,
-    criterion,
     generator_output
 ):
-    real_pair = torch.cat((low_quality, high_quality), dim=1)
-    decision_real = discriminator(real_pair)
-    d_loss_real = criterion(decision_real, high_labels)
-
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    decision_fake = discriminator(fake_pair)
-    d_loss_fake = criterion(decision_fake, low_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
+    # 1. Forward pass through the ensemble discriminator.
+    #    Inputs are now passed separately: (real_target, fake_candidate).
+    #    generator_output is detached because only D is optimized here.
+    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
+
+    # 2. Calculate the LSGAN loss
+    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
 
     return d_loss
 
 def generator_train(
-    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    discriminator_decision = discriminator(fake_pair)
-    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
-
-    mae_loss = signal_mae(generator_output, high_quality)
-    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
-
-    lambda_mae = 10.0
-    lambda_stft = 2.5
-    lambda_adv = 2.5
-
-    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
-
-    return combined_loss, adversarial_loss
+    low_quality,
+    high_quality,
+    generator,
+    discriminator,
+    generator_output
+):
+    # 1. Forward pass through the discriminator.
+    #    generator_output is NOT detached here: G needs its gradients.
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
+
+    # 2. Adversarial loss (try to fool D into scoring G's output as real)
+    loss_gen_adv = generator_adv_loss(y_d_gs)
+
+    # 3. Feature matching loss (force G to match D's internal features)
+    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
+
+    # 4. Multi-resolution STFT loss (audio quality)
+    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
+
+    # -----------------------------------------
+    # 5. Combine losses
+    # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec-style architectures:
+    # STFT is dominant (45), FM provides stability (2), adv provides texture (1)
+    lambda_stft = 45.0
+    lambda_fm = 2.0
+    lambda_adv = 1.0
+
+    combined_loss = (lambda_stft * stft_loss) + \
+                    (lambda_fm * loss_fm) + \
+                    (lambda_adv * loss_gen_adv)
+
+    return combined_loss, loss_gen_adv
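
A toy sanity check of the LSGAN helpers (not part of the commit), feeding hand-made score lists shaped like the outputs of an MPD/MSD-style ensemble, i.e. one score tensor per sub-discriminator:

# Not part of the commit: exercise discriminator_loss / generator_adv_loss.
import torch

real_scores = [torch.ones(4, 1), torch.ones(4, 1)]    # two sub-Ds, all scored "real"
fake_scores = [torch.zeros(4, 1), torch.zeros(4, 1)]  # two sub-Ds, all scored "fake"

d_loss, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)
print(float(d_loss))                           # 0.0 -- D is already perfect on this toy input
print(float(generator_adv_loss(fake_scores)))  # 2.0 -- (0 - 1)^2 summed over both sub-Ds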