Files
SISU/discriminator.py

99 lines
3.3 KiB
Python

import torch
import torch.nn as nn
import torch.nn.utils as utils
import numpy as np
class PatchEmbedding(nn.Module):
    """Tokenize a 1-D signal into a sequence of non-overlapping patch embeddings.

    A single ``Conv1d`` whose kernel size equals its stride maps every
    ``patch_size`` consecutive samples onto one ``embed_dim``-dimensional
    token. A small ``patch_size`` gives more, finer-grained tokens; a large
    one gives fewer, more global tokens. Because the convolution has no
    padding, trailing samples that do not fill a complete patch are dropped.

    Args:
        in_channels: Number of channels in the input signal.
        embed_dim: Dimension of each output token.
        patch_size: Samples per patch (also used as the convolution stride).
        spectral_norm: If True, wrap the projection with
            ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
        super().__init__()
        conv = nn.Conv1d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # spectral_norm patches the module in place and returns it.
        self.proj = utils.spectral_norm(conv) if spectral_norm else conv

    def forward(self, x):
        """Map ``(batch, in_channels, time)`` to ``(batch, time // patch_size, embed_dim)``."""
        # The conv output is channel-first; swap to sequence-first (token-major).
        return self.proj(x).transpose(1, 2)
class TransformerDiscriminator(nn.Module):
    """ViT-style transformer discriminator for fixed-length single-channel audio.

    Pipeline:
        1. ``PatchEmbedding`` cuts the waveform into non-overlapping patches.
        2. A learnable CLS token is prepended.
        3. Learnable positional embeddings are added.
        4. The sequence runs through ``depth`` transformer encoder layers.
        5. Only the final CLS token is used. It is layer-normalised and
           linearly projected to one raw (unbounded) score per example.

    Args:
        audio_length: Input length, in samples, the model is built for. It
            fixes the number of positional embeddings.
        patch_size: Samples per token. Smaller values give more tokens and
            finer detail at a higher attention cost.
        embed_dim: Token dimension; ``nn.MultiheadAttention`` requires it to
            be divisible by ``heads``.
        depth: Number of transformer encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden size of each encoder layer's feed-forward network.
        spectral_norm: If True, wrap the patch projection and the output head
            in spectral normalisation. The encoder layers are not wrapped.

    Raises:
        ValueError: If ``patch_size`` is not positive, or ``audio_length`` is
            shorter than one patch (the model could never run).
    """

    def __init__(
        self,
        audio_length=8000,
        patch_size=16,   # Lower this for higher precision (e.g., 8 or 16)
        embed_dim=128,   # Dimension of the transformer tokens
        depth=4,         # Number of Transformer blocks
        heads=4,         # Number of attention heads
        mlp_dim=256,     # Hidden dimension of the feed-forward layer
        spectral_norm=True,
    ):
        super().__init__()
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size}")
        if audio_length < patch_size:
            raise ValueError(
                f"audio_length ({audio_length}) must be at least patch_size "
                f"({patch_size}) so that one complete patch fits"
            )
        self.audio_length = audio_length
        self.patch_size = patch_size
        # Conv1d with kernel == stride and no padding yields floor(T / P) tokens;
        # any trailing remainder of the waveform is dropped.
        self.num_patches = audio_length // patch_size

        # NOTE: creation order below is kept identical to the original so that
        # RNG consumption and state_dict layout are unchanged.
        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)

        # Learnable summary token; only its encoder output is scored.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learnable position vector per patch, plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embed_dim)
        head = nn.Linear(embed_dim, 1)
        self.head = utils.spectral_norm(head) if spectral_norm else head

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the CLS token and positional embeddings from N(0, 0.02**2)."""
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return one raw discriminator score per waveform.

        Args:
            x: Tensor of shape ``(batch, 1, time)``. ``time // patch_size``
                must equal ``self.num_patches``. Up to ``patch_size - 1``
                trailing samples beyond that are ignored.

        Returns:
            Tensor of shape ``(batch, 1)``.

        Raises:
            ValueError: If ``x`` is not 3-D with a single channel, or its
                length does not produce exactly ``self.num_patches`` patches.
                The learned positional embedding has a fixed length.
        """
        if x.dim() != 3 or x.shape[1] != 1:
            raise ValueError(
                f"expected input of shape (batch, 1, time), got {tuple(x.shape)}"
            )
        batch, _, length = x.shape
        # Validate before the patch projection so that a bad input neither fails
        # with an opaque broadcasting error nor advances spectral-norm state.
        if length // self.patch_size != self.num_patches:
            lo = self.num_patches * self.patch_size
            raise ValueError(
                f"input length {length} gives {length // self.patch_size} patches, "
                f"but the model was built for {self.num_patches} "
                f"(expected length in [{lo}, {lo + self.patch_size}))"
            )

        tokens = self.patch_embed(x)  # (batch, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        seq = torch.cat((cls_tokens, tokens), dim=1) + self.pos_embed
        seq = self.transformer(seq)
        # Score only position 0 (the CLS token) of the encoder output.
        return self.head(self.norm(seq[:, 0]))