From dcde20387a64f1c6d63c5147b8bf7cf9e3b6e5ca Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Wed, 18 Dec 2024 02:55:57 +0200 Subject: [PATCH] :poop: | Hopefully this does something else. --- training.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/training.py b/training.py index 4b0988d..7a30eb8 100644 --- a/training.py +++ b/training.py @@ -45,20 +45,21 @@ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', # Training loop num_epochs = 500 +lambda_gp = 10 for epoch in range(num_epochs): - original, crap_audio = torch.empty((1,2,3)), torch.empty((1,2,3)) + original = torch.empty((2)) + crap_audio = torch.empty((2)) for low_quality, high_quality in tqdm.tqdm(train_data_loader): - low_quality = low_quality.to(device) high_quality = high_quality.to(device) - batch_size = low_quality.size(0) + low_quality = low_quality.to(device) + + batch_size = high_quality.size(0) real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) - # Train Discriminator - optimizer_d.zero_grad() real_outputs = discriminator(high_quality) - fake_audio = generator(low_quality) - fake_outputs = discriminator(fake_audio.detach()) + fake_outputs = discriminator(generator(low_quality)) + d_loss_real = criterion(real_outputs, real_labels) d_loss_fake = criterion(fake_outputs, fake_labels) d_loss = (d_loss_real + d_loss_fake) * 0.5 @@ -67,10 +68,12 @@ for epoch in range(num_epochs): # Train Generator optimizer_g.zero_grad() + fake_audio = generator(low_quality) fake_outputs = discriminator(fake_audio) g_loss = criterion(fake_outputs, real_labels) g_loss.backward() optimizer_g.step() + original = high_quality crap_audio = fake_audio