import torch.nn as nn


def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
    """Build a ``Conv1d -> BatchNorm1d -> PReLU`` stack.

    Padding is ``(kernel_size // 2) * dilation``. For odd ``kernel_size``
    this keeps the output length equal to the input length; for even
    ``kernel_size`` the output is ``dilation`` samples longer than the input.

    Args:
        in_channels: Number of channels of the input signal.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Width of the convolution kernel.
        dilation: Spacing between kernel taps.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    padding = (kernel_size // 2) * dilation
    return nn.Sequential(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
        ),
        nn.BatchNorm1d(out_channels),
        nn.PReLU(),
    )


class AttentionBlock(nn.Module):
    """Gate a feature map with sigmoid weights computed from itself.

    Two kernel-size-1 convolutions (squeeze to a bottleneck, ReLU, expand
    back, sigmoid) produce a weight tensor with the same shape as the input;
    the input is multiplied element-wise by it.

    Args:
        channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, channels):
        super().__init__()
        # Clamp to at least one channel: ``channels // 4`` is 0 for
        # ``channels < 4``, which would create zero-width convolutions with
        # no learnable gating at all.
        bottleneck = max(1, channels // 4)
        self.attention = nn.Sequential(
            nn.Conv1d(channels, bottleneck, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(bottleneck, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by its attention weights."""
        attention_weights = self.attention(x)
        return x * attention_weights


class ResidualInResidualBlock(nn.Module):
    """Stack of ``conv_block`` layers plus attention, wrapped in a skip connection.

    Args:
        channels: Channel count kept constant through the whole block.
        num_convs: How many ``conv_block(channels, channels)`` layers to chain.
    """

    def __init__(self, channels, num_convs=3):
        super().__init__()
        self.conv_layers = nn.Sequential(
            *(conv_block(channels, channels) for _ in range(num_convs))
        )
        self.attention = AttentionBlock(channels)

    def forward(self, x):
        """Apply the conv stack and attention, then add the block input back."""
        residual = x
        x = self.conv_layers(x)
        x = self.attention(x)
        return x + residual


class SISUGenerator(nn.Module):
    """Single-channel 1-D generator built from residual-in-residual blocks.

    The network lifts a 1-channel input to ``layer`` channels, runs it
    through ``num_rirb`` ``ResidualInResidualBlock`` modules, projects back
    to 1 channel, and adds the original input (global skip connection).
    All convolutions use length-preserving padding, so the output has the
    same shape as the input.

    Args:
        layer: Number of feature channels used inside the network.
        num_rirb: Number of ``ResidualInResidualBlock`` modules to chain.
    """

    def __init__(self, layer=64, num_rirb=4):  # increased base layer and rirb amounts
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, layer, kernel_size=7, padding=3),
            nn.BatchNorm1d(layer),
            nn.PReLU(),
        )
        self.rir_blocks = nn.Sequential(
            *(ResidualInResidualBlock(layer) for _ in range(num_rirb))
        )
        self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the generator.

        Args:
            x: Input tensor with exactly one channel; expected layout is
                ``(batch, 1, length)`` as required by ``Conv1d``/``BatchNorm1d``.

        Returns:
            Tensor of the same shape as ``x``: the network output plus ``x``.
        """
        residual = x
        x = self.conv1(x)
        x = self.rir_blocks(x)
        x = self.final_layer(x)
        return x + residual