import torch.nn as nn


class SISUDiscriminator(nn.Module):
    """1-D convolutional discriminator mapping a 2-channel sequence to one score.

    Five ``Conv1d`` layers (kernel 3, padding 1, default stride 1, so the
    sequence length is preserved) take the channel count
    2 -> 128 -> 256 -> 128 -> 64 -> 1. The single remaining channel is then
    averaged over the length axis.

    NOTE(review): only one LeakyReLU (after the second conv) is active. Earlier
    revisions of this file had the other activations commented out.
    Convolutions with no nonlinearity between them compose into a single affine
    map, so this stack is effectively two conv blocks. Re-enabling activations
    would also shift the ``nn.Sequential`` indices and break existing
    ``state_dict`` checkpoints (``model.3.*``, ``model.4.*``, ...). Confirm the
    intent before changing it.
    """

    def __init__(self):
        super().__init__()

        def conv(in_ch, out_ch):
            # kernel_size=3 with padding=1 keeps the sequence length unchanged.
            return nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1)

        # Layers are created in the original order, so that:
        # - seeded weight initialisation consumes the RNG identically; and
        # - the Sequential indices (0, 1, 2=activation, 3, 4, 5) stay stable.
        self.model = nn.Sequential(
            conv(2, 128),
            conv(128, 256),
            nn.LeakyReLU(0.2, inplace=True),
            conv(256, 128),
            conv(128, 64),
            conv(64, 1),
        )
        # Reduces the length axis to a single value per channel.
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Score a batch of sequences.

        Assumes ``x`` is ``(batch, 2, length)``; the first layer requires 2
        input channels. Returns a tensor of shape ``(batch_size, 1)``.
        """
        features = self.model(x)               # (batch, 1, length)
        pooled = self.global_avg_pool(features)  # (batch, 1, 1)
        return pooled.view(-1, 1)              # flatten to (batch_size, 1)