From 782a3bab28e2484893eee1cb6b5dae54b979030a Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Tue, 18 Nov 2025 21:34:59 +0200 Subject: [PATCH] :alembic: | More architectural changes --- AudioUtils.py | 86 +++++------------------------ app.py | 55 ++++++++++++++---- data.py | 62 +++++++++------------ discriminator.py | 65 ++++++++++------------ generator.py | 95 +++++++++++++++++++++++--------- training.py | 33 +++++++---- utils/MultiResolutionSTFTLoss.py | 43 ++++----------- utils/TrainingTools.py | 60 ++++++++++---------- 8 files changed, 245 insertions(+), 254 deletions(-) diff --git a/AudioUtils.py b/AudioUtils.py index 183dc36..ff6a24f 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -3,95 +3,39 @@ import torch.nn.functional as F def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: - """ - Convert stereo (C, N) to mono (1, N). Ensures a channel dimension. - """ - if waveform.dim() == 1: - waveform = waveform.unsqueeze(0) # (N,) -> (1, N) - - if waveform.shape[0] > 1: - mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N) - else: - mono_waveform = waveform - return mono_waveform + mono_tensor = torch.mean(waveform, dim=0, keepdim=True) + return mono_tensor -def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor: - """ - Stretch audio along time dimension to target_length. - Input assumed (1, N). Returns (1, target_length). - """ - if tensor.dim() == 1: - tensor = tensor.unsqueeze(0) # ensure (1, N) +def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor: + padding_amount = target_length - audio_tensor.size(-1) + if padding_amount <= 0: + return audio_tensor - tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate - stretched = F.interpolate( - tensor, size=target_length, mode="linear", align_corners=False - ) - return stretched.squeeze(0) # back to (1, target_length) - - -def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor: - """ - Pad to fixed length. Input assumed (1, N). Returns (1, target_length). - """ - if audio_tensor.dim() == 1: - audio_tensor = audio_tensor.unsqueeze(0) - - current_length = audio_tensor.shape[-1] - if current_length < target_length: - padding_needed = target_length - current_length - padding_tuple = (0, padding_needed) - padded_audio_tensor = F.pad( - audio_tensor, padding_tuple, mode="constant", value=0 - ) - else: - padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long + padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount)) return padded_audio_tensor -def split_audio( - audio_tensor: torch.Tensor, chunk_size: int = 128 -) -> list[torch.Tensor]: - """ - Split into chunks of (1, chunk_size). - """ - if not isinstance(chunk_size, int) or chunk_size <= 0: - raise ValueError("chunk_size must be a positive integer.") +def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]: + chunks = list(torch.split(audio_tensor, chunk_size, dim=1)) - if audio_tensor.dim() == 1: - audio_tensor = audio_tensor.unsqueeze(0) + if pad_last_tensor: + last_chunk = chunks[-1] - num_samples = audio_tensor.shape[-1] - if num_samples == 0: - return [] + if last_chunk.size(-1) < chunk_size: + chunks[-1] = pad_tensor(last_chunk, chunk_size) - chunks = list(torch.split(audio_tensor, chunk_size, dim=-1)) return chunks def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: - """ - Reconstruct audio from chunks. Returns (1, N). - """ - if not chunks: - return torch.empty(1, 0) - - chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks] - try: - reconstructed_tensor = torch.cat(chunks, dim=-1) - except RuntimeError as e: - raise RuntimeError( - f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes " - f"for concatenation along dim -1. Original error: {e}" - ) - + reconstructed_tensor = torch.cat(chunks, dim=-1) return reconstructed_tensor def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: max_val = torch.max(torch.abs(audio_tensor)) if max_val < eps: - return audio_tensor # silence, skip normalization + return audio_tensor return audio_tensor / max_val diff --git a/app.py b/app.py index 006cba5..0bacee2 100644 --- a/app.py +++ b/app.py @@ -4,6 +4,7 @@ import torch import torchaudio import torchcodec import tqdm +from accelerate import Accelerator import AudioUtils from generator import SISUGenerator @@ -15,7 +16,7 @@ parser.add_argument("--model", type=str, help="Model to use for upscaling") parser.add_argument( "--clip_length", type=int, - default=16384, + default=8000, help="Internal clip length, leave unspecified if unsure", ) parser.add_argument( @@ -38,21 +39,44 @@ if args.sample_rate < 8000: ) exit() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") -print(f"Using device: {device}") +# --------------------------- +# Init accelerator +# --------------------------- -generator = SISUGenerator().to(device) +accelerator = Accelerator(mixed_precision="bf16") + +# --------------------------- +# Models +# --------------------------- +generator = SISUGenerator() + +accelerator.print("🔨 | Compiling models...") generator = torch.compile(generator) +accelerator.print("✅ | Compiling done!") + +# --------------------------- +# Prepare accelerator +# --------------------------- + +generator = accelerator.prepare(generator) + +# --------------------------- +# Checkpoint helpers +# --------------------------- + models_dir = args.model clip_length = args.clip_length input_audio = args.input output_audio = args.output if models_dir: - ckpt = torch.load(models_dir, map_location=device) - generator.load_state_dict(ckpt["G"]) + ckpt = torch.load(models_dir) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + + accelerator.print("💾 | Loaded model!") else: print( "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" @@ -67,7 +91,8 @@ def start(): audio = decoded_samples.data original_sample_rate = decoded_samples.sample_rate - audio = AudioUtils.stereo_tensor_to_mono(audio) + # Support for multichannel audio + # audio = AudioUtils.stereo_tensor_to_mono(audio) audio = AudioUtils.normalize(audio) resample_transform = torchaudio.transforms.Resample( @@ -77,14 +102,20 @@ def start(): audio = resample_transform(audio) splitted_audio = AudioUtils.split_audio(audio, clip_length) - splitted_audio_on_device = [t.to(device) for t in splitted_audio] + splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio] processed_audio = [] - - for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): - processed_audio.append(generator(clip)) + with torch.no_grad(): + for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): + channels = [] + for audio_channel in torch.split(clip, 1, dim=1): + output_piece = generator(audio_channel) + channels.append(output_piece.detach().cpu()) + output_clip = torch.cat(channels, dim=1) + processed_audio.append(output_clip) reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) - print(f"Saving {output_audio}!") + reconstructed_audio = reconstructed_audio.squeeze(0) + print(f"🔊 | Saving {output_audio}!") torchaudio.save_with_torchcodec( uri=output_audio, src=reconstructed_audio, diff --git a/data.py b/data.py index a2ddb71..13b0ca8 100644 --- a/data.py +++ b/data.py @@ -1,6 +1,7 @@ import os import random +import torch import torchaudio import torchcodec.decoders as decoders import tqdm @@ -10,9 +11,9 @@ import AudioUtils class AudioDataset(Dataset): - audio_sample_rates = [11025] + audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100] - def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True): + def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True): self.clip_length = clip_length self.normalize = normalize @@ -30,45 +31,20 @@ class AudioDataset(Dataset): decoder = decoders.AudioDecoder(audio_clip) decoded_samples = decoder.get_all_samples() - audio = decoded_samples.data.float() # ensure float32 + audio = decoded_samples.data.float() original_sample_rate = decoded_samples.sample_rate - audio = AudioUtils.stereo_tensor_to_mono(audio) if normalize: audio = AudioUtils.normalize(audio) - mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample( - original_sample_rate, mangled_sample_rate - ) - resample_transform_high = torchaudio.transforms.Resample( - mangled_sample_rate, original_sample_rate - ) + splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True) - low_audio = resample_transform_high(resample_transform_low(audio)) + if not splitted_high_quality_audio: + continue - splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) - - if not splitted_high_quality_audio or not splitted_low_quality_audio: - continue # skip empty or invalid clips - - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( - splitted_high_quality_audio[-1], clip_length - ) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( - splitted_low_quality_audio[-1], clip_length - ) - - for high_quality_data, low_quality_data in zip( - splitted_high_quality_audio, splitted_low_quality_audio - ): - data.append( - ( - (high_quality_data, low_quality_data), - (original_sample_rate, mangled_sample_rate), - ) - ) + for splitted_audio_clip in splitted_high_quality_audio: + for audio_clip in torch.split(splitted_audio_clip, 1): + data.append((audio_clip, original_sample_rate)) self.audio_data = data @@ -76,4 +52,20 @@ class AudioDataset(Dataset): return len(self.audio_data) def __getitem__(self, idx): - return self.audio_data[idx] + audio_clip = self.audio_data[idx] + mangled_sample_rate = random.choice(self.audio_sample_rates) + + resample_transform_low = torchaudio.transforms.Resample( + audio_clip[1], mangled_sample_rate + ) + + resample_transform_high = torchaudio.transforms.Resample( + mangled_sample_rate, audio_clip[1] + ) + + low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0])) + if audio_clip[0].shape[1] < low_audio_clip.shape[1]: + low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]] + elif audio_clip[0].shape[1] > low_audio_clip.shape[1]: + low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length) + return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate)) diff --git a/discriminator.py b/discriminator.py index 5e8442b..69de5ce 100644 --- a/discriminator.py +++ b/discriminator.py @@ -5,32 +5,25 @@ import torch.nn.utils as utils def discriminator_block( in_channels, out_channels, - kernel_size=3, + kernel_size=15, stride=1, - dilation=1, - spectral_norm=True, - use_instance_norm=True, + dilation=1 ): - padding = (kernel_size // 2) * dilation + padding = dilation * (kernel_size - 1) // 2 + conv_layer = nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding, + padding=padding ) - if spectral_norm: - conv_layer = utils.spectral_norm(conv_layer) + conv_layer = utils.spectral_norm(conv_layer) + leaky_relu = nn.LeakyReLU(0.2) - layers = [conv_layer] - layers.append(nn.LeakyReLU(0.2, inplace=True)) - - if use_instance_norm: - layers.append(nn.InstanceNorm1d(out_channels)) - - return nn.Sequential(*layers) + return nn.Sequential(conv_layer, leaky_relu) class AttentionBlock(nn.Module): @@ -38,38 +31,40 @@ class AttentionBlock(nn.Module): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(inplace=True), + nn.ReLU(), nn.Conv1d(channels // 4, channels, kernel_size=1), nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) - return x * attention_weights + return x + (x * attention_weights) class SISUDiscriminator(nn.Module): - def __init__(self, layers=32): + def __init__(self, layers=8): super(SISUDiscriminator, self).__init__() - self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2), - discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), + self.discriminator_blocks = nn.Sequential( + # 1 -> 32 + discriminator_block(2, layers), + AttentionBlock(layers), + # 32 -> 64 + discriminator_block(layers, layers * 2, dilation=2), + # 64 -> 128 + discriminator_block(layers * 2, layers * 4, dilation=4), AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), - discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2), - discriminator_block( - layers * 2, - 1, - spectral_norm=False, - use_instance_norm=False, - ), + # 128 -> 256 + discriminator_block(layers * 4, layers * 8, stride=4), + # 256 -> 512 + # discriminator_block(layers * 8, layers * 16, stride=4) ) - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) + self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1) + + self.avg_pool = nn.AdaptiveAvgPool1d(1) def forward(self, x): - x = self.model(x) - x = self.global_avg_pool(x) - x = x.view(x.size(0), -1) - return x + x = self.discriminator_blocks(x) + x = self.final_conv(x) + x = self.avg_pool(x) + return x.squeeze(2) diff --git a/generator.py b/generator.py index b6d2204..bc994ac 100644 --- a/generator.py +++ b/generator.py @@ -2,25 +2,23 @@ import torch import torch.nn as nn -def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): +def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): + padding = (kernel_size - 1) // 2 * dilation return nn.Sequential( nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, + stride=stride, dilation=dilation, - padding=(kernel_size // 2) * dilation, + padding=padding ), nn.InstanceNorm1d(out_channels), - nn.PReLU(), + nn.PReLU(num_parameters=1, init=0.1), ) class AttentionBlock(nn.Module): - """ - Simple Channel Attention Block. Learns to weight channels based on their importance. - """ - def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( @@ -32,7 +30,7 @@ class AttentionBlock(nn.Module): def forward(self, x): attention_weights = self.attention(x) - return x * attention_weights + return x + (x * attention_weights) class ResidualInResidualBlock(nn.Module): @@ -40,7 +38,7 @@ class ResidualInResidualBlock(nn.Module): super(ResidualInResidualBlock, self).__init__() self.conv_layers = nn.Sequential( - *[conv_block(channels, channels) for _ in range(num_convs)] + *[GeneratorBlock(channels, channels) for _ in range(num_convs)] ) self.attention = AttentionBlock(channels) @@ -51,31 +49,74 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual +def UpsampleBlock(in_channels, out_channels): + return nn.Sequential( + nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, + padding=1 + ), + nn.InstanceNorm1d(out_channels), + nn.PReLU(num_parameters=1, init=0.1) + ) class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1): + def __init__(self, channels=32, num_rirb=1): super(SISUGenerator, self).__init__() - self.alpha = alpha - self.conv1 = nn.Sequential( - nn.Conv1d(1, channels, kernel_size=7, padding=3), - nn.InstanceNorm1d(channels), - nn.PReLU(), + self.first_conv = GeneratorBlock(1, channels) + + self.downsample = GeneratorBlock(channels, channels * 2, stride=2) + self.downsample_attn = AttentionBlock(channels * 2) + self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2) + self.downsample_2_attn = AttentionBlock(channels * 4) + + self.rirb = ResidualInResidualBlock(channels * 4) + # self.rirb = nn.Sequential( + # *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)] + # ) + + self.upsample = UpsampleBlock(channels * 4, channels * 2) + self.upsample_attn = AttentionBlock(channels * 2) + self.compress_1 = GeneratorBlock(channels * 4, channels * 2) + + self.upsample_2 = UpsampleBlock(channels * 2, channels) + self.upsample_2_attn = AttentionBlock(channels) + self.compress_2 = GeneratorBlock(channels * 2, channels) + + self.final_conv = nn.Sequential( + nn.Conv1d(channels, 1, kernel_size=7, padding=3), + nn.Tanh() ) - self.rir_blocks = nn.Sequential( - *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] - ) - - self.final_layer = nn.Sequential( - nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh() - ) def forward(self, x): residual_input = x - x = self.conv1(x) - x_rirb_out = self.rir_blocks(x) - learned_residual = self.final_layer(x_rirb_out) - output = residual_input + self.alpha * learned_residual + x1 = self.first_conv(x) - return torch.tanh(output) + x2 = self.downsample(x1) + x2 = self.downsample_attn(x2) + + x3 = self.downsample_2(x2) + x3 = self.downsample_2_attn(x3) + + x_rirb = self.rirb(x3) + + up1 = self.upsample(x_rirb) + up1 = self.upsample_attn(up1) + + cat1 = torch.cat((up1, x2), dim=1) + comp1 = self.compress_1(cat1) + + up2 = self.upsample_2(comp1) + up2 = self.upsample_2_attn(up2) + + cat2 = torch.cat((up2, x1), dim=1) + comp2 = self.compress_2(cat2) + + learned_residual = self.final_conv(comp2) + + output = residual_input + learned_residual + return output diff --git a/training.py b/training.py index 0e0caf8..6f962f3 100644 --- a/training.py +++ b/training.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import torch @@ -52,7 +53,7 @@ accelerator.print("✅ | Compiling done!") # Dataset / DataLoader # --------------------------- accelerator.print("📊 | Fetching dataset...") -dataset = AudioDataset("./dataset") +dataset = AudioDataset("./dataset", 8192) sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None pin_memory = torch.cuda.is_available() and not args.no_pin_memory @@ -93,7 +94,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_d, mode="min", factor=0.5, patience=5 ) -criterion_g = nn.BCEWithLogitsLoss() criterion_d = nn.MSELoss() # --------------------------- @@ -143,12 +143,8 @@ if args.resume: start_epoch = ckpt.get("epoch", 1) accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!") -real_buf = torch.full( - (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32 -) -fake_buf = torch.zeros( - (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32 -) +real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32) +fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32) accelerator.print("🏋️ | Started training...") @@ -157,35 +153,45 @@ try: generator.train() discriminator.train() + discriminator_time = 0 + generator_time = 0 + running_d, running_g, steps = 0.0, 0.0, 0 + progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs") + for i, ( (high_quality, low_quality), (high_sample_rate, low_sample_rate), - ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): + ) in enumerate(progress_bar): batch_size = high_quality.size(0) real_labels = real_buf[:batch_size].to(accelerator.device) fake_labels = fake_buf[:batch_size].to(accelerator.device) + with accelerator.autocast(): + generator_output = generator(low_quality) + # --- Discriminator --- + d_time = datetime.datetime.now() optimizer_d.zero_grad(set_to_none=True) with accelerator.autocast(): d_loss = discriminator_train( high_quality, - low_quality, + low_quality.detach(), real_labels, fake_labels, discriminator, - generator, criterion_d, + generator_output.detach() ) accelerator.backward(d_loss) - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) optimizer_d.step() + discriminator_time = (datetime.datetime.now() - d_time).microseconds # --- Generator --- + g_time = datetime.datetime.now() optimizer_g.zero_grad(set_to_none=True) with accelerator.autocast(): g_total, g_adv = generator_train( @@ -195,11 +201,13 @@ try: generator, discriminator, criterion_d, + generator_output ) accelerator.backward(g_total) torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) optimizer_g.step() + generator_time = (datetime.datetime.now() - g_time).microseconds d_val = accelerator.gather(d_loss.detach()).mean() g_val = accelerator.gather(g_total.detach()).mean() @@ -219,6 +227,7 @@ try: ) steps += 1 + progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs") # epoch averages & schedulers if steps == 0: diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py index eab6355..560191a 100644 --- a/utils/MultiResolutionSTFTLoss.py +++ b/utils/MultiResolutionSTFTLoss.py @@ -7,18 +7,13 @@ import torchaudio.transforms as T class MultiResolutionSTFTLoss(nn.Module): - """ - Multi-resolution STFT loss. - Combines spectral convergence loss and log-magnitude loss - across multiple STFT resolutions. - """ - def __init__( self, - fft_sizes: List[int] = [1024, 2048, 512], - hop_sizes: List[int] = [120, 240, 50], - win_lengths: List[int] = [600, 1200, 240], + fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192], + hop_sizes: List[int] = [64, 128, 256, 512, 1024], + win_lengths: List[int] = [256, 512, 1024, 2048, 4096], eps: float = 1e-7, + center: bool = True ): super().__init__() @@ -26,15 +21,14 @@ class MultiResolutionSTFTLoss(nn.Module): self.n_resolutions = len(fft_sizes) self.stft_transforms = nn.ModuleList() - for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths): - window = torch.hann_window(win_len) + for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)): stft = T.Spectrogram( n_fft=n_fft, hop_length=hop_len, win_length=win_len, - window_fn=lambda _: window, - power=None, # Keep complex output - center=True, + window_fn=torch.hann_window, + power=None, + center=center, pad_mode="reflect", normalized=False, ) @@ -43,12 +37,6 @@ class MultiResolutionSTFTLoss(nn.Module): def forward( self, y_true: torch.Tensor, y_pred: torch.Tensor ) -> Dict[str, torch.Tensor]: - """ - Args: - y_true: (B, T) or (B, 1, T) waveform - y_pred: (B, T) or (B, 1, T) waveform - """ - # Ensure correct shape (B, T) if y_true.dim() == 3 and y_true.size(1) == 1: y_true = y_true.squeeze(1) if y_pred.dim() == 3 and y_pred.size(1) == 1: @@ -58,28 +46,21 @@ class MultiResolutionSTFTLoss(nn.Module): mag_loss = 0.0 for stft in self.stft_transforms: - stft = stft.to(y_pred.device) - - # Complex STFTs: (B, F, T, 2) + stft.window = stft.window.to(y_true.device) stft_true = stft(y_true) stft_pred = stft(y_pred) - # Magnitudes stft_mag_true = torch.abs(stft_true) stft_mag_pred = torch.abs(stft_pred) - # --- Spectral Convergence Loss --- norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) - # --- Log STFT Magnitude Loss --- - mag_loss += F.l1_loss( - torch.log(stft_mag_pred + self.eps), - torch.log(stft_mag_true + self.eps), - ) + log_mag_pred = torch.log(stft_mag_pred + self.eps) + log_mag_true = torch.log(stft_mag_true + self.eps) + mag_loss += F.l1_loss(log_mag_pred, log_mag_true) - # Average across resolutions sc_loss /= self.n_resolutions mag_loss /= self.n_resolutions total_loss = sc_loss + mag_loss diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py index d1959a0..581a890 100644 --- a/utils/TrainingTools.py +++ b/utils/TrainingTools.py @@ -1,12 +1,17 @@ import torch -# In case if needed again... -# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss -# -# stft_loss_fn = MultiResolutionSTFTLoss( -# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] -# ) +from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss +# stft_loss_fn = MultiResolutionSTFTLoss( +# fft_sizes=[512, 1024, 2048, 4096], +# hop_sizes=[128, 256, 512, 1024], +# win_lengths=[512, 1024, 2048, 4096] +# ) +stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[512, 1024, 2048], + hop_sizes=[64, 128, 256], + win_lengths=[256, 512, 1024] +) def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor: absolute_difference = torch.abs(input_one - input_two) @@ -19,42 +24,35 @@ def discriminator_train( high_labels, low_labels, discriminator, - generator, criterion, + generator_output ): - decision_high = discriminator(high_quality) - d_loss_high = criterion(decision_high, high_labels) - # print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}") - decision_low = discriminator(low_quality) - d_loss_low = criterion(decision_low, low_labels) - # print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}") + real_pair = torch.cat((low_quality, high_quality), dim=1) + decision_real = discriminator(real_pair) + d_loss_real = criterion(decision_real, high_labels) - with torch.no_grad(): - generator_quality = generator(low_quality) - decision_gen = discriminator(generator_quality) - d_loss_gen = criterion(decision_gen, low_labels) - - noise = torch.rand_like(high_quality) * 0.08 - decision_noise = discriminator(high_quality + noise) - d_loss_noise = criterion(decision_noise, low_labels) - - d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0 + fake_pair = torch.cat((low_quality, generator_output), dim=1) + decision_fake = discriminator(fake_pair) + d_loss_fake = criterion(decision_fake, low_labels) + d_loss = (d_loss_real + d_loss_fake) / 2.0 return d_loss def generator_train( - low_quality, high_quality, real_labels, generator, discriminator, adv_criterion -): - generator_output = generator(low_quality) + low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output): - discriminator_decision = discriminator(generator_output) + fake_pair = torch.cat((low_quality, generator_output), dim=1) + + discriminator_decision = discriminator(fake_pair) adversarial_loss = adv_criterion(discriminator_decision, real_labels) - # Signal similarity - similarity_loss = signal_mae(generator_output, high_quality) - - combined_loss = adversarial_loss + (similarity_loss * 100) + mae_loss = signal_mae(generator_output, high_quality) + stft_loss = stft_loss_fn(high_quality, generator_output)["total"] + lambda_mae = 10.0 + lambda_stft = 2.5 + lambda_adv = 2.5 + combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss) return combined_loss, adversarial_loss