7 Commits

9 changed files with 1178 additions and 267 deletions

README.md

@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
 1. **Set Up**:
    - Make sure you have Python installed (version 3.8 or higher).
    - Install needed packages: `pip install -r requirements.txt`
+   - Install the current version of PyTorch (CUDA/ROCm/whatever your device supports)
 2. **Prepare Audio Data**:
    - Put your audio files in the `dataset/good` folder.

data.py

@@ -1,53 +1,104 @@
# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch
import torchaudio
import torchaudio.transforms as T  # Keep using torchaudio transforms

# Import NumPy
import numpy as np

import os
import random

# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils


class AudioDatasetNumPy(Dataset):  # Renamed slightly for clarity
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100  # Define your desired maximum length here

    def __init__(self, input_dir):
        """
        Initializes the dataset. Device argument is removed.
        """
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        """
        Loads audio, processes it, and returns NumPy arrays.
        """
        # --- Load and Resample using torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                self.input_files[idx], normalize=True
            )
        except Exception as e:
            print(f"Error loading file {self.input_files[idx]}: {e}")
            # Return None or raise error, or return dummy data if preferred
            # Returning dummy data might hide issues
            return None  # Or handle appropriately

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)

        # Ensure sample rates are different before resampling
        if original_sample_rate != mangled_sample_rate:
            resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
            low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)

            resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
        else:
            # If rates match, just copy the tensor
            low_quality_audio_pt = high_quality_audio_pt.clone()

        # --- Process Stereo to Mono (still using PyTorch tensors) ---
        # Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
        # and returns PyTorch Tensor (1, L)
        try:
            high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
            low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
        except Exception as e:
            # Handle cases where mono conversion might fail (e.g., already mono)
            # This depends on how AudioUtils is implemented. Let's assume it handles it.
            print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
            high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
            low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)

        # --- Convert to NumPy Arrays ---
        high_quality_audio_np = high_quality_audio_pt_mono.numpy()  # Shape (1, L)
        low_quality_audio_np = low_quality_audio_pt_mono.numpy()    # Shape (1, L)

        # --- Pad or Truncate using NumPy ---
        def process_numpy_audio(audio_np, max_len):
            current_len = audio_np.shape[1]  # Length is axis 1 for shape (1, L)
            if current_len < max_len:
                padding_needed = max_len - current_len
                # np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
                # We only pad axis 1 (length) at the end
                audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
            elif current_len > max_len:
                # Truncate axis 1 (length)
                audio_np = audio_np[:, :max_len]
            return audio_np

        high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
        low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)

        # --- Device handling removed: .to(self.device) is no longer needed ---

        # --- Return NumPy arrays and metadata ---
        # Note: Arrays likely have shape (1, MAX_LENGTH) here
        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)
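Because `__getitem__` returns None when a file fails to load, PyTorch's default collate function would raise on such a batch. A minimal sketch (not part of this commit; names are illustrative) of a collate_fn that silently drops failed samples:

import numpy as np
from torch.utils.data import DataLoader, default_collate
from data import AudioDatasetNumPy

def collate_skip_none(batch):
    # Drop samples that failed to load (None returned by __getitem__)
    batch = [item for item in batch if item is not None]
    if not batch:
        return None  # Caller must skip empty batches
    return default_collate(batch)

# loader = DataLoader(AudioDatasetNumPy('./dataset/good'), batch_size=4,
#                     shuffle=True, collate_fn=collate_skip_none)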

discriminator.py

@@ -1,58 +1,145 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple

# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion

# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
        mean = jnp.mean(x, axis=1, keepdims=True)
        var = jnp.var(x, axis=1, keepdims=True)
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized


# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        ks1 = (1,)
        attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights


# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
    """Equivalent of the PyTorch discriminator_block function."""
    in_channels: int  # Needed for clarity, though not strictly used by layers if input shape is known
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    # spectral_norm: bool = True  # Flag for where SN would be applied
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L', C_out) - L' depends on stride/padding
        """
        # Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
        ks = (self.kernel_size,)
        st = (self.stride,)
        di = (self.dilation,)

        # Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
        # NOTE: Spectral Norm is omitted here.
        # If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
        # conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
        y = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            strides=st,
            kernel_dilation=di,
            padding='SAME'  # Often used in GANs
        )(x)

        # Apply LeakyReLU first (as in the original code if IN is used)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)

        # Conditionally apply InstanceNorm
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)

        return y


class SISUDiscriminator(nn.Module):
    """SISUDiscriminator model translated to Flax."""
    base_channels: int = 16

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, 1) - logits
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

        ch = self.base_channels

        # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)

        # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)

        # Attention Block
        y = AttentionBlock(channels=ch*4)(y)

        # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)

        # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
        # NOTE: Spectral Norm omitted (as per original config)
        y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)

        # Global Average Pooling (across Length dimension)
        pooled = jnp.mean(y, axis=1)  # Shape becomes (N, C=1)

        # Flatten (optional, as shape is likely already (N, 1))
        output = jnp.reshape(pooled, (pooled.shape[0], -1))  # Shape (N, 1)

        return output
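A quick way to sanity-check the converted discriminator is to initialize it with a dummy NLC batch and confirm the (N, 1) logit shape; a minimal sketch, assuming the module above is importable as discriminator.SISUDiscriminator:

import jax
import jax.numpy as jnp
from discriminator import SISUDiscriminator

model = SISUDiscriminator(base_channels=16)
dummy = jnp.zeros((2, 44100, 1))                    # (batch, length, channels)
params = model.init(jax.random.PRNGKey(0), dummy)['params']
logits = model.apply({'params': params}, dummy)
print(logits.shape)                                 # Expected: (2, 1)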

file_utils.py Normal file

@@ -0,0 +1,28 @@
import json

filepath = "my_data.json"

def write_data(filepath, data):
    try:
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=4)  # Use indent for pretty formatting
        print(f"Data written to '{filepath}'")
    except Exception as e:
        print(f"Error writing to file: {e}")


def read_data(filepath):
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
        print(f"Data read from '{filepath}'")
        return data
    except FileNotFoundError:
        print(f"File not found: {filepath}")
        return None
    except json.JSONDecodeError:
        print(f"Error decoding JSON from file: {filepath}")
        return None
    except Exception as e:
        print(f"Error reading from file: {e}")
        return None
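A short usage sketch of the helpers above, mirroring how the training scripts call them (the file name is just an example):

import file_utils as Data

Data.write_data("models/epoch_data.json", {"epoch": 12})
state = Data.read_data("models/epoch_data.json")
if state is not None:
    print(state["epoch"])  # -> 12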

generator.py

@@ -1,52 +1,173 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple


# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
    """
    Flax implementation of Instance Normalization for 1D data (NLC format).
    Normalizes across the 'L' dimension.
    """
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor of shape (batch, length, features)
        Returns:
            Normalized tensor.
        """
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")

        # Calculate mean and variance across the length dimension (axis=1)
        # Keep dims for broadcasting
        mean = jnp.mean(x, axis=1, keepdims=True)
        # Variance calculation using mean needs care for numerical stability if needed,
        # but jnp.var should handle it.
        var = jnp.var(x, axis=1, keepdims=True)

        # Normalize
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)

        # Apply learnable scale and bias if enabled
        if self.use_scale:
            # Parameter shape: (features,) to broadcast across N and L
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            # Parameter shape: (features,)
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias

        return normalized


# --- Converted Modules ---
class ConvBlock(nn.Module):
    """Equivalent of the PyTorch conv_block function."""
    out_channels: int
    kernel_size: int = 3
    dilation: int = 1

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L, C_out)
        """
        # Flax Conv expects kernel_size and dilation as sequences (tuples)
        ks = (self.kernel_size,)
        di = (self.dilation,)

        # Padding='SAME' attempts to preserve the length dimension for stride=1
        x = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            kernel_dilation=di,
            padding='SAME'
        )(x)
        x = InstanceNorm1d(features=self.out_channels)(x)  # Use custom InstanceNorm
        x = nn.PReLU()(x)  # PReLU learns 'alpha' parameter per channel
        return x


class AttentionBlock(nn.Module):
    """Simple Channel Attention Block in Flax."""
    channels: int

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Attention-weighted output tensor (N, L, C)
        """
        # Flax Conv expects kernel_size as a sequence (tuple)
        ks1 = (1,)
        attention_weights = nn.Conv(
            features=self.channels // 4, kernel_size=ks1, padding='SAME'
        )(x)
        # NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(
            features=self.channels, kernel_size=ks1, padding='SAME'
        )(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights


class ResidualInResidualBlock(nn.Module):
    """ResidualInResidualBlock in Flax."""
    channels: int
    num_convs: int = 3

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Output tensor (N, L, C)
        """
        residual = x
        y = x
        # Sequentially apply ConvBlocks
        for _ in range(self.num_convs):
            y = ConvBlock(
                out_channels=self.channels,
                kernel_size=3,  # Assuming kernel_size 3 as in original conv_block default
                dilation=1      # Assuming dilation 1 as in original conv_block default
            )(y)
        y = AttentionBlock(channels=self.channels)(y)
        return y + residual


class SISUGenerator(nn.Module):
    """SISUGenerator model translated to Flax."""
    channels: int = 16
    num_rirb: int = 4
    alpha: float = 1.0  # Non-learnable parameter, passed during init

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, L, 1)
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

        residual_input = x

        # Initial convolution block
        # Flax Conv expects kernel_size as sequence
        ks7 = (7,)
        ks3 = (3,)

        y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
        y = InstanceNorm1d(features=self.channels)(y)
        y = nn.PReLU()(y)

        # Residual-in-Residual Blocks
        rirb_out = y
        for _ in range(self.num_rirb):
            rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)

        # Final layer
        learned_residual = nn.Conv(
            features=1, kernel_size=ks3, padding='SAME'
        )(rirb_out)

        # Combine with input residual
        output = residual_input + self.alpha * learned_residual

        return output
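As with the discriminator, the Flax generator is initialized lazily from a sample input; a minimal sketch (shapes are illustrative) that also counts the resulting parameters:

import jax
import jax.numpy as jnp
from generator import SISUGenerator

model = SISUGenerator(channels=16, num_rirb=4, alpha=1.0)
dummy = jnp.zeros((1, 44100, 1))                    # (batch, length, channels)
params = model.init(jax.random.PRNGKey(0), dummy)['params']
out = model.apply({'params': params}, dummy)        # same (1, 44100, 1) shape as the input
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(out.shape, n_params)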

requirements.txt

@@ -5,10 +5,8 @@ MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
 numpy==2.2.3
-pytorch-triton-rocm==3.2.0+git4b3bb1f8
+pillow==11.0.0
 setuptools==70.2.0
 sympy==1.13.3
-torch==2.7.0.dev20250226+rocm6.3
-torchaudio==2.6.0.dev20250226+rocm6.3
 tqdm==4.67.1
 typing_extensions==4.12.2
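Note that the new training and model code additionally imports `jax`, `flax`, `optax` and `scipy`, none of which are pinned here, and it still relies on `torch`/`torchaudio` for the DataLoader and dataset. Assuming a CPU-only setup, something like `pip install jax flax optax scipy torch torchaudio` would cover them; GPU/ROCm wheels should be installed per the JAX and PyTorch install guides.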

training.py

@@ -1,192 +1,481 @@
import jax
import jax.numpy as jnp
import optax
import tqdm
import pickle  # Using pickle for simplicity to save JAX states
import argparse
import math
import os
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile  # Example for saving audio

import torch
from torch.utils.data import DataLoader

import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from discriminator import SISUDiscriminator

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")

args = parser.parse_args()

# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20  # If using MFCC

debug = args.debug

# Initialize JAX random key
key = jax.random.PRNGKey(0)

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDatasetNumPy(dataset_dir)  # Use your JAX dataset
# Use your JAX DataLoader

# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)

models_dir = "models"
os.makedirs(models_dir, exist_ok=True)

audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= MODELS =========

try:
    # Fetch the first batch
    first_batch = next(iter(train_data_loader))
    # The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
    # We need the high-quality audio NumPy array batch for initialization
    sample_input_np = first_batch[0][0]  # Get the high-quality audio NumPy array batch
    # Convert the NumPy array batch to a JAX array
    sample_input_array = jnp.array(sample_input_np)

    # === FIX ===
    # Transpose the array from (batch, channels, length) to (batch, length, channels)
    # The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
    # The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
    if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
        sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1))  # Swap axes 1 and 2
    print(sample_input_array.shape)  # Should now print (4, 44100, 1)
    # === END FIX ===

except StopIteration:
    print("Error: Data loader is empty. Cannot initialize models.")
    exit()  # Exit if no data is available

key, init_key_g, init_key_d = jax.random.split(key, 3)

generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()

# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']

# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply

# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy  # For Mel, STFT, MFCC losses

# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)

# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)

# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.

# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
    try:
        with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
            generator_opt_state = loaded_state['opt_state']
        with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
            discriminator_opt_state = loaded_state['opt_state']
        epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
        start_epoch = epoch_data.get("epoch", 0) + 1
        print(f"Continuing training from epoch {start_epoch}")
    except FileNotFoundError:
        print("Continue training requested but temp models not found. Starting from scratch.")
    except Exception as e:
        print(f"Error loading temp models: {e}. Starting from scratch.")

if args.generator is not None:
    try:
        with open(args.generator, 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
        print(f"Loaded generator from {args.generator}")
    except FileNotFoundError:
        print(f"Generator model not found at {args.generator}")

if args.discriminator is not None:
    try:
        with open(args.discriminator, 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
        print(f"Loaded discriminator from {args.discriminator}")
    except FileNotFoundError:
        print(f"Discriminator model not found at {args.discriminator}")

# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
#     sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
#     win_length=win_length, n_mels=n_mels, power=1.0  # Magnitude Mel
# )

# stft_transform_fn = SpectrogramJAX(
#     n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )

# mfcc_transform_fn = MFCCJAX(
#     sample_rate,
#     n_mfcc,
#     melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )

# ========= JAX TRAINING STEPS =========

@jax.jit
def discriminator_train_step(
    discriminator_params,
    generator_params,
    discriminator_opt_state,
    high_quality_audio,  # JAX array (batch, length, channels)
    low_quality_audio,   # JAX array (batch, length, channels)
    real_labels,         # JAX array
    fake_labels,         # JAX array
    discriminator_apply_fn,
    generator_apply_fn,
    discriminator_optimizer,
    criterion_d,
    key                  # JAX random key
):
    # Split key for potential randomness in model application
    key, disc_key, gen_key = jax.random.split(key, 3)

    def loss_fn(d_params):
        # Generate fake audio
        # Note: Generator is not being trained in this step, so its parameters are static
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)

        enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)

        # Pass data through the discriminator
        # Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
        if enhanced_audio.ndim == 2:  # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)

        # Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
        if high_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
        elif high_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)

        real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
        fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)

        # Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
        # Ensure the shapes match the labels (batch_size, 1)
        real_output = real_output.reshape(-1, 1)
        fake_output = fake_output.reshape(-1, 1)

        real_loss = jnp.mean(criterion_d(real_output, real_labels))
        fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
        total_loss = real_loss + fake_loss
        return total_loss, (real_loss, fake_loss)

    # Compute gradients
    # Use jax.value_and_grad to get both the loss value and the gradients
    (loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)

    # Apply updates
    updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
    new_discriminator_params = optax.apply_updates(discriminator_params, updates)

    return new_discriminator_params, new_discriminator_opt_state, loss, key


@jax.jit
def generator_train_step(
    generator_params,
    discriminator_params,
    generator_opt_state,
    low_quality_audio,   # JAX array (batch, length, channels)
    high_quality_audio,  # JAX array (batch, length, channels)
    real_labels,         # JAX array
    generator_apply_fn,
    discriminator_apply_fn,
    generator_optimizer,
    criterion_d,   # Adversarial loss
    criterion_l1,  # Feature matching loss
    key            # JAX random key
):
    # Split key for potential randomness
    key, gen_key, disc_key = jax.random.split(key, 3)

    def loss_fn(g_params):
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2:  # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)

        # Generate enhanced audio
        enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)

        # Ensure enhanced_audio has a leading dimension if not already present
        if enhanced_audio.ndim == 2:  # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1:  # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)

        # Calculate adversarial loss (generator wants discriminator to think fake is real)
        # Note: Discriminator is not being trained in this step, so its parameters are static
        fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
        # Ensure the shape matches the labels (batch_size, 1)
        fake_output = fake_output.reshape(-1, 1)
        adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels))  # Generator wants fake_output to be close to real_labels (1s)

        # Feature matching losses (assuming you add these back later)
        # You would need to implement JAX versions of your audio transforms
        # mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
        # stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
        # mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))
        # combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
        combined_loss = adversarial_loss  # For now, only adversarial loss

        # Return combined_loss and any other metrics needed for logging/analysis
        # For now, just adversarial loss and enhanced_audio
        return combined_loss, (adversarial_loss, enhanced_audio)  # Add other losses here when implemented

    # Compute gradients
    # Update: loss_fn now returns (loss, (aux1, aux2, ...))
    (loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)

    # Apply updates
    updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
    new_generator_params = optax.apply_updates(generator_params, updates)

    # Return the loss components separately along with the enhanced audio and key
    return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key


# ========= MAIN TRAINING LOOP =========

def start_training():
    global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key

    generator_epochs = 5000

    for generator_epoch in range(generator_epochs):
        current_epoch = start_epoch + generator_epoch

        # These will hold the last processed audio examples from a batch for saving
        last_high_quality_audio = None
        last_low_quality_audio = None
        last_ai_enhanced_audio = None
        last_sample_rate = None

        # ========= TRAINING =========
        # Use tqdm for progress bar
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):

            # high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
            # Extract audio arrays and sample rates (assuming batch dimension is first)
            # The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
            high_quality_audio_batch_np = high_quality_clip[0]
            low_quality_audio_batch_np = low_quality_clip[0]
            sample_rate_batch_np = high_quality_clip[1]  # Assuming sample rates are the same for paired clips

            # Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
            # Only transpose if the shape is (batch, channels, length)
            if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
                high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
            else:
                high_quality_audio_batch = jnp.array(high_quality_audio_batch_np)  # Assume already NLC or handle other cases

            if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
                low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
            else:
                low_quality_audio_batch = jnp.array(low_quality_audio_batch_np)  # Assume already NLC or handle other cases

            sample_rate_batch = jnp.array(sample_rate_batch_np)

            batch_size = high_quality_audio_batch.shape[0]

            # Create labels - JAX arrays
            real_labels = jnp.ones((batch_size, 1))
            fake_labels = jnp.zeros((batch_size, 1))

            # Split key for each batch
            key, batch_key = jax.random.split(key)

            # ========= DISCRIMINATOR =========
            # Call the jitted discriminator training step
            discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
                discriminator_params,
                generator_params,
                discriminator_opt_state,
                high_quality_audio_batch,
                low_quality_audio_batch,
                real_labels,
                fake_labels,
                discriminator_apply_fn,
                generator_apply_fn,
                optimizer_d,
                criterion_d,
                batch_key
            )

            # ========= GENERATOR =========
            # Call the jitted generator training step
            generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
                generator_params,
                discriminator_params,
                generator_opt_state,
                low_quality_audio_batch,
                high_quality_audio_batch,
                real_labels,  # Generator tries to make fake data look real
                generator_apply_fn,
                discriminator_apply_fn,
                optimizer_g,
                criterion_d,
                criterion_l1,
                batch_key
            )

            # Print debug logs (requires waiting for JIT compilation on first step)
            if debug:
                # Use .block_until_ready() to ensure computation is finished before printing
                # In a real scenario, you might want to log metrics less frequently
                d_loss_val = d_loss.block_until_ready().item()
                combined_loss_val = combined_loss.block_until_ready().item()
                adversarial_loss_val = adversarial_loss.block_until_ready().item()
                # Assuming other losses are returned by generator_train_step and unpacked
                # mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
                # stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
                # mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
                print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
                # Print other losses here when implemented and returned

            # Schedulers - Implement your learning rate scheduling logic here if needed
            # based on the losses (e.g., reducing learning rate if loss plateaus).
            # This logic would typically live outside the jitted step function.
            # For Optax, you might use a schedule within the optimizer definition
            # or update the learning rate of the optimizer manually.

            # ========= SAVE LATEST AUDIO (from the last batch processed) =========
            # Access the first sample of the batch for saving
            # Ensure enhanced_audio_batch has a batch dimension and is in NLC format
            if enhanced_audio_batch.ndim == 2:  # Assuming (length, channel), add batch dim
                enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
            elif enhanced_audio_batch.ndim == 1:  # Assuming (length), add batch and channel dims
                enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)

            last_high_quality_audio = high_quality_audio_batch[0]
            last_low_quality_audio = low_quality_audio_batch[0]
            last_ai_enhanced_audio = enhanced_audio_batch[0]
            last_sample_rate = sample_rate_batch[0].item()  # Assuming sample rate is scalar per batch item

        # Save audio files periodically (outside the batch loop)
        if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
            print(f"Saving audio for epoch {current_epoch}!")
            try:
                # Convert JAX arrays to NumPy arrays for saving
                # Transpose back to (length, channels) or (length) if needed by wavfile.write
                # Assuming the models output (length, 1) or (length) after removing batch dim
                low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
                ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
                high_quality_audio_np_save = jax.device_get(last_high_quality_audio)

                # Remove the channel dimension if it's 1 for saving with wavfile
                if low_quality_audio_np_save.shape[-1] == 1:
                    low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
                if ai_enhanced_audio_np_save.shape[-1] == 1:
                    ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
                if high_quality_audio_np_save.shape[-1] == 1:
                    high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)

                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16))   # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16))     # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16))  # Assuming audio is int16
            except Exception as e:
                print(f"Error saving audio files: {e}")

        # Save model states periodically (outside the batch loop)
        # Use pickle to save parameters and optimizer states
        try:
            with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
            with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
        except Exception as e:
            print(f"Error saving temp model states: {e}")

    # Save final model states after all epochs
    print("Training complete! Saving final models.")
    try:
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(discriminator_params)}, f)
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(generator_params)}, f)
    except Exception as e:
        print(f"Error saving final model states: {e}")


start_training()
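Two side notes on the script above, as hedged sketches rather than part of this commit. First, torchaudio.load(..., normalize=True) yields float audio in [-1, 1], so casting directly to int16 before wavfile.write truncates toward zero; a small scaling helper (illustrative names) could be applied first:

import numpy as np
import scipy.io.wavfile as wavfile

def float_to_int16(audio: np.ndarray) -> np.ndarray:
    # Clip to [-1, 1] and rescale to the int16 range before casting
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767.0).astype(np.int16)

# wavfile.write(path, sample_rate, float_to_int16(audio_np))
# Alternatively, wavfile.write also accepts float32 arrays in [-1, 1] directly.

Second, the comments above leave ReduceLROnPlateau unreplicated; one way to get a similar effect with Optax is optax.inject_hyperparams, which exposes the learning rate as mutable optimizer state that the outer (non-jitted) loop can adjust when a loss plateaus:

import optax

# Wrap adam so learning_rate becomes a hyperparameter stored in the optimizer state
optimizer_g = optax.inject_hyperparams(optax.adam)(learning_rate=1e-4, b1=0.5, b2=0.999)
generator_opt_state = optimizer_g.init(generator_params)

# ... later, in the outer loop, after computing an epoch-level metric:
# if plateau_detected:
#     generator_opt_state.hyperparams['learning_rate'] *= 0.5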

training.txt Normal file

@@ -0,0 +1,194 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os

from torch.utils.data import random_split
from torch.utils.data import DataLoader

import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator

from training_utils import discriminator_train, generator_train
import file_utils as Data

import torchaudio.transforms as T

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")

args = parser.parse_args()

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20  # If using MFCC

mfcc_transform = T.MFCC(
    sample_rate,
    n_mfcc,
    melkwargs={'n_fft': n_fft, 'hop_length': hop_length}
).to(device)

mel_transform = T.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
    win_length=win_length, n_mels=n_mels, power=1.0  # Magnitude Mel
).to(device)

stft_transform = T.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)

debug = args.debug

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)

models_dir = "models"
os.makedirs(models_dir, exist_ok=True)

audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= SINGLE =========

train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# ========= MODELS =========

generator = SISUGenerator()
discriminator = SISUDiscriminator()

epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")

if args.continue_training:
    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
    epoch = epoch_from_file["epoch"] + 1
else:
    if args.generator is not None:
        generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
    if args.discriminator is not None:
        discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))

generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

def start_training():
    generator_epochs = 5000
    for generator_epoch in range(generator_epochs):
        low_quality_audio = (torch.empty((1)), 1)
        high_quality_audio = (torch.empty((1)), 1)
        ai_enhanced_audio = (torch.empty((1)), 1)

        times_correct = 0

        # ========= TRAINING =========
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
            # for high_quality_clip, low_quality_clip in train_data_loader:
            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])

            # ========= LABELS =========
            batch_size = high_quality_clip[0].size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # ========= DISCRIMINATOR =========
            discriminator.train()
            d_loss = discriminator_train(
                high_quality_sample,
                low_quality_sample,
                real_labels,
                fake_labels,
                discriminator,
                generator,
                criterion_d,
                optimizer_d
            )

            # ========= GENERATOR =========
            generator.train()
            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
                low_quality_sample,
                high_quality_sample,
                real_labels,
                generator,
                discriminator,
                criterion_d,
                optimizer_g,
                device,
                mel_transform,
                stft_transform,
                mfcc_transform
            )

            if debug:
                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
            scheduler_d.step(d_loss.detach())
            scheduler_g.step(adversarial_loss.detach())

            # ========= SAVE LATEST AUDIO =========
            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])

        new_epoch = generator_epoch + epoch
        if generator_epoch % 25 == 0:
            print(f"Saved epoch {new_epoch}!")
            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1])  # <-- Because audio clip was resampled in data.py from original to crap and to original again.
            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
            torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])

            #if debug:
            #    print(generator.state_dict().keys())
            #    print(discriminator.state_dict().keys())
            torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
            torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})

    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
    torch.save(generator, "models/epoch-5000-generator.pt")
    print("Training complete!")

start_training()

training_utils.py Normal file

@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T

def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)
    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
    mfccs_true = mfccs_true[:, :, :min_len]
    mfccs_pred = mfccs_pred[:, :, :min_len]
    loss = torch.mean((mfccs_true - mfccs_pred)**2)
    return loss

def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]
    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
    return loss

def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]
    loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
    return loss

def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)
    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]
    loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
    return loss

def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)
    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]
    norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
    norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
    loss = torch.mean(norm_diff / (norm_true + eps))
    return loss

def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
    optimizer.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion(discriminator_decision_from_real, real_labels)

    with torch.no_grad():
        generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))

    d_loss = (d_loss_real + d_loss_fake) / 2.0

    d_loss.backward()
    # Optional: Gradient Clipping (can be helpful)
    # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
    optimizer.step()

    return d_loss

def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    g_optimizer,
    device,
    mel_transform: T.MelSpectrogram,
    stft_transform: T.Spectrogram,
    mfcc_transform: T.MFCC,
    lambda_adv: float = 1.0,
    lambda_mel_l1: float = 10.0,
    lambda_log_stft: float = 1.0,
    lambda_mfcc: float = 1.0
):
    g_optimizer.zero_grad()

    generator_output = generator(low_quality[0])

    discriminator_decision = discriminator(generator_output)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))

    mel_l1 = 0.0
    log_stft_l1 = 0.0
    mfcc_l = 0.0

    # Calculate Mel L1 Loss if weight is positive
    if lambda_mel_l1 > 0:
        mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)

    # Calculate Log STFT L1 Loss if weight is positive
    if lambda_log_stft > 0:
        log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)

    # Calculate MFCC Loss if weight is positive
    if lambda_mfcc > 0:
        mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)

    mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
    log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
    mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l

    combined_loss = (lambda_adv * adversarial_loss) + \
                    (lambda_mel_l1 * mel_l1_tensor) + \
                    (lambda_log_stft * log_stft_l1_tensor) + \
                    (lambda_mfcc * mfcc_l_tensor)

    combined_loss.backward()
    # Optional: Gradient Clipping
    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
    g_optimizer.step()

    # 6. Return values for logging
    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
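A minimal sketch of exercising these losses standalone on dummy waveforms (the transform settings mirror the ones used in training.txt; shapes are illustrative):

import torch
import torchaudio.transforms as T
from training_utils import mel_spectrogram_l1_loss, log_stft_magnitude_loss, spectral_convergence_loss

mel_transform = T.MelSpectrogram(sample_rate=44100, n_fft=2048, hop_length=256,
                                 win_length=2048, n_mels=128, power=1.0)
stft_transform = T.Spectrogram(n_fft=2048, win_length=2048, hop_length=256)

y_true = torch.randn(2, 1, 44100)   # (batch, channels, samples)
y_pred = torch.randn(2, 1, 44100)

print(mel_spectrogram_l1_loss(mel_transform, y_true, y_pred).item())
print(log_stft_magnitude_loss(stft_transform, y_true, y_pred).item())
print(spectral_convergence_loss(stft_transform, y_true, y_pred).item())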