diff --git a/AudioUtils.py b/AudioUtils.py index 04f75db..ff6a24f 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -1,18 +1,41 @@ import torch import torch.nn.functional as F -def stereo_tensor_to_mono(waveform): - if waveform.shape[0] > 1: - # Average across channels - mono_waveform = torch.mean(waveform, dim=0, keepdim=True) - else: - # Already mono - mono_waveform = waveform - return mono_waveform -def stretch_tensor(tensor, target_length): - scale_factor = target_length / tensor.size(1) +def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: + mono_tensor = torch.mean(waveform, dim=0, keepdim=True) + return mono_tensor - tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False) - return tensor +def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor: + padding_amount = target_length - audio_tensor.size(-1) + if padding_amount <= 0: + return audio_tensor + + padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount)) + + return padded_audio_tensor + + +def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]: + chunks = list(torch.split(audio_tensor, chunk_size, dim=1)) + + if pad_last_tensor: + last_chunk = chunks[-1] + + if last_chunk.size(-1) < chunk_size: + chunks[-1] = pad_tensor(last_chunk, chunk_size) + + return chunks + + +def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: + reconstructed_tensor = torch.cat(chunks, dim=-1) + return reconstructed_tensor + + +def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + max_val = torch.max(torch.abs(audio_tensor)) + if max_val < eps: + return audio_tensor + return audio_tensor / max_val diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app.py b/app.py new file mode 100644 index 0000000..0bacee2 --- /dev/null +++ b/app.py @@ -0,0 +1,128 @@ +import argparse + +import torch +import torchaudio +import torchcodec +import tqdm +from accelerate import Accelerator + +import AudioUtils +from generator import SISUGenerator + +# Init script argument parser +parser = argparse.ArgumentParser(description="Audio upscaling script") +parser.add_argument("--device", type=str, default="cpu", help="Select device") +parser.add_argument("--model", type=str, help="Model to use for upscaling") +parser.add_argument( + "--clip_length", + type=int, + default=8000, + help="Internal clip length, leave unspecified if unsure", +) +parser.add_argument( + "--sample_rate", type=int, default=44100, help="Output clip sample rate" +) +parser.add_argument( + "--bitrate", + type=int, + default=192000, + help="Output clip bitrate", +) +parser.add_argument("-i", "--input", type=str, help="Input audio file") +parser.add_argument("-o", "--output", type=str, help="Output audio file") + +args = parser.parse_args() + +if args.sample_rate < 8000: + print( + "Sample rate cannot be lower than 8000! 
(44100 is recommended for base models)" + ) + exit() + +# --------------------------- +# Init accelerator +# --------------------------- + +accelerator = Accelerator(mixed_precision="bf16") + +# --------------------------- +# Models +# --------------------------- +generator = SISUGenerator() + +accelerator.print("๐Ÿ”จ | Compiling models...") + +generator = torch.compile(generator) + +accelerator.print("โœ… | Compiling done!") + +# --------------------------- +# Prepare accelerator +# --------------------------- + +generator = accelerator.prepare(generator) + +# --------------------------- +# Checkpoint helpers +# --------------------------- + +models_dir = args.model +clip_length = args.clip_length +input_audio = args.input +output_audio = args.output + +if models_dir: + ckpt = torch.load(models_dir) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + + accelerator.print("๐Ÿ’พ | Loaded model!") +else: + print( + "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" + ) + + +def start(): + # To Mono! + decoder = torchcodec.decoders.AudioDecoder(input_audio) + + decoded_samples = decoder.get_all_samples() + audio = decoded_samples.data + original_sample_rate = decoded_samples.sample_rate + + # Support for multichannel audio + # audio = AudioUtils.stereo_tensor_to_mono(audio) + audio = AudioUtils.normalize(audio) + + resample_transform = torchaudio.transforms.Resample( + original_sample_rate, args.sample_rate + ) + + audio = resample_transform(audio) + + splitted_audio = AudioUtils.split_audio(audio, clip_length) + splitted_audio_on_device = [t.reshape(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio]  # reshape (not view): chunks from torch.split are non-contiguous for multichannel audio + processed_audio = [] + with torch.no_grad(): + for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): + channels = [] + for audio_channel in torch.split(clip, 1, dim=1): + output_piece = generator(audio_channel) + channels.append(output_piece.detach().cpu()) + output_clip = torch.cat(channels, dim=1) + processed_audio.append(output_clip) + + reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) + reconstructed_audio = reconstructed_audio.squeeze(0) + print(f"๐Ÿ”Š | Saving {output_audio}!") + torchaudio.save_with_torchcodec( + uri=output_audio, + src=reconstructed_audio, + sample_rate=args.sample_rate, + channels_first=True, + compression=args.bitrate, + ) + + +start() diff --git a/data.py b/data.py index bc7574f..13b0ca8 100644 --- a/data.py +++ b/data.py @@ -1,53 +1,71 @@ -from torch.utils.data import Dataset -import torch.nn.functional as F -import torch -import torchaudio import os import random -import torchaudio.transforms as T + +import torch +import torchaudio +import torchcodec.decoders as decoders +import tqdm +from torch.utils.data import Dataset + import AudioUtils -class AudioDataset(Dataset): - audio_sample_rates = [11025] - MAX_LENGTH = 44100 # Define your desired maximum length here - def __init__(self, input_dir, device): - self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] - self.device = device +class AudioDataset(Dataset): + audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100] + + def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True): + self.clip_length = clip_length + self.normalize = normalize + + input_files = [ + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if 
os.path.isfile(os.path.join(input_dir, f)) + and f.lower().endswith((".wav", ".mp3", ".flac")) + ] + + data = [] + for audio_clip in tqdm.tqdm( + input_files, desc=f"Processing {len(input_files)} audio file(s)" + ): + decoder = decoders.AudioDecoder(audio_clip) + decoded_samples = decoder.get_all_samples() + + audio = decoded_samples.data.float() + original_sample_rate = decoded_samples.sample_rate + + if normalize: + audio = AudioUtils.normalize(audio) + + splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True) + + if not splitted_high_quality_audio: + continue + + for splitted_audio_clip in splitted_high_quality_audio: + for audio_clip in torch.split(splitted_audio_clip, 1): + data.append((audio_clip, original_sample_rate)) + + self.audio_data = data def __len__(self): - return len(self.input_files) + return len(self.audio_data) def __getitem__(self, idx): - # Load high-quality audio - high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) - - # Generate low-quality audio with random downsampling + audio_clip = self.audio_data[idx] mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - low_quality_audio = resample_transform_low(high_quality_audio) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) - low_quality_audio = resample_transform_high(low_quality_audio) + resample_transform_low = torchaudio.transforms.Resample( + audio_clip[1], mangled_sample_rate + ) - high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) - low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio) + resample_transform_high = torchaudio.transforms.Resample( + mangled_sample_rate, audio_clip[1] + ) - # Pad or truncate high-quality audio - if high_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - high_quality_audio.shape[1] - high_quality_audio = F.pad(high_quality_audio, (0, padding)) - elif high_quality_audio.shape[1] > self.MAX_LENGTH: - high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH] - - # Pad or truncate low-quality audio - if low_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - low_quality_audio.shape[1] - low_quality_audio = F.pad(low_quality_audio, (0, padding)) - elif low_quality_audio.shape[1] > self.MAX_LENGTH: - low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH] - - high_quality_audio = high_quality_audio.to(self.device) - low_quality_audio = low_quality_audio.to(self.device) - - return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate) + low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0])) + if audio_clip[0].shape[1] < low_audio_clip.shape[1]: + low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]] + elif audio_clip[0].shape[1] > low_audio_clip.shape[1]: + low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length) + return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate)) diff --git a/discriminator.py b/discriminator.py index dfd0126..0572d17 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,63 +1,179 @@ import torch import torch.nn as nn -import torch.nn.utils as utils +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm, spectral_norm -def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, 
spectral_norm=True, use_instance_norm=True): - padding = (kernel_size // 2) * dilation - conv_layer = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding - ) +# ------------------------------------------------------------------- +# 1. Multi-Period Discriminator (MPD) +# Captures periodic structures (pitch/timbre) by folding audio. +# ------------------------------------------------------------------- - if spectral_norm: - conv_layer = utils.spectral_norm(conv_layer) +class DiscriminatorP(nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm - layers = [conv_layer] - layers.append(nn.LeakyReLU(0.2, inplace=True)) + # Use spectral_norm for stability, or weight_norm for performance + norm_f = spectral_norm if use_spectral_norm else weight_norm - if use_instance_norm: - layers.append(nn.InstanceNorm1d(out_channels)) + # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time) + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) - return nn.Sequential(*layers) - -class AttentionBlock(nn.Module): - def __init__(self, channels): - super(AttentionBlock, self).__init__() - self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() - ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): - attention_weights = self.attention(x) - return x * attention_weights + fmap = [] + + # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P] + b, c, t = x.shape + if t % self.period != 0: # Pad if not divisible by period + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) # Store feature map for Feature Matching Loss + + x = self.conv_post(x) + fmap.append(x) + + # Flatten back to 1D for score + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, periods=[2, 3, 5, 7, 11]): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(p) for p in periods + ]) + + def forward(self, y, y_hat): + y_d_rs = [] # Real scores + y_d_gs = [] # Generated (Fake) scores + fmap_rs = [] # Real feature maps + fmap_gs = [] # Generated (Fake) feature maps + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# ------------------------------------------------------------------- +# 2. Multi-Scale Discriminator (MSD) +# Captures structure at different audio resolutions (raw, x0.5, x0.25). 
+# ------------------------------------------------------------------- + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = spectral_norm if use_spectral_norm else weight_norm + + # Standard 1D Convolutions with large receptive field + self.convs = nn.ModuleList([ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + # 3 Scales: Original, Downsampled x2, Downsampled x4 + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + nn.AvgPool1d(4, 2, padding=2), + nn.AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + if i != 0: + # Downsample input for subsequent discriminators + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# ------------------------------------------------------------------- +# 3. Master Wrapper +# Combines MPD and MSD into one class to fit your training script. 
+# ------------------------------------------------------------------- class SISUDiscriminator(nn.Module): - def __init__(self, base_channels=16): + def __init__(self): super(SISUDiscriminator, self).__init__() - layers = base_channels - self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True), - AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False) + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, y, y_hat): + # Return format: + # scores_real, scores_fake, features_real, features_fake + + # Run Multi-Period + mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat) + + # Run Multi-Scale + msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat) + + # Combine all results + return ( + mpd_y_d_rs + msd_y_d_rs, # All real scores + mpd_y_d_gs + msd_y_d_gs, # All fake scores + mpd_fmap_rs + msd_fmap_rs, # All real feature maps + mpd_fmap_gs + msd_fmap_gs # All fake feature maps ) - - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) - - def forward(self, x): - x = self.model(x) - x = self.global_avg_pool(x) - x = x.view(x.size(0), -1) - return x diff --git a/file_utils.py b/file_utils.py deleted file mode 100644 index a723688..0000000 --- a/file_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import json - -filepath = "my_data.json" - -def write_data(filepath, data): - try: - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) # Use indent for pretty formatting - print(f"Data written to '{filepath}'") - except Exception as e: - print(f"Error writing to file: {e}") - - -def read_data(filepath): - try: - with open(filepath, 'r') as f: - data = json.load(f) - print(f"Data read from '{filepath}'") - return data - except FileNotFoundError: - print(f"File not found: {filepath}") - return None - except json.JSONDecodeError: - print(f"Error decoding JSON from file: {filepath}") - return None - except Exception as e: - print(f"Error reading from file: {e}") - return None diff --git a/generator.py b/generator.py index a53feb7..15279b1 100644 --- a/generator.py +++ b/generator.py @@ -1,42 +1,45 @@ import torch import torch.nn as nn +from torch.nn.utils.parametrizations import weight_norm + +def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): + padding = (kernel_size - 1) // 2 * dilation -def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): return nn.Sequential( - nn.Conv1d( + + weight_norm(nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, + stride=stride, dilation=dilation, - padding=(kernel_size // 2) * dilation - ), - nn.InstanceNorm1d(out_channels), - nn.PReLU() + padding=padding + )), + nn.PReLU(num_parameters=1, 
init=0.1), ) + class AttentionBlock(nn.Module): - """ - Simple Channel Attention Block. Learns to weight channels based on their importance. - """ def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), + weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)), nn.ReLU(inplace=True), - nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() + weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)), + nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) - return x * attention_weights + return x + (x * attention_weights) + class ResidualInResidualBlock(nn.Module): def __init__(self, channels, num_convs=3): super(ResidualInResidualBlock, self).__init__() self.conv_layers = nn.Sequential( - *[conv_block(channels, channels) for _ in range(num_convs)] + *[GeneratorBlock(channels, channels) for _ in range(num_convs)] ) self.attention = AttentionBlock(channels) @@ -47,28 +50,77 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual +def UpsampleBlock(in_channels, out_channels, scale_factor=2): + return nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + weight_norm(nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1 + )), + nn.PReLU(num_parameters=1, init=0.1) + ) + class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1.0): + def __init__(self, channels=32, num_rirb=4): super(SISUGenerator, self).__init__() - self.alpha = alpha - self.conv1 = nn.Sequential( - nn.Conv1d(1, channels, kernel_size=7, padding=3), - nn.InstanceNorm1d(channels), - nn.PReLU(), + self.first_conv = GeneratorBlock(1, channels) + + self.downsample = GeneratorBlock(channels, channels * 2, stride=2) + self.downsample_attn = AttentionBlock(channels * 2) + self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2) + self.downsample_2_attn = AttentionBlock(channels * 4) + + self.rirb = nn.Sequential( + *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)] ) - self.rir_blocks = nn.Sequential( - *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] + self.upsample = UpsampleBlock(channels * 4, channels * 2) + self.upsample_attn = AttentionBlock(channels * 2) + self.compress_1 = GeneratorBlock(channels * 4, channels * 2) + + self.upsample_2 = UpsampleBlock(channels * 2, channels) + self.upsample_2_attn = AttentionBlock(channels) + self.compress_2 = GeneratorBlock(channels * 2, channels) + + self.final_conv = nn.Sequential( + weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)), + nn.Tanh() ) - self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1) def forward(self, x): residual_input = x - x = self.conv1(x) - x_rirb_out = self.rir_blocks(x) - learned_residual = self.final_layer(x_rirb_out) - output = residual_input + self.alpha * learned_residual + # Encoding + x1 = self.first_conv(x) + + x2 = self.downsample(x1) + x2 = self.downsample_attn(x2) + + x3 = self.downsample_2(x2) + x3 = self.downsample_2_attn(x3) + + # Bottleneck (Deep Residual processing) + x_rirb = self.rirb(x3) + + # Decoding with Skip Connections + up1 = self.upsample(x_rirb) + up1 = self.upsample_attn(up1) + + cat1 = torch.cat((up1, x2), dim=1) + comp1 = self.compress_1(cat1) + + up2 = self.upsample_2(comp1) + up2 = self.upsample_2_attn(up2) + + cat2 = torch.cat((up2, x1), dim=1) + comp2 = self.compress_2(cat2) + + 
learned_residual = self.final_conv(comp2) + + output = residual_input + learned_residual return output diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 21f6bef..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -filelock==3.16.1 -fsspec==2024.10.0 -Jinja2==3.1.4 -MarkupSafe==2.1.5 -mpmath==1.3.0 -networkx==3.4.2 -numpy==2.2.3 -pillow==11.0.0 -setuptools==70.2.0 -sympy==1.13.3 -tqdm==4.67.1 -typing_extensions==4.12.2 diff --git a/training.py b/training.py index 01ea749..8107b03 100644 --- a/training.py +++ b/training.py @@ -1,194 +1,247 @@ +import argparse +import datetime +import os + import torch import torch.nn as nn import torch.optim as optim - -import torch.nn.functional as F -import torchaudio import tqdm +from accelerate import Accelerator +from torch.utils.data import DataLoader, DistributedSampler -import argparse - -import math - -import os - -from torch.utils.data import random_split -from torch.utils.data import DataLoader - -import AudioUtils from data import AudioDataset -from generator import SISUGenerator from discriminator import SISUDiscriminator +from generator import SISUGenerator +from utils.TrainingTools import discriminator_train, generator_train -from training_utils import discriminator_train, generator_train -import file_utils as Data - -import torchaudio.transforms as T - -# Init script argument parser -parser = argparse.ArgumentParser(description="Training script") -parser.add_argument("--generator", type=str, default=None, - help="Path to the generator model file") -parser.add_argument("--discriminator", type=str, default=None, - help="Path to the discriminator model file") -parser.add_argument("--device", type=str, default="cpu", help="Select device") -parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") -parser.add_argument("--debug", action="store_true", help="Print debug logs") -parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models") - +# --------------------------- +# Argument parsing +# --------------------------- +parser = argparse.ArgumentParser(description="Training script (safer defaults)") +parser.add_argument("--resume", action="store_true", help="Resume training") +parser.add_argument( + "--epochs", type=int, default=5000, help="Number of training epochs" +) +parser.add_argument("--batch_size", type=int, default=8, help="Batch size") +parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers") # Increased workers slightly +parser.add_argument("--debug", action="store_true", help="Print debug logs") +parser.add_argument( + "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA" +) args = parser.parse_args() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") -print(f"Using device: {device}") +# --------------------------- +# Init accelerator +# --------------------------- -# Parameters -sample_rate = 44100 -n_fft = 2048 -hop_length = 256 -win_length = n_fft -n_mels = 128 -n_mfcc = 20 # If using MFCC - -mfcc_transform = T.MFCC( - sample_rate, - n_mfcc, - melkwargs = {'n_fft': n_fft, 'hop_length': hop_length} -).to(device) - -mel_transform = T.MelSpectrogram( - sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, - win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel -).to(device) - -stft_transform = T.Spectrogram( - n_fft=n_fft, win_length=win_length, hop_length=hop_length -).to(device) - 
-debug = args.debug - -# Initialize dataset and dataloader -dataset_dir = './dataset/good' -dataset = AudioDataset(dataset_dir, device) -models_dir = "models" -os.makedirs(models_dir, exist_ok=True) -audio_output_dir = "output" -os.makedirs(audio_output_dir, exist_ok=True) - -# ========= SINGLE ========= - -train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True) - - -# ========= MODELS ========= +accelerator = Accelerator(mixed_precision="bf16") +# --------------------------- +# Models +# --------------------------- generator = SISUGenerator() +# Note: SISUDiscriminator is now an Ensemble (MPD + MSD) discriminator = SISUDiscriminator() -epoch: int = args.epoch -epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") +accelerator.print("๐Ÿ”จ | Compiling models...") -if args.continue_training: - generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - epoch = epoch_from_file["epoch"] + 1 -else: - if args.generator is not None: - generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) - if args.discriminator is not None: - discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) +# Torch compile is great, but if you hit errors with the new List/Tuple outputs +# of the discriminator, you might need to disable it for D. +generator = torch.compile(generator) +discriminator = torch.compile(discriminator) -generator = generator.to(device) -discriminator = discriminator.to(device) +accelerator.print("โœ… | Compiling done!") -# Loss -criterion_g = nn.BCEWithLogitsLoss() -criterion_d = nn.BCEWithLogitsLoss() +# --------------------------- +# Dataset / DataLoader +# --------------------------- +accelerator.print("๐Ÿ“Š | Fetching dataset...") +dataset = AudioDataset("./dataset", 8192) -# Optimizers -optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) -optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) +sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None +pin_memory = torch.cuda.is_available() and not args.no_pin_memory -# Scheduler -scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) -scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) +train_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=args.batch_size, + shuffle=(sampler is None), + num_workers=args.num_workers, + pin_memory=pin_memory, + persistent_workers=pin_memory, +) -def start_training(): - generator_epochs = 5000 - for generator_epoch in range(generator_epochs): - low_quality_audio = (torch.empty((1)), 1) - high_quality_audio = (torch.empty((1)), 1) - ai_enhanced_audio = (torch.empty((1)), 1) +if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0: + accelerator.print("๐Ÿชน | There is no data to train with! 
Exiting...") + exit() - times_correct = 0 +loader_batch_size = train_loader.batch_size - # ========= TRAINING ========= - for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): - # for high_quality_clip, low_quality_clip in train_data_loader: - high_quality_sample = (high_quality_clip[0], high_quality_clip[1]) - low_quality_sample = (low_quality_clip[0], low_quality_clip[1]) +accelerator.print("โœ… | Dataset fetched!") - # ========= LABELS ========= - batch_size = high_quality_clip[0].size(0) - real_labels = torch.ones(batch_size, 1).to(device) - fake_labels = torch.zeros(batch_size, 1).to(device) +# --------------------------- +# Losses / Optimizers / Scalers +# --------------------------- - # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - high_quality_sample, - low_quality_sample, - real_labels, - fake_labels, - discriminator, - generator, - criterion_d, - optimizer_d - ) +optimizer_g = optim.AdamW( + generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) +optimizer_d = optim.AdamW( + discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) - # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - low_quality_sample, - high_quality_sample, - real_labels, - generator, - discriminator, - criterion_d, - optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform - ) +scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_g, mode="min", factor=0.5, patience=5 +) +scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_d, mode="min", factor=0.5, patience=5 +) - if debug: - print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) +# --------------------------- +# Prepare accelerator +# --------------------------- - # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) - low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) - ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0]) +generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare( + generator, discriminator, optimizer_g, optimizer_d, train_loader +) - new_epoch = generator_epoch+epoch - - if generator_epoch % 25 == 0: - print(f"Saved epoch {new_epoch}!") - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. 
- torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) - - #if debug: - # print(generator.state_dict().keys()) - # print(discriminator.state_dict().keys()) - torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") - torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") - Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) +# --------------------------- +# Checkpoint helpers +# --------------------------- +models_dir = "./models" +os.makedirs(models_dir, exist_ok=True) - torch.save(discriminator, "models/epoch-5000-discriminator.pt") - torch.save(generator, "models/epoch-5000-generator.pt") - print("Training complete!") +def save_ckpt(path, epoch, loss=None, is_best=False): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state = { + "epoch": epoch, + "G": accelerator.unwrap_model(generator).state_dict(), + "D": accelerator.unwrap_model(discriminator).state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict() + } -start_training() + accelerator.save(state, path) + + if is_best: + accelerator.save(state, os.path.join(models_dir, "best.pt")) + accelerator.print(f"๐ŸŒŸ | New best model saved with G Loss: {loss:.4f}") + + +start_epoch = 0 +if args.resume: + ckpt_path = os.path.join(models_dir, "last.pt") + if os.path.exists(ckpt_path): + ckpt = torch.load(ckpt_path) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"]) + optimizer_g.load_state_dict(ckpt["optG"]) + optimizer_d.load_state_dict(ckpt["optD"]) + scheduler_g.load_state_dict(ckpt["schedG"]) + scheduler_d.load_state_dict(ckpt["schedD"]) + + start_epoch = ckpt.get("epoch", 1) + accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") + else: + accelerator.print("โš ๏ธ | Resume requested but no checkpoint found. 
Starting fresh.") + + +accelerator.print("๐Ÿ‹๏ธ | Started training...") + +try: + for epoch in range(start_epoch, args.epochs): + generator.train() + discriminator.train() + + discriminator_time = 0 + generator_time = 0 + + running_d, running_g, steps = 0.0, 0.0, 0 + + progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") + + for i, ( + (high_quality, low_quality), + (high_sample_rate, low_sample_rate), + ) in enumerate(progress_bar): + with accelerator.autocast(): + generator_output = generator(low_quality) + + # --- Discriminator --- + d_time = datetime.datetime.now() + optimizer_d.zero_grad(set_to_none=True) + with accelerator.autocast(): + d_loss = discriminator_train( + high_quality, + discriminator, + generator_output.detach() + ) + + accelerator.backward(d_loss) + optimizer_d.step() + discriminator_time = (datetime.datetime.now() - d_time).microseconds + + # --- Generator --- + g_time = datetime.datetime.now() + optimizer_g.zero_grad(set_to_none=True) + with accelerator.autocast(): + g_total, g_adv = generator_train( + low_quality, + high_quality, + generator, + discriminator, + generator_output + ) + + accelerator.backward(g_total) + torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) + optimizer_g.step() + generator_time = (datetime.datetime.now() - g_time).microseconds + + d_val = accelerator.gather(d_loss.detach()).mean() + g_val = accelerator.gather(g_total.detach()).mean() + + if torch.isfinite(d_val): + running_d += d_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in discriminator loss at step {i}, skipping update." + ) + + if torch.isfinite(g_val): + running_g += g_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in generator loss at step {i}, skipping update." + ) + + steps += 1 + progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") + + if steps == 0: + accelerator.print("๐Ÿชน | No steps in epoch (empty dataloader?). 
Exiting.") + break + + mean_d = running_d / steps + mean_g = running_g / steps + + scheduler_d.step(mean_d) + scheduler_g.step(mean_g) + + save_ckpt(os.path.join(models_dir, "last.pt"), epoch) + accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") + +except Exception: + try: + save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch) + accelerator.print(f"๐Ÿ’พ | Saved crash checkpoint for epoch {epoch}") + except Exception as e: + accelerator.print("๐Ÿ˜ฌ | Failed saving crash checkpoint:", e) + raise + +accelerator.print("๐Ÿ | Training finished.") diff --git a/training_utils.py b/training_utils.py deleted file mode 100644 index 6f26f58..0000000 --- a/training_utils.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -import torchaudio -import torchaudio.transforms as T - -def gpu_mfcc_loss(mfcc_transform, y_true, y_pred): - mfccs_true = mfcc_transform(y_true) - mfccs_pred = mfcc_transform(y_pred) - - min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2]) - mfccs_true = mfccs_true[:, :, :min_len] - mfccs_pred = mfccs_pred[:, :, :min_len] - - loss = torch.mean((mfccs_true - mfccs_pred)**2) - return loss - -def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) - - # Ensure same time dimension length (due to potential framing differences) - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] - - # L1 Loss (Mean Absolute Error) - loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) - return loss - -def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) - - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] - - loss = torch.mean((mel_spec_true - mel_spec_pred)**2) - return loss - -def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) - - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] - - loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps))) - return loss - -def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) - - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] - - norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1)) - norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1)) - - loss = torch.mean(norm_diff / (norm_true + eps)) - return loss - -def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer): - optimizer.zero_grad() - - # Forward pass for real samples - discriminator_decision_from_real = 
discriminator(high_quality[0]) - d_loss_real = criterion(discriminator_decision_from_real, real_labels) - - with torch.no_grad(): - generator_output = generator(low_quality[0]) - discriminator_decision_from_fake = discriminator(generator_output) - d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) - - d_loss = (d_loss_real + d_loss_fake) / 2.0 - - d_loss.backward() - # Optional: Gradient Clipping (can be helpful) - # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping - optimizer.step() - - return d_loss - -def generator_train( - low_quality, - high_quality, - real_labels, - generator, - discriminator, - adv_criterion, - g_optimizer, - device, - mel_transform: T.MelSpectrogram, - stft_transform: T.Spectrogram, - mfcc_transform: T.MFCC, - lambda_adv: float = 1.0, - lambda_mel_l1: float = 10.0, - lambda_log_stft: float = 1.0, - lambda_mfcc: float = 1.0 -): - g_optimizer.zero_grad() - - generator_output = generator(low_quality[0]) - - discriminator_decision = discriminator(generator_output) - adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) - - mel_l1 = 0.0 - log_stft_l1 = 0.0 - mfcc_l = 0.0 - - # Calculate Mel L1 Loss if weight is positive - if lambda_mel_l1 > 0: - mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output) - - # Calculate Log STFT L1 Loss if weight is positive - if lambda_log_stft > 0: - log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output) - - # Calculate MFCC Loss if weight is positive - if lambda_mfcc > 0: - mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output) - - mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 - log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 - mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l - - combined_loss = (lambda_adv * adversarial_loss) + \ - (lambda_mel_l1 * mel_l1_tensor) + \ - (lambda_log_stft * log_stft_l1_tensor) + \ - (lambda_mfcc * mfcc_l_tensor) - - combined_loss.backward() - # Optional: Gradient Clipping - # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) - g_optimizer.step() - - # 6. 
Return values for logging - return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py new file mode 100644 index 0000000..560191a --- /dev/null +++ b/utils/MultiResolutionSTFTLoss.py @@ -0,0 +1,68 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.transforms as T + + +class MultiResolutionSTFTLoss(nn.Module): + def __init__( + self, + fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192], + hop_sizes: List[int] = [64, 128, 256, 512, 1024], + win_lengths: List[int] = [256, 512, 1024, 2048, 4096], + eps: float = 1e-7, + center: bool = True + ): + super().__init__() + + self.eps = eps + self.n_resolutions = len(fft_sizes) + + self.stft_transforms = nn.ModuleList() + for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)): + stft = T.Spectrogram( + n_fft=n_fft, + hop_length=hop_len, + win_length=win_len, + window_fn=torch.hann_window, + power=None, + center=center, + pad_mode="reflect", + normalized=False, + ) + self.stft_transforms.append(stft) + + def forward( + self, y_true: torch.Tensor, y_pred: torch.Tensor + ) -> Dict[str, torch.Tensor]: + if y_true.dim() == 3 and y_true.size(1) == 1: + y_true = y_true.squeeze(1) + if y_pred.dim() == 3 and y_pred.size(1) == 1: + y_pred = y_pred.squeeze(1) + + sc_loss = 0.0 + mag_loss = 0.0 + + for stft in self.stft_transforms: + stft.window = stft.window.to(y_true.device) + stft_true = stft(y_true) + stft_pred = stft(y_pred) + + stft_mag_true = torch.abs(stft_true) + stft_mag_pred = torch.abs(stft_pred) + + norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) + norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) + sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) + + log_mag_pred = torch.log(stft_mag_pred + self.eps) + log_mag_true = torch.log(stft_mag_true + self.eps) + mag_loss += F.l1_loss(log_mag_pred, log_mag_true) + + sc_loss /= self.n_resolutions + mag_loss /= self.n_resolutions + total_loss = sc_loss + mag_loss + + return {"total": total_loss, "sc": sc_loss, "mag": mag_loss} diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py new file mode 100644 index 0000000..cd0350a --- /dev/null +++ b/utils/TrainingTools.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss + +# Keep STFT settings as is +stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[512, 1024, 2048], + hop_sizes=[64, 128, 256], + win_lengths=[256, 512, 1024] +) + + +def feature_matching_loss(fmap_r, fmap_g): + """ + Computes L1 distance between real and fake feature maps. + """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.detach() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + """ + Least Squares GAN Loss (LSGAN) for the Discriminator. + Objective: Real -> 1, Fake -> 0 + """ + loss = 0 + r_losses = [] + g_losses = [] + + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((dr - 1) ** 2) + g_loss = torch.mean(dg ** 2) + + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_adv_loss(disc_generated_outputs): + """ + Least Squares GAN Loss for the Generator. 
+ Objective: Fake -> 1 (Fool the discriminator) + """ + loss = 0 + for dg in zip(disc_generated_outputs): + dg = dg[0] # Unpack tuple + loss += torch.mean((dg - 1) ** 2) + return loss + + +def discriminator_train( + high_quality, + discriminator, + generator_output +): + y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach()) + + d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs) + + return d_loss + + +def generator_train( + low_quality, + high_quality, + generator, + discriminator, + generator_output +): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output) + + loss_gen_adv = generator_adv_loss(y_d_gs) + + loss_fm = feature_matching_loss(fmap_rs, fmap_gs) + + stft_loss = stft_loss_fn(high_quality, generator_output)["total"] + + lambda_stft = 45.0 + lambda_fm = 2.0 + lambda_adv = 1.0 + + combined_loss = (lambda_stft * stft_loss) + \ + (lambda_fm * loss_fm) + \ + (lambda_adv * loss_gen_adv) + + return combined_loss, loss_gen_adv diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29
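
Usage sketch (illustrative, not part of the diff): the new generator, the MPD+MSD discriminator ensemble, and the loss helpers in utils/TrainingTools.py only meet inside training.py, so the snippet below shows how they compose on a dummy batch. The batch shape (2, 1, 8192) is an assumption chosen so the generator's two stride-2 downsampling stages line up with its skip connections (any clip length divisible by 4 works); it is not a value required by the code above.

import torch

from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

# Dummy (batch, channels, samples) waveforms standing in for a real DataLoader batch.
low_quality = torch.randn(2, 1, 8192)   # degraded input
high_quality = torch.randn(2, 1, 8192)  # clean target

generator = SISUGenerator()
discriminator = SISUDiscriminator()  # MPD + MSD wrapped behind one forward(y, y_hat)

fake = generator(low_quality)  # residual U-Net output, same shape as the input

# LSGAN discriminator loss over all MPD/MSD scores (the fake batch is detached inside).
d_loss = discriminator_train(high_quality, discriminator, fake)

# Generator objective: adversarial + feature-matching + multi-resolution STFT terms.
g_total, g_adv = generator_train(low_quality, high_quality, generator, discriminator, fake)

print(f"D {d_loss.item():.3f} | G {g_total.item():.3f} (adv {g_adv.item():.3f})")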