Compare commits


1 commit
jax ... main

6 changed files with 270 additions and 981 deletions

data.py

@ -1,104 +1,53 @@
jax branch:

# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch
import torchaudio
import torchaudio.transforms as T  # Keep using torchaudio transforms
# Import NumPy
import numpy as np
import os
import random
# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils

class AudioDatasetNumPy(Dataset):  # Renamed slightly for clarity
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100  # Define your desired maximum length here

    def __init__(self, input_dir):
        """
        Initializes the dataset. Device argument is removed.
        """
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        """
        Loads audio, processes it, and returns NumPy arrays.
        """
        # --- Load and Resample using torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                self.input_files[idx], normalize=True
            )
        except Exception as e:
            print(f"Error loading file {self.input_files[idx]}: {e}")
            # Return None or raise error, or return dummy data if preferred
            # Returning dummy data might hide issues
            return None  # Or handle appropriately

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)

        # Ensure sample rates are different before resampling
        if original_sample_rate != mangled_sample_rate:
            resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
            low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)
            resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
        else:
            # If rates match, just copy the tensor
            low_quality_audio_pt = high_quality_audio_pt.clone()

        # --- Process Stereo to Mono (still using PyTorch tensors) ---
        # Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
        # and returns PyTorch Tensor (1, L)
        try:
            high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
            low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
        except Exception as e:
            # Handle cases where mono conversion might fail (e.g., already mono)
            # This depends on how AudioUtils is implemented. Let's assume it handles it.
            print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
            high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
            low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)

        # --- Convert to NumPy Arrays ---
        high_quality_audio_np = high_quality_audio_pt_mono.numpy()  # Shape (1, L)
        low_quality_audio_np = low_quality_audio_pt_mono.numpy()  # Shape (1, L)

        # --- Pad or Truncate using NumPy ---
        def process_numpy_audio(audio_np, max_len):
            current_len = audio_np.shape[1]  # Length is axis 1 for shape (1, L)
            if current_len < max_len:
                padding_needed = max_len - current_len
                # np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
                # We only pad axis 1 (length) at the end
                audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
            elif current_len > max_len:
                # Truncate axis 1 (length)
                audio_np = audio_np[:, :max_len]
            return audio_np

        high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
        low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)

        # --- Remove Device Handling ---
        # .to(self.device) is removed.

        # --- Return NumPy arrays and metadata ---
        # Note: Arrays likely have shape (1, MAX_LENGTH) here
        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)

main:

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os
import random
import torchaudio.transforms as T
import AudioUtils

class AudioDataset(Dataset):
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100  # Define your desired maximum length here

    def __init__(self, input_dir, device):
        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
        self.device = device

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # Load high-quality audio
        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resample_transform_low(high_quality_audio)
        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
        low_quality_audio = resample_transform_high(low_quality_audio)

        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)

        # Pad or truncate high-quality audio
        if high_quality_audio.shape[1] < self.MAX_LENGTH:
            padding = self.MAX_LENGTH - high_quality_audio.shape[1]
            high_quality_audio = F.pad(high_quality_audio, (0, padding))
        elif high_quality_audio.shape[1] > self.MAX_LENGTH:
            high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]

        # Pad or truncate low-quality audio
        if low_quality_audio.shape[1] < self.MAX_LENGTH:
            padding = self.MAX_LENGTH - low_quality_audio.shape[1]
            low_quality_audio = F.pad(low_quality_audio, (0, padding))
        elif low_quality_audio.shape[1] > self.MAX_LENGTH:
            low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]

        high_quality_audio = high_quality_audio.to(self.device)
        low_quality_audio = low_quality_audio.to(self.device)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
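Both versions delegate stereo-to-mono conversion to AudioUtils.stereo_tensor_to_mono, which is not included in this compare. A minimal sketch of what such a helper might look like, inferred from how it is called here (names and behavior are assumptions, not the repository's actual code):

```python
import torch

def stereo_tensor_to_mono(audio: torch.Tensor) -> torch.Tensor:
    """Collapse a (channels, length) tensor to (1, length) by averaging channels."""
    if audio.dim() != 2:
        raise ValueError(f"Expected (channels, length), got shape {tuple(audio.shape)}")
    if audio.shape[0] == 1:
        return audio  # already mono
    return torch.mean(audio, dim=0, keepdim=True)
```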

discriminator.py

@ -1,145 +1,63 @@
jax branch:

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple

# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion

# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
        mean = jnp.mean(x, axis=1, keepdims=True)
        var = jnp.var(x, axis=1, keepdims=True)
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized

# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        ks1 = (1,)
        attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights

# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
    """Equivalent of the PyTorch discriminator_block function."""
    in_channels: int  # Needed for clarity, though not strictly used by layers if input shape is known
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    # spectral_norm: bool = True  # Flag for where SN would be applied
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L', C_out) - L' depends on stride/padding
        """
        # Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
        ks = (self.kernel_size,)
        st = (self.stride,)
        di = (self.dilation,)
        # Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
        # NOTE: Spectral Norm is omitted here.
        # If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
        # conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
        y = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            strides=st,
            kernel_dilation=di,
            padding='SAME'  # Often used in GANs
        )(x)
        # Apply LeakyReLU first (as in the original code if IN is used)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)
        # Conditionally apply InstanceNorm
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)
        return y

class SISUDiscriminator(nn.Module):
    """SISUDiscriminator model translated to Flax."""
    base_channels: int = 16

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, 1) - logits
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
        ch = self.base_channels
        # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)
        # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)
        # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)
        # Attention Block
        y = AttentionBlock(channels=ch*4)(y)
        # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)
        # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)
        # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)
        # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)
        # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
        # NOTE: Spectral Norm omitted (as per original config)
        y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)
        # Global Average Pooling (across Length dimension)
        pooled = jnp.mean(y, axis=1)  # Shape becomes (N, C=1)
        # Flatten (optional, as shape is likely already (N, 1))
        output = jnp.reshape(pooled, (pooled.shape[0], -1))  # Shape (N, 1)
        return output

main:

import torch
import torch.nn as nn
import torch.nn.utils as utils

def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
    padding = (kernel_size // 2) * dilation
    conv_layer = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=padding
    )

    if spectral_norm:
        conv_layer = utils.spectral_norm(conv_layer)

    layers = [conv_layer]
    layers.append(nn.LeakyReLU(0.2, inplace=True))

    if use_instance_norm:
        layers.append(nn.InstanceNorm1d(out_channels))

    return nn.Sequential(*layers)

class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class SISUDiscriminator(nn.Module):
    def __init__(self, base_channels=16):
        super(SISUDiscriminator, self).__init__()
        layers = base_channels
        self.model = nn.Sequential(
            discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
            AttentionBlock(layers * 4),
            discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x = self.model(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        return x
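Not part of the diff: a minimal sketch of how the Flax discriminator above might be initialized and applied, assuming jax and flax are installed and the (batch, length, channels) input convention the module expects:

```python
import jax
import jax.numpy as jnp

model = SISUDiscriminator(base_channels=16)
dummy = jnp.zeros((2, 44100, 1))                  # (N, L, C), single-channel audio
params = model.init(jax.random.PRNGKey(0), dummy)['params']
logits = model.apply({'params': params}, dummy)   # raw logits, shape (2, 1)
print(logits.shape)
```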

generator.py

@ -1,173 +1,74 @@
jax branch:

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple

# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
    """
    Flax implementation of Instance Normalization for 1D data (NLC format).
    Normalizes across the 'L' dimension.
    """
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor of shape (batch, length, features)
        Returns:
            Normalized tensor.
        """
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
        # Calculate mean and variance across the length dimension (axis=1)
        # Keep dims for broadcasting
        mean = jnp.mean(x, axis=1, keepdims=True)
        # Variance calculation using mean needs care for numerical stability if needed,
        # but jnp.var should handle it.
        var = jnp.var(x, axis=1, keepdims=True)
        # Normalize
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        # Apply learnable scale and bias if enabled
        if self.use_scale:
            # Parameter shape: (features,) to broadcast across N and L
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            # Parameter shape: (features,)
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized

# --- Converted Modules ---
class ConvBlock(nn.Module):
    """Equivalent of the PyTorch conv_block function."""
    out_channels: int
    kernel_size: int = 3
    dilation: int = 1

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L, C_out)
        """
        # Flax Conv expects kernel_size and dilation as sequences (tuples)
        ks = (self.kernel_size,)
        di = (self.dilation,)
        # Padding='SAME' attempts to preserve the length dimension for stride=1
        x = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            kernel_dilation=di,
            padding='SAME'
        )(x)
        x = InstanceNorm1d(features=self.out_channels)(x)  # Use custom InstanceNorm
        x = nn.PReLU()(x)  # PReLU learns 'alpha' parameter per channel
        return x

class AttentionBlock(nn.Module):
    """Simple Channel Attention Block in Flax."""
    channels: int

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Attention-weighted output tensor (N, L, C)
        """
        # Flax Conv expects kernel_size as a sequence (tuple)
        ks1 = (1,)
        attention_weights = nn.Conv(
            features=self.channels // 4, kernel_size=ks1, padding='SAME'
        )(x)
        # NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(
            features=self.channels, kernel_size=ks1, padding='SAME'
        )(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights

class ResidualInResidualBlock(nn.Module):
    """ResidualInResidualBlock in Flax."""
    channels: int
    num_convs: int = 3

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Output tensor (N, L, C)
        """
        residual = x
        y = x
        # Sequentially apply ConvBlocks
        for _ in range(self.num_convs):
            y = ConvBlock(
                out_channels=self.channels,
                kernel_size=3,  # Assuming kernel_size 3 as in original conv_block default
                dilation=1  # Assuming dilation 1 as in original conv_block default
            )(y)
        y = AttentionBlock(channels=self.channels)(y)
        return y + residual

class SISUGenerator(nn.Module):
    """SISUGenerator model translated to Flax."""
    channels: int = 16
    num_rirb: int = 4
    alpha: float = 1.0  # Non-learnable parameter, passed during init

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, L, 1)
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
        residual_input = x
        # Initial convolution block
        # Flax Conv expects kernel_size as sequence
        ks7 = (7,)
        ks3 = (3,)
        y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
        y = InstanceNorm1d(features=self.channels)(y)
        y = nn.PReLU()(y)
        # Residual-in-Residual Blocks
        rirb_out = y
        for _ in range(self.num_rirb):
            rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)
        # Final layer
        learned_residual = nn.Conv(
            features=1, kernel_size=ks3, padding='SAME'
        )(rirb_out)
        # Combine with input residual
        output = residual_input + self.alpha * learned_residual
        return output

main:

import torch
import torch.nn as nn

def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
    return nn.Sequential(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=(kernel_size // 2) * dilation
        ),
        nn.InstanceNorm1d(out_channels),
        nn.PReLU()
    )

class AttentionBlock(nn.Module):
    """
    Simple Channel Attention Block. Learns to weight channels based on their importance.
    """
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class ResidualInResidualBlock(nn.Module):
    def __init__(self, channels, num_convs=3):
        super(ResidualInResidualBlock, self).__init__()
        self.conv_layers = nn.Sequential(
            *[conv_block(channels, channels) for _ in range(num_convs)]
        )
        self.attention = AttentionBlock(channels)

    def forward(self, x):
        residual = x
        x = self.conv_layers(x)
        x = self.attention(x)
        return x + residual

class SISUGenerator(nn.Module):
    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
        super(SISUGenerator, self).__init__()
        self.alpha = alpha
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, channels, kernel_size=7, padding=3),
            nn.InstanceNorm1d(channels),
            nn.PReLU(),
        )
        self.rir_blocks = nn.Sequential(
            *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
        )
        self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)

    def forward(self, x):
        residual_input = x
        x = self.conv1(x)
        x_rirb_out = self.rir_blocks(x)
        learned_residual = self.final_layer(x_rirb_out)
        output = residual_input + self.alpha * learned_residual
        return output
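A quick shape check for the PyTorch (main) generator above — a sketch that assumes the file is importable as `generator`, as the training script does; it is not part of the diff:

```python
import torch
from generator import SISUGenerator

g = SISUGenerator(channels=16, num_rirb=4, alpha=1.0)
x = torch.randn(2, 1, 44100)   # (batch, channels=1, samples)
with torch.no_grad():
    y = g(x)
assert y.shape == x.shape      # residual design preserves the input shape
```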

Training script

(The listing below is the jax branch version of this file. The main side of the diff is, line for line, the PyTorch training script reproduced in the removed file further down this page.)

@ -1,33 +1,46 @@
import jax
import jax.numpy as jnp
import optax
import tqdm
import pickle  # Using pickle for simplicity to save JAX states
import os
import argparse
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile  # Example for saving audio
import torch
from torch.utils.data import DataLoader

import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from discriminator import SISUDiscriminator

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()

# Parameters
sample_rate = 44100
n_fft = 2048
@ -36,446 +49,146 @@ win_length = n_fft
n_mels = 128
n_mfcc = 20  # If using MFCC

debug = args.debug

# Initialize JAX random key
key = jax.random.PRNGKey(0)

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDatasetNumPy(dataset_dir)  # Use your JAX dataset
train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True)  # Use your JAX DataLoader

models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= MODELS =========
try:
    # Fetch the first batch
    first_batch = next(iter(train_data_loader))
    # The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
    # We need the high-quality audio NumPy array batch for initialization
    sample_input_np = first_batch[0][0]  # Get the high-quality audio NumPy array batch
    # Convert the NumPy array batch to a JAX array
    sample_input_array = jnp.array(sample_input_np)
    # === FIX ===
    # Transpose the array from (batch, channels, length) to (batch, length, channels)
    # The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
    # The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
    if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
        sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1))  # Swap axes 1 and 2
    print(sample_input_array.shape)  # Should now print (4, 44100, 1)
    # === END FIX ===
except StopIteration:
    print("Error: Data loader is empty. Cannot initialize models.")
    exit()  # Exit if no data is available

key, init_key_g, init_key_d = jax.random.split(key, 3)

generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()

# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']

# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply

# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy  # For Mel, STFT, MFCC losses

# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)

# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)

# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.

# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
    try:
        with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
            generator_opt_state = loaded_state['opt_state']
        with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
            discriminator_opt_state = loaded_state['opt_state']
        epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
        start_epoch = epoch_data.get("epoch", 0) + 1
        print(f"Continuing training from epoch {start_epoch}")
    except FileNotFoundError:
        print("Continue training requested but temp models not found. Starting from scratch.")
    except Exception as e:
        print(f"Error loading temp models: {e}. Starting from scratch.")

if args.generator is not None:
    try:
        with open(args.generator, 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
        print(f"Loaded generator from {args.generator}")
    except FileNotFoundError:
        print(f"Generator model not found at {args.generator}")

if args.discriminator is not None:
    try:
        with open(args.discriminator, 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
        print(f"Loaded discriminator from {args.discriminator}")
    except FileNotFoundError:
        print(f"Discriminator model not found at {args.discriminator}")

# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
#     sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
#     win_length=win_length, n_mels=n_mels, power=1.0  # Magnitude Mel
# )
# stft_transform_fn = SpectrogramJAX(
#     n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )
# mfcc_transform_fn = MFCCJAX(
#     sample_rate,
#     n_mfcc,
#     melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )

# ========= JAX TRAINING STEPS =========
@jax.jit
def discriminator_train_step(
    discriminator_params,
    generator_params,
    discriminator_opt_state,
    high_quality_audio,  # JAX array (batch, length, channels)
    low_quality_audio,  # JAX array (batch, length, channels)
    real_labels,  # JAX array
    fake_labels,  # JAX array
    discriminator_apply_fn,
    generator_apply_fn,
    discriminator_optimizer,
    criterion_d,
    key  # JAX random key
):
    # Split key for potential randomness in model application
    key, disc_key, gen_key = jax.random.split(key, 3)

    def loss_fn(d_params):
        # Generate fake audio
        # Note: Generator is not being trained in this step, so its parameters are static
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
        enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)
        # Pass data through the discriminator
        # Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
        if enhanced_audio.ndim == 2:  # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
        # Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
        if high_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
        elif high_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)
        real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
        fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)
        # Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
        # Ensure the shapes match the labels (batch_size, 1)
        real_output = real_output.reshape(-1, 1)
        fake_output = fake_output.reshape(-1, 1)
        real_loss = jnp.mean(criterion_d(real_output, real_labels))
        fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
        total_loss = real_loss + fake_loss
        return total_loss, (real_loss, fake_loss)

    # Compute gradients
    # Use jax.value_and_grad to get both the loss value and the gradients
    (loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)
    # Apply updates
    updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
    new_discriminator_params = optax.apply_updates(discriminator_params, updates)
    return new_discriminator_params, new_discriminator_opt_state, loss, key

@jax.jit
def generator_train_step(
    generator_params,
    discriminator_params,
    generator_opt_state,
    low_quality_audio,  # JAX array (batch, length, channels)
    high_quality_audio,  # JAX array (batch, length, channels)
    real_labels,  # JAX array
    generator_apply_fn,
    discriminator_apply_fn,
    generator_optimizer,
    criterion_d,  # Adversarial loss
    criterion_l1,  # Feature matching loss
    key  # JAX random key
):
    # Split key for potential randomness
    key, gen_key, disc_key = jax.random.split(key, 3)

    def loss_fn(g_params):
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
        # Generate enhanced audio
        enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)
        # Ensure enhanced_audio has a leading dimension if not already present
        if enhanced_audio.ndim == 2:  # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
        # Calculate adversarial loss (generator wants discriminator to think fake is real)
        # Note: Discriminator is not being trained in this step, so its parameters are static
        fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
        # Ensure the shape matches the labels (batch_size, 1)
        fake_output = fake_output.reshape(-1, 1)
        adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels))  # Generator wants fake_output to be close to real_labels (1s)
        # Feature matching losses (assuming you add these back later)
        # You would need to implement JAX versions of your audio transforms
        # mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
        # stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
        # mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))
        # combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
        combined_loss = adversarial_loss  # For now, only adversarial loss
        # Return combined_loss and any other metrics needed for logging/analysis
        # For now, just adversarial loss and enhanced_audio
        return combined_loss, (adversarial_loss, enhanced_audio)  # Add other losses here when implemented

    # Compute gradients
    # Update: loss_fn now returns (loss, (aux1, aux2, ...))
    (loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)
    # Apply updates
    updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
    new_generator_params = optax.apply_updates(generator_params, updates)
    # Return the loss components separately along with the enhanced audio and key
    return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key

# ========= MAIN TRAINING LOOP =========
def start_training():
    global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key
    generator_epochs = 5000

    for generator_epoch in range(generator_epochs):
        current_epoch = start_epoch + generator_epoch

        # These will hold the last processed audio examples from a batch for saving
        last_high_quality_audio = None
        last_low_quality_audio = None
        last_ai_enhanced_audio = None
        last_sample_rate = None

        # Use tqdm for progress bar
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):
            # high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
            # Extract audio arrays and sample rates (assuming batch dimension is first)
            # The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
            high_quality_audio_batch_np = high_quality_clip[0]
            low_quality_audio_batch_np = low_quality_clip[0]
            sample_rate_batch_np = high_quality_clip[1]  # Assuming sample rates are the same for paired clips

            # Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
            # Only transpose if the shape is (batch, channels, length)
            if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
                high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
            else:
                high_quality_audio_batch = jnp.array(high_quality_audio_batch_np)  # Assume already NLC or handle other cases
            if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
                low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
            else:
                low_quality_audio_batch = jnp.array(low_quality_audio_batch_np)  # Assume already NLC or handle other cases
            sample_rate_batch = jnp.array(sample_rate_batch_np)

            batch_size = high_quality_audio_batch.shape[0]
            # Create labels - JAX arrays
            real_labels = jnp.ones((batch_size, 1))
            fake_labels = jnp.zeros((batch_size, 1))

            # Split key for each batch
            key, batch_key = jax.random.split(key)

            # ========= DISCRIMINATOR =========
            # Call the jitted discriminator training step
            discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
                discriminator_params,
                generator_params,
                discriminator_opt_state,
                high_quality_audio_batch,
                low_quality_audio_batch,
                real_labels,
                fake_labels,
                discriminator_apply_fn,
                generator_apply_fn,
                optimizer_d,
                criterion_d,
                batch_key
            )

            # ========= GENERATOR =========
            # Call the jitted generator training step
            generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
                generator_params,
                discriminator_params,
                generator_opt_state,
                low_quality_audio_batch,
                high_quality_audio_batch,
                real_labels,  # Generator tries to make fake data look real
                generator_apply_fn,
                discriminator_apply_fn,
                optimizer_g,
                criterion_d,
                criterion_l1,
                batch_key
            )

            # Print debug logs (requires waiting for JIT compilation on first step)
            if debug:
                # Use .block_until_ready() to ensure computation is finished before printing
                # In a real scenario, you might want to log metrics less frequently
                d_loss_val = d_loss.block_until_ready().item()
                combined_loss_val = combined_loss.block_until_ready().item()
                adversarial_loss_val = adversarial_loss.block_until_ready().item()
                # Assuming other losses are returned by generator_train_step and unpacked
                # mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
                # stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
                # mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
                print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
                # Print other losses here when implemented and returned

            # Schedulers - Implement your learning rate scheduling logic here if needed
            # based on the losses (e.g., reducing learning rate if loss plateaus).
            # This logic would typically live outside the jitted step function.
            # For Optax, you might use a schedule within the optimizer definition
            # or update the learning rate of the optimizer manually.

            # ========= SAVE LATEST AUDIO (from the last batch processed) =========
            # Access the first sample of the batch for saving
            # Ensure enhanced_audio_batch has a batch dimension and is in NLC format
            if enhanced_audio_batch.ndim == 2:  # Assuming (length, channel), add batch dim
                enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
            elif enhanced_audio_batch.ndim == 1:  # Assuming (length), add batch and channel dims
                enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)

            last_high_quality_audio = high_quality_audio_batch[0]
            last_low_quality_audio = low_quality_audio_batch[0]
            last_ai_enhanced_audio = enhanced_audio_batch[0]
            last_sample_rate = sample_rate_batch[0].item()  # Assuming sample rate is scalar per batch item

        # Save audio files periodically (outside the batch loop)
        if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
            print(f"Saving audio for epoch {current_epoch}!")
            try:
                # Convert JAX arrays to NumPy arrays for saving
                # Transpose back to (length, channels) or (length) if needed by wavfile.write
                # Assuming the models output (length, 1) or (length) after removing batch dim
                low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
                ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
                high_quality_audio_np_save = jax.device_get(last_high_quality_audio)
                # Remove the channel dimension if it's 1 for saving with wavfile
                if low_quality_audio_np_save.shape[-1] == 1:
                    low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
                if ai_enhanced_audio_np_save.shape[-1] == 1:
                    ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
                if high_quality_audio_np_save.shape[-1] == 1:
                    high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16))  # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16))  # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16))  # Assuming audio is int16
            except Exception as e:
                print(f"Error saving audio files: {e}")

        # Save model states periodically (outside the batch loop)
        # Use pickle to save parameters and optimizer states
        try:
            with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
            with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
        except Exception as e:
            print(f"Error saving temp model states: {e}")

    # Save final model states after all epochs
    print("Training complete! Saving final models.")
    try:
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(discriminator_params)}, f)
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(generator_params)}, f)
    except Exception as e:
        print(f"Error saving final model states: {e}")

start_training()
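The jax-branch comments above leave the ReduceLROnPlateau behaviour as an exercise. A sketch of how the generator optimizer could be built so the learning rate can be dropped on a plateau, assuming optax provides inject_hyperparams (the plateau bookkeeping itself is hypothetical glue code, not the repository's):

```python
import optax

# Make the learning rate a mutable hyperparameter stored in the optimizer state.
optimizer_g = optax.inject_hyperparams(optax.adam)(learning_rate=1e-4, b1=0.5, b2=0.999)
generator_opt_state = optimizer_g.init(generator_params)

best_loss, patience, bad_epochs = float("inf"), 5, 0

def maybe_reduce_lr(opt_state, epoch_loss, factor=0.5):
    """Halve the learning rate when the loss has not improved for `patience` epochs."""
    global best_loss, bad_epochs
    if epoch_loss < best_loss:
        best_loss, bad_epochs = epoch_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            opt_state.hyperparams['learning_rate'] *= factor
            bad_epochs = 0
    return opt_state
```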

Training script (PyTorch version, removed)

@ -1,194 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
start_training()
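Both training scripts do `import file_utils as Data` and call Data.read_data / Data.write_data on models/epoch_data.json, but file_utils is not part of this compare. A minimal sketch of JSON helpers with that interface — an assumption, not the repository's actual implementation:

```python
import json

def read_data(path: str) -> dict:
    """Load a JSON file such as models/epoch_data.json into a dict."""
    with open(path, "r") as f:
        return json.load(f)

def write_data(path: str, data: dict) -> None:
    """Write a dict back out as JSON."""
    with open(path, "w") as f:
        json.dump(data, f)
```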

training_utils.py

@ -20,10 +20,12 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
Relative to the jax branch, the main version of this hunk only adds two explanatory comments inside mel_spectrogram_l1_loss; the code is otherwise unchanged. The main version of the hunk:

    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)

    # Ensure same time dimension length (due to potential framing differences)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]

    # L1 Loss (Mean Absolute Error)
    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
    return loss
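The rest of training_utils.py is not shown, but the training script also consumes a log_stft_l1_tensor computed with stft_transform. By analogy with the mel L1 loss above, such a loss might look like the following sketch (an assumption, not the file's actual code):

```python
import torch
import torchaudio.transforms as T

def log_stft_l1_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """L1 distance between log spectrograms, cropped to a common number of frames."""
    spec_true = torch.log(stft_transform(y_true) + eps)
    spec_pred = torch.log(stft_transform(y_pred) + eps)
    min_len = min(spec_true.shape[-1], spec_pred.shape[-1])
    return torch.mean(torch.abs(spec_true[..., :min_len] - spec_pred[..., :min_len]))
```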