diff --git a/AudioUtils.py b/AudioUtils.py index f45efb5..183dc36 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -1,71 +1,97 @@ import torch import torch.nn.functional as F -def stereo_tensor_to_mono(waveform): + +def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: + """ + Convert stereo (C, N) to mono (1, N). Ensures a channel dimension. + """ + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) # (N,) -> (1, N) + if waveform.shape[0] > 1: - # Average across channels - mono_waveform = torch.mean(waveform, dim=0, keepdim=True) + mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N) else: - # Already mono mono_waveform = waveform return mono_waveform -def stretch_tensor(tensor, target_length): - scale_factor = target_length / tensor.size(1) - tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False) +def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor: + """ + Stretch audio along time dimension to target_length. + Input assumed (1, N). Returns (1, target_length). + """ + if tensor.dim() == 1: + tensor = tensor.unsqueeze(0) # ensure (1, N) - return tensor + tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate + stretched = F.interpolate( + tensor, size=target_length, mode="linear", align_corners=False + ) + return stretched.squeeze(0) # back to (1, target_length) + + +def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor: + """ + Pad to fixed length. Input assumed (1, N). Returns (1, target_length). + """ + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) -def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128): current_length = audio_tensor.shape[-1] - if current_length < target_length: padding_needed = target_length - current_length - padding_tuple = (0, padding_needed) - padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0) + padded_audio_tensor = F.pad( + audio_tensor, padding_tuple, mode="constant", value=0 + ) else: - padded_audio_tensor = audio_tensor + padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long return padded_audio_tensor -def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]: + +def split_audio( + audio_tensor: torch.Tensor, chunk_size: int = 128 +) -> list[torch.Tensor]: + """ + Split into chunks of (1, chunk_size). + """ if not isinstance(chunk_size, int) or chunk_size <= 0: raise ValueError("chunk_size must be a positive integer.") - # Handle scalar tensor edge case if necessary - if audio_tensor.dim() == 0: - return [audio_tensor] if audio_tensor.numel() > 0 else [] - - # Identify the dimension to split (usually the last one, representing time/samples) - split_dim = -1 - num_samples = audio_tensor.shape[split_dim] + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + num_samples = audio_tensor.shape[-1] if num_samples == 0: - return [] # Return empty list if the dimension to split is empty - - # Use torch.split to divide the tensor into chunks - # It handles the last chunk being potentially smaller automatically. - chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim)) + return [] + chunks = list(torch.split(audio_tensor, chunk_size, dim=-1)) return chunks + def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: + """ + Reconstruct audio from chunks. Returns (1, N). 
+ """ if not chunks: - return torch.empty(0) - - if len(chunks) == 1 and chunks[0].dim() == 0: - return chunks[0] - - concat_dim = -1 + return torch.empty(1, 0) + chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks] try: - reconstructed_tensor = torch.cat(chunks, dim=concat_dim) + reconstructed_tensor = torch.cat(chunks, dim=-1) except RuntimeError as e: raise RuntimeError( f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes " - f"for concatenation along dimension {concat_dim}. Original error: {e}" + f"for concatenation along dim -1. Original error: {e}" ) return reconstructed_tensor + + +def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + max_val = torch.max(torch.abs(audio_tensor)) + if max_val < eps: + return audio_tensor # silence, skip normalization + return audio_tensor / max_val diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app.py b/app.py index 8c669a9..006cba5 100644 --- a/app.py +++ b/app.py @@ -68,6 +68,7 @@ def start(): original_sample_rate = decoded_samples.sample_rate audio = AudioUtils.stereo_tensor_to_mono(audio) + audio = AudioUtils.normalize(audio) resample_transform = torchaudio.transforms.Resample( original_sample_rate, args.sample_rate diff --git a/data.py b/data.py index c9f4e76..a2ddb71 100644 --- a/data.py +++ b/data.py @@ -12,12 +12,15 @@ import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, clip_length=16384): + def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True): + self.clip_length = clip_length + self.normalize = normalize + input_files = [ - os.path.join(root, f) - for root, _, files in os.walk(input_dir) - for f in files - if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac") + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if os.path.isfile(os.path.join(input_dir, f)) + and f.lower().endswith((".wav", ".mp3", ".flac")) ] data = [] @@ -25,14 +28,15 @@ class AudioDataset(Dataset): input_files, desc=f"Processing {len(input_files)} audio file(s)" ): decoder = decoders.AudioDecoder(audio_clip) - decoded_samples = decoder.get_all_samples() - audio = decoded_samples.data + + audio = decoded_samples.data.float() # ensure float32 original_sample_rate = decoded_samples.sample_rate audio = AudioUtils.stereo_tensor_to_mono(audio) + if normalize: + audio = AudioUtils.normalize(audio) - # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) resample_transform_low = torchaudio.transforms.Resample( original_sample_rate, mangled_sample_rate @@ -41,25 +45,27 @@ class AudioDataset(Dataset): mangled_sample_rate, original_sample_rate ) - low_audio = resample_transform_low(audio) - low_audio = resample_transform_high(low_audio) + low_audio = resample_transform_high(resample_transform_low(audio)) splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) + splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) + + if not splitted_high_quality_audio or not splitted_low_quality_audio: + continue # skip empty or invalid clips + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( splitted_high_quality_audio[-1], clip_length ) - - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( splitted_low_quality_audio[-1], clip_length ) - for high_quality_sample, low_quality_sample in zip( + for 
high_quality_data, low_quality_data in zip( splitted_high_quality_audio, splitted_low_quality_audio ): data.append( ( - (high_quality_sample, low_quality_sample), + (high_quality_data, low_quality_data), (original_sample_rate, mangled_sample_rate), ) ) diff --git a/discriminator.py b/discriminator.py index ce2e84c..5e8442b 100644 --- a/discriminator.py +++ b/discriminator.py @@ -49,74 +49,18 @@ class AttentionBlock(nn.Module): class SISUDiscriminator(nn.Module): - def __init__(self, base_channels=16): + def __init__(self, layers=32): super(SISUDiscriminator, self).__init__() - layers = base_channels self.model = nn.Sequential( - discriminator_block( - 1, - layers, - kernel_size=7, - stride=1, - spectral_norm=True, - use_instance_norm=False, - ), - discriminator_block( - layers, - layers * 2, - kernel_size=5, - stride=2, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 2, - layers * 4, - kernel_size=5, - stride=1, - dilation=2, - spectral_norm=True, - use_instance_norm=True, - ), + discriminator_block(1, layers, kernel_size=7, stride=1), + discriminator_block(layers, layers * 2, kernel_size=5, stride=2), + discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), AttentionBlock(layers * 4), - discriminator_block( - layers * 4, - layers * 8, - kernel_size=5, - stride=1, - dilation=4, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 8, - layers * 4, - kernel_size=5, - stride=2, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 4, - layers * 2, - kernel_size=3, - stride=1, - spectral_norm=True, - use_instance_norm=True, - ), + discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), + discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2), discriminator_block( layers * 2, - layers, - kernel_size=3, - stride=1, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers, 1, - kernel_size=3, - stride=1, spectral_norm=False, use_instance_norm=False, ), diff --git a/file_utils.py b/file_utils.py deleted file mode 100644 index 98f70bc..0000000 --- a/file_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import json - -filepath = "my_data.json" - -def write_data(filepath, data, debug=False): - try: - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) # Use indent for pretty formatting - if debug: - print(f"Data written to '{filepath}'") - except Exception as e: - print(f"Error writing to file: {e}") - - -def read_data(filepath, debug=False): - try: - with open(filepath, 'r') as f: - data = json.load(f) - if debug: - print(f"Data read from '{filepath}'") - return data - except FileNotFoundError: - print(f"File not found: {filepath}") - return None - except json.JSONDecodeError: - print(f"Error decoding JSON from file: {filepath}") - return None - except Exception as e: - print(f"Error reading from file: {e}") - return None diff --git a/generator.py b/generator.py index 0240860..b6d2204 100644 --- a/generator.py +++ b/generator.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -52,7 +53,7 @@ class ResidualInResidualBlock(nn.Module): class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1.0): + def __init__(self, channels=16, num_rirb=4, alpha=1): super(SISUGenerator, self).__init__() self.alpha = alpha @@ -66,7 +67,9 @@ class SISUGenerator(nn.Module): *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] ) - self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1) + 
self.final_layer = nn.Sequential( + nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh() + ) def forward(self, x): residual_input = x @@ -75,4 +78,4 @@ class SISUGenerator(nn.Module): learned_residual = self.final_layer(x_rirb_out) output = residual_input + self.alpha * learned_residual - return output + return torch.tanh(output) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 21f6bef..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -filelock==3.16.1 -fsspec==2024.10.0 -Jinja2==3.1.4 -MarkupSafe==2.1.5 -mpmath==1.3.0 -networkx==3.4.2 -numpy==2.2.3 -pillow==11.0.0 -setuptools==70.2.0 -sympy==1.13.3 -tqdm==4.67.1 -typing_extensions==4.12.2 diff --git a/training.py b/training.py index 8876740..0e0caf8 100644 --- a/training.py +++ b/training.py @@ -4,25 +4,20 @@ import os import torch import torch.nn as nn import torch.optim as optim -import torchaudio.transforms as T import tqdm -from torch.amp import GradScaler, autocast -from torch.utils.data import DataLoader +from accelerate import Accelerator +from torch.utils.data import DataLoader, DistributedSampler -import training_utils from data import AudioDataset from discriminator import SISUDiscriminator from generator import SISUGenerator -from training_utils import discriminator_train, generator_train +from utils.TrainingTools import discriminator_train, generator_train # --------------------------- # Argument parsing # --------------------------- parser = argparse.ArgumentParser(description="Training script (safer defaults)") parser.add_argument("--resume", action="store_true", help="Resume training") -parser.add_argument( - "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)" -) parser.add_argument( "--epochs", type=int, default=5000, help="Number of training epochs" ) @@ -35,86 +30,54 @@ parser.add_argument( args = parser.parse_args() # --------------------------- -# Device setup +# Init accelerator # --------------------------- -# Use requested device only if available -device = torch.device( - args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu" -) -print(f"Using device: {device}") -# sensible performance flags -if device.type == "cuda": - torch.backends.cudnn.benchmark = True - # optional: torch.set_float32_matmul_precision("high") -debug = args.debug -# --------------------------- -# Audio transforms -# --------------------------- -sample_rate = 44100 -n_fft = 1024 -win_length = n_fft -hop_length = n_fft // 4 -n_mels = 96 -# n_mfcc = 13 - -# mfcc_transform = T.MFCC( -# sample_rate=sample_rate, -# n_mfcc=n_mfcc, -# melkwargs=dict( -# n_fft=n_fft, -# hop_length=hop_length, -# win_length=win_length, -# n_mels=n_mels, -# power=1.0, -# ), -# ).to(device) - -mel_transform = T.MelSpectrogram( - sample_rate=sample_rate, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mels=n_mels, - power=1.0, -).to(device) - -stft_transform = T.Spectrogram( - n_fft=n_fft, win_length=win_length, hop_length=hop_length -).to(device) - -# training_utils.init(mel_transform, stft_transform, mfcc_transform) -training_utils.init(mel_transform, stft_transform) - -# --------------------------- -# Dataset / DataLoader -# --------------------------- -dataset_dir = "./dataset/good" -dataset = AudioDataset(dataset_dir) - -train_loader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.num_workers, - pin_memory=True, - persistent_workers=True, -) +accelerator = Accelerator(mixed_precision="bf16") # 
--------------------------- # Models # --------------------------- -generator = SISUGenerator().to(device) -discriminator = SISUDiscriminator().to(device) +generator = SISUGenerator() +discriminator = SISUDiscriminator() + +accelerator.print("๐Ÿ”จ | Compiling models...") generator = torch.compile(generator) discriminator = torch.compile(discriminator) +accelerator.print("โœ… | Compiling done!") + +# --------------------------- +# Dataset / DataLoader +# --------------------------- +accelerator.print("๐Ÿ“Š | Fetching dataset...") +dataset = AudioDataset("./dataset") + +sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None +pin_memory = torch.cuda.is_available() and not args.no_pin_memory + +train_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=args.batch_size, + shuffle=(sampler is None), + num_workers=args.num_workers, + pin_memory=pin_memory, + persistent_workers=pin_memory, +) + +if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0: + accelerator.print("๐Ÿชน | There is no data to train with! Exiting...") + exit() + +loader_batch_size = train_loader.batch_size + +accelerator.print("โœ… | Dataset fetched!") + # --------------------------- # Losses / Optimizers / Scalers # --------------------------- -criterion_g = nn.BCEWithLogitsLoss() -criterion_d = nn.BCEWithLogitsLoss() optimizer_g = optim.AdamW( generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 @@ -123,9 +86,6 @@ optimizer_d = optim.AdamW( discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 ) -# Use modern GradScaler signature; choose device_type based on runtime device. -scaler = GradScaler(device=device) - scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_g, mode="min", factor=0.5, patience=5 ) @@ -133,6 +93,17 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_d, mode="min", factor=0.5, patience=5 ) +criterion_g = nn.BCEWithLogitsLoss() +criterion_d = nn.MSELoss() + +# --------------------------- +# Prepare accelerator +# --------------------------- + +generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare( + generator, discriminator, optimizer_g, optimizer_d, train_loader +) + # --------------------------- # Checkpoint helpers # --------------------------- @@ -141,44 +112,45 @@ os.makedirs(models_dir, exist_ok=True) def save_ckpt(path, epoch): - torch.save( - { - "epoch": epoch, - "G": generator.state_dict(), - "D": discriminator.state_dict(), - "optG": optimizer_g.state_dict(), - "optD": optimizer_d.state_dict(), - "scaler": scaler.state_dict(), - "schedG": scheduler_g.state_dict(), - "schedD": scheduler_d.state_dict(), - }, - path, - ) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + accelerator.save( + { + "epoch": epoch, + "G": accelerator.unwrap_model(generator).state_dict(), + "D": accelerator.unwrap_model(discriminator).state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict(), + }, + path, + ) start_epoch = 0 if args.resume: - ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device) - generator.load_state_dict(ckpt["G"]) - discriminator.load_state_dict(ckpt["D"]) + ckpt_path = os.path.join(models_dir, "last.pt") + ckpt = torch.load(ckpt_path) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"]) 
optimizer_g.load_state_dict(ckpt["optG"]) optimizer_d.load_state_dict(ckpt["optD"]) - scaler.load_state_dict(ckpt["scaler"]) scheduler_g.load_state_dict(ckpt["schedG"]) scheduler_d.load_state_dict(ckpt["schedD"]) + start_epoch = ckpt.get("epoch", 1) + accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") -# --------------------------- -# Training loop (safer) -# --------------------------- +real_buf = torch.full( + (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32 +) +fake_buf = torch.zeros( + (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32 +) -if not train_loader or not train_loader.batch_size: - print("There is no data to train with! Exiting...") - exit() - -max_batch = max(1, train_loader.batch_size) -real_buf = torch.full((max_batch, 1), 0.9, device=device) # label smoothing -fake_buf = torch.zeros(max_batch, 1, device=device) +accelerator.print("๐Ÿ‹๏ธ | Started training...") try: for epoch in range(start_epoch, args.epochs): @@ -193,15 +165,12 @@ try: ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): batch_size = high_quality.size(0) - high_quality = high_quality.to(device, non_blocking=True) - low_quality = low_quality.to(device, non_blocking=True) - - real_labels = real_buf[:batch_size] - fake_labels = fake_buf[:batch_size] + real_labels = real_buf[:batch_size].to(accelerator.device) + fake_labels = fake_buf[:batch_size].to(accelerator.device) # --- Discriminator --- optimizer_d.zero_grad(set_to_none=True) - with autocast(device_type=device.type): + with accelerator.autocast(): d_loss = discriminator_train( high_quality, low_quality, @@ -212,15 +181,14 @@ try: criterion_d, ) - scaler.scale(d_loss).backward() - scaler.unscale_(optimizer_d) - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0) - scaler.step(optimizer_d) + accelerator.backward(d_loss) + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) + optimizer_d.step() # --- Generator --- optimizer_g.zero_grad(set_to_none=True) - with autocast(device_type=device.type): - g_out, g_total, g_adv = generator_train( + with accelerator.autocast(): + g_total, g_adv = generator_train( low_quality, high_quality, real_labels, @@ -229,20 +197,32 @@ try: criterion_d, ) - scaler.scale(g_total).backward() - scaler.unscale_(optimizer_g) - torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0) - scaler.step(optimizer_g) + accelerator.backward(g_total) + torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) + optimizer_g.step() - scaler.update() + d_val = accelerator.gather(d_loss.detach()).mean() + g_val = accelerator.gather(g_total.detach()).mean() + + if torch.isfinite(d_val): + running_d += d_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in discriminator loss at step {i}, skipping update." + ) + + if torch.isfinite(g_val): + running_g += g_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in generator loss at step {i}, skipping update." + ) - running_d += float(d_loss.detach().cpu().item()) - running_g += float(g_total.detach().cpu().item()) steps += 1 # epoch averages & schedulers if steps == 0: - print("No steps in epoch (empty dataloader?). Exiting.") + accelerator.print("๐Ÿชน | No steps in epoch (empty dataloader?). 
Exiting.") break mean_d = running_d / steps @@ -252,22 +232,14 @@ try: scheduler_g.step(mean_g) save_ckpt(os.path.join(models_dir, "last.pt"), epoch) - print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") + accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") except Exception: try: save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch) - print(f"Saved crash checkpoint for epoch {epoch}") + accelerator.print(f"๐Ÿ’พ | Saved crash checkpoint for epoch {epoch}") except Exception as e: - print("Failed saving crash checkpoint:", e) + accelerator.print("๐Ÿ˜ฌ | Failed saving crash checkpoint:", e) raise -try: - torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt")) - torch.save( - discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt") - ) -except Exception as e: - print("Failed to save final states:", e) - -print("Training finished.") +accelerator.print("๐Ÿ | Training finished.") diff --git a/training_utils.py b/training_utils.py deleted file mode 100644 index 02403af..0000000 --- a/training_utils.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import torchaudio.transforms as T - -from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - -mel_transform: T.MelSpectrogram -stft_transform: T.Spectrogram -# mfcc_transform: T.MFCC - - -# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC): -# """Initializes the global transform variables for the module.""" -# global mel_transform, stft_transform, mfcc_transform -# mel_transform = mel_trans -# stft_transform = stft_trans -# mfcc_transform = mfcc_trans - - -def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram): - """Initializes the global transform variables for the module.""" - global mel_transform, stft_transform - mel_transform = mel_trans - stft_transform = stft_trans - - -# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: -# """Computes the Mean Squared Error (MSE) loss on MFCCs.""" -# mfccs_true = mfcc_transform(y_true) -# mfccs_pred = mfcc_transform(y_pred) -# return F.mse_loss(mfccs_pred, mfccs_true) - - -# def mel_spectrogram_loss( -# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1" -# ) -> torch.Tensor: -# """Calculates L1 or L2 loss on the Mel Spectrogram.""" -# mel_spec_true = mel_transform(y_true) -# mel_spec_pred = mel_transform(y_pred) -# if loss_type == "l1": -# return F.l1_loss(mel_spec_pred, mel_spec_true) -# elif loss_type == "l2": -# return F.mse_loss(mel_spec_pred, mel_spec_true) -# else: -# raise ValueError("loss_type must be 'l1' or 'l2'") - - -# def log_stft_magnitude_loss( -# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7 -# ) -> torch.Tensor: -# """Calculates L1 loss on the log STFT magnitude.""" -# stft_mag_true = stft_transform(y_true) -# stft_mag_pred = stft_transform(y_pred) -# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps)) - - -stft_loss_fn = MultiResolutionSTFTLoss( - fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] -) - - -def discriminator_train( - high_quality, - low_quality, - real_labels, - fake_labels, - discriminator, - generator, - criterion, -): - discriminator_decision_from_real = discriminator(high_quality) - d_loss_real = criterion(discriminator_decision_from_real, real_labels) - - with torch.no_grad(): - generator_output = generator(low_quality) - discriminator_decision_from_fake = discriminator(generator_output) - d_loss_fake = 
criterion( - discriminator_decision_from_fake, - fake_labels.expand_as(discriminator_decision_from_fake), - ) - - d_loss = (d_loss_real + d_loss_fake) / 2.0 - - return d_loss - - -def generator_train( - low_quality, - high_quality, - real_labels, - generator, - discriminator, - adv_criterion, - lambda_adv: float = 1.0, - lambda_feat: float = 10.0, - lambda_stft: float = 2.5, -): - generator_output = generator(low_quality) - - discriminator_decision = discriminator(generator_output) - # adversarial_loss = adv_criterion( - # discriminator_decision, real_labels.expand_as(discriminator_decision) - # ) - adversarial_loss = adv_criterion(discriminator_decision, real_labels) - - combined_loss = lambda_adv * adversarial_loss - - stft_losses = stft_loss_fn(high_quality, generator_output) - stft_loss = stft_losses["total"] - - combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss) - - return generator_output, combined_loss, adversarial_loss - - -# def generator_train( -# low_quality, -# high_quality, -# real_labels, -# generator, -# discriminator, -# adv_criterion, -# lambda_adv: float = 1.0, -# lambda_mel_l1: float = 10.0, -# lambda_log_stft: float = 1.0, - -# ): -# generator_output = generator(low_quality) - -# discriminator_decision = discriminator(generator_output) -# adversarial_loss = adv_criterion( -# discriminator_decision, real_labels.expand_as(discriminator_decision) -# ) - -# combined_loss = lambda_adv * adversarial_loss - -# if lambda_mel_l1 > 0: -# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1") -# combined_loss += lambda_mel_l1 * mel_l1_loss -# else: -# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging - -# if lambda_log_stft > 0: -# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output) -# combined_loss += lambda_log_stft * log_stft_loss -# else: -# log_stft_loss = torch.tensor(0.0, device=low_quality.device) - -# if lambda_mfcc > 0: -# mfcc_loss_val = mfcc_loss(high_quality, generator_output) -# combined_loss += lambda_mfcc * mfcc_loss_val -# else: -# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device) - -# return generator_output, combined_loss, adversarial_loss diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py index 5712fc3..eab6355 100644 --- a/utils/MultiResolutionSTFTLoss.py +++ b/utils/MultiResolutionSTFTLoss.py @@ -8,8 +8,9 @@ import torchaudio.transforms as T class MultiResolutionSTFTLoss(nn.Module): """ - Computes a loss based on multiple STFT resolutions, including both - spectral convergence and log STFT magnitude components. + Multi-resolution STFT loss. + Combines spectral convergence loss and log-magnitude loss + across multiple STFT resolutions. 
""" def __init__( @@ -20,43 +21,67 @@ class MultiResolutionSTFTLoss(nn.Module): eps: float = 1e-7, ): super().__init__() - self.stft_transforms = nn.ModuleList( - [ - T.Spectrogram( - n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None - ) - for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths) - ] - ) + self.eps = eps + self.n_resolutions = len(fft_sizes) + + self.stft_transforms = nn.ModuleList() + for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths): + window = torch.hann_window(win_len) + stft = T.Spectrogram( + n_fft=n_fft, + hop_length=hop_len, + win_length=win_len, + window_fn=lambda _: window, + power=None, # Keep complex output + center=True, + pad_mode="reflect", + normalized=False, + ) + self.stft_transforms.append(stft) def forward( self, y_true: torch.Tensor, y_pred: torch.Tensor ) -> Dict[str, torch.Tensor]: - sc_loss = 0.0 # Spectral Convergence Loss - mag_loss = 0.0 # Log STFT Magnitude Loss + """ + Args: + y_true: (B, T) or (B, 1, T) waveform + y_pred: (B, T) or (B, 1, T) waveform + """ + # Ensure correct shape (B, T) + if y_true.dim() == 3 and y_true.size(1) == 1: + y_true = y_true.squeeze(1) + if y_pred.dim() == 3 and y_pred.size(1) == 1: + y_pred = y_pred.squeeze(1) + + sc_loss = 0.0 + mag_loss = 0.0 for stft in self.stft_transforms: - stft.to(y_pred.device) # Ensure transform is on the correct device + stft = stft.to(y_pred.device) - # Get complex STFTs + # Complex STFTs: (B, F, T, 2) stft_true = stft(y_true) stft_pred = stft(y_pred) - # Get magnitudes + # Magnitudes stft_mag_true = torch.abs(stft_true) stft_mag_pred = torch.abs(stft_pred) # --- Spectral Convergence Loss --- - # || |S_true| - |S_pred| ||_F / || |S_true| ||_F norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) # --- Log STFT Magnitude Loss --- mag_loss += F.l1_loss( - torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps) + torch.log(stft_mag_pred + self.eps), + torch.log(stft_mag_true + self.eps), ) + # Average across resolutions + sc_loss /= self.n_resolutions + mag_loss /= self.n_resolutions total_loss = sc_loss + mag_loss + return {"total": total_loss, "sc": sc_loss, "mag": mag_loss} diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py new file mode 100644 index 0000000..d1959a0 --- /dev/null +++ b/utils/TrainingTools.py @@ -0,0 +1,60 @@ +import torch + +# In case if needed again... 
+# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
+#
+# stft_loss_fn = MultiResolutionSTFTLoss(
+#     fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
+# )
+
+
+def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
+    absolute_difference = torch.abs(input_one - input_two)
+    return torch.mean(absolute_difference)
+
+
+def discriminator_train(
+    high_quality,
+    low_quality,
+    high_labels,
+    low_labels,
+    discriminator,
+    generator,
+    criterion,
+):
+    decision_high = discriminator(high_quality)
+    d_loss_high = criterion(decision_high, high_labels)
+    # print(f"Is this real?: {decision_high} | {d_loss_high}")
+
+    decision_low = discriminator(low_quality)
+    d_loss_low = criterion(decision_low, low_labels)
+    # print(f"Is this real?: {decision_low} | {d_loss_low}")
+
+    with torch.no_grad():
+        generator_quality = generator(low_quality)
+    decision_gen = discriminator(generator_quality)
+    d_loss_gen = criterion(decision_gen, low_labels)
+
+    noise = torch.rand_like(high_quality) * 0.08
+    decision_noise = discriminator(high_quality + noise)
+    d_loss_noise = criterion(decision_noise, low_labels)
+
+    d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
+
+    return d_loss
+
+
+def generator_train(
+    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
+):
+    generator_output = generator(low_quality)
+
+    discriminator_decision = discriminator(generator_output)
+    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
+
+    # Signal similarity
+    similarity_loss = signal_mae(generator_output, high_quality)
+
+    combined_loss = adversarial_loss + (similarity_loss * 100)
+
+    return combined_loss, adversarial_loss
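
Usage note (illustrative, not part of the diff): a minimal sketch of the split/pad/reconstruct round trip that data.py now relies on, using the updated AudioUtils helpers. The clip length of 8000 matches the new AudioDataset default; the waveform shape and sample counts are made up for illustration.

import torch

import AudioUtils

# Hypothetical stereo clip: 2 channels, 20 000 samples (illustrative values only).
waveform = torch.randn(2, 20_000)

mono = AudioUtils.stereo_tensor_to_mono(waveform)            # (1, 20000)
mono = AudioUtils.normalize(mono)                            # peak-normalize to [-1, 1]

clip_length = 8000                                           # new AudioDataset default
chunks = AudioUtils.split_audio(mono, clip_length)           # [(1, 8000), (1, 8000), (1, 4000)]
chunks[-1] = AudioUtils.pad_tensor(chunks[-1], clip_length)  # zero-pad the tail chunk

restored = AudioUtils.reconstruct_audio(chunks)              # (1, 24000), padding included
restored = restored[..., : mono.shape[-1]]                   # trim back to the original length

assert restored.shape == mono.shape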
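
Usage note (illustrative, not part of the diff): the reworked MultiResolutionSTFTLoss accepts (B, T) or (B, 1, T) waveforms and returns the spectral-convergence, log-magnitude, and total losses averaged over the resolutions. The resolutions below are the ones the deleted training_utils.py used; batch size and clip length are illustrative.

import torch

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 512],
    hop_sizes=[120, 240, 50],
    win_lengths=[600, 1200, 240],
)

target = torch.randn(4, 8000)      # clean audio, (B, T)
prediction = torch.randn(4, 8000)  # generator output, (B, T)

losses = stft_loss_fn(target, prediction)
print(losses["total"], losses["sc"], losses["mag"])  # averaged over the three resolutions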
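
Usage note (illustrative, not part of the diff): a sketch of how the new utils.TrainingTools loss helpers are called for one batch, with backward and optimizer steps omitted. It assumes (batch, 1, samples) inputs and a discriminator that reduces to one logit per clip, which is what the (batch, 1) label buffers in training.py imply; the criteria mirror training.py (MSE for the discriminator, BCE-with-logits for the generator).

import torch
import torch.nn as nn

from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

generator = SISUGenerator()
discriminator = SISUDiscriminator()

# Illustrative batch: 4 mono clips of 8000 samples each.
high_quality = torch.randn(4, 1, 8000)
low_quality = torch.randn(4, 1, 8000)

real_labels = torch.ones(4, 1)   # as in training.py's real_buf
fake_labels = torch.zeros(4, 1)  # as in training.py's fake_buf

criterion_d = nn.MSELoss()
criterion_g = nn.BCEWithLogitsLoss()

d_loss = discriminator_train(
    high_quality, low_quality, real_labels, fake_labels,
    discriminator, generator, criterion_d,
)
g_total, g_adv = generator_train(
    low_quality, high_quality, real_labels, generator, discriminator, criterion_g,
)
print(float(d_loss), float(g_total), float(g_adv))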