Made training bit... spicier.

2025-09-10 19:52:53 +03:00
parent ff38cefdd3
commit 0bc8fc2792
8 changed files with 581 additions and 303 deletions


@@ -1,8 +1,16 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
 
-def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
+def discriminator_block(
+    in_channels,
+    out_channels,
+    kernel_size=3,
+    stride=1,
+    dilation=1,
+    spectral_norm=True,
+    use_instance_norm=True,
+):
     padding = (kernel_size // 2) * dilation
     conv_layer = nn.Conv1d(
         in_channels,
@@ -10,7 +18,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila
         kernel_size=kernel_size,
         stride=stride,
         dilation=dilation,
-        padding=padding
+        padding=padding,
     )
     if spectral_norm:
@@ -24,6 +32,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila
     return nn.Sequential(*layers)
 
 class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
@@ -31,27 +40,86 @@ class AttentionBlock(nn.Module):
             nn.Conv1d(channels, channels // 4, kernel_size=1),
             nn.ReLU(inplace=True),
             nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )
 
     def forward(self, x):
         attention_weights = self.attention(x)
         return x * attention_weights
 
 class SISUDiscriminator(nn.Module):
     def __init__(self, base_channels=16):
         super(SISUDiscriminator, self).__init__()
         layers = base_channels
         self.model = nn.Sequential(
-            discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
-            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
-            discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(
+                1,
+                layers,
+                kernel_size=7,
+                stride=1,
+                spectral_norm=True,
+                use_instance_norm=False,
+            ),
+            discriminator_block(
+                layers,
+                layers * 2,
+                kernel_size=5,
+                stride=2,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
+            discriminator_block(
+                layers * 2,
+                layers * 4,
+                kernel_size=5,
+                stride=1,
+                dilation=2,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
             AttentionBlock(layers * 4),
-            discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
-            discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
-            discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
-            discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
-            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
+            discriminator_block(
+                layers * 4,
+                layers * 8,
+                kernel_size=5,
+                stride=1,
+                dilation=4,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
+            discriminator_block(
+                layers * 8,
+                layers * 4,
+                kernel_size=5,
+                stride=2,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
+            discriminator_block(
+                layers * 4,
+                layers * 2,
+                kernel_size=3,
+                stride=1,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
+            discriminator_block(
+                layers * 2,
+                layers,
+                kernel_size=3,
+                stride=1,
+                spectral_norm=True,
+                use_instance_norm=True,
+            ),
+            discriminator_block(
+                layers,
+                1,
+                kernel_size=3,
+                stride=1,
+                spectral_norm=False,
+                use_instance_norm=False,
+            ),
         )
         self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
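
For reference, a minimal smoke-test sketch (not part of this commit): it assumes the file diffed above is saved as `discriminator.py` (hypothetical path) and only exercises the two attributes this hunk defines, `self.model` and `self.global_avg_pool`, since the real `forward()` is outside the hunks shown here.

```python
# Hypothetical smoke test, not from the commit: assumes the file above
# is importable as discriminator.py next to this script.
import torch

from discriminator import SISUDiscriminator

disc = SISUDiscriminator(base_channels=16)

# Dummy mono audio batch: (batch, channels, samples).
waveform = torch.randn(2, 1, 16000)

# Drive the conv stack and pooling head directly; forward() is not shown
# in the hunks above, so this only checks the layers defined here.
features = disc.model(waveform)          # (2, 1, ~4000) after two stride-2 blocks
score = disc.global_avg_pool(features)   # (2, 1, 1) pooled scalar per clip
print(features.shape, score.shape)
```

With kernel_size=7/5/3 and `padding = (kernel_size // 2) * dilation`, only the two stride-2 blocks change the sequence length, so a 16000-sample input should come out roughly 4000 samples long with a single channel before pooling.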