"""Adversarial training script for the SISU generator and discriminator."""

import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import tqdm
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

import training_utils
from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from training_utils import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)"
)
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()

# ---------------------------
# Device setup
# ---------------------------
# Use the requested device only if its backend is actually available.
# Both accelerator backends are checked (including indexed forms such as
# "cuda:1") so an unavailable request degrades to CPU instead of failing later.
cuda_missing = args.device.startswith("cuda") and not torch.cuda.is_available()
mps_missing = (
    args.device.startswith("mps") and not torch.backends.mps.is_available()
)
if cuda_missing or mps_missing:
    print(f"Requested device '{args.device}' is not available; falling back to CPU.")
    device = torch.device("cpu")
else:
    device = torch.device(args.device)
print(f"Using device: {device}")

# sensible performance flags
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    # optional: torch.set_float32_matmul_precision("high")

debug = args.debug

# ---------------------------
# Audio transforms
# ---------------------------
sample_rate = 44100
n_fft = 1024
win_length = n_fft
hop_length = n_fft // 4
n_mels = 96
# n_mfcc = 13

# Optional MFCC transform, currently disabled:
# mfcc_transform = T.MFCC(
#     sample_rate=sample_rate,
#     n_mfcc=n_mfcc,
#     melkwargs=dict(
#         n_fft=n_fft,
#         hop_length=hop_length,
#         win_length=win_length,
#         n_mels=n_mels,
#         power=1.0,
#     ),
# ).to(device)

mel_transform = T.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    n_mels=n_mels,
    power=1.0,
).to(device)

stft_transform = T.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)

# training_utils.init(mel_transform, stft_transform, mfcc_transform)
training_utils.init(mel_transform, stft_transform)

# ---------------------------
# Dataset / DataLoader
# ---------------------------
dataset_dir = "./dataset/good"
dataset = AudioDataset(dataset_dir)

# Pinned host memory is only requested on CUDA, and --no_pin_memory can
# switch it off there as well.
use_pin_memory = device.type == "cuda" and not args.no_pin_memory

train_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=use_pin_memory,
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=args.num_workers > 0,
)

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator().to(device)
discriminator = SISUDiscriminator().to(device)

generator = torch.compile(generator)
discriminator = torch.compile(discriminator)

# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()

optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

# Use modern GradScaler signature; choose device_type based on runtime device.
# GradScaler documents `device` as a device-type string, so pass device.type
# rather than the torch.device object.
scaler = GradScaler(device=device.type)

# Halve the learning rate when the epoch-mean loss plateaus.
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)


def save_ckpt(path, epoch):
    """Write a full training checkpoint to ``path``.

    The checkpoint holds the model, optimizer, grad-scaler and scheduler
    state dicts so that ``--resume`` can restore all of them.

    Args:
        path: Destination file handed to ``torch.save``.
        epoch: Index of the last fully completed epoch. Resuming continues
            with ``epoch + 1``.
    """
    torch.save(
        {
            "epoch": epoch,
            "G": generator.state_dict(),
            "D": discriminator.state_dict(),
            "optG": optimizer_g.state_dict(),
            "optD": optimizer_d.state_dict(),
            "scaler": scaler.state_dict(),
            "schedG": scheduler_g.state_dict(),
            "schedD": scheduler_d.state_dict(),
        },
        path,
    )


start_epoch = 0
if args.resume:
    ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device)
    generator.load_state_dict(ckpt["G"])
    discriminator.load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scaler.load_state_dict(ckpt["scaler"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])
    # "last.pt" is written after an epoch finishes, so the stored index is
    # already done; continue with the next one instead of repeating it.
    start_epoch = int(ckpt.get("epoch", -1)) + 1

# ---------------------------
# Training loop (safer)
# ---------------------------
if len(train_loader) == 0 or not train_loader.batch_size:
    print("There is no data to train with! Exiting...")
    # Same exit status as the bare exit() helper, without relying on `site`.
    raise SystemExit(0)

# Pre-allocated label tensors, sliced per batch so a short final batch works.
max_batch = max(1, train_loader.batch_size)
real_buf = torch.full((max_batch, 1), 0.9, device=device)  # label smoothing
fake_buf = torch.zeros(max_batch, 1, device=device)

epoch = start_epoch  # keeps `epoch` bound for the crash handler below
try:
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()

        running_d, running_g, steps = 0.0, 0.0, 0

        progress = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
        # The second element of each batch (sample rates) is not used here.
        for (high_quality, low_quality), _sample_rates in progress:
            batch_size = high_quality.size(0)

            high_quality = high_quality.to(device, non_blocking=True)
            low_quality = low_quality.to(device, non_blocking=True)

            real_labels = real_buf[:batch_size]
            fake_labels = fake_buf[:batch_size]

            # --- Discriminator ---
            optimizer_d.zero_grad(set_to_none=True)
            with autocast(device_type=device.type):
                d_loss = discriminator_train(
                    high_quality,
                    low_quality,
                    real_labels,
                    fake_labels,
                    discriminator,
                    generator,
                    criterion_d,
                )

            scaler.scale(d_loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradient magnitudes.
            scaler.unscale_(optimizer_d)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
            scaler.step(optimizer_d)

            # --- Generator ---
            optimizer_g.zero_grad(set_to_none=True)
            with autocast(device_type=device.type):
                _g_out, g_total, _g_adv = generator_train(
                    low_quality,
                    high_quality,
                    real_labels,
                    generator,
                    discriminator,
                    # Identically configured to criterion_d; use the
                    # generator's own criterion for consistency.
                    criterion_g,
                )

            scaler.scale(g_total).backward()
            scaler.unscale_(optimizer_g)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
            scaler.step(optimizer_g)

            # One scale update per iteration, after both optimizers stepped.
            scaler.update()

            running_d += d_loss.item()
            running_g += g_total.item()
            steps += 1

        # epoch averages & schedulers
        if steps == 0:
            print("No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
        mean_g = running_g / steps

        scheduler_d.step(mean_d)
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")

except Exception:
    try:
        # Store the last completed epoch so that resuming from this file
        # would re-run the interrupted epoch rather than skip it.
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch - 1)
        print(f"Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        print("Failed saving crash checkpoint:", e)
    raise

# Best-effort export of the final weights; a failure here is reported but
# does not abort the script.
# NOTE(review): both models are wrapped by torch.compile earlier in this file,
# so these state-dict keys may carry the wrapper's prefix — confirm before
# loading them into an uncompiled model.
try:
    torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt"))
    torch.save(
        discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt")
    )
except Exception as e:
    print("Failed to save final states:", e)

print("Training finished.")