Files
SISU/discriminator.py
2025-09-10 19:52:53 +03:00

132 lines
3.5 KiB
Python

import torch.nn as nn
import torch.nn.utils as utils
def discriminator_block(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    dilation=1,
    spectral_norm=True,
    use_instance_norm=True,
):
    """Build one Conv1d -> LeakyReLU(0.2) [-> InstanceNorm1d] stage.

    Padding is ``dilation * (kernel_size // 2)``. For an odd ``kernel_size``
    this yields an output length of ``ceil(L / stride)`` (i.e. length is kept
    when ``stride == 1``). Even kernel sizes are not length-preserving.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel width.
        stride: convolution stride.
        dilation: convolution dilation; also scales the padding.
        spectral_norm: wrap the conv with ``torch.nn.utils.spectral_norm``.
        use_instance_norm: append an ``InstanceNorm1d`` after the activation.

    Returns:
        An ``nn.Sequential`` holding the conv, the activation and, optionally,
        the instance norm, in that order.
    """
    conv = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=dilation * (kernel_size // 2),
    )
    # spectral_norm modifies the conv in place and returns the same module.
    modules = [
        utils.spectral_norm(conv) if spectral_norm else conv,
        nn.LeakyReLU(0.2, inplace=True),
    ]
    # Note: the norm is applied after the activation in this design.
    if use_instance_norm:
        modules += [nn.InstanceNorm1d(out_channels)]
    return nn.Sequential(*modules)
class AttentionBlock(nn.Module):
    """Gate the input element-wise with learned sigmoid weights.

    A 1x1-convolution bottleneck (channels -> hidden -> channels) followed by
    a sigmoid produces a weight in (0, 1) for every channel and position; the
    input is multiplied by these weights and returned with the same shape.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck reduction factor. The hidden width is
            ``max(1, channels // reduction)``.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        # Clamp to at least one hidden channel: for channels < reduction the
        # plain floor division gives 0, i.e. a zero-width bottleneck through
        # which the attention weights could not depend on the input at all.
        hidden = max(1, channels // reduction)
        self.attention = nn.Sequential(
            nn.Conv1d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by its attention weights."""
        attention_weights = self.attention(x)
        return x * attention_weights
class SISUDiscriminator(nn.Module):
    """Stack of 1-D conv stages, time-averaged to one value per example.

    The first stage takes 1 input channel, so the input is expected as
    ``(batch, 1, length)``. Channel widths go c -> 2c -> 4c -> [attention]
    -> 8c -> 4c -> 2c -> c -> 1 (with ``c = base_channels``), two stages use
    stride 2, and the final 1-channel map is averaged over time, giving an
    output of shape ``(batch, 1)``.

    Args:
        base_channels: width ``c`` of the first stage; other widths are
            multiples of it.
    """

    def __init__(self, base_channels=16):
        super().__init__()
        c = base_channels
        # One row per conv stage: (in_ch, out_ch, kernel, stride, dilation).
        conv_plan = (
            (1, c, 7, 1, 1),
            (c, 2 * c, 5, 2, 1),
            (2 * c, 4 * c, 5, 1, 2),
            (4 * c, 8 * c, 5, 1, 4),
            (8 * c, 4 * c, 5, 2, 1),
            (4 * c, 2 * c, 3, 1, 1),
            (2 * c, c, 3, 1, 1),
            (c, 1, 3, 1, 1),
        )
        last = len(conv_plan) - 1
        stages = []
        # Modules are created strictly in order so parameter initialisation
        # consumes the RNG in the same sequence as an explicit listing would.
        for idx, (c_in, c_out, k, s, d) in enumerate(conv_plan):
            stages.append(
                discriminator_block(
                    c_in,
                    c_out,
                    kernel_size=k,
                    stride=s,
                    dilation=d,
                    # Every stage but the output one is spectrally normalised.
                    spectral_norm=idx != last,
                    # First and last stages skip instance norm.
                    use_instance_norm=idx not in (0, last),
                )
            )
            if idx == 2:
                # Attention gating sits right after the first dilated stage.
                stages.append(AttentionBlock(c_out))
        self.model = nn.Sequential(*stages)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Run the conv stack, average over time and flatten to (batch, -1)."""
        features = self.model(x)
        pooled = self.global_avg_pool(features)
        return pooled.view(pooled.size(0), -1)