From 5735557ec3493db72f29c9c8f96d7d697d113507 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Wed, 30 Apr 2025 23:45:05 +0300 Subject: [PATCH] :poop: | VERY CRUDE JAX implementation... --- data.py | 109 +++++++--- discriminator.py | 192 ++++++++++++----- generator.py | 215 ++++++++++++++----- training.py | 535 +++++++++++++++++++++++++++++++++++----------- training.txt | 194 +++++++++++++++++ training_utils.py | 2 - 6 files changed, 979 insertions(+), 268 deletions(-) create mode 100644 training.txt diff --git a/data.py b/data.py index bc7574f..c095970 100644 --- a/data.py +++ b/data.py @@ -1,53 +1,104 @@ +# Keep necessary PyTorch imports for torchaudio and Dataset structure from torch.utils.data import Dataset -import torch.nn.functional as F import torch import torchaudio +import torchaudio.transforms as T # Keep using torchaudio transforms + +# Import NumPy +import numpy as np + import os import random -import torchaudio.transforms as T +# Assume AudioUtils is available and works on PyTorch Tensors as before import AudioUtils -class AudioDataset(Dataset): +class AudioDatasetNumPy(Dataset): # Renamed slightly for clarity audio_sample_rates = [11025] MAX_LENGTH = 44100 # Define your desired maximum length here - def __init__(self, input_dir, device): - self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] - self.device = device + def __init__(self, input_dir): + """ + Initializes the dataset. Device argument is removed. + """ + self.input_files = [ + os.path.join(root, f) + for root, _, files in os.walk(input_dir) + for f in files if f.endswith('.wav') + ] + if not self.input_files: + print(f"Warning: No .wav files found in {input_dir}") def __len__(self): return len(self.input_files) def __getitem__(self, idx): - # Load high-quality audio - high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) + """ + Loads audio, processes it, and returns NumPy arrays. 
+ """ + # --- Load and Resample using torchaudio (produces PyTorch tensors) --- + try: + high_quality_audio_pt, original_sample_rate = torchaudio.load( + self.input_files[idx], normalize=True + ) + except Exception as e: + print(f"Error loading file {self.input_files[idx]}: {e}") + # Return None or raise error, or return dummy data if preferred + # Returning dummy data might hide issues + return None # Or handle appropriately # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - low_quality_audio = resample_transform_low(high_quality_audio) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) - low_quality_audio = resample_transform_high(low_quality_audio) + # Ensure sample rates are different before resampling + if original_sample_rate != mangled_sample_rate: + resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate) + low_quality_audio_pt = resample_transform_low(high_quality_audio_pt) - high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) - low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio) + resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate) + low_quality_audio_pt = resample_transform_high(low_quality_audio_pt) + else: + # If rates match, just copy the tensor + low_quality_audio_pt = high_quality_audio_pt.clone() - # Pad or truncate high-quality audio - if high_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - high_quality_audio.shape[1] - high_quality_audio = F.pad(high_quality_audio, (0, padding)) - elif high_quality_audio.shape[1] > self.MAX_LENGTH: - high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH] - # Pad or truncate low-quality audio - if low_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - low_quality_audio.shape[1] - low_quality_audio = F.pad(low_quality_audio, (0, padding)) - elif low_quality_audio.shape[1] > self.MAX_LENGTH: - low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH] + # --- Process Stereo to Mono (still using PyTorch tensors) --- + # Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L) + # and returns PyTorch Tensor (1, L) + try: + high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt) + low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt) + except Exception as e: + # Handle cases where mono conversion might fail (e.g., already mono) + # This depends on how AudioUtils is implemented. Let's assume it handles it. + print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. 
Using original.") + high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True) + low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True) - high_quality_audio = high_quality_audio.to(self.device) - low_quality_audio = low_quality_audio.to(self.device) - return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate) + # --- Convert to NumPy Arrays --- + high_quality_audio_np = high_quality_audio_pt_mono.numpy() # Shape (1, L) + low_quality_audio_np = low_quality_audio_pt_mono.numpy() # Shape (1, L) + + + # --- Pad or Truncate using NumPy --- + def process_numpy_audio(audio_np, max_len): + current_len = audio_np.shape[1] # Length is axis 1 for shape (1, L) + if current_len < max_len: + padding_needed = max_len - current_len + # np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...) + # We only pad axis 1 (length) at the end + audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0) + elif current_len > max_len: + # Truncate axis 1 (length) + audio_np = audio_np[:, :max_len] + return audio_np + + high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH) + low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH) + + # --- Remove Device Handling --- + # .to(self.device) is removed. + + # --- Return NumPy arrays and metadata --- + # Note: Arrays likely have shape (1, MAX_LENGTH) here + return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate) diff --git a/discriminator.py b/discriminator.py index dfd0126..0e3d494 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,63 +1,145 @@ -import torch -import torch.nn as nn -import torch.nn.utils as utils +import jax +import jax.numpy as jnp +from flax import linen as nn +from typing import Sequence, Tuple -def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True): - padding = (kernel_size // 2) * dilation - conv_layer = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding - ) - - if spectral_norm: - conv_layer = utils.spectral_norm(conv_layer) - - layers = [conv_layer] - layers.append(nn.LeakyReLU(0.2, inplace=True)) - - if use_instance_norm: - layers.append(nn.InstanceNorm1d(out_channels)) - - return nn.Sequential(*layers) +# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion +# --- Custom InstanceNorm1d Implementation (from Generator) --- +class InstanceNorm1d(nn.Module): + features: int + epsilon: float = 1e-5 + use_scale: bool = True + use_bias: bool = True + @nn.compact + def __call__(self, x): + if x.shape[-1] != self.features: + raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}") + mean = jnp.mean(x, axis=1, keepdims=True) + var = jnp.var(x, axis=1, keepdims=True) + normalized = (x - mean) / jnp.sqrt(var + self.epsilon) + if self.use_scale: + scale = self.param('scale', nn.initializers.ones, (self.features,)) + normalized *= scale + if self.use_bias: + bias = self.param('bias', nn.initializers.zeros, (self.features,)) + normalized += bias + return normalized +# --- AttentionBlock Implementation (from Generator) --- class AttentionBlock(nn.Module): - 
def __init__(self, channels): - super(AttentionBlock, self).__init__() - self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() - ) - - def forward(self, x): - attention_weights = self.attention(x) + channels: int + @nn.compact + def __call__(self, x): + ks1 = (1,) + attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x) + attention_weights = nn.relu(attention_weights) + attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights) + attention_weights = nn.sigmoid(attention_weights) return x * attention_weights +# --- Converted Discriminator Modules --- + +class DiscriminatorBlock(nn.Module): + """Equivalent of the PyTorch discriminator_block function.""" + in_channels: int # Needed for clarity, though not strictly used by layers if input shape is known + out_channels: int + kernel_size: int = 3 + stride: int = 1 + dilation: int = 1 + # spectral_norm: bool = True # Flag for where SN would be applied + use_instance_norm: bool = True + negative_slope: float = 0.2 + + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, C_in) + Returns: + Output tensor (N, L', C_out) - L' depends on stride/padding + """ + # Flax Conv expects kernel_size, stride, dilation as sequences (tuples) + ks = (self.kernel_size,) + st = (self.stride,) + di = (self.dilation,) + + # Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling + # NOTE: Spectral Norm is omitted here. + # If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version. + # conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...)) + y = nn.Conv( + features=self.out_channels, + kernel_size=ks, + strides=st, + kernel_dilation=di, + padding='SAME' # Often used in GANs + )(x) + + # Apply LeakyReLU first (as in the original code if IN is used) + y = nn.leaky_relu(y, negative_slope=self.negative_slope) + + # Conditionally apply InstanceNorm + if self.use_instance_norm: + y = InstanceNorm1d(features=self.out_channels)(y) + + return y + class SISUDiscriminator(nn.Module): - def __init__(self, base_channels=16): - super(SISUDiscriminator, self).__init__() - layers = base_channels - self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True), - AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False) - ) + """SISUDiscriminator model translated to Flax.""" + base_channels: int = 16 - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, 1) - assumes single channel input + Returns: + 
Output tensor (N, 1) - logits + """ + if x.shape[-1] != 1: + raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}") - def forward(self, x): - x = self.model(x) - x = self.global_avg_pool(x) - x = x.view(x.size(0), -1) - return x + ch = self.base_channels + + # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x) + + # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y) + + # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y) + + # Attention Block + y = AttentionBlock(channels=ch*4)(y) + + # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y) + + # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y) + + # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y) + + # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T + # NOTE: Spectral Norm omitted + y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y) + + # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F + # NOTE: Spectral Norm omitted (as per original config) + y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y) + + # Global Average Pooling (across Length dimension) + pooled = jnp.mean(y, axis=1) # Shape becomes (N, C=1) + + # Flatten (optional, as shape is likely already (N, 1)) + output = jnp.reshape(pooled, (pooled.shape[0], -1)) # Shape (N, 1) + + return output diff --git a/generator.py b/generator.py index a53feb7..b384eed 100644 --- a/generator.py +++ b/generator.py @@ -1,74 +1,173 @@ -import torch -import torch.nn as nn +import jax +import jax.numpy as jnp +from flax import linen as nn +from typing import Sequence, Tuple -def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): - return nn.Sequential( - nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - dilation=dilation, - padding=(kernel_size // 2) * dilation - ), - nn.InstanceNorm1d(out_channels), - nn.PReLU() - ) +# --- Custom InstanceNorm1d Implementation --- +class InstanceNorm1d(nn.Module): + """ + Flax implementation of Instance Normalization for 1D data (NLC format). + Normalizes across the 'L' dimension. + """ + features: int + epsilon: float = 1e-5 + use_scale: bool = True + use_bias: bool = True + + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor of shape (batch, length, features) + + Returns: + Normalized tensor. 
+ """ + if x.shape[-1] != self.features: + raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}") + + # Calculate mean and variance across the length dimension (axis=1) + # Keep dims for broadcasting + mean = jnp.mean(x, axis=1, keepdims=True) + # Variance calculation using mean needs care for numerical stability if needed, + # but jnp.var should handle it. + var = jnp.var(x, axis=1, keepdims=True) + + # Normalize + normalized = (x - mean) / jnp.sqrt(var + self.epsilon) + + # Apply learnable scale and bias if enabled + if self.use_scale: + # Parameter shape: (features,) to broadcast across N and L + scale = self.param('scale', nn.initializers.ones, (self.features,)) + normalized *= scale + if self.use_bias: + # Parameter shape: (features,) + bias = self.param('bias', nn.initializers.zeros, (self.features,)) + normalized += bias + + return normalized + +# --- Converted Modules --- + +class ConvBlock(nn.Module): + """Equivalent of the PyTorch conv_block function.""" + out_channels: int + kernel_size: int = 3 + dilation: int = 1 + + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, C_in) + Returns: + Output tensor (N, L, C_out) + """ + # Flax Conv expects kernel_size and dilation as sequences (tuples) + ks = (self.kernel_size,) + di = (self.dilation,) + + # Padding='SAME' attempts to preserve the length dimension for stride=1 + x = nn.Conv( + features=self.out_channels, + kernel_size=ks, + kernel_dilation=di, + padding='SAME' + )(x) + x = InstanceNorm1d(features=self.out_channels)(x) # Use custom InstanceNorm + x = nn.PReLU()(x) # PReLU learns 'alpha' parameter per channel + return x class AttentionBlock(nn.Module): - """ - Simple Channel Attention Block. Learns to weight channels based on their importance. 
- """ - def __init__(self, channels): - super(AttentionBlock, self).__init__() - self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() - ) + """Simple Channel Attention Block in Flax.""" + channels: int + + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, C) + Returns: + Attention-weighted output tensor (N, L, C) + """ + # Flax Conv expects kernel_size as a sequence (tuple) + ks1 = (1,) + attention_weights = nn.Conv( + features=self.channels // 4, kernel_size=ks1, padding='SAME' + )(x) + # NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace + attention_weights = nn.relu(attention_weights) + attention_weights = nn.Conv( + features=self.channels, kernel_size=ks1, padding='SAME' + )(attention_weights) + attention_weights = nn.sigmoid(attention_weights) - def forward(self, x): - attention_weights = self.attention(x) return x * attention_weights class ResidualInResidualBlock(nn.Module): - def __init__(self, channels, num_convs=3): - super(ResidualInResidualBlock, self).__init__() + """ResidualInResidualBlock in Flax.""" + channels: int + num_convs: int = 3 - self.conv_layers = nn.Sequential( - *[conv_block(channels, channels) for _ in range(num_convs)] - ) - - self.attention = AttentionBlock(channels) - - def forward(self, x): + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, C) + Returns: + Output tensor (N, L, C) + """ residual = x - x = self.conv_layers(x) - x = self.attention(x) - return x + residual + y = x + # Sequentially apply ConvBlocks + for _ in range(self.num_convs): + y = ConvBlock( + out_channels=self.channels, + kernel_size=3, # Assuming kernel_size 3 as in original conv_block default + dilation=1 # Assuming dilation 1 as in original conv_block default + )(y) + + y = AttentionBlock(channels=self.channels)(y) + return y + residual class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1.0): - super(SISUGenerator, self).__init__() - self.alpha = alpha + """SISUGenerator model translated to Flax.""" + channels: int = 16 + num_rirb: int = 4 + alpha: float = 1.0 # Non-learnable parameter, passed during init - self.conv1 = nn.Sequential( - nn.Conv1d(1, channels, kernel_size=7, padding=3), - nn.InstanceNorm1d(channels), - nn.PReLU(), - ) + @nn.compact + def __call__(self, x): + """ + Args: + x: Input tensor (N, L, 1) - assumes single channel input + Returns: + Output tensor (N, L, 1) + """ + if x.shape[-1] != 1: + raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}") - self.rir_blocks = nn.Sequential( - *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] - ) - - self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1) - - def forward(self, x): residual_input = x - x = self.conv1(x) - x_rirb_out = self.rir_blocks(x) - learned_residual = self.final_layer(x_rirb_out) - output = residual_input + self.alpha * learned_residual + # Initial convolution block + # Flax Conv expects kernel_size as sequence + ks7 = (7,) + ks3 = (3,) + y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x) + y = InstanceNorm1d(features=self.channels)(y) + y = nn.PReLU()(y) + + # Residual-in-Residual Blocks + rirb_out = y + for _ in range(self.num_rirb): + rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out) + + # Final layer + learned_residual = nn.Conv( + features=1, kernel_size=ks3, padding='SAME' + )(rirb_out) + + # 
Combine with input residual + output = residual_input + self.alpha * learned_residual return output diff --git a/training.py b/training.py index 01ea749..45dc7cb 100644 --- a/training.py +++ b/training.py @@ -1,46 +1,33 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -import torch.nn.functional as F -import torchaudio +import jax +import jax.numpy as jnp +import optax import tqdm - -import argparse - -import math - +import pickle # Using pickle for simplicity to save JAX states import os +import argparse +# You might need a JAX-compatible library for audio loading/saving or convert to numpy +import scipy.io.wavfile as wavfile # Example for saving audio -from torch.utils.data import random_split +import torch from torch.utils.data import DataLoader -import AudioUtils -from data import AudioDataset +import file_utils as Data +from data import AudioDatasetNumPy from generator import SISUGenerator from discriminator import SISUDiscriminator -from training_utils import discriminator_train, generator_train -import file_utils as Data - -import torchaudio.transforms as T - # Init script argument parser parser = argparse.ArgumentParser(description="Training script") parser.add_argument("--generator", type=str, default=None, help="Path to the generator model file") parser.add_argument("--discriminator", type=str, default=None, help="Path to the discriminator model file") -parser.add_argument("--device", type=str, default="cpu", help="Select device") -parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") +parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning") parser.add_argument("--debug", action="store_true", help="Print debug logs") parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models") args = parser.parse_args() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") -print(f"Using device: {device}") - # Parameters sample_rate = 44100 n_fft = 2048 @@ -49,146 +36,446 @@ win_length = n_fft n_mels = 128 n_mfcc = 20 # If using MFCC -mfcc_transform = T.MFCC( - sample_rate, - n_mfcc, - melkwargs = {'n_fft': n_fft, 'hop_length': hop_length} -).to(device) - -mel_transform = T.MelSpectrogram( - sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, - win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel -).to(device) - -stft_transform = T.Spectrogram( - n_fft=n_fft, win_length=win_length, hop_length=hop_length -).to(device) - debug = args.debug +# Initialize JAX random key +key = jax.random.PRNGKey(0) + # Initialize dataset and dataloader dataset_dir = './dataset/good' -dataset = AudioDataset(dataset_dir, device) +dataset = AudioDatasetNumPy(dataset_dir) # Use your JAX dataset +train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True) # Use your JAX DataLoader + models_dir = "models" os.makedirs(models_dir, exist_ok=True) audio_output_dir = "output" os.makedirs(audio_output_dir, exist_ok=True) -# ========= SINGLE ========= - -train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True) - # ========= MODELS ========= -generator = SISUGenerator() -discriminator = SISUDiscriminator() +try: + # Fetch the first batch + first_batch = next(iter(train_data_loader)) + # The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate)) + # We need the high-quality audio NumPy array batch for initialization + 
sample_input_np = first_batch[0][0] # Get the high-quality audio NumPy array batch + # Convert the NumPy array batch to a JAX array + sample_input_array = jnp.array(sample_input_np) -epoch: int = args.epoch -epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") + # === FIX === + # Transpose the array from (batch, channels, length) to (batch, length, channels) + # The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100) + # The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1) + if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1: + sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1)) # Swap axes 1 and 2 + print(sample_input_array.shape) # Should now print (4, 44100, 1) + # === END FIX === + +except StopIteration: + print("Error: Data loader is empty. Cannot initialize models.") + exit() # Exit if no data is available + + +key, init_key_g, init_key_d = jax.random.split(key, 3) +generator_model = SISUGenerator() +discriminator_model = SISUDiscriminator() + +# Initialize parameters +generator_params = generator_model.init(init_key_g, sample_input_array)['params'] +discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params'] + +# Define apply functions +generator_apply_fn = generator_model.apply +discriminator_apply_fn = discriminator_model.apply + + +# Loss functions (JAX equivalents) +# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to +# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation +# in the model output or handling logits directly. Assuming your discriminator +# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate. +criterion_d = optax.sigmoid_binary_cross_entropy +criterion_l1 = optax.sigmoid_binary_cross_entropy # For Mel, STFT, MFCC losses + +# Optimizers (using Optax) +optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999) +optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999) + +# Initialize optimizer states +generator_opt_state = optimizer_g.init(generator_params) +discriminator_opt_state = optimizer_d.init(discriminator_params) + +# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau +# is stateful and usually handled outside the jitted training step, +# or you can implement a custom learning rate schedule in Optax that +# takes a metric. For simplicity here, we won't directly replicate the +# PyTorch ReduceLROnPlateau but you could add logic based on losses +# in the main loop to adjust the learning rate if needed. 
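+
+# Note (sketch, not wired into the steps below): criterion_l1 above is a BCE
+# placeholder, not an L1 distance. A plain mean-absolute-error helper like this
+# could serve the Mel/STFT/MFCC feature-matching terms once JAX versions of
+# those transforms exist:
+def l1_loss(pred, target):
+    """Mean absolute error between two arrays of matching shape."""
+    return jnp.mean(jnp.abs(pred - target))
+
+# Similarly, an Optax schedule (e.g. optax.exponential_decay) passed as the
+# learning_rate to optax.adam is one stateless alternative to ReduceLROnPlateau.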
+ + +# Load saved state if continuing training +start_epoch = args.epoch if args.continue_training: - generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - epoch = epoch_from_file["epoch"] + 1 -else: - if args.generator is not None: - generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) - if args.discriminator is not None: - discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) + try: + with open(f"{models_dir}/temp_generator.pkl", 'rb') as f: + loaded_state = pickle.load(f) + generator_params = loaded_state['params'] + generator_opt_state = loaded_state['opt_state'] + with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f: + loaded_state = pickle.load(f) + discriminator_params = loaded_state['params'] + discriminator_opt_state = loaded_state['opt_state'] + epoch_data = Data.read_data(f"{models_dir}/epoch_data.json") + start_epoch = epoch_data.get("epoch", 0) + 1 + print(f"Continuing training from epoch {start_epoch}") + except FileNotFoundError: + print("Continue training requested but temp models not found. Starting from scratch.") + except Exception as e: + print(f"Error loading temp models: {e}. Starting from scratch.") -generator = generator.to(device) -discriminator = discriminator.to(device) +if args.generator is not None: + try: + with open(args.generator, 'rb') as f: + loaded_state = pickle.load(f) + generator_params = loaded_state['params'] + print(f"Loaded generator from {args.generator}") + except FileNotFoundError: + print(f"Generator model not found at {args.generator}") -# Loss -criterion_g = nn.BCEWithLogitsLoss() -criterion_d = nn.BCEWithLogitsLoss() +if args.discriminator is not None: + try: + with open(args.discriminator, 'rb') as f: + loaded_state = pickle.load(f) + discriminator_params = loaded_state['params'] + print(f"Loaded discriminator from {args.discriminator}") + except FileNotFoundError: + print(f"Discriminator model not found at {args.discriminator}") -# Optimizers -optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) -optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) -# Scheduler -scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) -scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) +# Initialize JAX audio transforms +# mel_transform_fn = MelSpectrogramJAX( +# sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, +# win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel +# ) + +# stft_transform_fn = SpectrogramJAX( +# n_fft=n_fft, win_length=win_length, hop_length=hop_length +# ) + +# mfcc_transform_fn = MFCCJAX( +# sample_rate, +# n_mfcc, +# melkwargs = {'n_fft': n_fft, 'hop_length': hop_length} +# ) + + +# ========= JAX TRAINING STEPS ========= + +@jax.jit +def discriminator_train_step( + discriminator_params, + generator_params, + discriminator_opt_state, + high_quality_audio, # JAX array (batch, length, channels) + low_quality_audio, # JAX array (batch, length, channels) + real_labels, # JAX array + fake_labels, # JAX array + discriminator_apply_fn, + generator_apply_fn, + discriminator_optimizer, + criterion_d, + key # JAX random key +): + # Split key for potential randomness in model application + 
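+    # NOTE: for this step to actually compile, the function-valued arguments
+    # (discriminator_apply_fn, generator_apply_fn, discriminator_optimizer,
+    # criterion_d) would need to be closed over or marked static, e.g. with
+    # functools.partial or jax.jit(..., static_argnames=...); jax.jit cannot
+    # trace plain Python callables passed as ordinary arguments.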
key, disc_key, gen_key = jax.random.split(key, 3) + + def loss_fn(d_params): + # Generate fake audio + # Note: Generator is not being trained in this step, so its parameters are static + # Ensure low_quality_audio is in the expected NLC format (batch, length, channels) + if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim + low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1) + elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims + low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1) + + + enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio) + + # Pass data through the discriminator + # Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size) + if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim + enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0) + elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims + enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1) + + + # Ensure high_quality_audio is in the expected NLC format (batch, length, channels) + if high_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim + high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1) + elif high_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims + high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1) + + + real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio) + fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio) + + # Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar) + # Ensure the shapes match the labels (batch_size, 1) + real_output = real_output.reshape(-1, 1) + fake_output = fake_output.reshape(-1, 1) + + real_loss = jnp.mean(criterion_d(real_output, real_labels)) + fake_loss = jnp.mean(criterion_d(fake_output, fake_labels)) + total_loss = real_loss + fake_loss + return total_loss, (real_loss, fake_loss) + + # Compute gradients + # Use jax.value_and_grad to get both the loss value and the gradients + (loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params) + + # Apply updates + updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params) + new_discriminator_params = optax.apply_updates(discriminator_params, updates) + + return new_discriminator_params, new_discriminator_opt_state, loss, key + + +@jax.jit +def generator_train_step( + generator_params, + discriminator_params, + generator_opt_state, + low_quality_audio, # JAX array (batch, length, channels) + high_quality_audio, # JAX array (batch, length, channels) + real_labels, # JAX array + generator_apply_fn, + discriminator_apply_fn, + generator_optimizer, + criterion_d, # Adversarial loss + criterion_l1, # Feature matching loss + key # JAX random key +): + # Split key for potential randomness + key, gen_key, disc_key = jax.random.split(key, 3) + + def loss_fn(g_params): + # Ensure low_quality_audio is in the expected NLC format (batch, length, channels) + if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim + low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1) + elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims + low_quality_audio = 
jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1) + + # Generate enhanced audio + enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio) + + # Ensure enhanced_audio has a leading dimension if not already present + if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim + enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0) + elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims + enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1) + + + # Calculate adversarial loss (generator wants discriminator to think fake is real) + # Note: Discriminator is not being trained in this step, so its parameters are static + fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio) + # Ensure the shape matches the labels (batch_size, 1) + fake_output = fake_output.reshape(-1, 1) + adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels)) # Generator wants fake_output to be close to real_labels (1s) + + # Feature matching losses (assuming you add these back later) + # You would need to implement JAX versions of your audio transforms + # mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio)) + # stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio)) + # mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio)) + + # combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss + combined_loss = adversarial_loss # For now, only adversarial loss + + # Return combined_loss and any other metrics needed for logging/analysis + # For now, just adversarial loss and enhanced_audio + return combined_loss, (adversarial_loss, enhanced_audio) # Add other losses here when implemented + + # Compute gradients + # Update: loss_fn now returns (loss, (aux1, aux2, ...)) + (loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params) + + # Apply updates + updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params) + new_generator_params = optax.apply_updates(generator_params, updates) + + # Return the loss components separately along with the enhanced audio and key + return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key + + +# ========= MAIN TRAINING LOOP ========= def start_training(): + global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key generator_epochs = 5000 + for generator_epoch in range(generator_epochs): - low_quality_audio = (torch.empty((1)), 1) - high_quality_audio = (torch.empty((1)), 1) - ai_enhanced_audio = (torch.empty((1)), 1) + current_epoch = start_epoch + generator_epoch - times_correct = 0 + # These will hold the last processed audio examples from a batch for saving + last_high_quality_audio = None + last_low_quality_audio = None + last_ai_enhanced_audio = None + last_sample_rate = None - # ========= TRAINING ========= - for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): - # for high_quality_clip, low_quality_clip in train_data_loader: - high_quality_sample = (high_quality_clip[0], high_quality_clip[1]) - low_quality_sample = (low_quality_clip[0], low_quality_clip[1]) - # ========= LABELS ========= - 
batch_size = high_quality_clip[0].size(0) - real_labels = torch.ones(batch_size, 1).to(device) - fake_labels = torch.zeros(batch_size, 1).to(device) + # Use tqdm for progress bar + for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"): + + # high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array) + # Extract audio arrays and sample rates (assuming batch dimension is first) + # The arrays are NumPy arrays at this point, likely in (batch, channels, length) format + high_quality_audio_batch_np = high_quality_clip[0] + low_quality_audio_batch_np = low_quality_clip[0] + sample_rate_batch_np = high_quality_clip[1] # Assuming sample rates are the same for paired clips + + # Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels) + # Only transpose if the shape is (batch, channels, length) + if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1: + high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1)) + else: + high_quality_audio_batch = jnp.array(high_quality_audio_batch_np) # Assume already NLC or handle other cases + + if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1: + low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1)) + else: + low_quality_audio_batch = jnp.array(low_quality_audio_batch_np) # Assume already NLC or handle other cases + + sample_rate_batch = jnp.array(sample_rate_batch_np) + + + batch_size = high_quality_audio_batch.shape[0] + # Create labels - JAX arrays + real_labels = jnp.ones((batch_size, 1)) + fake_labels = jnp.zeros((batch_size, 1)) + + # Split key for each batch + key, batch_key = jax.random.split(key) # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - high_quality_sample, - low_quality_sample, + # Call the jitted discriminator training step + discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step( + discriminator_params, + generator_params, + discriminator_opt_state, + high_quality_audio_batch, + low_quality_audio_batch, real_labels, fake_labels, - discriminator, - generator, + discriminator_apply_fn, + generator_apply_fn, + optimizer_d, criterion_d, - optimizer_d + batch_key ) # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - low_quality_sample, - high_quality_sample, - real_labels, - generator, - discriminator, - criterion_d, + # Call the jitted generator training step + generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step( + generator_params, + discriminator_params, + generator_opt_state, + low_quality_audio_batch, + high_quality_audio_batch, + real_labels, # Generator tries to make fake data look real + generator_apply_fn, + discriminator_apply_fn, optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform + criterion_d, + criterion_l1, + batch_key ) + # Print debug logs (requires waiting for JIT compilation on first step) if debug: - print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: 
{mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) + # Use .block_until_ready() to ensure computation is finished before printing + # In a real scenario, you might want to log metrics less frequently + d_loss_val = d_loss.block_until_ready().item() + combined_loss_val = combined_loss.block_until_ready().item() + adversarial_loss_val = adversarial_loss.block_until_ready().item() + # Assuming other losses are returned by generator_train_step and unpacked + # mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0 + # stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0 + # mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0 + print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}") + # Print other losses here when implemented and returned - # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) - low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) - ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0]) - - new_epoch = generator_epoch+epoch - - if generator_epoch % 25 == 0: - print(f"Saved epoch {new_epoch}!") - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) - - #if debug: - # print(generator.state_dict().keys()) - # print(discriminator.state_dict().keys()) - torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") - torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") - Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) + # Schedulers - Implement your learning rate scheduling logic here if needed + # based on the losses (e.g., reducing learning rate if loss plateaus). + # This logic would typically live outside the jitted step function. + # For Optax, you might use a schedule within the optimizer definition + # or update the learning rate of the optimizer manually. 
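+            # For example (sketch, assuming the optimizers were built with
+            # optax.inject_hyperparams(optax.adam)(learning_rate=1e-4, b1=0.5, b2=0.999)),
+            # the learning rate could be halved here when a loss plateaus:
+            #   discriminator_opt_state.hyperparams['learning_rate'] *= 0.5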
- torch.save(discriminator, "models/epoch-5000-discriminator.pt") - torch.save(generator, "models/epoch-5000-generator.pt") - print("Training complete!") + # ========= SAVE LATEST AUDIO (from the last batch processed) ========= + # Access the first sample of the batch for saving + # Ensure enhanced_audio_batch has a batch dimension and is in NLC format + if enhanced_audio_batch.ndim == 2: # Assuming (length, channel), add batch dim + enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0) + elif enhanced_audio_batch.ndim == 1: # Assuming (length), add batch and channel dims + enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1) + + + last_high_quality_audio = high_quality_audio_batch[0] + last_low_quality_audio = low_quality_audio_batch[0] + last_ai_enhanced_audio = enhanced_audio_batch[0] + last_sample_rate = sample_rate_batch[0].item() # Assuming sample rate is scalar per batch item + + + # Save audio files periodically (outside the batch loop) + if generator_epoch % 25 == 0 and last_high_quality_audio is not None: + print(f"Saving audio for epoch {current_epoch}!") + try: + # Convert JAX arrays to NumPy arrays for saving + # Transpose back to (length, channels) or (length) if needed by wavfile.write + # Assuming the models output (length, 1) or (length) after removing batch dim + low_quality_audio_np_save = jax.device_get(last_low_quality_audio) + ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio) + high_quality_audio_np_save = jax.device_get(last_high_quality_audio) + + # Remove the channel dimension if it's 1 for saving with wavfile + if low_quality_audio_np_save.shape[-1] == 1: + low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1) + if ai_enhanced_audio_np_save.shape[-1] == 1: + ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1) + if high_quality_audio_np_save.shape[-1] == 1: + high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1) + + + wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16 + wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16)) # Assuming audio is int16 + wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16 + except Exception as e: + print(f"Error saving audio files: {e}") + + + # Save model states periodically (outside the batch loop) + # Use pickle to save parameters and optimizer states + try: + with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f: + pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f) + with open(f"{models_dir}/temp_generator.pkl", 'wb') as f: + pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f) + Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch}) + except Exception as e: + print(f"Error saving temp model states: {e}") + + + # Save final model states after all epochs + print("Training complete! 
Saving final models.") + try: + with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f: + pickle.dump({'params': jax.device_get(discriminator_params)}, f) + with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f: + pickle.dump({'params': jax.device_get(generator_params)}, f) + except Exception as e: + print(f"Error saving final model states: {e}") + start_training() diff --git a/training.txt b/training.txt new file mode 100644 index 0000000..01ea749 --- /dev/null +++ b/training.txt @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +import torch.nn.functional as F +import torchaudio +import tqdm + +import argparse + +import math + +import os + +from torch.utils.data import random_split +from torch.utils.data import DataLoader + +import AudioUtils +from data import AudioDataset +from generator import SISUGenerator +from discriminator import SISUDiscriminator + +from training_utils import discriminator_train, generator_train +import file_utils as Data + +import torchaudio.transforms as T + +# Init script argument parser +parser = argparse.ArgumentParser(description="Training script") +parser.add_argument("--generator", type=str, default=None, + help="Path to the generator model file") +parser.add_argument("--discriminator", type=str, default=None, + help="Path to the discriminator model file") +parser.add_argument("--device", type=str, default="cpu", help="Select device") +parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") +parser.add_argument("--debug", action="store_true", help="Print debug logs") +parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models") + +args = parser.parse_args() + +device = torch.device(args.device if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# Parameters +sample_rate = 44100 +n_fft = 2048 +hop_length = 256 +win_length = n_fft +n_mels = 128 +n_mfcc = 20 # If using MFCC + +mfcc_transform = T.MFCC( + sample_rate, + n_mfcc, + melkwargs = {'n_fft': n_fft, 'hop_length': hop_length} +).to(device) + +mel_transform = T.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, + win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel +).to(device) + +stft_transform = T.Spectrogram( + n_fft=n_fft, win_length=win_length, hop_length=hop_length +).to(device) + +debug = args.debug + +# Initialize dataset and dataloader +dataset_dir = './dataset/good' +dataset = AudioDataset(dataset_dir, device) +models_dir = "models" +os.makedirs(models_dir, exist_ok=True) +audio_output_dir = "output" +os.makedirs(audio_output_dir, exist_ok=True) + +# ========= SINGLE ========= + +train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True) + + +# ========= MODELS ========= + +generator = SISUGenerator() +discriminator = SISUDiscriminator() + +epoch: int = args.epoch +epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") + +if args.continue_training: + generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + epoch = epoch_from_file["epoch"] + 1 +else: + if args.generator is not None: + generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) + if args.discriminator 
is not None: + discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) + +generator = generator.to(device) +discriminator = discriminator.to(device) + +# Loss +criterion_g = nn.BCEWithLogitsLoss() +criterion_d = nn.BCEWithLogitsLoss() + +# Optimizers +optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) +optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) + +# Scheduler +scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) +scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) + +def start_training(): + generator_epochs = 5000 + for generator_epoch in range(generator_epochs): + low_quality_audio = (torch.empty((1)), 1) + high_quality_audio = (torch.empty((1)), 1) + ai_enhanced_audio = (torch.empty((1)), 1) + + times_correct = 0 + + # ========= TRAINING ========= + for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): + # for high_quality_clip, low_quality_clip in train_data_loader: + high_quality_sample = (high_quality_clip[0], high_quality_clip[1]) + low_quality_sample = (low_quality_clip[0], low_quality_clip[1]) + + # ========= LABELS ========= + batch_size = high_quality_clip[0].size(0) + real_labels = torch.ones(batch_size, 1).to(device) + fake_labels = torch.zeros(batch_size, 1).to(device) + + # ========= DISCRIMINATOR ========= + discriminator.train() + d_loss = discriminator_train( + high_quality_sample, + low_quality_sample, + real_labels, + fake_labels, + discriminator, + generator, + criterion_d, + optimizer_d + ) + + # ========= GENERATOR ========= + generator.train() + generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( + low_quality_sample, + high_quality_sample, + real_labels, + generator, + discriminator, + criterion_d, + optimizer_g, + device, + mel_transform, + stft_transform, + mfcc_transform + ) + + if debug: + print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") + scheduler_d.step(d_loss.detach()) + scheduler_g.step(adversarial_loss.detach()) + + # ========= SAVE LATEST AUDIO ========= + high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) + low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) + ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0]) + + new_epoch = generator_epoch+epoch + + if generator_epoch % 25 == 0: + print(f"Saved epoch {new_epoch}!") + torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. 
+ torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) + torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) + + #if debug: + # print(generator.state_dict().keys()) + # print(discriminator.state_dict().keys()) + torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") + torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") + Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) + + + torch.save(discriminator, "models/epoch-5000-discriminator.pt") + torch.save(generator, "models/epoch-5000-generator.pt") + print("Training complete!") + +start_training() diff --git a/training_utils.py b/training_utils.py index 6f26f58..989b5ca 100644 --- a/training_utils.py +++ b/training_utils.py @@ -20,12 +20,10 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso mel_spec_true = mel_transform(y_true) mel_spec_pred = mel_transform(y_pred) - # Ensure same time dimension length (due to potential framing differences) min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) mel_spec_true = mel_spec_true[..., :min_len] mel_spec_pred = mel_spec_pred[..., :min_len] - # L1 Loss (Mean Absolute Error) loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) return loss
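
A possible follow-up, referenced by the commented-out transform stubs in training.py: a JAX analogue of the spectral L1 losses used on the PyTorch side (see mel_spectrogram_l1_loss in training_utils.py). This is a sketch only; it assumes jax.scipy.signal.stft is available in the installed JAX version and that audio arrives as (batch, length) float arrays, and it is not part of the patch above.

import jax.numpy as jnp
from jax.scipy.signal import stft

def stft_l1_loss(y_pred, y_true, n_fft=2048, hop_length=256):
    # Magnitude STFTs over the last (time) axis; Zxx has shape (batch, freq, frames).
    _, _, pred_spec = stft(y_pred, nperseg=n_fft, noverlap=n_fft - hop_length)
    _, _, true_spec = stft(y_true, nperseg=n_fft, noverlap=n_fft - hop_length)
    # Trim to a common frame count (framing can differ by one), then take the MAE.
    frames = min(pred_spec.shape[-1], true_spec.shape[-1])
    return jnp.mean(jnp.abs(jnp.abs(pred_spec[..., :frames]) - jnp.abs(true_spec[..., :frames])))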