From b6d16e4f11582078b40af2a05d60d37487ee0090 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Mon, 14 Apr 2025 17:51:34 +0300
Subject: [PATCH] :recycle: | Restructured project code.

---
 file_utils.py     |  26 ++++++++
 training.py       | 127 ++++++++++++++++++++--------------------------
 training_utils.py |  55 ++++++++++++++++++++
 3 files changed, 135 insertions(+), 73 deletions(-)
 create mode 100644 file_utils.py
 create mode 100644 training_utils.py

diff --git a/file_utils.py b/file_utils.py
new file mode 100644
index 0000000..a723688
--- /dev/null
+++ b/file_utils.py
@@ -0,0 +1,26 @@
+import json
+
+def write_data(filepath, data):
+    try:
+        with open(filepath, 'w') as f:
+            json.dump(data, f, indent=4) # Use indent for pretty formatting
+        print(f"Data written to '{filepath}'")
+    except Exception as e:
+        print(f"Error writing to file: {e}")
+
+
+def read_data(filepath):
+    try:
+        with open(filepath, 'r') as f:
+            data = json.load(f)
+        print(f"Data read from '{filepath}'")
+        return data
+    except FileNotFoundError:
+        print(f"File not found: {filepath}")
+        return None
+    except json.JSONDecodeError:
+        print(f"Error decoding JSON from file: {filepath}")
+        return None
+    except Exception as e:
+        print(f"Error reading from file: {e}")
+        return None
diff --git a/training.py b/training.py
index 47982bf..17843e0 100644
--- a/training.py
+++ b/training.py
@@ -20,6 +20,9 @@ from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
+from training_utils import discriminator_train, generator_train
+import file_utils as Data
+
 import torchaudio.transforms as T
 
 # Init script argument parser
@@ -31,92 +34,55 @@ parser.add_argument("--discriminator", type=str, default=None,
 parser.add_argument("--device", type=str, default="cpu", help="Select device")
 parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
+parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
 
 args = parser.parse_args()
 
 device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
-mfcc_transform = T.MFCC(
-    sample_rate=44100,
-    n_mfcc=20,
-    melkwargs={'n_fft': 2048, 'hop_length': 256}
-).to(device)
-
-def gpu_mfcc_loss(y_true, y_pred):
-    mfccs_true = mfcc_transform(y_true)
-    mfccs_pred = mfcc_transform(y_pred)
-    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
-    mfccs_true = mfccs_true[:, :, :min_len]
-    mfccs_pred = mfccs_pred[:, :, :min_len]
-    loss = torch.mean((mfccs_true - mfccs_pred)**2)
-    return loss
-
-def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
-    optimizer_d.zero_grad()
-
-    # Forward pass for real samples
-    discriminator_decision_from_real = discriminator(high_quality[0])
-    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
-
-    # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0])
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
-
-    # Combine real and fake losses
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-
-    # Backward pass and optimization
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
-    optimizer_d.step()
-
-    return d_loss
-
-def generator_train(low_quality, high_quality, real_labels):
-    optimizer_g.zero_grad()
-
-    # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0])
-
-    #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
-
-    discriminator_decision = discriminator(generator_output)
-    adversarial_loss = criterion_g(discriminator_decision, real_labels)
-
-    #combined_loss = adversarial_loss + 0.5 * mfcc_l
-
-    adversarial_loss.backward()
-    optimizer_g.step()
-
-    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
-    return (generator_output, adversarial_loss)
+# mfcc_transform = T.MFCC(
+#     sample_rate=44100,
+#     n_mfcc=20,
+#     melkwargs={'n_fft': 2048, 'hop_length': 256}
+# ).to(device)
 
 debug = args.debug
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
 dataset = AudioDataset(dataset_dir, device)
+models_dir = "models"
+os.makedirs(models_dir, exist_ok=True)
+audio_output_dir = "output"
+os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
 
 train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True)
 
-# Initialize models and move them to device
+
+# ========= MODELS =========
+
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
+epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
+
+if args.continue_training:
+    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
+    epoch = epoch_from_file["epoch"] + 1
+else:
+    if args.generator is not None:
+        generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+    if args.discriminator is not None:
+        discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
 
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
-if args.generator is not None:
-    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
-if args.discriminator is not None:
-    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
-
 # Loss
 criterion_g = nn.BCEWithLogitsLoss()
 criterion_d = nn.BCEWithLogitsLoss()
@@ -129,9 +95,6 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
 
-models_dir = "models"
-os.makedirs(models_dir, exist_ok=True)
-
 def start_training():
     generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
@@ -154,12 +117,28 @@ def start_training():
 
             # ========= DISCRIMINATOR =========
             discriminator.train()
-            d_loss = discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)
+            d_loss = discriminator_train(
+                high_quality_sample,
+                low_quality_sample,
+                real_labels,
+                fake_labels,
+                discriminator,
+                generator,
+                criterion_d,
+                optimizer_d
+            )
 
             # ========= GENERATOR =========
             generator.train()
-            #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
-            generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)
+            generator_output, adversarial_loss = generator_train(
+                low_quality_sample,
+                high_quality_sample,
+                real_labels,
+                generator,
+                discriminator,
+                criterion_g,
+                optimizer_g
+            )
 
             if debug:
                 print(d_loss, adversarial_loss)
@@ -173,17 +152,19 @@ def start_training():
 
         new_epoch = generator_epoch+epoch
 
-        if generator_epoch % 10 == 0:
+        if generator_epoch % 25 == 0:
             print(f"Saved epoch {new_epoch}!")
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
+            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # Saved at the original sample rate: data.py resampled the clip down to low quality and back up again.
+            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
+            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
 
             if debug:
                 print(generator.state_dict().keys())
                 print(discriminator.state_dict().keys())
-            torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{new_epoch}.pt")
-            torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{new_epoch}.pt")
+            torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
+            torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
+            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
+
 
 torch.save(discriminator, "models/epoch-5000-discriminator.pt")
 torch.save(generator, "models/epoch-5000-generator.pt")
diff --git a/training_utils.py b/training_utils.py
new file mode 100644
index 0000000..a1d2c19
--- /dev/null
+++ b/training_utils.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+import torchaudio
+
+def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
+    loss = torch.mean((mfccs_true - mfccs_pred)**2)
+    return loss
+
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
+    optimizer.zero_grad()
+
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
+    d_loss_real = criterion(discriminator_decision_from_real, real_labels)
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
+    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
+
+    # Combine real and fake losses
+    d_loss = (d_loss_real + d_loss_fake) / 2.0
+
+    # Backward pass and optimization
+    d_loss.backward()
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
+    optimizer.step()
+
+    return d_loss
+
+def generator_train(low_quality, high_quality, real_labels, generator, discriminator, criterion, optimizer):
+    optimizer.zero_grad()
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+
+    #mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
+
+    discriminator_decision = discriminator(generator_output)
+    adversarial_loss = criterion(discriminator_decision, real_labels)
+
+    #combined_loss = adversarial_loss + 0.5 * mfcc_l
+
+    adversarial_loss.backward()
+    optimizer.step()
+
+    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
+    return (generator_output, adversarial_loss)
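
Usage sketch (illustrative note, not part of the patch): training.py now persists three artifacts under models/ -- temp_generator.pt, temp_discriminator.pt, and epoch_data.json holding {"epoch": N} -- and --continue_training resumes from them. The snippet below round-trips the epoch counter through file_utils the same way training.py does; the value 42 is an arbitrary example.

    # Fresh run (writes the temp checkpoints + epoch_data.json every 25 epochs):
    #   python training.py --device cuda
    # Resume from the temp checkpoints:
    #   python training.py --device cuda --continue_training

    import os
    import file_utils as Data

    os.makedirs("models", exist_ok=True)
    Data.write_data("models/epoch_data.json", {"epoch": 42})
    state = Data.read_data("models/epoch_data.json")
    assert state == {"epoch": 42}
    # read_data returns None instead of raising on a missing or corrupt file,
    # so a fresh run simply gets None back here.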
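
Signature sketch for the relocated helpers (assumptions: the (waveform, sample_rate) tuple layout mirrors what data.py yields; the Conv1d stand-ins replace the real SISUGenerator/SISUDiscriminator only to keep the example self-contained, and the tensor sizes are arbitrary):

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from training_utils import discriminator_train, generator_train

    # Stand-in models; the discriminator must emit one logit per sample
    # so it matches the label shape expected by BCEWithLogitsLoss.
    generator = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    discriminator = nn.Sequential(
        nn.Conv1d(1, 1, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool1d(1),
        nn.Flatten(),
    )

    criterion = nn.BCEWithLogitsLoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

    batch = 4
    # The helpers index [0], so each sample is a (waveform, sample_rate) pair.
    high_quality = (torch.randn(batch, 1, 1024), 44100)
    low_quality = (torch.randn(batch, 1, 1024), 44100)
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    d_loss = discriminator_train(high_quality, low_quality, real_labels, fake_labels,
                                 discriminator, generator, criterion, optimizer_d)
    output, g_loss = generator_train(low_quality, high_quality, real_labels,
                                     generator, discriminator, criterion, optimizer_g)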