From 3936b6c1600a62b18a7e38371707c47f4dfcfaf4 Mon Sep 17 00:00:00 2001
From: nsiltala <144348410+nsiltala@users.noreply.github.com>
Date: Mon, 7 Apr 2025 14:49:07 +0300
Subject: [PATCH] :bug: | Fixed NVIDIA training... again.

---
 training.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/training.py b/training.py
index c050b9c..47982bf 100644
--- a/training.py
+++ b/training.py
@@ -101,7 +101,7 @@
 dataset = AudioDataset(dataset_dir, device)
 
 # ========= SINGLE =========
-train_data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True)
 
 # Initialize models and move them to device
 generator = SISUGenerator()
@@ -118,7 +118,7 @@ if args.discriminator is not None:
     discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
 
 # Loss
-criterion_g = nn.MSELoss()
+criterion_g = nn.BCEWithLogitsLoss()
 criterion_d = nn.BCEWithLogitsLoss()
 
 # Optimizers
@@ -163,8 +163,8 @@ def start_training():
             if debug:
                 print(d_loss, adversarial_loss)
 
-            scheduler_d.step(d_loss)
-            scheduler_g.step(adversarial_loss)
+            scheduler_d.step(d_loss.detach())
+            scheduler_g.step(adversarial_loss.detach())
 
             # ========= SAVE LATEST AUDIO =========
             high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
@@ -175,9 +175,9 @@ def start_training():
 
             if generator_epoch % 10 == 0:
                 print(f"Saved epoch {new_epoch}!")
-                torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-                torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
-                torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])
+                torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+                torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
+                torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
 
             if debug:
                 print(generator.state_dict().keys())
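
Note for reviewers (not part of the patch itself): a minimal standalone sketch of the detach pattern the last two hunks rely on. The toy conv layer, loss, sample rate, and file name below are placeholders, not the real SISU pipeline; the point is only that ReduceLROnPlateau.step needs a grad-free metric and torchaudio.save needs a plain CPU tensor with no autograd history.

# Sketch only -- placeholder model and data, assuming PyTorch + torchaudio.
import torch
import torch.nn as nn
import torchaudio

generator = nn.Conv1d(1, 1, kernel_size=3, padding=1)  # stand-in for SISUGenerator
optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

low_quality = torch.randn(1, 1, 16000)          # fake 1-second clip, batch of 1
enhanced = generator(low_quality)                # output still tracks gradients
loss = nn.functional.mse_loss(enhanced, torch.zeros_like(enhanced))

optimizer.zero_grad()
loss.backward()
optimizer.step()

# The scheduler only needs a number; detaching (or .item()) keeps it from
# holding a reference to the whole computation graph.
scheduler.step(loss.detach())

# torchaudio.save expects a (channels, frames) CPU tensor without grad,
# hence the .cpu().detach() added to the save calls in the patch.
torchaudio.save("enhanced.wav", enhanced[0].cpu().detach(), 16000)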