"""Adversarial training loop for the SISU generator/discriminator pair, driven by Hugging Face Accelerate."""

import argparse
import datetime
import os

import torch
import torch.optim as optim
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler

from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)

args = parser.parse_args()

# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
discriminator = SISUDiscriminator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset", 8192)
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None

pin_memory = torch.cuda.is_available() and not args.no_pin_memory

train_loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    shuffle=(sampler is None),
    num_workers=args.num_workers,
    pin_memory=pin_memory,
    # persistent_workers requires at least one worker process; tying it only to
    # pin_memory would raise a ValueError when --num_workers is 0.
    persistent_workers=(pin_memory and args.num_workers > 0),
)

if len(dataset) == 0:
    accelerator.print("🪹 | There is no data to train with! Exiting...")
    exit()

loader_batch_size = train_loader.batch_size

accelerator.print("✅ | Dataset fetched!")

# ---------------------------
# Optimizers / Schedulers
# ---------------------------
optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
    generator, discriminator, optimizer_g, optimizer_d, train_loader
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)


def save_ckpt(path, epoch):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save(
            {
                "epoch": epoch,
                "G": accelerator.unwrap_model(generator).state_dict(),
                "D": accelerator.unwrap_model(discriminator).state_dict(),
                "optG": optimizer_g.state_dict(),
                "optD": optimizer_d.state_dict(),
                "schedG": scheduler_g.state_dict(),
                "schedD": scheduler_d.state_dict(),
            },
            path,
        )


start_epoch = 0
if args.resume:
    ckpt_path = os.path.join(models_dir, "last.pt")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
        optimizer_g.load_state_dict(ckpt["optG"])
        optimizer_d.load_state_dict(ckpt["optD"])
        scheduler_g.load_state_dict(ckpt["schedG"])
        scheduler_d.load_state_dict(ckpt["schedD"])
        # The checkpoint stores the last completed epoch, so resume from the next one.
        start_epoch = ckpt.get("epoch", -1) + 1
        accelerator.print(f"🔁 | Resuming from epoch {start_epoch}!")
    else:
        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")

accelerator.print("🏋️ | Started training...")

smallest_loss = float("inf")
epoch = start_epoch  # defined before the try block so the crash handler can reference it

try:
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()

        # Re-seed the distributed sampler so each epoch sees a different shuffle order.
        if sampler is not None:
            sampler.set_epoch(epoch)

        discriminator_time = 0
        generator_time = 0

        running_d, running_g, steps = 0.0, 0.0, 0
        progress_bar = tqdm.tqdm(
            train_loader,
            desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs",
        )

        for i, (
            (high_quality, low_quality),
            (high_sample_rate, low_sample_rate),
        ) in enumerate(progress_bar):
            with accelerator.autocast():
                generator_output = generator(low_quality)

            # --- Discriminator ---
            d_time = datetime.datetime.now()
            optimizer_d.zero_grad(set_to_none=True)
            with accelerator.autocast():
                d_loss = discriminator_train(
                    high_quality,
                    low_quality.detach(),
                    discriminator,
                    generator_output.detach(),
                )
            accelerator.backward(d_loss)
            optimizer_d.step()
            # Total elapsed time in microseconds (not just the microseconds component).
            discriminator_time = int(
                (datetime.datetime.now() - d_time).total_seconds() * 1_000_000
            )

            # --- Generator ---
            g_time = datetime.datetime.now()
            optimizer_g.zero_grad(set_to_none=True)
            with accelerator.autocast():
                g_total, g_adv = generator_train(
                    low_quality,
                    high_quality,
                    generator,
                    discriminator,
                    generator_output,
                )
            accelerator.backward(g_total)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
            optimizer_g.step()
            generator_time = int(
                (datetime.datetime.now() - g_time).total_seconds() * 1_000_000
            )

            d_val = accelerator.gather(d_loss.detach()).mean()
            g_val = accelerator.gather(g_total.detach()).mean()

            if torch.isfinite(d_val):
                running_d += d_val.item()
            else:
                accelerator.print(
                    f"🫥 | Non-finite discriminator loss at step {i}; excluding it from the epoch average."
                )
            if torch.isfinite(g_val):
                running_g += g_val.item()
            else:
                accelerator.print(
                    f"🫥 | Non-finite generator loss at step {i}; excluding it from the epoch average."
                )

            steps += 1
            progress_bar.set_description(
                f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs"
            )

        # Epoch averages & schedulers
        if steps == 0:
            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
        mean_g = running_g / steps

        scheduler_d.step(mean_d)
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        if smallest_loss > mean_g:
            smallest_loss = mean_g
            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)

        accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")

except Exception:
    try:
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
        accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        accelerator.print("😬 | Failed saving crash checkpoint:", e)
    raise

accelerator.print("🏁 | Training finished.")