import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler

from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()

# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
discriminator = SISUDiscriminator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset")

sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory

train_loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    shuffle=(sampler is None),
    num_workers=args.num_workers,
    pin_memory=pin_memory,
    # persistent_workers is only valid when there are worker processes.
    persistent_workers=args.num_workers > 0,
)

if len(train_loader) == 0:
    accelerator.print("🪹 | There is no data to train with! Exiting...")
    exit()

loader_batch_size = train_loader.batch_size

accelerator.print("✅ | Dataset fetched!")

# ---------------------------
# Losses / Optimizers / Schedulers
# ---------------------------
optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

# NOTE: criterion_g is currently unused; both training helpers receive criterion_d.
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.MSELoss()

# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
    generator, discriminator, optimizer_g, optimizer_d, train_loader
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)


def save_ckpt(path, epoch):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save(
            {
                "epoch": epoch,
                "G": accelerator.unwrap_model(generator).state_dict(),
                "D": accelerator.unwrap_model(discriminator).state_dict(),
                "optG": optimizer_g.state_dict(),
                "optD": optimizer_d.state_dict(),
                "schedG": scheduler_g.state_dict(),
                "schedD": scheduler_d.state_dict(),
            },
            path,
        )


start_epoch = 0
if args.resume:
    ckpt_path = os.path.join(models_dir, "last.pt")
    ckpt = torch.load(ckpt_path, map_location="cpu")
    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])
    # The checkpoint stores the last completed epoch, so resume from the next one.
    start_epoch = ckpt.get("epoch", 0) + 1
    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")

# Label buffers, allocated once and sliced per batch.
real_buf = torch.full(
    (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
)
fake_buf = torch.zeros(
    (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
)

accelerator.print("🏋️ | Started training...")

epoch = start_epoch  # defined up front so the crash handler can reference it

try:
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()

        if sampler is not None:
            # Reshuffle the distributed shards each epoch.
            sampler.set_epoch(epoch)

        running_d, running_g, steps = 0.0, 0.0, 0

        for i, (
            (high_quality, low_quality),
            (high_sample_rate, low_sample_rate),
        ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
            batch_size = high_quality.size(0)

            real_labels = real_buf[:batch_size].to(accelerator.device)
            fake_labels = fake_buf[:batch_size].to(accelerator.device)

            # --- Discriminator ---
            optimizer_d.zero_grad(set_to_none=True)
            with accelerator.autocast():
                d_loss = discriminator_train(
                    high_quality,
                    low_quality,
                    real_labels,
                    fake_labels,
                    discriminator,
                    generator,
                    criterion_d,
                )

            accelerator.backward(d_loss)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
            optimizer_d.step()

            # --- Generator ---
            optimizer_g.zero_grad(set_to_none=True)
            with accelerator.autocast():
                g_total, g_adv = generator_train(
                    low_quality,
                    high_quality,
                    real_labels,
                    generator,
                    discriminator,
                    criterion_d,
                )

            accelerator.backward(g_total)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
            optimizer_g.step()

            d_val = accelerator.gather(d_loss.detach()).mean()
            g_val = accelerator.gather(g_total.detach()).mean()

            if torch.isfinite(d_val):
                running_d += d_val.item()
            else:
                accelerator.print(
                    f"🫥 | Non-finite discriminator loss at step {i}, excluding it from the epoch average."
                )

            if torch.isfinite(g_val):
                running_g += g_val.item()
            else:
                accelerator.print(
                    f"🫥 | Non-finite generator loss at step {i}, excluding it from the epoch average."
                )

            steps += 1

        # epoch averages & schedulers
        if steps == 0:
            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
        mean_g = running_g / steps
        scheduler_d.step(mean_d)
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")

except Exception:
    try:
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
        accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        accelerator.print("😬 | Failed saving crash checkpoint:", e)
    raise

accelerator.print("🏁 | Training finished.")
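
# Usage sketch (assumptions: this file is saved as train.py, and ./dataset holds
# the audio files AudioDataset expects). Launching through the Accelerate CLI
# lets the script pick up multi-GPU and mixed-precision settings:
#
#   accelerate launch train.py --batch_size 8 --num_workers 2
#   accelerate launch train.py --resume   # continues from ./models/last.pt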