import torch.nn as nn
import torch.nn.utils as utils


def discriminator_block(in_channels, out_channels, kernel_size=15, stride=1, dilation=1):
    # "Same" padding for the dilated kernel, so stride=1 blocks preserve sequence length.
    padding = dilation * (kernel_size - 1) // 2
    conv_layer = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=padding,
    )
    # Spectral normalization keeps the discriminator's Lipschitz constant in check,
    # which stabilizes GAN training.
    conv_layer = utils.spectral_norm(conv_layer)
    leaky_relu = nn.LeakyReLU(0.2)
    return nn.Sequential(conv_layer, leaky_relu)


class AttentionBlock(nn.Module):
    """Channel-wise gating (squeeze to channels // 4, expand back, sigmoid)
    applied as a residual: x + x * weights."""

    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x + (x * attention_weights)


class SISUDiscriminator(nn.Module):
    def __init__(self, layers=8):
        super(SISUDiscriminator, self).__init__()
        self.discriminator_blocks = nn.Sequential(
            # 2 -> layers (2 -> 8 with the default)
            discriminator_block(2, layers),
            AttentionBlock(layers),
            # layers -> layers * 2 (8 -> 16)
            discriminator_block(layers, layers * 2, dilation=2),
            # layers * 2 -> layers * 4 (16 -> 32)
            discriminator_block(layers * 2, layers * 4, dilation=4),
            AttentionBlock(layers * 4),
            # layers * 4 -> layers * 8 (32 -> 64), downsampling time by 4
            discriminator_block(layers * 4, layers * 8, stride=4),
            # layers * 8 -> layers * 16 (64 -> 128)
            # discriminator_block(layers * 8, layers * 16, stride=4)
        )
        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x = self.discriminator_blocks(x)  # (B, 2, T) -> (B, layers * 8, T // 4)
        x = self.final_conv(x)            # -> (B, 1, T // 4)
        x = self.avg_pool(x)              # -> (B, 1, 1)
        return x.squeeze(2)               # -> (B, 1), one realness score per clip
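
# Minimal smoke test: a sketch, not part of the original module. It assumes
# 2-channel (stereo) waveform input, since the first block takes 2 input
# channels, and the default layers=8. Batch size and sample count are
# arbitrary illustration values.
if __name__ == "__main__":
    import torch

    model = SISUDiscriminator()
    # Hypothetical dummy batch: 4 clips, 2 channels, 16384 samples each.
    dummy_audio = torch.randn(4, 2, 16384)
    scores = model(dummy_audio)
    print(scores.shape)  # torch.Size([4, 1])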