🐛 | Fixed NVIDIA training... again.
This commit is contained in:
parent
fbcd5803b8
commit
3936b6c160
14
training.py
14
training.py
@ -101,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)
|
|||||||
|
|
||||||
# ========= SINGLE =========
|
# ========= SINGLE =========
|
||||||
|
|
||||||
train_data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
|
train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True)
|
||||||
|
|
||||||
# Initialize models and move them to device
|
# Initialize models and move them to device
|
||||||
generator = SISUGenerator()
|
generator = SISUGenerator()
|
||||||
@ -118,7 +118,7 @@ if args.discriminator is not None:
|
|||||||
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
|
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
criterion_g = nn.MSELoss()
|
criterion_g = nn.BCEWithLogitsLoss()
|
||||||
criterion_d = nn.BCEWithLogitsLoss()
|
criterion_d = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# Optimizers
|
# Optimizers
|
||||||
@ -163,8 +163,8 @@ def start_training():
|
|||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(d_loss, adversarial_loss)
|
print(d_loss, adversarial_loss)
|
||||||
scheduler_d.step(d_loss)
|
scheduler_d.step(d_loss.detach())
|
||||||
scheduler_g.step(adversarial_loss)
|
scheduler_g.step(adversarial_loss.detach())
|
||||||
|
|
||||||
# ========= SAVE LATEST AUDIO =========
|
# ========= SAVE LATEST AUDIO =========
|
||||||
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
|
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
|
||||||
@ -175,9 +175,9 @@ def start_training():
|
|||||||
|
|
||||||
if generator_epoch % 10 == 0:
|
if generator_epoch % 10 == 0:
|
||||||
print(f"Saved epoch {new_epoch}!")
|
print(f"Saved epoch {new_epoch}!")
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(generator.state_dict().keys())
|
print(generator.state_dict().keys())
|
||||||
|
Loading…
Reference in New Issue
Block a user