import torch
import torch.nn as nn
import torch.nn.utils as utils


class PatchEmbedding(nn.Module):
    """
    Converts raw audio into a sequence of embeddings (tokens).
    Small patch_size = higher precision (more tokens, finer detail).
    Large patch_size = lower precision (fewer tokens, more global).
    """
    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
        super().__init__()
        # We use a Conv1d with stride=patch_size to create non-overlapping patches
        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        if spectral_norm:
            self.proj = utils.spectral_norm(self.proj)

    def forward(self, x):
        # x shape: (batch, 1, 8000)
        x = self.proj(x)       # shape: (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
        return x

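
# Shape walk-through (illustrative only, assuming the defaults used below:
# audio_length=8000, patch_size=16, embed_dim=128):
#   input waveform:  (batch, 1, 8000)
#   Conv1d with kernel_size = stride = 16 and no padding yields
#       num_patches = 8000 // 16 = 500    ->  (batch, 128, 500)
#   transpose(1, 2)                       ->  (batch, 500, 128)
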
class TransformerDiscriminator(nn.Module):
    def __init__(
        self,
        audio_length=8000,
        patch_size=16,       # Lower this for higher precision (e.g., 8 or 4)
        embed_dim=128,       # Dimension of the transformer tokens
        depth=4,             # Number of Transformer blocks
        heads=4,             # Number of attention heads
        mlp_dim=256,         # Hidden dimension of the feed-forward layer
        spectral_norm=True
    ):
        super().__init__()

        # 1. Calculate sequence length
        self.num_patches = audio_length // patch_size

        # 2. Patch embedding (tokenizer)
        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)

        # 3. Class token (as in BERT/ViT) to aggregate global info
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 4. Positional embedding (learnable)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

        # 5. Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # 6. Final classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, 1)

        if spectral_norm:
            self.head = utils.spectral_norm(self.head)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        b, c, t = x.shape

        # --- 1. Tokenize audio ---
        x = self.patch_embed(x)  # (batch, num_patches, embed_dim)

        # --- 2. Prepend CLS token ---
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (batch, num_patches + 1, embed_dim)

        # --- 3. Add positional embeddings ---
        x = x + self.pos_embed

        # --- 4. Transformer layers ---
        x = self.transformer(x)

        # --- 5. Classification (use only the CLS token) ---
        cls_output = x[:, 0]  # Take the first token
        cls_output = self.norm(cls_output)

        score = self.head(cls_output)  # (batch, 1)

        return score
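

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; the batch size and the random
    # noise standing in for audio are assumptions, not part of the model).
    model = TransformerDiscriminator()   # defaults: audio_length=8000, patch_size=16
    waveform = torch.randn(4, 1, 8000)   # (batch, channels, time)
    scores = model(waveform)
    print(scores.shape)                  # expected: torch.Size([4, 1])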