⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

This commit is contained in:
2025-12-04 14:22:48 +02:00
parent 782a3bab28
commit bf0a6e58e9
4 changed files with 210 additions and 131 deletions

View File

@@ -1,70 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.utils as utils
import numpy as np
class PatchEmbedding(nn.Module):
"""
Converts raw audio into a sequence of embeddings (tokens).
Small patch_size = Higher Precision (more tokens, finer detail).
Large patch_size = Lower Precision (fewer tokens, more global).
"""
def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
super().__init__()
# We use a Conv1d with stride=patch_size to create non-overlapping patches
self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def discriminator_block(
in_channels,
out_channels,
kernel_size=15,
stride=1,
dilation=1
):
padding = dilation * (kernel_size - 1) // 2
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding
)
conv_layer = utils.spectral_norm(conv_layer)
leaky_relu = nn.LeakyReLU(0.2)
return nn.Sequential(conv_layer, leaky_relu)
class AttentionBlock(nn.Module):
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
if spectral_norm:
self.proj = utils.spectral_norm(self.proj)
def forward(self, x):
attention_weights = self.attention(x)
return x + (x * attention_weights)
# x shape: (batch, 1, 8000)
x = self.proj(x) # shape: (batch, embed_dim, num_patches)
x = x.transpose(1, 2) # shape: (batch, num_patches, embed_dim)
return x
class TransformerDiscriminator(nn.Module):
def __init__(
self,
audio_length=8000,
patch_size=16, # Lower this for higher precision (e.g., 8 or 16)
embed_dim=128, # Dimension of the transformer tokens
depth=4, # Number of Transformer blocks
heads=4, # Number of attention heads
mlp_dim=256, # Hidden dimension of the feed-forward layer
spectral_norm=True
):
super().__init__()
class SISUDiscriminator(nn.Module):
def __init__(self, layers=8):
super(SISUDiscriminator, self).__init__()
self.discriminator_blocks = nn.Sequential(
# 1 -> 32
discriminator_block(2, layers),
AttentionBlock(layers),
# 32 -> 64
discriminator_block(layers, layers * 2, dilation=2),
# 64 -> 128
discriminator_block(layers * 2, layers * 4, dilation=4),
AttentionBlock(layers * 4),
# 128 -> 256
discriminator_block(layers * 4, layers * 8, stride=4),
# 256 -> 512
# discriminator_block(layers * 8, layers * 16, stride=4)
# 1. Calculate sequence length
self.num_patches = audio_length // patch_size
# 2. Patch Embedding (Tokenizer)
self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
# 3. Class Token (like in BERT/ViT) to aggregate global info
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 4. Positional Embedding (Learnable)
self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
# 5. Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=heads,
dim_feedforward=mlp_dim,
dropout=0.1,
activation='gelu',
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
# 6. Final Classification Head
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, 1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
if spectral_norm:
self.head = utils.spectral_norm(self.head)
# Initialize weights
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.pos_embed, std=0.02)
def forward(self, x):
x = self.discriminator_blocks(x)
x = self.final_conv(x)
x = self.avg_pool(x)
return x.squeeze(2)
b, c, t = x.shape
# --- 1. Tokenize Audio ---
x = self.patch_embed(x) # (Batch, Num_Patches, Embed_Dim)
# --- 2. Add CLS Token ---
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1) # (Batch, Num_Patches + 1, Embed_Dim)
# --- 3. Add Positional Embeddings ---
x = x + self.pos_embed
# --- 4. Transformer Layers ---
x = self.transformer(x)
# --- 5. Classification (Use only CLS token) ---
cls_output = x[:, 0] # Take the first token
cls_output = self.norm(cls_output)
score = self.head(cls_output) # (Batch, 1)
return score