22 Commits

Author SHA1 Message Date
5735557ec3 💩 | VERY CRUDE JAX implementation... 2025-04-30 23:45:05 +03:00
d70c86c257 | Implemented MFCC and STFT. 2025-04-26 17:03:28 +03:00
c04b072de6 | Added smarter ways that would've been needed from the begining. 2025-04-16 17:08:13 +03:00
b6d16e4f11 ♻️ | Restructured procject code. 2025-04-14 17:51:34 +03:00
3936b6c160 🐛 | Fixed NVIDIA training... again. 2025-04-07 14:49:07 +03:00
fbcd5803b8 🐛 | Fixed training on CPU and NVIDIA hardware. 2025-04-07 02:14:06 +03:00
9394bc6c5a :albemic: | Fat architecture. Hopefully better results. 2025-04-06 00:05:43 +03:00
f928d8c2cf :albemic: | More tests. 2025-03-25 21:51:29 +02:00
54338e55a9 :albemic: | Tests. 2025-03-25 19:50:51 +02:00
7e1c7e935a :albemic: | Experimenting with other model layouts. 2025-03-15 18:01:19 +02:00
416500f7fc | Removed/Updated dependencies. 2025-02-26 20:15:30 +02:00
8332b0df2d | Added ability to set epoch. 2025-02-26 19:36:43 +02:00
741dcce7b4 ⚗️ | Increase discriminator size and implement mfcc_loss for generator. 2025-02-23 13:52:01 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
de72ee31ea 🔥 | Removed unnecessary models. 2024-12-21 23:28:34 +02:00
70e20f53d4 ⚗️ | Experiment with other layer layouts. 2024-12-21 23:27:38 +02:00
12 changed files with 1230 additions and 185 deletions

.gitignore vendored

@@ -166,3 +166,4 @@ dataset/
old-output/
output/
*.wav
models/

AudioUtils.py Normal file

@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F
def stereo_tensor_to_mono(waveform):
if waveform.shape[0] > 1:
# Average across channels
mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
else:
# Already mono
mono_waveform = waveform
return mono_waveform
def stretch_tensor(tensor, target_length):
scale_factor = target_length / tensor.size(1)
tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
return tensor
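
A minimal usage sketch of the helper above (the shapes are examples, not part of the committed file):

# Hypothetical quick check of AudioUtils.stereo_tensor_to_mono; shapes chosen for illustration only.
import torch
import AudioUtils

stereo = torch.randn(2, 44100)                   # (channels, samples)
mono = AudioUtils.stereo_tensor_to_mono(stereo)
print(mono.shape)                                # torch.Size([1, 44100])

Note that stretch_tensor as committed passes a 2-D (C, L) tensor straight to F.interpolate with mode='linear', which expects a 3-D (N, C, L) input, so callers would likely need to add a batch dimension (and compute the scale factor from the last axis) before using it.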

README.md

@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
1. **Set Up**:
   - Make sure you have Python installed (version 3.8 or higher).
   - Install needed packages: `pip install -r requirements.txt`
   - Install the current version of PyTorch (CUDA/ROCm/whatever your device supports).
2. **Prepare Audio Data**:
   - Put your audio files in the `dataset/good` folder.

data.py

@@ -1,50 +1,104 @@
# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch
import torchaudio
import torchaudio.transforms as T  # Keep using torchaudio transforms
# Import NumPy
import numpy as np
import os
import random
# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils

class AudioDatasetNumPy(Dataset):  # Renamed slightly for clarity
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100  # Define your desired maximum length here

    def __init__(self, input_dir):
        """
        Initializes the dataset. Device argument is removed.
        """
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        """
        Loads audio, processes it, and returns NumPy arrays.
        """
        # --- Load and Resample using torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                self.input_files[idx], normalize=True
            )
        except Exception as e:
            print(f"Error loading file {self.input_files[idx]}: {e}")
            # Return None or raise error, or return dummy data if preferred
            # Returning dummy data might hide issues
            return None  # Or handle appropriately

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)

        # Ensure sample rates are different before resampling
        if original_sample_rate != mangled_sample_rate:
            resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
            low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)
            resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
        else:
            # If rates match, just copy the tensor
            low_quality_audio_pt = high_quality_audio_pt.clone()

        # --- Process Stereo to Mono (still using PyTorch tensors) ---
        # Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
        # and returns PyTorch Tensor (1, L)
        try:
            high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
            low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
        except Exception as e:
            # Handle cases where mono conversion might fail (e.g., already mono)
            # This depends on how AudioUtils is implemented. Let's assume it handles it.
            print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
            high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
            low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)

        # --- Convert to NumPy Arrays ---
        high_quality_audio_np = high_quality_audio_pt_mono.numpy()  # Shape (1, L)
        low_quality_audio_np = low_quality_audio_pt_mono.numpy()    # Shape (1, L)

        # --- Pad or Truncate using NumPy ---
        def process_numpy_audio(audio_np, max_len):
            current_len = audio_np.shape[1]  # Length is axis 1 for shape (1, L)
            if current_len < max_len:
                padding_needed = max_len - current_len
                # np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
                # We only pad axis 1 (length) at the end
                audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
            elif current_len > max_len:
                # Truncate axis 1 (length)
                audio_np = audio_np[:, :max_len]
            return audio_np

        high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
        low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)

        # --- Remove Device Handling ---
        # .to(self.device) is removed.

        # --- Return NumPy arrays and metadata ---
        # Note: Arrays likely have shape (1, MAX_LENGTH) here
        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)
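
A hedged loader check for the rewritten dataset (the directory, batch size, and the expected (4, 1, 44100) shapes are assumptions taken from the comments in training.py below):

# Hypothetical smoke test for AudioDatasetNumPy; paths and batch size are examples only.
from torch.utils.data import DataLoader
from data import AudioDatasetNumPy

dataset = AudioDatasetNumPy("./dataset/good")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Default collation turns the returned NumPy arrays into torch tensors.
(hq_audio, hq_rate), (lq_audio, lq_rate) = next(iter(loader))
print(hq_audio.shape, lq_audio.shape)   # expected: (4, 1, 44100) each, if every clip loads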

discriminator.py

@@ -1,24 +1,145 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion
# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
mean = jnp.mean(x, axis=1, keepdims=True)
var = jnp.var(x, axis=1, keepdims=True)
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
if self.use_scale:
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
channels: int
@nn.compact
def __call__(self, x):
ks1 = (1,)
attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
return x * attention_weights
# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
"""Equivalent of the PyTorch discriminator_block function."""
in_channels: int # Needed for clarity, though not strictly used by layers if input shape is known
out_channels: int
kernel_size: int = 3
stride: int = 1
dilation: int = 1
# spectral_norm: bool = True # Flag for where SN would be applied
use_instance_norm: bool = True
negative_slope: float = 0.2
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L', C_out) - L' depends on stride/padding
"""
# Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
ks = (self.kernel_size,)
st = (self.stride,)
di = (self.dilation,)
# Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
# NOTE: Spectral Norm is omitted here.
# If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
# conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
y = nn.Conv(
features=self.out_channels,
kernel_size=ks,
strides=st,
kernel_dilation=di,
padding='SAME' # Often used in GANs
)(x)
# Apply LeakyReLU first (as in the original code if IN is used)
y = nn.leaky_relu(y, negative_slope=self.negative_slope)
# Conditionally apply InstanceNorm
if self.use_instance_norm:
y = InstanceNorm1d(features=self.out_channels)(y)
return y
class SISUDiscriminator(nn.Module):
"""SISUDiscriminator model translated to Flax."""
base_channels: int = 16

@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, 1) - logits
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
ch = self.base_channels
# Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)
# Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)
# Attention Block
y = AttentionBlock(channels=ch*4)(y)
# Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)
# Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
# NOTE: Spectral Norm omitted (as per original config)
y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)
# Global Average Pooling (across Length dimension)
pooled = jnp.mean(y, axis=1) # Shape becomes (N, C=1)
# Flatten (optional, as shape is likely already (N, 1))
output = jnp.reshape(pooled, (pooled.shape[0], -1)) # Shape (N, 1)
return output
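
Since test.py is removed in this change (see below), a hedged Flax equivalent of that smoke test might look like this (the NLC input shape and the PRNGKey seed are assumptions):

# Hypothetical smoke test for the Flax discriminator.
import jax
import jax.numpy as jnp
from discriminator import SISUDiscriminator

model = SISUDiscriminator(base_channels=16)
x = jnp.zeros((4, 44100, 1))                            # (batch, length, channels)
params = model.init(jax.random.PRNGKey(0), x)['params']
logits = model.apply({'params': params}, x)
print(logits.shape)                                     # (4, 1)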

file_utils.py Normal file

@@ -0,0 +1,28 @@
import json
filepath = "my_data.json"
def write_data(filepath, data):
try:
with open(filepath, 'w') as f:
json.dump(data, f, indent=4) # Use indent for pretty formatting
print(f"Data written to '{filepath}'")
except Exception as e:
print(f"Error writing to file: {e}")
def read_data(filepath):
try:
with open(filepath, 'r') as f:
data = json.load(f)
print(f"Data read from '{filepath}'")
return data
except FileNotFoundError:
print(f"File not found: {filepath}")
return None
except json.JSONDecodeError:
print(f"Error decoding JSON from file: {filepath}")
return None
except Exception as e:
print(f"Error reading from file: {e}")
return None
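
A short round-trip sketch for these helpers (the path and payload are placeholders):

# Hypothetical usage of file_utils; file name and dictionary are examples only.
import file_utils as Data

Data.write_data("models/epoch_data.json", {"epoch": 42})
state = Data.read_data("models/epoch_data.json")
print(state["epoch"] if state else "no checkpoint metadata found")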

generator.py

@@ -1,23 +1,173 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
"""
Flax implementation of Instance Normalization for 1D data (NLC format).
Normalizes across the 'L' dimension.
"""
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor of shape (batch, length, features)
Returns:
Normalized tensor.
"""
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
# Calculate mean and variance across the length dimension (axis=1)
# Keep dims for broadcasting
mean = jnp.mean(x, axis=1, keepdims=True)
# Variance calculation using mean needs care for numerical stability if needed,
# but jnp.var should handle it.
var = jnp.var(x, axis=1, keepdims=True)
# Normalize
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
# Apply learnable scale and bias if enabled
if self.use_scale:
# Parameter shape: (features,) to broadcast across N and L
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
# Parameter shape: (features,)
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- Converted Modules ---
class ConvBlock(nn.Module):
"""Equivalent of the PyTorch conv_block function."""
out_channels: int
kernel_size: int = 3
dilation: int = 1
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L, C_out)
"""
# Flax Conv expects kernel_size and dilation as sequences (tuples)
ks = (self.kernel_size,)
di = (self.dilation,)
# Padding='SAME' attempts to preserve the length dimension for stride=1
x = nn.Conv(
features=self.out_channels,
kernel_size=ks,
kernel_dilation=di,
padding='SAME'
)(x)
x = InstanceNorm1d(features=self.out_channels)(x) # Use custom InstanceNorm
x = nn.PReLU()(x) # PReLU learns 'alpha' parameter per channel
return x
class AttentionBlock(nn.Module):
"""Simple Channel Attention Block in Flax."""
channels: int
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Attention-weighted output tensor (N, L, C)
"""
# Flax Conv expects kernel_size as a sequence (tuple)
ks1 = (1,)
attention_weights = nn.Conv(
features=self.channels // 4, kernel_size=ks1, padding='SAME'
)(x)
# NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(
features=self.channels, kernel_size=ks1, padding='SAME'
)(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
return x * attention_weights
class ResidualInResidualBlock(nn.Module):
"""ResidualInResidualBlock in Flax."""
channels: int
num_convs: int = 3
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Output tensor (N, L, C)
"""
residual = x
y = x
# Sequentially apply ConvBlocks
for _ in range(self.num_convs):
y = ConvBlock(
out_channels=self.channels,
kernel_size=3, # Assuming kernel_size 3 as in original conv_block default
dilation=1 # Assuming dilation 1 as in original conv_block default
)(y)
y = AttentionBlock(channels=self.channels)(y)
return y + residual
class SISUGenerator(nn.Module):
"""SISUGenerator model translated to Flax."""
channels: int = 16
num_rirb: int = 4
alpha: float = 1.0 # Non-learnable parameter, passed during init

@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, L, 1)
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

residual_input = x

# Initial convolution block
# Flax Conv expects kernel_size as sequence
ks7 = (7,)
ks3 = (3,)
y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
y = InstanceNorm1d(features=self.channels)(y)
y = nn.PReLU()(y)
# Residual-in-Residual Blocks
rirb_out = y
for _ in range(self.num_rirb):
rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)
# Final layer
learned_residual = nn.Conv(
features=1, kernel_size=ks3, padding='SAME'
)(rirb_out)
# Combine with input residual
output = residual_input + self.alpha * learned_residual
return output
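
As with the discriminator, a hedged initialisation sketch for the Flax generator (input shape and seed are assumptions):

# Hypothetical smoke test for the Flax generator.
import jax
import jax.numpy as jnp
from generator import SISUGenerator

model = SISUGenerator(channels=16, num_rirb=4, alpha=1.0)
x = jnp.zeros((2, 44100, 1))                            # (batch, length, channels)
params = model.init(jax.random.PRNGKey(0), x)['params']
y = model.apply({'params': params}, x)
print(y.shape)                                          # (2, 44100, 1), same shape as the input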

requirements.txt

@@ -1,12 +1,12 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.3
pillow==11.0.0
setuptools==70.2.0
sympy==1.13.3
tqdm==4.67.1
typing_extensions==4.12.2

test.py

@@ -1,10 +0,0 @@
import torch.nn as nn
import torch
from discriminator import SISUDiscriminator
discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000) # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

training.py

@@ -1,135 +1,481 @@
import jax
import jax.numpy as jnp
import optax
import tqdm
import pickle # Using pickle for simplicity to save JAX states
import os
import argparse
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile # Example for saving audio

import torch
from torch.utils.data import DataLoader

import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from discriminator import SISUDiscriminator

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
debug = args.debug
# Initialize JAX random key
key = jax.random.PRNGKey(0)
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDatasetNumPy(dataset_dir) # Use your JAX dataset
train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True) # Use your JAX DataLoader

models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= MODELS =========

try:
# Fetch the first batch
first_batch = next(iter(train_data_loader))
# The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
# We need the high-quality audio NumPy array batch for initialization
sample_input_np = first_batch[0][0] # Get the high-quality audio NumPy array batch
# Convert the NumPy array batch to a JAX array
sample_input_array = jnp.array(sample_input_np)

# === FIX ===
# Transpose the array from (batch, channels, length) to (batch, length, channels)
# The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
# The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1)) # Swap axes 1 and 2
print(sample_input_array.shape) # Should now print (4, 44100, 1)
# === END FIX ===
except StopIteration:
print("Error: Data loader is empty. Cannot initialize models.")
exit() # Exit if no data is available

key, init_key_g, init_key_d = jax.random.split(key, 3)
generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()

# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']

# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply

# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy # For Mel, STFT, MFCC losses

# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)
# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.
# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
try:
with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
generator_opt_state = loaded_state['opt_state']
with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
discriminator_opt_state = loaded_state['opt_state']
epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
start_epoch = epoch_data.get("epoch", 0) + 1
print(f"Continuing training from epoch {start_epoch}")
except FileNotFoundError:
print("Continue training requested but temp models not found. Starting from scratch.")
except Exception as e:
print(f"Error loading temp models: {e}. Starting from scratch.")
if args.generator is not None:
try:
with open(args.generator, 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
print(f"Loaded generator from {args.generator}")
except FileNotFoundError:
print(f"Generator model not found at {args.generator}")
if args.discriminator is not None:
try:
with open(args.discriminator, 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
print(f"Loaded discriminator from {args.discriminator}")
except FileNotFoundError:
print(f"Discriminator model not found at {args.discriminator}")
# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
# sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
# win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
# )
# stft_transform_fn = SpectrogramJAX(
# n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )
# mfcc_transform_fn = MFCCJAX(
# sample_rate,
# n_mfcc,
# melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )
# ========= JAX TRAINING STEPS =========
@jax.jit
def discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio, # JAX array (batch, length, channels)
low_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
fake_labels, # JAX array
discriminator_apply_fn,
generator_apply_fn,
discriminator_optimizer,
criterion_d,
key # JAX random key
):
# Split key for potential randomness in model application
key, disc_key, gen_key = jax.random.split(key, 3)
def loss_fn(d_params):
# Generate fake audio
# Note: Generator is not being trained in this step, so its parameters are static
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)
# Pass data through the discriminator
# Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
if high_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
elif high_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)
real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)
# Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
# Ensure the shapes match the labels (batch_size, 1)
real_output = real_output.reshape(-1, 1)
fake_output = fake_output.reshape(-1, 1)
real_loss = jnp.mean(criterion_d(real_output, real_labels))
fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
total_loss = real_loss + fake_loss
return total_loss, (real_loss, fake_loss)
# Compute gradients
# Use jax.value_and_grad to get both the loss value and the gradients
(loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)
# Apply updates
updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
new_discriminator_params = optax.apply_updates(discriminator_params, updates)
return new_discriminator_params, new_discriminator_opt_state, loss, key
@jax.jit
def generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio, # JAX array (batch, length, channels)
high_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
generator_apply_fn,
discriminator_apply_fn,
generator_optimizer,
criterion_d, # Adversarial loss
criterion_l1, # Feature matching loss
key # JAX random key
):
# Split key for potential randomness
key, gen_key, disc_key = jax.random.split(key, 3)
def loss_fn(g_params):
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
# Generate enhanced audio
enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)
# Ensure enhanced_audio has a leading dimension if not already present
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Calculate adversarial loss (generator wants discriminator to think fake is real)
# Note: Discriminator is not being trained in this step, so its parameters are static
fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
# Ensure the shape matches the labels (batch_size, 1)
fake_output = fake_output.reshape(-1, 1)
adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels)) # Generator wants fake_output to be close to real_labels (1s)
# Feature matching losses (assuming you add these back later)
# You would need to implement JAX versions of your audio transforms
# mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
# stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
# mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))
# combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
combined_loss = adversarial_loss # For now, only adversarial loss
# Return combined_loss and any other metrics needed for logging/analysis
# For now, just adversarial loss and enhanced_audio
return combined_loss, (adversarial_loss, enhanced_audio) # Add other losses here when implemented
# Compute gradients
# Update: loss_fn now returns (loss, (aux1, aux2, ...))
(loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)
# Apply updates
updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
new_generator_params = optax.apply_updates(generator_params, updates)
# Return the loss components separately along with the enhanced audio and key
return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key
# ========= MAIN TRAINING LOOP =========
def start_training():
global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key

generator_epochs = 5000
for generator_epoch in range(generator_epochs):
current_epoch = start_epoch + generator_epoch

# These will hold the last processed audio examples from a batch for saving
last_high_quality_audio = None
last_low_quality_audio = None
last_ai_enhanced_audio = None
last_sample_rate = None

# Use tqdm for progress bar
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):
# high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
# Extract audio arrays and sample rates (assuming batch dimension is first)
# The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
high_quality_audio_batch_np = high_quality_clip[0]
low_quality_audio_batch_np = low_quality_clip[0]
sample_rate_batch_np = high_quality_clip[1] # Assuming sample rates are the same for paired clips

# Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
# Only transpose if the shape is (batch, channels, length)
if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
else:
high_quality_audio_batch = jnp.array(high_quality_audio_batch_np) # Assume already NLC or handle other cases

if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
else:
low_quality_audio_batch = jnp.array(low_quality_audio_batch_np) # Assume already NLC or handle other cases
sample_rate_batch = jnp.array(sample_rate_batch_np)

batch_size = high_quality_audio_batch.shape[0]
# Create labels - JAX arrays
real_labels = jnp.ones((batch_size, 1))
fake_labels = jnp.zeros((batch_size, 1))

# Split key for each batch
key, batch_key = jax.random.split(key)

# ========= DISCRIMINATOR =========
# Call the jitted discriminator training step
discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio_batch,
low_quality_audio_batch,
real_labels,
fake_labels,
discriminator_apply_fn,
generator_apply_fn,
optimizer_d,
criterion_d,
batch_key
)
# ========= GENERATOR =========
# Call the jitted generator training step
generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio_batch,
high_quality_audio_batch,
real_labels, # Generator tries to make fake data look real
generator_apply_fn,
discriminator_apply_fn,
optimizer_g,
criterion_d,
criterion_l1,
batch_key
)
# Print debug logs (requires waiting for JIT compilation on first step)
if debug:
# Use .block_until_ready() to ensure computation is finished before printing
# In a real scenario, you might want to log metrics less frequently
d_loss_val = d_loss.block_until_ready().item()
combined_loss_val = combined_loss.block_until_ready().item()
adversarial_loss_val = adversarial_loss.block_until_ready().item()
# Assuming other losses are returned by generator_train_step and unpacked
# mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
# stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
# mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
# Print other losses here when implemented and returned
# Schedulers - Implement your learning rate scheduling logic here if needed
# based on the losses (e.g., reducing learning rate if loss plateaus).
# This logic would typically live outside the jitted step function.
# For Optax, you might use a schedule within the optimizer definition
# or update the learning rate of the optimizer manually.
# ========= SAVE LATEST AUDIO (from the last batch processed) =========
# Access the first sample of the batch for saving
# Ensure enhanced_audio_batch has a batch dimension and is in NLC format
if enhanced_audio_batch.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
elif enhanced_audio_batch.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)
last_high_quality_audio = high_quality_audio_batch[0]
last_low_quality_audio = low_quality_audio_batch[0]
last_ai_enhanced_audio = enhanced_audio_batch[0]
last_sample_rate = sample_rate_batch[0].item() # Assuming sample rate is scalar per batch item
# Save audio files periodically (outside the batch loop)
if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
print(f"Saving audio for epoch {current_epoch}!")
try:
# Convert JAX arrays to NumPy arrays for saving
# Transpose back to (length, channels) or (length) if needed by wavfile.write
# Assuming the models output (length, 1) or (length) after removing batch dim
low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
high_quality_audio_np_save = jax.device_get(last_high_quality_audio)
# Remove the channel dimension if it's 1 for saving with wavfile
if low_quality_audio_np_save.shape[-1] == 1:
low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
if ai_enhanced_audio_np_save.shape[-1] == 1:
ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
if high_quality_audio_np_save.shape[-1] == 1:
high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
except Exception as e:
print(f"Error saving audio files: {e}")
# Save model states periodically (outside the batch loop)
# Use pickle to save parameters and optimizer states
try:
with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
except Exception as e:
print(f"Error saving temp model states: {e}")
# Save final model states after all epochs
print("Training complete! Saving final models.")
try:
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params)}, f)
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params)}, f)
except Exception as e:
print(f"Error saving final model states: {e}")
start_training()
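
A hedged sketch of reloading the pickled checkpoint written above for inference (the file name and input shape are assumptions; the generator is applied here without an extra RNG argument):

# Hypothetical inference from the pickle checkpoint format used by this script.
import pickle
import jax.numpy as jnp
from generator import SISUGenerator

with open("models/temp_generator.pkl", "rb") as f:
    generator_params = pickle.load(f)["params"]

model = SISUGenerator()
low_quality = jnp.zeros((1, 44100, 1))   # NLC, as in training
enhanced = model.apply({'params': generator_params}, low_quality)
print(enhanced.shape)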

training.txt Normal file

@@ -0,0 +1,194 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
start_training()

training_utils.py Normal file

@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
loss = torch.mean((mfccs_true - mfccs_pred)**2)
return loss
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
return loss
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
return loss
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
return loss
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
loss = torch.mean(norm_diff / (norm_true + eps))
return loss
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
with torch.no_grad():
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
d_loss = (d_loss_real + d_loss_fake) / 2.0
d_loss.backward()
# Optional: Gradient Clipping (can be helpful)
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer.step()
return d_loss
def generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
adv_criterion,
g_optimizer,
device,
mel_transform: T.MelSpectrogram,
stft_transform: T.Spectrogram,
mfcc_transform: T.MFCC,
lambda_adv: float = 1.0,
lambda_mel_l1: float = 10.0,
lambda_log_stft: float = 1.0,
lambda_mfcc: float = 1.0
):
g_optimizer.zero_grad()
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
mel_l1 = 0.0
log_stft_l1 = 0.0
mfcc_l = 0.0
# Calculate Mel L1 Loss if weight is positive
if lambda_mel_l1 > 0:
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
# Calculate Log STFT L1 Loss if weight is positive
if lambda_log_stft > 0:
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
# Calculate MFCC Loss if weight is positive
if lambda_mfcc > 0:
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
combined_loss = (lambda_adv * adversarial_loss) + \
(lambda_mel_l1 * mel_l1_tensor) + \
(lambda_log_stft * log_stft_l1_tensor) + \
(lambda_mfcc * mfcc_l_tensor)
combined_loss.backward()
# Optional: Gradient Clipping
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
g_optimizer.step()
# 6. Return values for logging
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
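
A minimal check of one of the spectral losses above (the transform settings mirror training.txt; the random tensors are stand-ins for real audio):

# Hypothetical check of mel_spectrogram_l1_loss.
import torch
import torchaudio.transforms as T
from training_utils import mel_spectrogram_l1_loss

mel_transform = T.MelSpectrogram(sample_rate=44100, n_fft=2048, hop_length=256,
                                 win_length=2048, n_mels=128, power=1.0)
y_true = torch.randn(1, 1, 44100)   # (batch, channel, samples)
y_pred = torch.randn(1, 1, 44100)
print(mel_spectrogram_l1_loss(mel_transform, y_true, y_pred).item())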