# File listing metadata (from the source viewer): 11 lines, 287 B, Python.
import torch
import torch.nn as nn

from discriminator import SISUDiscriminator

# Smoke test: build the discriminator with its default arguments, push one
# random batch through it, and print the result for manual inspection.
discriminator = SISUDiscriminator()

# Example input (batch_size, channels, frames)
test_input = torch.randn(1, 2, 1000)

output = discriminator(test_input)

print(output)
print("Output shape:", output.shape)