⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
discriminator.py
@@ -1,70 +1,98 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
+import numpy as np
-
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-
-    return nn.Sequential(conv_layer, leaky_relu)
-
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
-
-
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
-        )
-
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+
+
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = Higher Precision (more tokens, finer detail).
+    Large patch_size = Lower Precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # We use a Conv1d with stride=patch_size to create non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
+
+    def forward(self, x):
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
+        return x
+
+
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,      # Lower this for higher precision (e.g., 8 or 16)
+        embed_dim=128,      # Dimension of the transformer tokens
+        depth=4,            # Number of Transformer blocks
+        heads=4,            # Number of attention heads
+        mlp_dim=256,        # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
+
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch Embedding (Tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class Token (like in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional Embedding (Learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer Encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
+
+        # 6. Final Classification Head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
+
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x):
+        b, c, t = x.shape
+
+        # --- 1. Tokenize Audio ---
+        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+
+        # --- 2. Add CLS Token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+
+        # --- 3. Add Positional Embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer Layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (Use only CLS token) ---
+        cls_output = x[:, 0]  # Take the first token
+        cls_output = self.norm(cls_output)
+
+        score = self.head(cls_output)  # (Batch, 1)
+
+        return score
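For quick verification, here is a minimal smoke test of the classes added in this diff. It assumes the file is importable as discriminator (per the discriminator.py header above); the batch size of 4 and the use of the default hyperparameters are arbitrary choices for illustration.

import torch
from discriminator import PatchEmbedding, TransformerDiscriminator  # names from this diff

x = torch.randn(4, 1, 8000)  # (batch, channels, samples); 8000 matches the audio_length default

# Tokenizer alone: 8000 samples with patch_size=16 yield 8000 // 16 = 500
# non-overlapping patches, each projected to a 128-dim token
patches = PatchEmbedding(in_channels=1, embed_dim=128, patch_size=16)(x)
assert patches.shape == (4, 500, 128)

# Full model: 500 patches + 1 CLS token pass through the encoder, and only the
# CLS token is scored, so the discriminator emits a single logit per clip
disc = TransformerDiscriminator()
score = disc(x)
assert score.shape == (4, 1)

Note that pos_embed is sized to num_patches + 1 at construction time, so inputs must match audio_length exactly; a clip of any other length will fail the addition in step 3 of forward.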