From 8332b0df2daa97561a78fb95012dc0bc349ffb41 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Wed, 26 Feb 2025 19:36:43 +0200
Subject: [PATCH] :sparkles: | Added ability to set epoch.

---
 training.py | 66 ++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 48 insertions(+), 18 deletions(-)

diff --git a/training.py b/training.py
index 3992829..710dd65 100644
--- a/training.py
+++ b/training.py
@@ -21,21 +21,45 @@ from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
 import librosa
+import numpy as np
 
 def mfcc_loss(y_true, y_pred, sr):
-    # 1. Ensure sr is a NumPy scalar (not a Tensor)
-    if isinstance(sr, torch.Tensor): # Check if it's a Tensor
-        sr = sr.item() # Extract the value as a Python number
+    """Calculates MFCC loss between two audio signals.
 
-    # 2. Convert y_true and y_pred to NumPy arrays
-    y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() is crucial!
+    Args:
+        y_true: Target audio signal (PyTorch tensor).
+        y_pred: Predicted audio signal (PyTorch tensor).
+        sr: Sample rate (NumPy scalar).
+
+    Returns:
+        MFCC loss (PyTorch tensor).
+    """
+
+    # 1. Ensure sr is a NumPy scalar (not a Tensor)
+    if isinstance(sr, torch.Tensor):
+        sr = sr.item()
+
+    # 2. Convert y_true and y_pred to NumPy arrays (and detach from graph)
+    y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() and .detach() are crucial!
     y_pred_np = y_pred.cpu().detach().numpy()[0]
 
+    # 3. Dynamically calculate n_fft based on signal length
+    signal_length = min(y_true_np.shape[0], y_pred_np.shape[0]) # Use shortest signal length
+    n_fft = min(2048, 2**int(np.log2(signal_length))) # Power of 2, up to 2048
 
-    mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_mfcc=20)
-    mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_mfcc=20)
+    # 4. Calculate MFCCs using adjusted n_fft
+    mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
+    mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)
 
-    # 3. Convert MFCCs back to PyTorch tensors and ensure correct device
+    # 5. Truncate MFCCs to the same length (important!)
+    len_true = mfccs_true.shape[1]
+    len_pred = mfccs_pred.shape[1]
+    min_len = min(len_true, len_pred)
+
+    mfccs_true = mfccs_true[:, :min_len]
+    mfccs_pred = mfccs_pred[:, :min_len]
+
+    # 6. Convert MFCCs back to PyTorch tensors and ensure correct device
     mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32)
     mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32)
 
@@ -87,6 +111,7 @@
 
 parser.add_argument("--generator", type=str, default=None, help="Path to the generator model file")
 parser.add_argument("--discriminator", type=str, default=None, help="Path to the discriminator model file")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
 parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
 
 args = parser.parse_args()
@@ -120,6 +145,8 @@
 train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()
+epoch: int = args.epoch
+
 if args.generator is not None:
     generator.load_state_dict(torch.load(args.generator, weights_only=True))
 if args.discriminator is not None:
@@ -153,7 +180,7 @@ def start_training():
         times_correct = 0
 
         # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
         # for high_quality_clip, low_quality_clip in train_data_loader:
             high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
             low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
@@ -181,16 +208,19 @@ def start_training():
             low_quality_audio = low_quality_clip
             ai_enhanced_audio = (generator_output, high_quality_clip[1])
 
-        if generator_epoch % 10 == 0:
-            print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
+        new_epoch = generator_epoch+epoch
 
-        torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{generator_epoch}.pt")
-        torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{generator_epoch}.pt")
-        torch.save(discriminator, f"{models_dir}/discriminator_epoch_{generator_epoch}_full.pt")
-        torch.save(generator, f"{models_dir}/generator_epoch_{generator_epoch}_full.pt")
+        if generator_epoch % 10 == 0:
+            print(f"Saved epoch {new_epoch}!")
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
+
+            if debug:
+                print(generator.state_dict().keys())
+                print(discriminator.state_dict().keys())
+            torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{new_epoch}.pt")
+            torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{new_epoch}.pt")
 
     torch.save(discriminator, "models/epoch-5000-discriminator.pt")
     torch.save(generator, "models/epoch-5000-generator.pt")
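A note on the reworked mfcc_loss: librosa's default n_fft of 2048 can warn and zero-pad heavily on clips shorter than the window, and two signals of different length produce a different number of MFCC frames, so the patch clamps n_fft to the largest power of two that fits the shorter signal and truncates both MFCC matrices to a common frame count. The standalone sketch below mirrors that logic on random data; the sample rate, signal lengths, and the final MSE reduction are illustrative assumptions, since the reduction actually applied after the tensor conversion lies outside the hunk shown above.

# Sketch only: reproduces the n_fft selection and frame truncation from the
# patched mfcc_loss on random data. sr, lengths, and the MSE reduction at the
# end are assumptions for illustration, not taken from training.py.
import numpy as np
import librosa
import torch
import torch.nn.functional as F

sr = 22050
y_true_np = np.random.randn(1500).astype(np.float32)  # shorter clip
y_pred_np = np.random.randn(2300).astype(np.float32)  # longer clip

# Largest power of two that fits the shorter signal, capped at 2048
signal_length = min(y_true_np.shape[0], y_pred_np.shape[0])
n_fft = min(2048, 2 ** int(np.log2(signal_length)))  # -> 1024 here

mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)

# Different signal lengths give different frame counts; keep only the overlap
min_len = min(mfccs_true.shape[1], mfccs_pred.shape[1])
mfccs_true = torch.tensor(mfccs_true[:, :min_len], dtype=torch.float32)
mfccs_pred = torch.tensor(mfccs_pred[:, :min_len], dtype=torch.float32)

loss = F.mse_loss(mfccs_pred, mfccs_true)  # assumed reduction
print(n_fft, mfccs_true.shape, loss.item())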
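A note on the new --epoch flag: checkpoints and output WAVs are now numbered with new_epoch = generator_epoch + epoch, so passing the epoch a previous run stopped at continues the file numbering instead of reusing the old names, e.g. python training.py --generator models/generator_epoch_100.pt --discriminator models/discriminator_epoch_100.pt --epoch 100 (paths are illustrative, not from the patch). A minimal sketch of the resulting numbering, assuming the same argparse flag and the same every-10th-epoch save condition as the patch:

# Sketch only: shows how the --epoch offset shifts checkpoint names. The epoch
# value and the number of training epochs are made up for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
args = parser.parse_args(["--epoch", "100"])  # as if run with: --epoch 100

epoch: int = args.epoch
generator_epochs = 25  # illustrative

for generator_epoch in range(generator_epochs):
    new_epoch = generator_epoch + epoch
    if generator_epoch % 10 == 0:
        # With --epoch 100 this prints ..._100.pt, ..._110.pt, ..._120.pt
        print(f"models/generator_epoch_{new_epoch}.pt")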