import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple


# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion

# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
        mean = jnp.mean(x, axis=1, keepdims=True)
        var = jnp.var(x, axis=1, keepdims=True)
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized

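# Minimal sanity check for InstanceNorm1d (a sketch, not part of the original file;
# the helper name and shapes are assumptions). At init, scale=1 and bias=0, so the
# per-sample, per-channel statistics over the length axis should be ~0 mean / ~1 var.
def _instance_norm_sanity_check():
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 8))  # (N, L, C)
    norm = InstanceNorm1d(features=8)
    variables = norm.init(jax.random.PRNGKey(1), x)
    y = norm.apply(variables, x)
    assert jnp.allclose(jnp.mean(y, axis=1), 0.0, atol=1e-4)
    assert jnp.allclose(jnp.var(y, axis=1), 1.0, atol=1e-2)
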
# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        ks1 = (1,)
        attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights

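# Illustrative shape check (a sketch, not part of the original file). AttentionBlock
# is a channel-gating mechanism built from two pointwise (kernel_size=1) convolutions,
# so its output always has the same (N, L, C) shape as its input.
def _attention_block_shape_check():
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 50, 64))  # (N, L, C)
    attn = AttentionBlock(channels=64)
    variables = attn.init(jax.random.PRNGKey(1), x)
    y = attn.apply(variables, x)
    assert y.shape == x.shape
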
# --- Converted Discriminator Modules ---

class DiscriminatorBlock(nn.Module):
    """Equivalent of the PyTorch discriminator_block function."""
    in_channels: int  # Kept for clarity; Flax infers input channels from the input shape
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    # spectral_norm: bool = True  # Flag for where SN would be applied
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L', C_out) - L' depends on stride/padding
        """
        # Flax Conv expects kernel_size, strides, and kernel_dilation as sequences (tuples)
        ks = (self.kernel_size,)
        st = (self.stride,)
        di = (self.dilation,)

        # Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling.
        # NOTE: Spectral Norm is omitted here. If implementing, you'd wrap or replace
        # nn.Conv with a spectral-normalized version, e.g.
        # conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
        y = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            strides=st,
            kernel_dilation=di,
            padding='SAME'  # Often used in GANs
        )(x)

        # Apply LeakyReLU first (as in the original code if IN is used)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)

        # Conditionally apply InstanceNorm
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)

        return y

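# One possible way to restore the spectral normalization omitted above is Flax's
# built-in nn.SpectralNorm wrapper (a sketch under the assumption of a recent Flax
# release that ships flax.linen.SpectralNorm; this class is not in the original
# file). The power-iteration state lives in the 'batch_stats' collection, so
# training steps must call apply(..., mutable=['batch_stats']) and pass
# update_stats=True.
class SpectralNormDiscriminatorBlock(nn.Module):
    """DiscriminatorBlock variant with a spectrally normalized convolution (sketch)."""
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x, update_stats: bool = False):
        conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(self.kernel_size,),
            strides=(self.stride,),
            kernel_dilation=(self.dilation,),
            padding='SAME',
        )
        # Wrap the conv so its kernel is rescaled by an estimate of its largest singular value.
        y = nn.SpectralNorm(conv)(x, update_stats=update_stats)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)
        return y
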
class SISUDiscriminator(nn.Module):
    """SISUDiscriminator model translated to Flax."""
    base_channels: int = 16

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single-channel input
        Returns:
            Output tensor (N, 1) - logits
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

        ch = self.base_channels

        # NOTE: Spectral Norm is omitted in all blocks below; SN=T marks where the
        # original PyTorch model applied it.

        # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
        y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)

        # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)

        # Attention Block
        y = AttentionBlock(channels=ch*4)(y)

        # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)

        # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F (no SN here even in the original config)
        y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)

        # Global Average Pooling (across the Length dimension)
        pooled = jnp.mean(y, axis=1)  # Shape becomes (N, C=1)

        # Flatten (optional, as the shape is already (N, 1))
        output = jnp.reshape(pooled, (pooled.shape[0], -1))  # Shape (N, 1)

        return output

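# Hypothetical smoke test (not part of the original file; the batch size and the
# 16000-sample length are arbitrary). Builds the discriminator, runs a dummy mono
# waveform batch through it, and prints the resulting logit shape.
if __name__ == "__main__":
    model = SISUDiscriminator(base_channels=16)
    dummy = jnp.ones((4, 16000, 1))  # (N, L, C) in NLC layout
    params = model.init(jax.random.PRNGKey(0), dummy)
    logits = model.apply(params, dummy)
    print(logits.shape)  # expected: (4, 1)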