| Added smarter ways that would've been needed from the begining.

This commit is contained in:
2025-04-16 17:08:13 +03:00
parent b6d16e4f11
commit c04b072de6
2 changed files with 148 additions and 24 deletions

View File

@ -41,11 +41,24 @@ args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# mfcc_transform = T.MFCC(
# sample_rate=44100,
# n_mfcc=20,
# melkwargs={'n_fft': 2048, 'hop_length': 256}
# ).to(device)
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
debug = args.debug
@ -130,18 +143,20 @@ def start_training():
# ========= GENERATOR =========
generator.train()
generator_output, adversarial_loss = generator_train(
generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_g,
optimizer_g
criterion_d,
optimizer_g,
device,
mel_transform
)
if debug:
print(d_loss, adversarial_loss)
print(combined_loss, adversarial_loss, mel_l1_tensor)
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())