⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

This commit is contained in:
2025-12-04 14:22:48 +02:00
parent 782a3bab28
commit bf0a6e58e9
4 changed files with 210 additions and 131 deletions

View File

@@ -1,19 +1,20 @@
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
padding = (kernel_size - 1) // 2 * dilation
return nn.Sequential(
nn.Conv1d(
weight_norm(nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding
),
nn.InstanceNorm1d(out_channels),
)),
nn.PReLU(num_parameters=1, init=0.1),
)
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
nn.Sigmoid(),
)
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
x = self.attention(x)
return x + residual
def UpsampleBlock(in_channels, out_channels):
def UpsampleBlock(in_channels, out_channels, scale_factor=2):
return nn.Sequential(
nn.ConvTranspose1d(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
weight_norm(nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2,
kernel_size=3,
stride=1,
padding=1
),
nn.InstanceNorm1d(out_channels),
)),
nn.PReLU(num_parameters=1, init=0.1)
)
class SISUGenerator(nn.Module):
def __init__(self, channels=32, num_rirb=1):
def __init__(self, channels=32, num_rirb=4):
super(SISUGenerator, self).__init__()
self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@ class SISUGenerator(nn.Module):
self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
self.downsample_2_attn = AttentionBlock(channels * 4)
self.rirb = ResidualInResidualBlock(channels * 4)
# self.rirb = nn.Sequential(
# *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
# )
self.rirb = nn.Sequential(
*[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
)
self.upsample = UpsampleBlock(channels * 4, channels * 2)
self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@ class SISUGenerator(nn.Module):
self.compress_2 = GeneratorBlock(channels * 2, channels)
self.final_conv = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=7, padding=3),
weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
nn.Tanh()
)
def forward(self, x):
residual_input = x
# Encoding
x1 = self.first_conv(x)
x2 = self.downsample(x1)
@@ -102,8 +104,10 @@ class SISUGenerator(nn.Module):
x3 = self.downsample_2(x2)
x3 = self.downsample_2_attn(x3)
# Bottleneck (Deep Residual processing)
x_rirb = self.rirb(x3)
# Decoding with Skip Connections
up1 = self.upsample(x_rirb)
up1 = self.upsample_attn(up1)