Files
SISU/discriminator.py
2025-11-18 21:34:59 +02:00

71 lines
1.9 KiB
Python

import torch.nn as nn
import torch.nn.utils as utils
def discriminator_block(
    in_channels,
    out_channels,
    kernel_size=15,
    stride=1,
    dilation=1
):
    """Build one discriminator stage: spectral-normalised Conv1d + LeakyReLU.

    Args:
        in_channels: Number of channels the convolution consumes.
        out_channels: Number of channels the convolution produces.
        kernel_size: Width of the 1-D kernel.
        stride: Step of the convolution along the last axis.
        dilation: Spacing between kernel taps.

    Returns:
        An ``nn.Sequential`` holding the spectrally normalised ``nn.Conv1d``
        followed by ``nn.LeakyReLU(0.2)``.
    """
    # Half of the dilated kernel span (floor division): with an odd
    # kernel_size and stride 1 this keeps the sequence length unchanged.
    same_padding = dilation * (kernel_size - 1) // 2

    normalized_conv = utils.spectral_norm(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=same_padding,
        )
    )
    return nn.Sequential(normalized_conv, nn.LeakyReLU(0.2))
class AttentionBlock(nn.Module):
    """Residual channel gate built from two 1x1 convolutions.

    A bottleneck (``channels // 4``) pair of kernel-size-1 convolutions with
    ReLU in between and a sigmoid at the end produces per-element weights in
    (0, 1); the input scaled by those weights is added back onto the input.
    """

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 4
        # Construction order (reduce, then expand) matches the order the
        # layers appear in the Sequential below.
        reduce_conv = nn.Conv1d(channels, bottleneck, kernel_size=1)
        expand_conv = nn.Conv1d(bottleneck, channels, kernel_size=1)
        self.attention = nn.Sequential(
            reduce_conv,
            nn.ReLU(),
            expand_conv,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` plus ``x`` scaled by the learned sigmoid gate."""
        gate = self.attention(x)
        gated = x * gate
        return x + gated
class SISUDiscriminator(nn.Module):
    """1-D convolutional discriminator producing one score per batch item.

    The network stacks four spectrally normalised convolution stages
    (``discriminator_block``), with an ``AttentionBlock`` after the first and
    third stages, then projects to a single channel and averages over the
    remaining length.

    Args:
        layers: Base channel width. The stages map
            ``in_channels -> layers -> 2*layers -> 4*layers -> 8*layers``
            (8, 16, 32, 64 with the default).
        in_channels: Number of channels in the input signal. Defaults to 2,
            the value that was previously hard-coded.
    """

    def __init__(self, layers=8, in_channels=2):
        super().__init__()
        # NOTE: the position of each module inside this Sequential determines
        # its state-dict key, so the order must not change.
        self.discriminator_blocks = nn.Sequential(
            # in_channels -> layers
            discriminator_block(in_channels, layers),
            AttentionBlock(layers),
            # layers -> 2 * layers, dilated to widen the receptive field
            discriminator_block(layers, layers * 2, dilation=2),
            # 2 * layers -> 4 * layers
            discriminator_block(layers * 2, layers * 4, dilation=4),
            AttentionBlock(layers * 4),
            # 4 * layers -> 8 * layers, stride 4 shortens the sequence
            discriminator_block(layers * 4, layers * 8, stride=4),
        )
        # Collapse the feature channels to a single score channel.
        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
        # Average the per-position scores down to length 1.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Score a batch of signals.

        Args:
            x: Tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, 1)``: the pooled output of the final
            convolution with the length axis squeezed away. No activation is
            applied to the score.
        """
        features = self.discriminator_blocks(x)
        scores = self.final_conv(features)
        pooled = self.avg_pool(scores)
        return pooled.squeeze(2)