import jax
import jax.numpy as jnp
from flax import linen as nn

# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion.


# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        if x.shape[-1] != self.features:
            raise ValueError(
                f"Input features {x.shape[-1]} do not match "
                f"InstanceNorm1d features {self.features}"
            )
        # Normalize each channel over the length dimension (axis=1), per sample.
        mean = jnp.mean(x, axis=1, keepdims=True)
        var = jnp.var(x, axis=1, keepdims=True)
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized


# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        # Channel attention: 1x1 conv bottleneck, ReLU, 1x1 conv, sigmoid gate.
        ks1 = (1,)
        attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1,
                                    padding='SAME')(x)
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(features=self.channels, kernel_size=ks1,
                                    padding='SAME')(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights


# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
    """Equivalent of the PyTorch discriminator_block function."""
    in_channels: int  # Kept for parity with the PyTorch signature; Flax infers input channels
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    # spectral_norm: bool = True  # Flag for where SN would be applied
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)

        Returns:
            Output tensor (N, L', C_out) - L' depends on stride/padding.
        """
        # Flax Conv expects kernel_size, strides, and kernel_dilation as sequences (tuples).
        ks = (self.kernel_size,)
        st = (self.stride,)
        di = (self.dilation,)
        # padding='SAME' works reasonably well for stride=1 and stride=2 downsampling.
        # NOTE: Spectral norm is omitted here. If implementing, wrap or replace
        # nn.Conv with a spectral-normalized version (see the sketch after this class).
        y = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            strides=st,
            kernel_dilation=di,
            padding='SAME'  # Often used in GANs
        )(x)
        # Apply LeakyReLU first (as in the original code when IN is used).
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)
        # Conditionally apply InstanceNorm.
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)
        return y
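
# --- Optional: spectral-normalized block (sketch) ---
# The blocks in this file omit spectral normalization. As a hedged sketch:
# recent Flax releases ship flax.linen.SpectralNorm, which wraps a layer
# instance and keeps its power-iteration state in the 'batch_stats'
# collection. The SpectralNormDiscriminatorBlock name and the `train` flag
# below are illustrative additions, not part of the original model.
class SpectralNormDiscriminatorBlock(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x, train: bool = True):
        conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(self.kernel_size,),
            strides=(self.stride,),
            kernel_dilation=(self.dilation,),
            padding='SAME',
        )
        # update_stats should be True only during training steps; the caller
        # must then run model.apply(..., mutable=['batch_stats']).
        y = nn.SpectralNorm(conv)(x, update_stats=train)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)
        return y
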
class SISUDiscriminator(nn.Module):
    """SISUDiscriminator model translated to Flax."""
    base_channels: int = 16

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single-channel input.

        Returns:
            Output tensor (N, 1) - logits.
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

        ch = self.base_channels

        # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7,
                               stride=1, use_instance_norm=False)(x)
        # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch, out_channels=ch * 2, kernel_size=5,
                               stride=2, use_instance_norm=True)(y)
        # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch * 2, out_channels=ch * 4, kernel_size=5,
                               stride=1, dilation=2, use_instance_norm=True)(y)
        # Attention block on ch*4 channels.
        y = AttentionBlock(channels=ch * 4)(y)
        # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch * 4, out_channels=ch * 8, kernel_size=5,
                               stride=1, dilation=4, use_instance_norm=True)(y)
        # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch * 8, out_channels=ch * 4, kernel_size=5,
                               stride=2, use_instance_norm=True)(y)
        # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch * 4, out_channels=ch * 2, kernel_size=3,
                               stride=1, use_instance_norm=True)(y)
        # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T (Spectral Norm omitted)
        y = DiscriminatorBlock(in_channels=ch * 2, out_channels=ch, kernel_size=3,
                               stride=1, use_instance_norm=True)(y)
        # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F (no SN here in the original config either)
        y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3,
                               stride=1, use_instance_norm=False)(y)

        # Global average pooling across the length dimension.
        pooled = jnp.mean(y, axis=1)  # Shape becomes (N, C=1)
        # Flatten (a no-op here, since the shape is already (N, 1)).
        output = jnp.reshape(pooled, (pooled.shape[0], -1))
        return output
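
# --- Usage sketch ---
# A minimal smoke test under assumed shapes: a hypothetical batch of 2
# single-channel signals of length 1024 in (N, L, C) layout. The seed and
# sizes are illustrative, not taken from the original training setup.
if __name__ == "__main__":
    model = SISUDiscriminator(base_channels=16)
    rng = jax.random.PRNGKey(0)
    dummy = jnp.zeros((2, 1024, 1))  # (N, L, C=1)
    variables = model.init(rng, dummy)
    logits = model.apply(variables, dummy)
    print(logits.shape)  # expected: (2, 1)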