import argparse
import datetime
import os

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler

from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument(
    "--num_workers", type=int, default=4, help="DataLoader num_workers"
)  # Increased workers slightly
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()

# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
discriminator = SISUDiscriminator()

accelerator.print("🔨 | Compiling models...")

# Torch compile is great, but if you hit errors with the new List/Tuple outputs
# of the discriminator, you might need to disable it for D.
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)

accelerator.print("✅ | Compiling done!")

# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset", 8192)

sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory

train_loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    shuffle=(sampler is None),
    num_workers=args.num_workers,
    pin_memory=pin_memory,
    persistent_workers=(args.num_workers > 0),  # persistent workers require num_workers > 0
)

if len(train_loader) == 0:
    accelerator.print("🪹 | There is no data to train with! Exiting...")
    exit()

loader_batch_size = train_loader.batch_size

accelerator.print("✅ | Dataset fetched!")
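# Note (assumption, based on how the training loop below unpacks batches): each item
# from AudioDataset is expected to look like
#   ((high_quality, low_quality), (high_sample_rate, low_sample_rate)),
# with 8192 presumably being the clip length in samples. If your dataset returns a
# different structure, adjust the unpacking in the loop accordingly.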
# ---------------------------
# Optimizers / Schedulers
# ---------------------------
optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
    generator, discriminator, optimizer_g, optimizer_d, train_loader
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)


def save_ckpt(path, epoch, loss=None, is_best=False):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        state = {
            "epoch": epoch,
            "G": accelerator.unwrap_model(generator).state_dict(),
            "D": accelerator.unwrap_model(discriminator).state_dict(),
            "optG": optimizer_g.state_dict(),
            "optD": optimizer_d.state_dict(),
            "schedG": scheduler_g.state_dict(),
            "schedD": scheduler_d.state_dict(),
        }
        # Write to the requested path; additionally refresh best.pt on a new best.
        accelerator.save(state, path)
        if is_best:
            accelerator.save(state, os.path.join(models_dir, "best.pt"))
            accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")


start_epoch = 0
if args.resume:
    ckpt_path = os.path.join(models_dir, "last.pt")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
        optimizer_g.load_state_dict(ckpt["optG"])
        optimizer_d.load_state_dict(ckpt["optD"])
        scheduler_g.load_state_dict(ckpt["schedG"])
        scheduler_d.load_state_dict(ckpt["schedD"])
        # The checkpoint stores the last completed epoch; resume from the next one.
        start_epoch = ckpt.get("epoch", -1) + 1
        accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
    else:
        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
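# Checkpoint layout, as written by save_ckpt above: "last.pt" holds the latest
# generator/discriminator weights plus optimizer and scheduler state, while
# "best.pt" is only produced when save_ckpt is called with is_best=True. The
# training loop below never passes is_best, so wire it up to a validation
# metric if you want best-model tracking.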
accelerator.print("🏋️ | Started training...")

try:
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()

        discriminator_time = 0
        generator_time = 0
        running_d, running_g, steps = 0.0, 0.0, 0

        progress_bar = tqdm.tqdm(
            train_loader,
            desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs",
        )

        for i, (
            (high_quality, low_quality),
            (high_sample_rate, low_sample_rate),
        ) in enumerate(progress_bar):
            with accelerator.autocast():
                generator_output = generator(low_quality)

            # --- Discriminator ---
            d_time = datetime.datetime.now()
            optimizer_d.zero_grad(set_to_none=True)
            with accelerator.autocast():
                d_loss = discriminator_train(
                    high_quality, discriminator, generator_output.detach()
                )
            accelerator.backward(d_loss)
            optimizer_d.step()
            # Total elapsed time in microseconds.
            discriminator_time = int(
                (datetime.datetime.now() - d_time).total_seconds() * 1_000_000
            )

            # --- Generator ---
            g_time = datetime.datetime.now()
            optimizer_g.zero_grad(set_to_none=True)
            with accelerator.autocast():
                g_total, g_adv = generator_train(
                    low_quality, high_quality, generator, discriminator, generator_output
                )
            accelerator.backward(g_total)
            # Clip through Accelerate so it cooperates with distributed/mixed-precision setups.
            accelerator.clip_grad_norm_(generator.parameters(), 1.0)
            optimizer_g.step()
            generator_time = int(
                (datetime.datetime.now() - g_time).total_seconds() * 1_000_000
            )

            d_val = accelerator.gather(d_loss.detach()).mean()
            g_val = accelerator.gather(g_total.detach()).mean()

            if torch.isfinite(d_val):
                running_d += d_val.item()
            else:
                accelerator.print(
                    f"🫥 | NaN in discriminator loss at step {i}, excluding it from the epoch mean."
                )

            if torch.isfinite(g_val):
                running_g += g_val.item()
            else:
                accelerator.print(
                    f"🫥 | NaN in generator loss at step {i}, excluding it from the epoch mean."
                )

            steps += 1
            progress_bar.set_description(
                f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs"
            )

        if steps == 0:
            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
        mean_g = running_g / steps

        scheduler_d.step(mean_d)
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
    try:
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
        accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        accelerator.print("😬 | Failed saving crash checkpoint:", e)
    raise

accelerator.print("🏁 | Training finished.")
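# Typical invocation (assuming this file is saved as train.py; the flags below
# match the argparse options defined at the top of the script):
#   single process:  python train.py --epochs 100 --batch_size 8
#   multi-GPU:       accelerate launch train.py --epochs 100 --batch_size 8
# Add --resume to continue from ./models/last.pt.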