import torch
import torch.nn as nn


def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
    # "Same"-style padding for stride 1; halves the length when stride is 2.
    padding = (kernel_size - 1) // 2 * dilation
    return nn.Sequential(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
        ),
        nn.InstanceNorm1d(out_channels),
        nn.PReLU(num_parameters=1, init=0.1),
    )


class AttentionBlock(nn.Module):
    """Lightweight gating: a channel bottleneck produces sigmoid weights that
    rescale the input, which is then added back as a residual."""

    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x + (x * attention_weights)


class ResidualInResidualBlock(nn.Module):
    def __init__(self, channels, num_convs=3):
        super(ResidualInResidualBlock, self).__init__()
        self.conv_layers = nn.Sequential(
            *[GeneratorBlock(channels, channels) for _ in range(num_convs)]
        )
        self.attention = AttentionBlock(channels)

    def forward(self, x):
        residual = x
        x = self.conv_layers(x)
        x = self.attention(x)
        return x + residual


def UpsampleBlock(in_channels, out_channels):
    # kernel_size=4, stride=2, padding=1 exactly doubles the sequence length.
    return nn.Sequential(
        nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        ),
        nn.InstanceNorm1d(out_channels),
        nn.PReLU(num_parameters=1, init=0.1),
    )


class SISUGenerator(nn.Module):
    """U-Net-style 1D generator: two stride-2 downsamples, a residual-in-residual
    bottleneck, two upsamples with skip concatenations, and a global residual so
    the network only has to learn a correction to its input."""

    def __init__(self, channels=32, num_rirb=1):
        super(SISUGenerator, self).__init__()
        self.first_conv = GeneratorBlock(1, channels)

        # Encoder: halve the length and double the channel count, twice.
        self.downsample = GeneratorBlock(channels, channels * 2, stride=2)
        self.downsample_attn = AttentionBlock(channels * 2)
        self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
        self.downsample_2_attn = AttentionBlock(channels * 4)

        # Bottleneck. num_rirb is only used by the stacked variant below,
        # which is currently disabled.
        self.rirb = ResidualInResidualBlock(channels * 4)
        # self.rirb = nn.Sequential(
        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
        # )

        # Decoder: upsample, then compress each skip concatenation back down.
        self.upsample = UpsampleBlock(channels * 4, channels * 2)
        self.upsample_attn = AttentionBlock(channels * 2)
        self.compress_1 = GeneratorBlock(channels * 4, channels * 2)

        self.upsample_2 = UpsampleBlock(channels * 2, channels)
        self.upsample_2_attn = AttentionBlock(channels)
        self.compress_2 = GeneratorBlock(channels * 2, channels)

        self.final_conv = nn.Sequential(
            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        residual_input = x

        x1 = self.first_conv(x)

        x2 = self.downsample(x1)
        x2 = self.downsample_attn(x2)

        x3 = self.downsample_2(x2)
        x3 = self.downsample_2_attn(x3)

        x_rirb = self.rirb(x3)

        up1 = self.upsample(x_rirb)
        up1 = self.upsample_attn(up1)
        # Skip connection from the first downsample level.
        cat1 = torch.cat((up1, x2), dim=1)
        comp1 = self.compress_1(cat1)

        up2 = self.upsample_2(comp1)
        up2 = self.upsample_2_attn(up2)
        # Skip connection from the full-resolution features.
        cat2 = torch.cat((up2, x1), dim=1)
        comp2 = self.compress_2(cat2)

        # Predict a bounded correction and add it to the raw input.
        learned_residual = self.final_conv(comp2)
        output = residual_input + learned_residual

        return output
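

# --- Minimal shape-check sketch (not part of the original module) ---
# Assumes mono audio shaped (batch, 1, samples) with an even sample count, so
# the two stride-2 downsamples and the two ConvTranspose1d upsamples round-trip
# to the original length and the skip concatenations line up.
if __name__ == "__main__":
    model = SISUGenerator(channels=32)
    dummy = torch.randn(2, 1, 1024)  # batch of 2, 1 channel, 1024 samples
    out = model(dummy)
    assert out.shape == dummy.shape, out.shape
    print(out.shape)  # torch.Size([2, 1, 1024])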