⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
discriminator.py (142 changed lines)
@@ -1,70 +1,98 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
+import numpy as np
 
 
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-
-    return nn.Sequential(conv_layer, leaky_relu)
-
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
-
-
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
-        )
-
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = Higher Precision (more tokens, finer detail).
+    Large patch_size = Lower Precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # We use a Conv1d with stride=patch_size to create non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
+
+    def forward(self, x):
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)  # shape: (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
+        return x
+
+
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,    # Lower this for higher precision (e.g., 8 or 16)
+        embed_dim=128,    # Dimension of the transformer tokens
+        depth=4,          # Number of Transformer blocks
+        heads=4,          # Number of attention heads
+        mlp_dim=256,      # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
+
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch Embedding (Tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class Token (like in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional Embedding (Learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer Encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
+
+        # 6. Final Classification Head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
+
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x):
+        b, c, t = x.shape
+
+        # --- 1. Tokenize Audio ---
+        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+
+        # --- 2. Add CLS Token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+
+        # --- 3. Add Positional Embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer Layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (Use only CLS token) ---
+        cls_output = x[:, 0]  # Take the first token
+        cls_output = self.norm(cls_output)
+
+        score = self.head(cls_output)  # (Batch, 1)
+
+        return score
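Note: a quick smoke test for the new discriminator (my sketch, not part of the commit; it assumes discriminator.py above is importable as a module). With the defaults, 8000 samples at patch_size=16 tokenize into 500 patches plus one CLS token, and the head returns a single score per clip:

import torch
from discriminator import TransformerDiscriminator

disc = TransformerDiscriminator()   # audio_length=8000, patch_size=16, embed_dim=128
audio = torch.randn(4, 1, 8000)     # (batch, channels, samples)
score = disc(audio)                 # internally: 500 patch tokens + 1 CLS token
print(score.shape)                  # torch.Size([4, 1])

One consequence of the learnable pos_embed being sized to num_patches + 1: inputs must match audio_length exactly, so variable-length clips would need cropping or padding before scoring.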
generator.py (40 changed lines)
@@ -1,19 +1,20 @@
 import torch
 import torch.nn as nn
 
+from torch.nn.utils.parametrizations import weight_norm
+
 def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
     padding = (kernel_size - 1) // 2 * dilation
 
     return nn.Sequential(
-        nn.Conv1d(
+        weight_norm(nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1),
     )
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
             nn.ReLU(inplace=True),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
             nn.Sigmoid(),
         )
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
         x = self.attention(x)
         return x + residual
 
-def UpsampleBlock(in_channels, out_channels):
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
     return nn.Sequential(
-        nn.ConvTranspose1d(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=4,
-            stride=2,
+            kernel_size=3,
+            stride=1,
             padding=1
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1)
     )
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=32, num_rirb=1):
+    def __init__(self, channels=32, num_rirb=4):
         super(SISUGenerator, self).__init__()
 
         self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@ class SISUGenerator(nn.Module):
         self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
         self.downsample_2_attn = AttentionBlock(channels * 4)
 
-        self.rirb = ResidualInResidualBlock(channels * 4)
-        # self.rirb = nn.Sequential(
-        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
-        # )
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
 
         self.upsample = UpsampleBlock(channels * 4, channels * 2)
         self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@ class SISUGenerator(nn.Module):
         self.compress_2 = GeneratorBlock(channels * 2, channels)
 
         self.final_conv = nn.Sequential(
-            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
             nn.Tanh()
         )
 
+
     def forward(self, x):
         residual_input = x
 
+        # Encoding
         x1 = self.first_conv(x)
 
         x2 = self.downsample(x1)
@@ -102,8 +104,10 @@ class SISUGenerator(nn.Module):
         x3 = self.downsample_2(x2)
         x3 = self.downsample_2_attn(x3)
 
+        # Bottleneck (Deep Residual processing)
         x_rirb = self.rirb(x3)
 
+        # Decoding with Skip Connections
         up1 = self.upsample(x_rirb)
         up1 = self.upsample_attn(up1)
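Note: the UpsampleBlock change swaps a strided transposed convolution for nearest-neighbour upsampling followed by a stride-1 convolution. Both double the time axis, and the resample-then-convolve form is a common way to avoid the checkerboard artifacts transposed convolutions can introduce (my reading of the change; the commit does not state a motivation). Dropping nn.InstanceNorm1d in the same blocks is consistent with weight_norm taking over normalization duty. A shape check, assuming only torch:

import torch
import torch.nn as nn

x = torch.randn(1, 128, 2000)
old_up = nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)
new_up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
)
print(old_up(x).shape)  # torch.Size([1, 64, 4000])
print(new_up(x).shape)  # torch.Size([1, 64, 4000])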
training.py (26 changed lines)
@@ -3,7 +3,6 @@ import datetime
 import os
 
 import torch
-import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
     "--epochs", type=int, default=5000, help="Number of training epochs"
 )
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers") # Increased workers slightly
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
 parser.add_argument(
     "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
 
-criterion_d = nn.MSELoss()
-
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,6 +128,7 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
+    if os.path.exists(ckpt_path):
         ckpt = torch.load(ckpt_path)
 
         accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
@@ -142,12 +140,13 @@ if args.resume:
         start_epoch = ckpt.get("epoch", 1)
         accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
-
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
 accelerator.print("🏋️ | Started training...")
 
 smallest_loss = float('inf')
 
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
 
@@ -179,10 +173,7 @@ try:
                 d_loss = discriminator_train(
                     high_quality,
                     low_quality.detach(),
-                    real_labels,
-                    fake_labels,
                     discriminator,
-                    criterion_d,
                     generator_output.detach()
                 )
 
@@ -197,10 +188,8 @@ try:
                 g_total, g_adv = generator_train(
                     low_quality,
                     high_quality,
-                    real_labels,
                     generator,
                     discriminator,
-                    criterion_d,
                     generator_output
                 )
 
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 
 except Exception:
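Note: the label buffers (real_buf/fake_buf) and criterion_d can go away because the new LSGAN losses below bake the targets into the loss expression itself, so no label tensors are needed (my inference from the diff). A small equivalence check:

import torch
import torch.nn as nn

dr = torch.rand(8, 1)                           # discriminator scores for real audio
mse = nn.MSELoss()(dr, torch.ones_like(dr))     # old style: criterion_d(decision_real, real_labels)
lsgan = torch.mean((dr - 1) ** 2)               # new style: target of 1.0 baked into the loss
assert torch.allclose(mse, lsgan)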
@@ -1,58 +1,113 @@
 import torch
 import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
 # stft_loss_fn = MultiResolutionSTFTLoss(
 #     fft_sizes=[512, 1024, 2048, 4096],
 #     hop_sizes=[128, 256, 512, 1024],
 #     win_lengths=[512, 1024, 2048, 4096]
 # )
 
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
 def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
     absolute_difference = torch.abs(input_one - input_two)
     return torch.mean(absolute_difference)
 
+def feature_matching_loss(fmap_r, fmap_g):
+    """
+    Computes L1 distance between real and fake feature maps.
+    """
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
+            rl = rl.detach()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    # Scale by number of feature maps to keep loss magnitude reasonable
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    """
+    Least Squares GAN Loss (LSGAN) for the Discriminator.
+    Objective: Real -> 1, Fake -> 0
+    """
+    loss = 0
+    r_losses = []
+    g_losses = []
+
+    # Iterate over both MPD and MSD outputs
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        # Real should be 1.0
+        r_loss = torch.mean((dr - 1) ** 2)
+        # Fake should be 0.0
+        g_loss = torch.mean(dg ** 2)
+
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_adv_loss(disc_generated_outputs):
+    """
+    Least Squares GAN Loss for the Generator.
+    Objective: Fake -> 1 (Fool the discriminator)
+    """
+    loss = 0
+    for dg in zip(disc_generated_outputs):
+        dg = dg[0]  # Unpack tuple
+        loss += torch.mean((dg - 1) ** 2)
+    return loss
+
+
 def discriminator_train(
     high_quality,
     low_quality,
-    high_labels,
-    low_labels,
     discriminator,
-    criterion,
     generator_output
 ):
-    real_pair = torch.cat((low_quality, high_quality), dim=1)
-    decision_real = discriminator(real_pair)
-    d_loss_real = criterion(decision_real, high_labels)
-
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    decision_fake = discriminator(fake_pair)
-    d_loss_fake = criterion(decision_fake, low_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
+    # 1. Forward pass through the Ensemble Discriminator
+    # Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
+    # We detach generator_output because we are only optimizing D here
+    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
+
+    # 2. Calculate Loss (LSGAN)
+    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+
     return d_loss
 
 
 def generator_train(
-    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-
-    discriminator_decision = discriminator(fake_pair)
-    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
-
-    mae_loss = signal_mae(generator_output, high_quality)
+    low_quality,
+    high_quality,
+    generator,
+    discriminator,
+    generator_output
+):
+    # 1. Forward pass through Discriminator
+    # We do NOT detach generator_output here, we need gradients for G
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
+
+    # 2. Adversarial Loss (Try to fool D into thinking G is Real)
+    loss_gen_adv = generator_adv_loss(y_d_gs)
+
+    # 3. Feature Matching Loss (Force G to match internal features of D)
+    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
 
+    # 4. Mel-Spectrogram / STFT Loss (Audio Quality)
     stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
 
-    lambda_mae = 10.0
-    lambda_stft = 2.5
-    lambda_adv = 2.5
-    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
-    return combined_loss, adversarial_loss
+    # -----------------------------------------
+    # 5. Combine Losses
+    # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec style architectures:
+    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
+    lambda_stft = 45.0
+    lambda_fm = 2.0
+    lambda_adv = 1.0
+
+    combined_loss = (lambda_stft * stft_loss) + \
+                    (lambda_fm * loss_fm) + \
+                    (lambda_adv * loss_gen_adv)
+
+    return combined_loss, loss_gen_adv
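Note: the MultiPeriodDiscriminator named in the commit title is not visible in this diff view, but the loss code pins down the interface it must satisfy: discriminator(real, fake) returns (y_d_rs, y_d_gs, fmap_rs, fmap_gs), i.e. per-sub-discriminator score lists plus per-layer feature-map lists. A minimal stand-in with that contract (illustrative only, not the commit's code; the real HiFi-GAN MPD reshapes audio into period-strided 2D views and also exposes every intermediate layer in the feature maps):

import torch
import torch.nn as nn

class ToyEnsembleDiscriminator(nn.Module):  # hypothetical stand-in
    def __init__(self, num_subdiscs=3):
        super().__init__()
        self.subs = nn.ModuleList(
            [nn.Conv1d(1, 1, kernel_size=15, stride=4, padding=7) for _ in range(num_subdiscs)]
        )

    def forward(self, y_real, y_fake):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for sub in self.subs:
            r, g = sub(y_real), sub(y_fake)
            y_d_rs.append(r)
            y_d_gs.append(g)
            fmap_rs.append([r])  # a real MPD would append each layer's activations
            fmap_gs.append([g])
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

# Usage against the loss functions above:
real, fake = torch.randn(2, 1, 8000), torch.randn(2, 1, 8000)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = ToyEnsembleDiscriminator()(real, fake)
d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)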