⚗️ | More architectural changes

This commit is contained in:
2025-11-18 21:34:59 +02:00
parent 3f23242d6f
commit 782a3bab28
8 changed files with 245 additions and 254 deletions

View File

@@ -5,32 +5,25 @@ import torch.nn.utils as utils
def discriminator_block(
in_channels,
out_channels,
kernel_size=3,
kernel_size=15,
stride=1,
dilation=1,
spectral_norm=True,
use_instance_norm=True,
dilation=1
):
padding = (kernel_size // 2) * dilation
padding = dilation * (kernel_size - 1) // 2
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding=padding
)
if spectral_norm:
conv_layer = utils.spectral_norm(conv_layer)
conv_layer = utils.spectral_norm(conv_layer)
leaky_relu = nn.LeakyReLU(0.2)
layers = [conv_layer]
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
return nn.Sequential(conv_layer, leaky_relu)
class AttentionBlock(nn.Module):
@@ -38,38 +31,40 @@ class AttentionBlock(nn.Module):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.ReLU(),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
return x + (x * attention_weights)
class SISUDiscriminator(nn.Module):
def __init__(self, layers=32):
def __init__(self, layers=8):
super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=7, stride=1),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
self.discriminator_blocks = nn.Sequential(
# 1 -> 32
discriminator_block(2, layers),
AttentionBlock(layers),
# 32 -> 64
discriminator_block(layers, layers * 2, dilation=2),
# 64 -> 128
discriminator_block(layers * 2, layers * 4, dilation=4),
AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2),
discriminator_block(
layers * 2,
1,
spectral_norm=False,
use_instance_norm=False,
),
# 128 -> 256
discriminator_block(layers * 4, layers * 8, stride=4),
# 256 -> 512
# discriminator_block(layers * 8, layers * 16, stride=4)
)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x):
x = self.model(x)
x = self.global_avg_pool(x)
x = x.view(x.size(0), -1)
return x
x = self.discriminator_blocks(x)
x = self.final_conv(x)
x = self.avg_pool(x)
return x.squeeze(2)