⚗️ | Experimenting, again.

2024-12-26 04:00:24 +02:00
parent 2ff45de22d
commit 89f8c68986
4 changed files with 49 additions and 55 deletions

@@ -28,14 +28,9 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
-    integer_scale = math.ceil(high_quality[1]/low_quality[1])
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], integer_scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
-    resampled = resample_transform(generator_output.detach())
-    discriminator_decision_from_fake = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
     # Combine real and fake losses
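The removed lines upsampled the detached generator output to the high-quality sample rate with torchaudio.transforms.Resample before the discriminator saw it; after this change the raw generator output goes in directly, so the generator is now expected to produce audio at the target rate itself. A minimal sketch of what the dropped step did (the rates and tensor shape here are illustrative, not taken from this repo):

    import math
    import torch
    import torchaudio

    low_rate, high_rate = 16000, 44100                 # example rates, not from this repo
    scale = math.ceil(high_rate / low_rate)            # integer upscale factor, here 3
    fake_x3 = torch.randn(1, 1, low_rate * scale)      # stand-in for generator output upsampled by `scale`
    resample = torchaudio.transforms.Resample(orig_freq=low_rate * scale, new_freq=high_rate)
    fake_at_high_rate = resample(fake_x3)              # exactly high_rate samples per second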
@@ -48,22 +43,17 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     return d_loss
-def generator_train(low_quality, real_labels, target_sample_rate=44100):
+def generator_train(low_quality, real_labels):
     optimizer_g.zero_grad()
-    scale = math.ceil(target_sample_rate/low_quality[1])
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
-    resampled = resample_transform(generator_output)
-    discriminator_decision = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision = discriminator(generator_output)
     g_loss = criterion_g(discriminator_decision, real_labels)
     g_loss.backward()
     optimizer_g.step()
-    return resampled
+    return generator_output
 # Init script argument parser
 parser = argparse.ArgumentParser(description="Training script")
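generator_train loses its target_sample_rate parameter and the resampling step, and now returns the raw generator output; the call site later in the diff drops the extra argument accordingly. Throughout the script, low_quality and high_quality appear to be (batched waveform, sample rate) pairs indexed as [0] and [1]. A minimal sketch of how the simplified function is assumed to be called under that convention (names, shapes and rates are hypothetical):

    import torch

    # assumed data convention: (waveform of shape [batch, channels, time], sample_rate)
    low_quality_sample = (torch.randn(1, 1, 16000), 16000)   # hypothetical 1 s low-quality clip
    real_labels = torch.ones(1, 1)                            # "real" target for the criteria

    # generator_output = generator_train(low_quality_sample, real_labels)
    # generator_output is now the generator's own output; no resampling happens in this function anymore.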
@@ -110,7 +100,7 @@ generator = generator.to(device)
 discriminator = discriminator.to(device)
 # Loss
-criterion_g = nn.L1Loss()
+criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()
 # Optimizers
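Swapping criterion_g from nn.L1Loss to nn.MSELoss means the generator's adversarial loss now squares the gap between the discriminator's decision and the real label, in the spirit of a least-squares GAN objective, instead of taking its absolute value. A small self-contained comparison (the decision values are made up):

    import torch
    import torch.nn as nn

    decision = torch.tensor([0.1, 0.6, 0.9])   # hypothetical discriminator outputs for fake samples
    real = torch.ones_like(decision)           # the generator wants these judged "real"

    l1 = nn.L1Loss()(decision, real)           # mean |decision - 1|    -> 0.4667
    mse = nn.MSELoss()(decision, real)         # mean (decision - 1)^2  -> 0.3267

With the squared error, samples the discriminator confidently rejects (decision far from 1) produce larger gradients for the generator than samples it almost accepts.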
@@ -172,7 +162,7 @@ def start_training():
         # ========= GENERATOR =========
         generator.train()
-        generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])
+        generator_output = generator_train(low_quality_sample, real_labels)
         # ========= SAVE LATEST AUDIO =========
         high_quality_audio = high_quality_clip
@@ -185,7 +175,7 @@ def start_training():
         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because the audio clip was resampled in data.py from whatever rate low_quality had to the high_quality rate
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
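The "crap" clip is now written with high_quality_audio[1] as its sample rate because, per the inline comment, data.py already resamples the low-quality clip up to the high-quality rate before it reaches this loop; saving it with the old low rate would make the waveform play back slowed down and pitched down. A short illustration of why the sample_rate argument must match the tensor (file names and rates are made up):

    import torch
    import torchaudio

    high_rate, low_rate = 44100, 16000
    clip = torch.randn(1, high_rate)              # 1 s of audio already at 44.1 kHz, shape (channels, time)

    torchaudio.save("right.wav", clip, high_rate) # plays for 1 s, as intended
    torchaudio.save("wrong.wav", clip, low_rate)  # same samples read back at 16 kHz:
                                                  # roughly 2.76 s long and pitched down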