# SISU/discriminator.py (file-viewer header: 64 lines, 2.5 KiB, Python)

import torch
import torch.nn as nn
import torch.nn.utils as utils
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
    """Build one Conv1d -> LeakyReLU(0.2) [-> InstanceNorm1d] stage.

    Args:
        in_channels: Channels of the incoming 1-D feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        dilation: Convolution dilation.
        spectral_norm: If True, wrap the convolution with
            ``torch.nn.utils.spectral_norm``.
        use_instance_norm: If True, append ``nn.InstanceNorm1d(out_channels)``
            after the activation.

    Returns:
        An ``nn.Sequential`` holding the conv, the activation and (optionally)
        the instance norm, in that order.
    """
    # dilation * (k // 2) keeps the sequence length unchanged for odd kernels
    # at stride 1 (for even kernels the output is `dilation` samples longer).
    conv = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=dilation * (kernel_size // 2),
    )

    stage = [
        utils.spectral_norm(conv) if spectral_norm else conv,
        nn.LeakyReLU(0.2, inplace=True),
    ]
    # NOTE: normalization is applied *after* the activation in this design.
    if use_instance_norm:
        stage.append(nn.InstanceNorm1d(out_channels))

    return nn.Sequential(*stage)
class AttentionBlock(nn.Module):
    """Element-wise sigmoid gating of a 1-D feature map.

    A 1x1 convolution bottleneck (``channels -> hidden -> channels``) followed
    by a sigmoid produces weights in (0, 1) with the same shape as the input.
    The input is multiplied by them element-wise. No pooling is applied, so
    the gate varies per channel and per position.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio; ``hidden = max(1, channels // reduction)``.
            Defaults to 4, matching the original hard-coded ratio.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to at least one channel: for channels < reduction the plain
        # integer division yields 0, which makes the gate a constant
        # sigmoid(bias) that ignores the input.
        hidden = max(1, channels // reduction)
        self.attention = nn.Sequential(
            nn.Conv1d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its learned gate (output shape == input shape).

        Assumes ``x`` is a Conv1d-style input, e.g. (batch, channels, length).
        """
        return x * self.attention(x)
class SISUDiscriminator(nn.Module):
    """Fully convolutional 1-D discriminator producing one score per example.

    The stack widens from ``base_channels`` up to ``8 * base_channels`` and back
    down to a single output channel. Two stride-2 stages reduce the length and
    two dilated stages widen the receptive field. An ``AttentionBlock`` gate
    sits after the third stage. The final map is averaged over length.

    Args:
        base_channels: Channel width of the first stage; the other stages
            use multiples of it.
    """

    def __init__(self, base_channels=16):
        super().__init__()
        c = base_channels

        def normed(cin, cout, k, stride=1, dilation=1):
            # Spectrally-normalized conv stage followed by InstanceNorm.
            return discriminator_block(
                cin, cout,
                kernel_size=k, stride=stride, dilation=dilation,
                spectral_norm=True, use_instance_norm=True,
            )

        # Arguments are evaluated left-to-right, so modules are constructed
        # (and randomly initialized) in exactly the listed order.
        self.model = nn.Sequential(
            # Input stem: first conv takes a single input channel; no norm.
            discriminator_block(1, c, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
            normed(c, 2 * c, 5, stride=2),
            normed(2 * c, 4 * c, 5, dilation=2),
            AttentionBlock(4 * c),
            normed(4 * c, 8 * c, 5, dilation=4),
            normed(8 * c, 4 * c, 5, stride=2),
            normed(4 * c, 2 * c, 3),
            normed(2 * c, c, 3),
            # Output head: one channel, plain conv (no spectral/instance norm).
            discriminator_block(c, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False),
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Score a batch; expects ``x`` with 1 channel, presumably (batch, 1, length).

        Returns:
            Tensor of shape (batch, 1): the length-averaged output channel.
        """
        features = self.model(x)
        pooled = self.global_avg_pool(features)
        return pooled.view(pooled.size(0), -1)