import torch
import torch.nn as nn
import torch.nn.utils as utils
import numpy as np


class PatchEmbedding(nn.Module):
    """Convert raw audio into a sequence of embeddings (tokens).

    Small patch_size = Higher Precision (more tokens, finer detail).
    Large patch_size = Lower Precision (fewer tokens, more global).

    Args:
        in_channels: Number of channels of the input waveform.
        embed_dim: Size of each output token.
        patch_size: Number of samples covered by one token; used as both
            the kernel size and the stride of the projection.
        spectral_norm: Wrap the projection in spectral normalisation.
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int,
        patch_size: int,
        spectral_norm: bool = True,
    ) -> None:
        super().__init__()
        # A Conv1d whose stride equals its kernel size yields non-overlapping
        # patches. Trailing samples that do not fill a whole patch are dropped.
        self.proj = nn.Conv1d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if spectral_norm:
            # NOTE: this is the legacy hook-based spectral norm. Moving to
            # torch.nn.utils.parametrizations.spectral_norm would rename the
            # state_dict entries, so it is intentionally left untouched.
            self.proj = utils.spectral_norm(self.proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, in_channels, samples) to (batch, num_patches, embed_dim)."""
        x = self.proj(x)  # (batch, embed_dim, num_patches)
        return x.transpose(1, 2)  # (batch, num_patches, embed_dim)


class TransformerDiscriminator(nn.Module):
    """Transformer encoder that scores a mono waveform with a single logit.

    The waveform is cut into non-overlapping patches, a learnable class token
    is prepended, learnable positional embeddings are added, and the encoder
    output at the class-token position is normalised and projected to a score.

    Args:
        audio_length: Longest input (in samples) the model supports; it fixes
            the size of the positional embedding table.
        patch_size: Samples per token. Lower this for higher precision
            (e.g., 8 or 16).
        embed_dim: Dimension of the transformer tokens.
        depth: Number of Transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden dimension of the feed-forward layer.
        spectral_norm: Apply spectral normalisation to the patch projection
            and to the classification head.
        dropout: Dropout probability used inside every encoder layer.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``audio_length`` is
            too short to produce a single patch.
    """

    def __init__(
        self,
        audio_length: int = 8000,
        patch_size: int = 16,
        embed_dim: int = 128,
        depth: int = 4,
        heads: int = 4,
        mlp_dim: int = 256,
        spectral_norm: bool = True,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if audio_length < patch_size:
            raise ValueError(
                f"audio_length ({audio_length}) must be at least "
                f"patch_size ({patch_size})"
            )

        # 1. Calculate sequence length (samples beyond the last full patch
        #    are ignored by the strided convolution).
        self.num_patches = audio_length // patch_size

        # 2. Patch Embedding (Tokenizer)
        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)

        # 3. Class Token (like in BERT/ViT) to aggregate global info
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 4. Positional Embedding (Learnable). The values drawn here are
        #    overwritten in _init_weights; the draw is kept so that seeded
        #    initialisation of the layers below stays reproducible.
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim)
        )

        # 5. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # 6. Final Classification Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, 1)
        if spectral_norm:
            self.head = utils.spectral_norm(self.head)

        # Initialize weights
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise the class token and positional table with std 0.02."""
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of waveforms.

        Args:
            x: Tensor of shape (batch, 1, samples). Inputs shorter than the
                configured ``audio_length`` are accepted and use a prefix of
                the positional embedding table.

        Returns:
            Tensor of shape (batch, 1) holding one unnormalised score per item.

        Raises:
            ValueError: If ``x`` is not 3-D, or if it produces more patches
                than the positional embedding table can cover.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, channels, samples), "
                f"got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # --- 1. Tokenize Audio ---
        tokens = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
        num_tokens = tokens.size(1)
        if num_tokens > self.num_patches:
            raise ValueError(
                f"input yields {num_tokens} patches but the model was built "
                f"for at most {self.num_patches}; increase audio_length"
            )

        # --- 2. Add CLS Token ---
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)

        # --- 3. Add Positional Embeddings ---
        # Slicing is a no-op for full-length input and avoids an opaque
        # broadcasting error for shorter clips.
        tokens = tokens + self.pos_embed[:, : num_tokens + 1]

        # --- 4. Transformer Layers ---
        tokens = self.transformer(tokens)

        # --- 5. Classification (Use only CLS token) ---
        cls_output = self.norm(tokens[:, 0])  # Take the first token
        return self.head(cls_output)  # (Batch, 1)