⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
discriminator.py
@@ -1,98 +1,179 @@
 import torch
 import torch.nn as nn
-import torch.nn.utils as utils
-import numpy as np
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm, spectral_norm

-class PatchEmbedding(nn.Module):
-    """
-    Converts raw audio into a sequence of embeddings (tokens).
-    Small patch_size = Higher Precision (more tokens, finer detail).
-    Large patch_size = Lower Precision (fewer tokens, more global).
-    """
-    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
-        super().__init__()
-        # We use a Conv1d with stride=patch_size to create non-overlapping patches
-        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+# -------------------------------------------------------------------
+# 1. Multi-Period Discriminator (MPD)
+# Captures periodic structures (pitch/timbre) by folding audio.
+# -------------------------------------------------------------------

-        if spectral_norm:
-            self.proj = utils.spectral_norm(self.proj)
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+
+        # Use spectral_norm for stability, or weight_norm for performance
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

     def forward(self, x):
-        # x shape: (batch, 1, 8000)
-        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
-        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
-        return x
+        fmap = []

-class TransformerDiscriminator(nn.Module):
-    def __init__(
-        self,
-        audio_length=8000,
-        patch_size=16,      # Lower this for higher precision (e.g., 8 or 16)
-        embed_dim=128,      # Dimension of the transformer tokens
-        depth=4,            # Number of Transformer blocks
-        heads=4,            # Number of attention heads
-        mlp_dim=256,        # Hidden dimension of the feed-forward layer
-        spectral_norm=True
-    ):
-        super().__init__()
-
-        # 1. Calculate sequence length
-        self.num_patches = audio_length // patch_size
-
-        # 2. Patch Embedding (Tokenizer)
-        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
-
-        # 3. Class Token (like in BERT/ViT) to aggregate global info
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-
-        # 4. Positional Embedding (Learnable)
-        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
-
-        # 5. Transformer Encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=heads,
-            dim_feedforward=mlp_dim,
-            dropout=0.1,
-            activation='gelu',
-            batch_first=True
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
-
-        # 6. Final Classification Head
-        self.norm = nn.LayerNorm(embed_dim)
-        self.head = nn.Linear(embed_dim, 1)
-
-        if spectral_norm:
-            self.head = utils.spectral_norm(self.head)
-
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self):
-        nn.init.normal_(self.cls_token, std=0.02)
-        nn.init.normal_(self.pos_embed, std=0.02)

-    def forward(self, x):
+        # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
         b, c, t = x.shape
+        if t % self.period != 0:  # Pad if not divisible by period
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad

-        # --- 1. Tokenize Audio ---
-        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+        x = x.view(b, c, t // self.period, self.period)

-        # --- 2. Add CLS Token ---
-        cls_tokens = self.cls_token.expand(b, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)  # Store feature map for Feature Matching Loss

-        # --- 3. Add Positional Embeddings ---
-        x = x + self.pos_embed
+        x = self.conv_post(x)
+        fmap.append(x)

-        # --- 4. Transformer Layers ---
-        x = self.transformer(x)
+        # Flatten back to 1D for score
+        x = torch.flatten(x, 1, -1)

-        # --- 5. Classification (Use only CLS token) ---
-        cls_output = x[:, 0]  # Take the first token
-        cls_output = self.norm(cls_output)
+        return x, fmap

-        score = self.head(cls_output)  # (Batch, 1)
-
-        return score

+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(p) for p in periods
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []   # Real scores
+        y_d_gs = []   # Generated (fake) scores
+        fmap_rs = []  # Real feature maps
+        fmap_gs = []  # Generated (fake) feature maps
+
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 2. Multi-Scale Discriminator (MSD)
+# Captures structure at different audio resolutions (raw, x0.5, x0.25).
+# -------------------------------------------------------------------
+
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        # Standard 1D convolutions with a large receptive field
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
+            norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+            norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+            norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+        # 3 scales: original, downsampled x2, downsampled x4
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True),
+            DiscriminatorS(),
+            DiscriminatorS(),
+        ])
+        self.meanpools = nn.ModuleList([
+            nn.AvgPool1d(4, 2, padding=2),
+            nn.AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                # Downsample input for subsequent discriminators
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 3. Master Wrapper
+# Combines MPD and MSD into one class to fit your training script.
+# -------------------------------------------------------------------
+
+class SISUDiscriminator(nn.Module):
+    def __init__(self):
+        super(SISUDiscriminator, self).__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def forward(self, y, y_hat):
+        # Return format:
+        # scores_real, scores_fake, features_real, features_fake
+
+        # Run Multi-Period
+        mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
+
+        # Run Multi-Scale
+        msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
+
+        # Combine all results
+        return (
+            mpd_y_d_rs + msd_y_d_rs,    # All real scores
+            mpd_y_d_gs + msd_y_d_gs,    # All fake scores
+            mpd_fmap_rs + msd_fmap_rs,  # All real feature maps
+            mpd_fmap_gs + msd_fmap_gs   # All fake feature maps
+        )
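
For reference, the "fold" in DiscriminatorP is just padding plus a reshape: same-phase samples stack along one axis, so the (kernel_size, 1) convolutions compare the waveform period by period. A minimal sketch of that step alone (the 8000-sample length mirrors the audio_length used elsewhere in this repo; the random input is a stand-in for real audio):

import torch
import torch.nn.functional as F

period = 3
x = torch.randn(1, 1, 8000)                # (batch, channels, time)
b, c, t = x.shape
if t % period != 0:                        # 8000 % 3 == 2, so pad 1 sample
    n_pad = period - (t % period)
    x = F.pad(x, (0, n_pad), "reflect")
    t = t + n_pad
x = x.view(b, c, t // period, period)      # (1, 1, 2667, 3)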
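A quick smoke test of the wrapper (a sketch, assuming discriminator.py is on the import path and 8000-sample mono clips as above; the random tensors are placeholders for real and generated audio). Note that the parametrizations import requires a PyTorch version that ships torch.nn.utils.parametrizations.weight_norm (2.1 or newer):

import torch
from discriminator import SISUDiscriminator

disc = SISUDiscriminator()
y = torch.randn(4, 1, 8000)      # real audio batch
y_hat = torch.randn(4, 1, 8000)  # generator output (placeholder)

scores_real, scores_fake, fmaps_real, fmaps_fake = disc(y, y_hat)
print(len(scores_real))          # 8 heads: 5 periods (MPD) + 3 scales (MSD)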
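The per-layer feature maps are returned so the generator can be trained with HiFi-GAN's feature-matching loss alongside its least-squares adversarial loss. A sketch of those losses as formulated in the HiFi-GAN paper (function names are illustrative; the training script is not part of this commit, so treat the exact weighting there, e.g. lambda_fm = 2, as an assumption):

import torch

def discriminator_loss(scores_real, scores_fake):
    # LSGAN objective: push real scores toward 1 and fake scores toward 0
    loss = 0.0
    for dr, dg in zip(scores_real, scores_fake):
        loss += torch.mean((1 - dr) ** 2) + torch.mean(dg ** 2)
    return loss

def generator_adversarial_loss(scores_fake):
    # The generator tries to move fake scores toward 1
    loss = 0.0
    for dg in scores_fake:
        loss += torch.mean((1 - dg) ** 2)
    return loss

def feature_matching_loss(fmaps_real, fmaps_fake):
    # L1 distance between real and fake activations at every stored layer
    loss = 0.0
    for fmap_r, fmap_g in zip(fmaps_real, fmaps_fake):
        for r, g in zip(fmap_r, fmap_g):
            loss += torch.mean(torch.abs(r - g))
    return loss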