diff --git a/data.py b/data.py
index fe41126..71f25dd 100644
--- a/data.py
+++ b/data.py
@@ -24,7 +24,7 @@ class AudioDataset(Dataset):
         sample_rate = random.choice(self.audio_sample_rates)
         resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
         low_quality_wav = resample_transform(high_quality_wav)
-        low_quality_wav = -low_quality_wav
+        low_quality_wav = low_quality_wav

         # Calculate target length based on desired duration and 16000 Hz
         if self.target_duration is not None:
diff --git a/discriminator.py b/discriminator.py
index e0083d4..9fd9e30 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -1,23 +1,25 @@
 import torch.nn as nn
+import torch

 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
         self.model = nn.Sequential(
-            nn.Conv1d(2, 64, kernel_size=4, stride=2, padding=1),  # Now accepts 2 input channels
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm1d(128),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(512, 1, kernel_size=4, stride=1, padding=0),
-            nn.Sigmoid()
+            nn.Conv1d(2, 128, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(128, 256, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(256, 128, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(128, 64, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(64, 1, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
         )
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # Output size (1,)

     def forward(self, x):
-        return self.model(x)
+        x = self.model(x)
+        x = self.global_avg_pool(x)
+        x = x.view(-1, 1)  # Flatten to (batch_size, 1)
+        return x
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..fbf81e6
--- /dev/null
+++ b/test.py
@@ -0,0 +1,10 @@
+import torch.nn as nn
+import torch
+from discriminator import SISUDiscriminator
+
+
+discriminator = SISUDiscriminator()
+test_input = torch.randn(1, 2, 1000)  # Example input (batch_size, channels, frames)
+output = discriminator(test_input)
+print(output)
+print("Output shape:", output.shape)
diff --git a/training.py b/training.py
index 828b21b..7d54123 100644
--- a/training.py
+++ b/training.py
@@ -25,8 +25,8 @@ val_size = int(dataset_size-train_size)

 train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

-train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=True)
+train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
@@ -36,16 +36,13 @@ generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
-criterion_g = nn.L1Loss()  # Perceptual Loss (L1 instead of MSE)
-criterion_d = nn.MSELoss()  # Can keep MSE for discriminator (optional)
+criterion_g = nn.L1Loss()
+criterion_d = nn.BCEWithLogitsLoss()

 # Optimizers
-optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))  # Reduced learning rate
+optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
 optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

-# Learning rate scheduler
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.1, patience=5)
-
 # Training loop
 num_epochs = 500

@@ -61,7 +58,7 @@ for epoch in range(num_epochs):
         high_quality = high_quality.to(device)
         low_quality = low_quality.to(device)

-        batch_size = high_quality.size(0)
+        batch_size = 1

         real_labels = torch.ones(batch_size, 1).to(device)
         fake_labels = torch.zeros(batch_size, 1).to(device)
@@ -75,7 +72,7 @@ for epoch in range(num_epochs):

         # 2. Fake data
         fake_audio = generator(low_quality)
-        fake_outputs = discriminator(fake_audio.detach())  # Detach to stop gradient flow to the generator
+        fake_outputs = discriminator(fake_audio.detach())
         d_loss_fake = criterion_d(fake_outputs, fake_labels)
         d_loss = (d_loss_real + d_loss_fake) / 2.0  # Without gradient penalty
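
The data.py hunk keeps the random-resample degradation and drops the stray negation of the waveform. For context, a minimal standalone sketch of that degradation step; the sample-rate pool, source rate, and stand-in waveform below are hypothetical placeholders, since the real values live elsewhere in data.py:

# Sketch of the degradation step shown in the AudioDataset.__getitem__ hunk.
import random
import torch
import torchaudio

audio_sample_rates = [8000, 11025, 16000]       # hypothetical pool; the real list is defined elsewhere in data.py
sr_original = 44100                             # hypothetical source sample rate
high_quality_wav = torch.randn(2, sr_original)  # stand-in for one second of stereo audio

sample_rate = random.choice(audio_sample_rates)
resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
low_quality_wav = resample_transform(high_quality_wav)  # degraded copy; the removed line used to negate this signal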
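The discriminator and loss changes go together: the model no longer ends in Sigmoid, so it emits raw scores, and training.py switches criterion_d from MSELoss to BCEWithLogitsLoss, which applies the sigmoid internally. A minimal sketch of that pairing, assuming the discriminator.py from this diff is importable; shapes follow test.py:

# Sketch: Sigmoid-free discriminator paired with BCEWithLogitsLoss.
import torch
import torch.nn as nn
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator()
criterion_d = nn.BCEWithLogitsLoss()   # expects raw scores, applies sigmoid internally

real_batch = torch.randn(1, 2, 1000)   # (batch_size, channels, frames), as in test.py
real_labels = torch.ones(1, 1)         # matches the (batch_size, 1) output of forward()

logits = discriminator(real_batch)     # raw scores; no Sigmoid inside the model
d_loss_real = criterion_d(logits, real_labels)
print(d_loss_real.item())

Applying the sigmoid inside the loss rather than in the model is the numerically stable formulation, which is the usual reason for this pairing.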