"""Smoke test for ``SISUDiscriminator``.

Builds a discriminator, feeds it one random input tensor and prints the
resulting output together with its shape.
"""

import torch
import torch.nn as nn  # noqa: F401  (kept from the original script; unused here)

from discriminator import SISUDiscriminator


def main() -> None:
    """Run one forward pass through ``SISUDiscriminator`` and print the result."""
    discriminator = SISUDiscriminator()

    # Example input laid out as (batch_size, channels, frames), per the
    # original script's comment.
    test_input = torch.randn(1, 2, 1000)

    # Inference-only check: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        output = discriminator(test_input)

    print(output)
    print("Output shape:", output.shape)


if __name__ == "__main__":
    main()