diff --git a/.gitignore b/.gitignore
index 493e1db..2a8b716 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ cython_debug/
 backup/
 dataset/
 old-output/
+*.wav
diff --git a/output.wav b/output.wav
deleted file mode 100644
index 90b8499..0000000
Binary files a/output.wav and /dev/null differ
diff --git a/training.py b/training.py
index 23a918a..4b0988d 100644
--- a/training.py
+++ b/training.py
@@ -46,64 +46,39 @@ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min',
 # Training loop
 num_epochs = 500
 for epoch in range(num_epochs):
-    latest_crap_audio = torch.empty((2,3), dtype=torch.int64)
-    for high_quality, low_quality in tqdm.tqdm(train_data_loader):
-        # Check for NaN values in input tensors
-        if torch.isnan(low_quality).any() or torch.isnan(high_quality).any():
-            continue
-
-        high_quality = high_quality.to(device)
+    original, crap_audio = torch.empty((1,2,3)), torch.empty((1,2,3))
+    for low_quality, high_quality in tqdm.tqdm(train_data_loader):
         low_quality = low_quality.to(device)
-
+        high_quality = high_quality.to(device)
         batch_size = low_quality.size(0)
-
-        # Labels
         real_labels = torch.ones(batch_size, 1).to(device)
         fake_labels = torch.zeros(batch_size, 1).to(device)
 
         # Train Discriminator
         optimizer_d.zero_grad()
-        outputs = discriminator(high_quality)
-        d_loss_real = criterion(outputs, real_labels)
-        d_loss_real.backward()
-
-        resampled_audio = generator(low_quality)
-
-        outputs = discriminator(resampled_audio.detach())
-        d_loss_fake = criterion(outputs, fake_labels)
-        d_loss_fake.backward()
-
-
-        # Gradient clipping for discriminator
-        clip_value = 2.0
-        for param in discriminator.parameters():
-            if param.grad is not None:
-                param.grad.clamp_(-clip_value, clip_value)
-
+        real_outputs = discriminator(high_quality)
+        fake_audio = generator(low_quality)
+        fake_outputs = discriminator(fake_audio.detach())
+        d_loss_real = criterion(real_outputs, real_labels)
+        d_loss_fake = criterion(fake_outputs, fake_labels)
+        d_loss = (d_loss_real + d_loss_fake) * 0.5
+        d_loss.backward()
         optimizer_d.step()
 
-        d_loss = d_loss_real + d_loss_fake
-
         # Train Generator
         optimizer_g.zero_grad()
-        outputs = discriminator(resampled_audio)
-        g_loss = criterion(outputs, real_labels)
+        fake_outputs = discriminator(fake_audio)
+        g_loss = criterion(fake_outputs, real_labels)
         g_loss.backward()
-
-        # Gradient clipping for generator
-        clip_value = 1.0
-        for param in generator.parameters():
-            if param.grad is not None:
-                param.grad.clamp_(-clip_value, clip_value)
-
         optimizer_g.step()
-
-        scheduler.step(d_loss + g_loss)
-        latest_crap_audio = resampled_audio
+        original = high_quality
+        crap_audio = fake_audio.detach()
 
     if epoch % 10 == 0:
-        print(latest_crap_audio.size())
-        torchaudio.save(f"./epoch-{epoch}-audio.wav", latest_crap_audio[0].cpu(), 44100)
+        print(crap_audio.size())
+        torchaudio.save(f"./epoch-{epoch}-audio.wav", crap_audio[0].cpu(), 44100)
+        torchaudio.save(f"./epoch-{epoch}-audio-orig.wav", original[0].cpu(), 44100)
+
 
     print(f'Epoch [{epoch+1}/{num_epochs}]')
 torch.save(generator.state_dict(), "generator.pt")