⚗️ | Experimenting, again.
training.py (28 lines changed)
@@ -28,14 +28,9 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    integer_scale = math.ceil(high_quality[1]/low_quality[1])
-
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], integer_scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
-    resampled = resample_transform(generator_output.detach())
-
-    discriminator_decision_from_fake = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

     # Combine real and fake losses
@@ -48,22 +43,17 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     return d_loss


-def generator_train(low_quality, real_labels, target_sample_rate=44100):
+def generator_train(low_quality, real_labels):
     optimizer_g.zero_grad()

-    scale = math.ceil(target_sample_rate/low_quality[1])
-
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
-    resampled = resample_transform(generator_output)
-
-    discriminator_decision = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision = discriminator(generator_output)
     g_loss = criterion_g(discriminator_decision, real_labels)

     g_loss.backward()
     optimizer_g.step()
-    return resampled
+    return generator_output

 # Init script argument parser
 parser = argparse.ArgumentParser(description="Training script")
@@ -110,7 +100,7 @@ generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
-criterion_g = nn.L1Loss()
+criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()

 # Optimizers
@@ -172,7 +162,7 @@ def start_training():

         # ========= GENERATOR =========
         generator.train()
-        generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])
+        generator_output = generator_train(low_quality_sample, real_labels)

         # ========= SAVE LATEST AUDIO =========
         high_quality_audio = high_quality_clip
@@ -185,7 +175,7 @@ def start_training():

         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])  # <-- because the clip was already resampled in data.py from whatever sample rate low_quality had up to the high_quality rate
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])

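For context, here is a minimal, self-contained sketch of the resampling path that this commit removes from both training functions (the integer_scale plus torchaudio.transforms.Resample steps). The sample rates below are made-up examples and the random tensor stands in for the generator output; only the Resample usage itself mirrors the deleted lines.

import math

import torch
import torchaudio

# Made-up example rates; in the repo the real values came from the
# (waveform, sample_rate) tuples passed into the training functions.
low_sr, high_sr = 16000, 44100
integer_scale = math.ceil(high_sr / low_sr)  # whole-number upsampling factor, here 3

# Stand-in for generator(low_quality[0], integer_scale): one second of audio
# produced at low_sr * integer_scale = 48000 Hz.
generator_output = torch.randn(1, low_sr * integer_scale)

# The removed step: bring the overshot 48 kHz signal down to the 44.1 kHz target
# before handing it to the discriminator.
resample_transform = torchaudio.transforms.Resample(low_sr * integer_scale, high_sr)
resampled = resample_transform(generator_output)
print(resampled.shape)  # roughly torch.Size([1, 44100])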