diff --git a/app.py b/app.py index ed51803..8c669a9 100644 --- a/app.py +++ b/app.py @@ -1,33 +1,49 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -import torch.nn.functional as F -import torchaudio -import tqdm - import argparse -import math -import os + +import torch +import torchaudio +import torchcodec +import tqdm import AudioUtils from generator import SISUGenerator - # Init script argument parser parser = argparse.ArgumentParser(description="Training script") -parser.add_argument("--device", type=str, default="cpu", help="Select device") +parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--model", type=str, help="Model to use for upscaling") -parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure") +parser.add_argument( + "--clip_length", + type=int, + default=16384, + help="Internal clip length, leave unspecified if unsure", +) +parser.add_argument( + "--sample_rate", type=int, default=44100, help="Output clip sample rate" +) +parser.add_argument( + "--bitrate", + type=int, + default=192000, + help="Output clip bitrate", +) parser.add_argument("-i", "--input", type=str, help="Input audio file") parser.add_argument("-o", "--output", type=str, help="Output audio file") args = parser.parse_args() +if args.sample_rate < 8000: + print( + "Sample rate cannot be lower than 8000! (44100 is recommended for base models)" + ) + exit() + device = torch.device(args.device if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") -generator = SISUGenerator() +generator = SISUGenerator().to(device) + +generator = torch.compile(generator) models_dir = args.model clip_length = args.clip_length @@ -35,17 +51,30 @@ input_audio = args.input output_audio = args.output if models_dir: - generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True)) + ckpt = torch.load(models_dir, map_location=device) + generator.load_state_dict(ckpt["G"]) else: - print(f"Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)") + print( + "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" + ) -generator = generator.to(device) def start(): # To Mono! - audio, original_sample_rate = torchaudio.load(input_audio, normalize=True) + decoder = torchcodec.decoders.AudioDecoder(input_audio) + + decoded_samples = decoder.get_all_samples() + audio = decoded_samples.data + original_sample_rate = decoded_samples.sample_rate + audio = AudioUtils.stereo_tensor_to_mono(audio) + resample_transform = torchaudio.transforms.Resample( + original_sample_rate, args.sample_rate + ) + + audio = resample_transform(audio) + splitted_audio = AudioUtils.split_audio(audio, clip_length) splitted_audio_on_device = [t.to(device) for t in splitted_audio] processed_audio = [] @@ -55,6 +84,13 @@ def start(): reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) print(f"Saving {output_audio}!") - torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate) + torchaudio.save_with_torchcodec( + uri=output_audio, + src=reconstructed_audio, + sample_rate=args.sample_rate, + channels_first=True, + compression=args.bitrate, + ) + start() diff --git a/data.py b/data.py index c3e1047..c9f4e76 100644 --- a/data.py +++ b/data.py @@ -1,41 +1,68 @@ -from torch.utils.data import Dataset -import torch.nn.functional as F -import torch -import torchaudio import os import random -import torchaudio.transforms as T + +import torchaudio +import torchcodec.decoders as decoders import tqdm +from torch.utils.data import Dataset + import AudioUtils + class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, device, clip_length = 1024): - self.device = device - input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] + def __init__(self, input_dir, clip_length=16384): + input_files = [ + os.path.join(root, f) + for root, _, files in os.walk(input_dir) + for f in files + if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac") + ] data = [] - for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"): - audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True) + for audio_clip in tqdm.tqdm( + input_files, desc=f"Processing {len(input_files)} audio file(s)" + ): + decoder = decoders.AudioDecoder(audio_clip) + + decoded_samples = decoder.get_all_samples() + audio = decoded_samples.data + original_sample_rate = decoded_samples.sample_rate + audio = AudioUtils.stereo_tensor_to_mono(audio) # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) + resample_transform_low = torchaudio.transforms.Resample( + original_sample_rate, mangled_sample_rate + ) + resample_transform_high = torchaudio.transforms.Resample( + mangled_sample_rate, original_sample_rate + ) low_audio = resample_transform_low(audio) low_audio = resample_transform_high(low_audio) splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( + splitted_high_quality_audio[-1], clip_length + ) splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length) + splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( + splitted_low_quality_audio[-1], clip_length + ) - for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio): - data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate))) + for high_quality_sample, low_quality_sample in zip( + splitted_high_quality_audio, splitted_low_quality_audio + ): + data.append( + ( + (high_quality_sample, low_quality_sample), + (original_sample_rate, mangled_sample_rate), + ) + ) self.audio_data = data diff --git a/discriminator.py b/discriminator.py index dfd0126..ce2e84c 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,8 +1,16 @@ -import torch import torch.nn as nn import torch.nn.utils as utils -def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True): + +def discriminator_block( + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + spectral_norm=True, + use_instance_norm=True, +): padding = (kernel_size // 2) * dilation conv_layer = nn.Conv1d( in_channels, @@ -10,7 +18,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding + padding=padding, ) if spectral_norm: @@ -24,6 +32,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila return nn.Sequential(*layers) + class AttentionBlock(nn.Module): def __init__(self, channels): super(AttentionBlock, self).__init__() @@ -31,27 +40,86 @@ class AttentionBlock(nn.Module): nn.Conv1d(channels, channels // 4, kernel_size=1), nn.ReLU(inplace=True), nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) return x * attention_weights + class SISUDiscriminator(nn.Module): def __init__(self, base_channels=16): super(SISUDiscriminator, self).__init__() layers = base_channels self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True), + discriminator_block( + 1, + layers, + kernel_size=7, + stride=1, + spectral_norm=True, + use_instance_norm=False, + ), + discriminator_block( + layers, + layers * 2, + kernel_size=5, + stride=2, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 2, + layers * 4, + kernel_size=5, + stride=1, + dilation=2, + spectral_norm=True, + use_instance_norm=True, + ), AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False) + discriminator_block( + layers * 4, + layers * 8, + kernel_size=5, + stride=1, + dilation=4, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 8, + layers * 4, + kernel_size=5, + stride=2, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 4, + layers * 2, + kernel_size=3, + stride=1, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 2, + layers, + kernel_size=3, + stride=1, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers, + 1, + kernel_size=3, + stride=1, + spectral_norm=False, + use_instance_norm=False, + ), ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) diff --git a/generator.py b/generator.py index a53feb7..0240860 100644 --- a/generator.py +++ b/generator.py @@ -1,6 +1,6 @@ -import torch import torch.nn as nn + def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): return nn.Sequential( nn.Conv1d( @@ -8,29 +8,32 @@ def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): out_channels, kernel_size=kernel_size, dilation=dilation, - padding=(kernel_size // 2) * dilation + padding=(kernel_size // 2) * dilation, ), nn.InstanceNorm1d(out_channels), - nn.PReLU() + nn.PReLU(), ) + class AttentionBlock(nn.Module): """ Simple Channel Attention Block. Learns to weight channels based on their importance. """ + def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( nn.Conv1d(channels, channels // 4, kernel_size=1), nn.ReLU(inplace=True), nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) return x * attention_weights + class ResidualInResidualBlock(nn.Module): def __init__(self, channels, num_convs=3): super(ResidualInResidualBlock, self).__init__() @@ -47,6 +50,7 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual + class SISUGenerator(nn.Module): def __init__(self, channels=16, num_rirb=4, alpha=1.0): super(SISUGenerator, self).__init__() diff --git a/training.py b/training.py index ab2b1e5..8876740 100644 --- a/training.py +++ b/training.py @@ -1,65 +1,74 @@ +import argparse +import os + import torch import torch.nn as nn import torch.optim as optim - -import torch.nn.functional as F -import torchaudio +import torchaudio.transforms as T import tqdm - -import argparse - -import math - -import os - -from torch.utils.data import random_split +from torch.amp import GradScaler, autocast from torch.utils.data import DataLoader -import AudioUtils +import training_utils from data import AudioDataset -from generator import SISUGenerator from discriminator import SISUDiscriminator - +from generator import SISUGenerator from training_utils import discriminator_train, generator_train -import file_utils as Data - -import torchaudio.transforms as T - -# Init script argument parser -parser = argparse.ArgumentParser(description="Training script") -parser.add_argument("--generator", type=str, default=None, - help="Path to the generator model file") -parser.add_argument("--discriminator", type=str, default=None, - help="Path to the discriminator model file") -parser.add_argument("--device", type=str, default="cpu", help="Select device") -parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") -parser.add_argument("--debug", action="store_true", help="Print debug logs") -parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models") +# --------------------------- +# Argument parsing +# --------------------------- +parser = argparse.ArgumentParser(description="Training script (safer defaults)") +parser.add_argument("--resume", action="store_true", help="Resume training") +parser.add_argument( + "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)" +) +parser.add_argument( + "--epochs", type=int, default=5000, help="Number of training epochs" +) +parser.add_argument("--batch_size", type=int, default=8, help="Batch size") +parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers") +parser.add_argument("--debug", action="store_true", help="Print debug logs") +parser.add_argument( + "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA" +) args = parser.parse_args() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") +# --------------------------- +# Device setup +# --------------------------- +# Use requested device only if available +device = torch.device( + args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu" +) print(f"Using device: {device}") +# sensible performance flags +if device.type == "cuda": + torch.backends.cudnn.benchmark = True + # optional: torch.set_float32_matmul_precision("high") +debug = args.debug -# Parameters +# --------------------------- +# Audio transforms +# --------------------------- sample_rate = 44100 n_fft = 1024 win_length = n_fft hop_length = n_fft // 4 -n_mels = 40 -n_mfcc = 13 +n_mels = 96 +# n_mfcc = 13 -mfcc_transform = T.MFCC( - sample_rate=sample_rate, - n_mfcc=n_mfcc, - melkwargs={ - 'n_fft': n_fft, - 'hop_length': hop_length, - 'win_length': win_length, - 'n_mels': n_mels, - 'power': 1.0, - } -).to(device) +# mfcc_transform = T.MFCC( +# sample_rate=sample_rate, +# n_mfcc=n_mfcc, +# melkwargs=dict( +# n_fft=n_fft, +# hop_length=hop_length, +# win_length=win_length, +# n_mels=n_mels, +# power=1.0, +# ), +# ).to(device) mel_transform = T.MelSpectrogram( sample_rate=sample_rate, @@ -67,138 +76,198 @@ mel_transform = T.MelSpectrogram( hop_length=hop_length, win_length=win_length, n_mels=n_mels, - power=1.0 # Magnitude Mel + power=1.0, ).to(device) stft_transform = T.Spectrogram( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length + n_fft=n_fft, win_length=win_length, hop_length=hop_length ).to(device) -debug = args.debug -# Initialize dataset and dataloader -dataset_dir = './dataset/good' -dataset = AudioDataset(dataset_dir, device) -models_dir = "./models" -os.makedirs(models_dir, exist_ok=True) -audio_output_dir = "./output" -os.makedirs(audio_output_dir, exist_ok=True) +# training_utils.init(mel_transform, stft_transform, mfcc_transform) +training_utils.init(mel_transform, stft_transform) -# ========= SINGLE ========= +# --------------------------- +# Dataset / DataLoader +# --------------------------- +dataset_dir = "./dataset/good" +dataset = AudioDataset(dataset_dir) -train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24) +train_loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=True, + persistent_workers=True, +) +# --------------------------- +# Models +# --------------------------- +generator = SISUGenerator().to(device) +discriminator = SISUDiscriminator().to(device) -# ========= MODELS ========= +generator = torch.compile(generator) +discriminator = torch.compile(discriminator) -generator = SISUGenerator() -discriminator = SISUDiscriminator() - -epoch: int = args.epoch - -if args.continue_training: - if args.generator is not None: - generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) - elif args.discriminator is not None: - discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) - else: - generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True)) - - epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") - epoch = epoch_from_file["epoch"] + 1 - -generator = generator.to(device) -discriminator = discriminator.to(device) - -# Loss +# --------------------------- +# Losses / Optimizers / Scalers +# --------------------------- criterion_g = nn.BCEWithLogitsLoss() criterion_d = nn.BCEWithLogitsLoss() -# Optimizers -optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) -optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) +optimizer_g = optim.AdamW( + generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) +optimizer_d = optim.AdamW( + discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) -# Scheduler -scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) -scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) +# Use modern GradScaler signature; choose device_type based on runtime device. +scaler = GradScaler(device=device) -def start_training(): - generator_epochs = 5000 - for generator_epoch in range(generator_epochs): - high_quality_audio = ([torch.empty((1))], 1) - low_quality_audio = ([torch.empty((1))], 1) - ai_enhanced_audio = ([torch.empty((1))], 1) +scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_g, mode="min", factor=0.5, patience=5 +) +scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_d, mode="min", factor=0.5, patience=5 +) - times_correct = 0 - - # ========= TRAINING ========= - for training_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): - ## Data structure: - # [[[float..., float..., float...], [float..., float..., float...]], [original_sample_rate, mangled_sample_rate]] - - # ========= LABELS ========= - good_quality_data = training_data[0][0].to(device) - bad_quality_data = training_data[0][1].to(device) - original_sample_rate = training_data[1][0] - mangled_sample_rate = training_data[1][1] - - batch_size = good_quality_data.size(0) - real_labels = torch.ones(batch_size, 1).to(device) - fake_labels = torch.zeros(batch_size, 1).to(device) - - high_quality_audio = (good_quality_data, original_sample_rate) - low_quality_audio = (bad_quality_data, mangled_sample_rate) - - # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - good_quality_data, - bad_quality_data, - real_labels, - fake_labels, - discriminator, - generator, - criterion_d, - optimizer_d - ) - - # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - bad_quality_data, - good_quality_data, - real_labels, - generator, - discriminator, - criterion_d, - optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform - ) - - if debug: - print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) - - # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (good_quality_data, original_sample_rate) - low_quality_audio = (bad_quality_data, original_sample_rate) - ai_enhanced_audio = (generator_output, original_sample_rate) - - torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") - torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") - - new_epoch = generator_epoch+epoch - Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) +# --------------------------- +# Checkpoint helpers +# --------------------------- +models_dir = "./models" +os.makedirs(models_dir, exist_ok=True) - torch.save(discriminator, "models/epoch-5000-discriminator.pt") - torch.save(generator, "models/epoch-5000-generator.pt") - print("Training complete!") +def save_ckpt(path, epoch): + torch.save( + { + "epoch": epoch, + "G": generator.state_dict(), + "D": discriminator.state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "scaler": scaler.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict(), + }, + path, + ) -start_training() + +start_epoch = 0 +if args.resume: + ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device) + generator.load_state_dict(ckpt["G"]) + discriminator.load_state_dict(ckpt["D"]) + optimizer_g.load_state_dict(ckpt["optG"]) + optimizer_d.load_state_dict(ckpt["optD"]) + scaler.load_state_dict(ckpt["scaler"]) + scheduler_g.load_state_dict(ckpt["schedG"]) + scheduler_d.load_state_dict(ckpt["schedD"]) + start_epoch = ckpt.get("epoch", 1) + +# --------------------------- +# Training loop (safer) +# --------------------------- + +if not train_loader or not train_loader.batch_size: + print("There is no data to train with! Exiting...") + exit() + +max_batch = max(1, train_loader.batch_size) +real_buf = torch.full((max_batch, 1), 0.9, device=device) # label smoothing +fake_buf = torch.zeros(max_batch, 1, device=device) + +try: + for epoch in range(start_epoch, args.epochs): + generator.train() + discriminator.train() + + running_d, running_g, steps = 0.0, 0.0, 0 + + for i, ( + (high_quality, low_quality), + (high_sample_rate, low_sample_rate), + ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): + batch_size = high_quality.size(0) + + high_quality = high_quality.to(device, non_blocking=True) + low_quality = low_quality.to(device, non_blocking=True) + + real_labels = real_buf[:batch_size] + fake_labels = fake_buf[:batch_size] + + # --- Discriminator --- + optimizer_d.zero_grad(set_to_none=True) + with autocast(device_type=device.type): + d_loss = discriminator_train( + high_quality, + low_quality, + real_labels, + fake_labels, + discriminator, + generator, + criterion_d, + ) + + scaler.scale(d_loss).backward() + scaler.unscale_(optimizer_d) + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0) + scaler.step(optimizer_d) + + # --- Generator --- + optimizer_g.zero_grad(set_to_none=True) + with autocast(device_type=device.type): + g_out, g_total, g_adv = generator_train( + low_quality, + high_quality, + real_labels, + generator, + discriminator, + criterion_d, + ) + + scaler.scale(g_total).backward() + scaler.unscale_(optimizer_g) + torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0) + scaler.step(optimizer_g) + + scaler.update() + + running_d += float(d_loss.detach().cpu().item()) + running_g += float(g_total.detach().cpu().item()) + steps += 1 + + # epoch averages & schedulers + if steps == 0: + print("No steps in epoch (empty dataloader?). Exiting.") + break + + mean_d = running_d / steps + mean_g = running_g / steps + + scheduler_d.step(mean_d) + scheduler_g.step(mean_g) + + save_ckpt(os.path.join(models_dir, "last.pt"), epoch) + print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") + +except Exception: + try: + save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch) + print(f"Saved crash checkpoint for epoch {epoch}") + except Exception as e: + print("Failed saving crash checkpoint:", e) + raise + +try: + torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt")) + torch.save( + discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt") + ) +except Exception as e: + print("Failed to save final states:", e) + +print("Training finished.") diff --git a/training_utils.py b/training_utils.py index c7d43e5..02403af 100644 --- a/training_utils.py +++ b/training_utils.py @@ -1,89 +1,88 @@ import torch -import torch.nn as nn -import torch.optim as optim - -import torchaudio import torchaudio.transforms as T -def gpu_mfcc_loss(mfcc_transform, y_true, y_pred): - mfccs_true = mfcc_transform(y_true) - mfccs_pred = mfcc_transform(y_pred) +from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2]) - mfccs_true = mfccs_true[:, :, :min_len] - mfccs_pred = mfccs_pred[:, :, :min_len] +mel_transform: T.MelSpectrogram +stft_transform: T.Spectrogram +# mfcc_transform: T.MFCC - loss = torch.mean((mfccs_true - mfccs_pred)**2) - return loss -def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) +# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC): +# """Initializes the global transform variables for the module.""" +# global mel_transform, stft_transform, mfcc_transform +# mel_transform = mel_trans +# stft_transform = stft_trans +# mfcc_transform = mfcc_trans - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] - loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) - return loss +def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram): + """Initializes the global transform variables for the module.""" + global mel_transform, stft_transform + mel_transform = mel_trans + stft_transform = stft_trans -def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] +# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: +# """Computes the Mean Squared Error (MSE) loss on MFCCs.""" +# mfccs_true = mfcc_transform(y_true) +# mfccs_pred = mfcc_transform(y_pred) +# return F.mse_loss(mfccs_pred, mfccs_true) - loss = torch.mean((mel_spec_true - mel_spec_pred)**2) - return loss -def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) +# def mel_spectrogram_loss( +# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1" +# ) -> torch.Tensor: +# """Calculates L1 or L2 loss on the Mel Spectrogram.""" +# mel_spec_true = mel_transform(y_true) +# mel_spec_pred = mel_transform(y_pred) +# if loss_type == "l1": +# return F.l1_loss(mel_spec_pred, mel_spec_true) +# elif loss_type == "l2": +# return F.mse_loss(mel_spec_pred, mel_spec_true) +# else: +# raise ValueError("loss_type must be 'l1' or 'l2'") - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] - loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps))) - return loss +# def log_stft_magnitude_loss( +# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7 +# ) -> torch.Tensor: +# """Calculates L1 loss on the log STFT magnitude.""" +# stft_mag_true = stft_transform(y_true) +# stft_mag_pred = stft_transform(y_pred) +# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps)) -def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] +stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] +) - norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1)) - norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1)) - loss = torch.mean(norm_diff / (norm_true + eps)) - return loss - -def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer): - optimizer.zero_grad() - - # Forward pass for real samples +def discriminator_train( + high_quality, + low_quality, + real_labels, + fake_labels, + discriminator, + generator, + criterion, +): discriminator_decision_from_real = discriminator(high_quality) d_loss_real = criterion(discriminator_decision_from_real, real_labels) with torch.no_grad(): generator_output = generator(low_quality) discriminator_decision_from_fake = discriminator(generator_output) - d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) + d_loss_fake = criterion( + discriminator_decision_from_fake, + fake_labels.expand_as(discriminator_decision_from_fake), + ) d_loss = (d_loss_real + d_loss_fake) / 2.0 - d_loss.backward() - # Optional: Gradient Clipping (can be helpful) - # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping - optimizer.step() - return d_loss + def generator_train( low_quality, high_quality, @@ -91,52 +90,65 @@ def generator_train( generator, discriminator, adv_criterion, - g_optimizer, - device, - mel_transform: T.MelSpectrogram, - stft_transform: T.Spectrogram, - mfcc_transform: T.MFCC, lambda_adv: float = 1.0, - lambda_mel_l1: float = 10.0, - lambda_log_stft: float = 1.0, - lambda_mfcc: float = 1.0 + lambda_feat: float = 10.0, + lambda_stft: float = 2.5, ): - g_optimizer.zero_grad() - generator_output = generator(low_quality) discriminator_decision = discriminator(generator_output) - adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) + # adversarial_loss = adv_criterion( + # discriminator_decision, real_labels.expand_as(discriminator_decision) + # ) + adversarial_loss = adv_criterion(discriminator_decision, real_labels) - mel_l1 = 0.0 - log_stft_l1 = 0.0 - mfcc_l = 0.0 + combined_loss = lambda_adv * adversarial_loss - # Calculate Mel L1 Loss if weight is positive - if lambda_mel_l1 > 0: - mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output) + stft_losses = stft_loss_fn(high_quality, generator_output) + stft_loss = stft_losses["total"] - # Calculate Log STFT L1 Loss if weight is positive - if lambda_log_stft > 0: - log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output) + combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss) - # Calculate MFCC Loss if weight is positive - if lambda_mfcc > 0: - mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output) + return generator_output, combined_loss, adversarial_loss - mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 - log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 - mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l - combined_loss = (lambda_adv * adversarial_loss) + \ - (lambda_mel_l1 * mel_l1_tensor) + \ - (lambda_log_stft * log_stft_l1_tensor) + \ - (lambda_mfcc * mfcc_l_tensor) +# def generator_train( +# low_quality, +# high_quality, +# real_labels, +# generator, +# discriminator, +# adv_criterion, +# lambda_adv: float = 1.0, +# lambda_mel_l1: float = 10.0, +# lambda_log_stft: float = 1.0, - combined_loss.backward() - # Optional: Gradient Clipping - # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) - g_optimizer.step() +# ): +# generator_output = generator(low_quality) - # 6. Return values for logging - return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor +# discriminator_decision = discriminator(generator_output) +# adversarial_loss = adv_criterion( +# discriminator_decision, real_labels.expand_as(discriminator_decision) +# ) + +# combined_loss = lambda_adv * adversarial_loss + +# if lambda_mel_l1 > 0: +# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1") +# combined_loss += lambda_mel_l1 * mel_l1_loss +# else: +# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging + +# if lambda_log_stft > 0: +# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output) +# combined_loss += lambda_log_stft * log_stft_loss +# else: +# log_stft_loss = torch.tensor(0.0, device=low_quality.device) + +# if lambda_mfcc > 0: +# mfcc_loss_val = mfcc_loss(high_quality, generator_output) +# combined_loss += lambda_mfcc * mfcc_loss_val +# else: +# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device) + +# return generator_output, combined_loss, adversarial_loss diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py new file mode 100644 index 0000000..5712fc3 --- /dev/null +++ b/utils/MultiResolutionSTFTLoss.py @@ -0,0 +1,62 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.transforms as T + + +class MultiResolutionSTFTLoss(nn.Module): + """ + Computes a loss based on multiple STFT resolutions, including both + spectral convergence and log STFT magnitude components. + """ + + def __init__( + self, + fft_sizes: List[int] = [1024, 2048, 512], + hop_sizes: List[int] = [120, 240, 50], + win_lengths: List[int] = [600, 1200, 240], + eps: float = 1e-7, + ): + super().__init__() + self.stft_transforms = nn.ModuleList( + [ + T.Spectrogram( + n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None + ) + for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths) + ] + ) + self.eps = eps + + def forward( + self, y_true: torch.Tensor, y_pred: torch.Tensor + ) -> Dict[str, torch.Tensor]: + sc_loss = 0.0 # Spectral Convergence Loss + mag_loss = 0.0 # Log STFT Magnitude Loss + + for stft in self.stft_transforms: + stft.to(y_pred.device) # Ensure transform is on the correct device + + # Get complex STFTs + stft_true = stft(y_true) + stft_pred = stft(y_pred) + + # Get magnitudes + stft_mag_true = torch.abs(stft_true) + stft_mag_pred = torch.abs(stft_pred) + + # --- Spectral Convergence Loss --- + # || |S_true| - |S_pred| ||_F / || |S_true| ||_F + norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) + norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) + sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) + + # --- Log STFT Magnitude Loss --- + mag_loss += F.l1_loss( + torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps) + ) + + total_loss = sc_loss + mag_loss + return {"total": total_loss, "sc": sc_loss, "mag": mag_loss} diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29