import torch.nn as nn
import torch.nn.utils as utils


def discriminator_block(
        in_channels,
        out_channels,
        kernel_size=15,
        stride=1,
        dilation=1
):
    # "Same" padding for the dilated convolution, so stride=1 blocks
    # preserve the temporal length of the input.
    padding = dilation * (kernel_size - 1) // 2

    conv_layer = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=padding
    )

    # Spectral normalization helps stabilize discriminator training.
    conv_layer = utils.spectral_norm(conv_layer)
    leaky_relu = nn.LeakyReLU(0.2)

    return nn.Sequential(conv_layer, leaky_relu)


class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        # Channel attention: squeeze to channels // 4, expand back,
        # and gate with a sigmoid in [0, 1].
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        # Residual gating: keep the input and add its attention-weighted copy.
        return x + (x * attention_weights)


class SISUDiscriminator(nn.Module):
    def __init__(self, layers=8):
        super(SISUDiscriminator, self).__init__()
        self.discriminator_blocks = nn.Sequential(
            # 2 -> layers
            discriminator_block(2, layers),
            AttentionBlock(layers),
            # layers -> layers * 2
            discriminator_block(layers, layers * 2, dilation=2),
            # layers * 2 -> layers * 4
            discriminator_block(layers * 2, layers * 4, dilation=4),
            AttentionBlock(layers * 4),
            # layers * 4 -> layers * 8 (stride 4 downsamples the time axis)
            discriminator_block(layers * 4, layers * 8, stride=4),
            # layers * 8 -> layers * 16
            # discriminator_block(layers * 8, layers * 16, stride=4)
        )

        # Collapse the channels to a single score map, then pool over time.
        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x = self.discriminator_blocks(x)
        x = self.final_conv(x)
        x = self.avg_pool(x)
        # (batch, 1, 1) -> (batch, 1)
        return x.squeeze(2)

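
# Minimal smoke test (not part of the original file): a sketch assuming the
# discriminator consumes 2-channel 1D audio tensors shaped (batch, 2, samples);
# the sample length of 16000 is an arbitrary illustrative choice.
if __name__ == "__main__":
    import torch

    model = SISUDiscriminator(layers=8)
    dummy = torch.randn(4, 2, 16000)  # batch of 4, 2 channels, 16k samples
    scores = model(dummy)
    print(scores.shape)  # expected: torch.Size([4, 1])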