1 Commit

Author SHA1 Message Date
1717e7a008 ⚗️ | Experimenting... 2025-02-10 19:35:50 +02:00
9 changed files with 218 additions and 1189 deletions

README.md
View File

@@ -18,7 +18,6 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
 1. **Set Up**:
    - Make sure you have Python installed (version 3.8 or higher).
    - Install needed packages: `pip install -r requirements.txt`
-   - Install current version of PyTorch (CUDA/ROCm/What ever your device supports)
 2. **Prepare Audio Data**:
    - Put your audio files in the `dataset/good` folder.

data.py
View File

@@ -1,104 +1,52 @@

Removed: the experimental AudioDatasetNumPy variant (~104 lines), which loaded audio with torchaudio, fell back to torch.mean for mono conversion if AudioUtils failed, converted the tensors to NumPy arrays, padded or truncated them to MAX_LENGTH = 44100 with np.pad, and returned NumPy arrays plus sample rates.

Added:

from torch.utils.data import Dataset
import torch
import torch.nn.functional as F  # F.pad is used in pad_or_truncate below
import torchaudio
import os
import random
from AudioUtils import stereo_tensor_to_mono, stretch_tensor

class AudioDataset(Dataset):
    audio_sample_rates = [11025]

    def __init__(self, input_dir):
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # Load high-quality audio
        high_quality_path = self.input_files[idx]
        high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path)
        high_quality_audio = stereo_tensor_to_mono(high_quality_audio)

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resample_low(high_quality_audio)
        resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
        low_quality_audio = resample_high(low_quality_audio)

        # Pad or truncate to match a fixed length
        target_length = 44100  # Adjust this based on your data
        high_quality_audio = self.pad_or_truncate(high_quality_audio, target_length)
        low_quality_audio = self.pad_or_truncate(low_quality_audio, target_length)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

    def pad_or_truncate(self, tensor, target_length):
        current_length = tensor.size(1)
        if current_length < target_length:
            # Pad with zeros
            padding = target_length - current_length
            tensor = F.pad(tensor, (0, padding))
        else:
            # Truncate to target length
            tensor = tensor[:, :target_length]
        return tensor
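The dataset pairs each clean clip with a resample-degraded copy of itself, so the quickest sanity check is to pull one batch and look at the shapes. A minimal sketch (not part of the commit), assuming `./dataset/good` contains at least four `.wav` files and that `AudioUtils.stereo_tensor_to_mono` returns a `(1, samples)` tensor:

# Sanity-check sketch: inspect one batch from AudioDataset.
from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset('./dataset/good')
loader = DataLoader(dataset, batch_size=4, shuffle=True)

(high, high_sr), (low, low_sr) = next(iter(loader))
print(high.shape, low.shape)                 # expected: torch.Size([4, 1, 44100]) for both
print(high_sr[0].item(), low_sr[0].item())   # original rate and the mangled rate (11025)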

discriminator.py
View File

@@ -1,145 +1,38 @@

Removed: the experimental Flax/JAX port (~145 lines) — custom InstanceNorm1d, AttentionBlock and DiscriminatorBlock modules plus a Flax SISUDiscriminator (base_channels=16, NLC-format inputs, spectral norm omitted, global average pooling over the length axis).

Added:

import torch
import torch.nn as nn
import torch.nn.utils as utils

def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
    padding = (kernel_size // 2) * dilation
    return nn.Sequential(
        utils.spectral_norm(
            nn.Conv1d(in_channels, out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      dilation=dilation,
                      padding=padding)
        ),
        nn.BatchNorm1d(out_channels),
        nn.LeakyReLU(0.2, inplace=True)
    )

class SISUDiscriminator(nn.Module):
    def __init__(self):
        super(SISUDiscriminator, self).__init__()
        layers = 4
        self.model = nn.Sequential(
            discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1),
            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1),
            discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4),
            discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8),
            discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16),
            discriminator_block(layers * 2, layers, kernel_size=5, dilation=2),
            discriminator_block(layers, 1, kernel_size=3, stride=1)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x = self.model(x)
        x = self.global_avg_pool(x)
        return x.view(-1, 1)
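Every block either keeps or halves the sequence length and the head is a global average pool, so the discriminator maps any (batch, 1, samples) input to one scalar score per clip. A small shape-check sketch (not part of the commit), assuming the module above is importable from discriminator.py:

# Shape-check sketch for the new discriminator.
import torch
from discriminator import SISUDiscriminator

d = SISUDiscriminator()
x = torch.randn(2, 1, 44100)   # (batch, channels, samples), as produced by data.py
scores = d(x)
print(scores.shape)            # torch.Size([2, 1]) — one unbounded score per clip

Note that there is no final sigmoid, so these are raw scores; the new training script pairs them with nn.BCELoss, which expects values in [0, 1], so a sigmoid (or switching to BCEWithLogitsLoss) may still be needed.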

file_utils.py (deleted)
View File

@@ -1,28 +0,0 @@
import json

filepath = "my_data.json"

def write_data(filepath, data):
    try:
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=4) # Use indent for pretty formatting
        print(f"Data written to '{filepath}'")
    except Exception as e:
        print(f"Error writing to file: {e}")

def read_data(filepath):
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
        print(f"Data read from '{filepath}'")
        return data
    except FileNotFoundError:
        print(f"File not found: {filepath}")
        return None
    except json.JSONDecodeError:
        print(f"Error decoding JSON from file: {filepath}")
        return None
    except Exception as e:
        print(f"Error reading from file: {e}")
        return None

generator.py
View File

@@ -1,173 +1,41 @@

Removed: the experimental Flax/JAX port (~173 lines) — custom InstanceNorm1d, ConvBlock, AttentionBlock and ResidualInResidualBlock modules plus a Flax SISUGenerator (channels=16, num_rirb=4) that added an alpha-scaled learned residual to the NLC-format input.

Added:

import torch.nn as nn

def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1):
    padding = (kernel_size // 2) * dilation
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
        nn.BatchNorm1d(out_channels),
        nn.PReLU(),
        nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
        nn.BatchNorm1d(out_channels)
    )

class SISUGenerator(nn.Module):
    def __init__(self):
        super(SISUGenerator, self).__init__()
        layers = 4
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, layers, kernel_size=7, padding=3),
            nn.BatchNorm1d(layers),
            nn.PReLU()
        )
        self.conv_blocks = nn.Sequential(
            conv_residual_block(layers, layers, kernel_size=3, dilation=1),
            conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2),
            conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16),
            conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8),
            conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2),
            conv_residual_block(layers, layers, kernel_size=3, dilation=1)
        )
        self.final_layer = nn.Sequential(
            nn.Conv1d(layers, 1, kernel_size=3, padding=1)
        )

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv_blocks(x) + x  # Adding residual connection after blocks
        x = self.final_layer(x)
        return x + residual
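Every convolution uses (kernel_size // 2) * dilation padding with stride 1 and the model only adds residuals, so the generator is length-preserving. A quick shape check (not part of the commit), assuming the module above is importable from generator.py:

# Shape-check sketch for the new generator.
import torch
from generator import SISUGenerator

g = SISUGenerator()
low = torch.randn(2, 1, 44100)   # degraded input, as produced by data.py
enhanced = g(low)
print(enhanced.shape)            # torch.Size([2, 1, 44100]) — same length as the input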

requirements.txt
View File

@@ -4,9 +4,11 @@ Jinja2==3.1.4
 MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
-numpy==2.2.3
-pillow==11.0.0
+numpy==2.2.1
+pytorch-triton-rocm==3.2.0+git0d4682f0
 setuptools==70.2.0
-sympy==1.13.3
+sympy==1.13.1
+torch==2.6.0.dev20241222+rocm6.2.4
+torchaudio==2.6.0.dev20241222+rocm6.2.4
 tqdm==4.67.1
 typing_extensions==4.12.2

training.py
View File

@@ -1,481 +1,164 @@

Removed: the experimental JAX/Optax training script (~481 lines) — Flax model initialization from a sample batch, optax Adam optimizers, jitted discriminator_train_step and generator_train_step functions, pickle-based checkpointing of params and optimizer state, epoch bookkeeping via file_utils, and scipy.io.wavfile audio dumps every 25 epochs.

Added:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm

import argparse

import math

from torch.utils.data import random_split
from torch.utils.data import DataLoader

import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator

def perceptual_loss(y_true, y_pred):
    return torch.mean((y_true - y_pred) ** 2)

def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
    optimizer_d.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

    # Forward pass for fake samples (from generator output)
    generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output.detach())
    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

    # Combine real and fake losses
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    # Backward pass and optimization
    d_loss.backward()
    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
    optimizer_d.step()

    return d_loss

def generator_train(low_quality, real_labels):
    optimizer_g.zero_grad()

    # Forward pass for fake samples (from generator output)
    generator_output = generator(low_quality[0])
    discriminator_decision = discriminator(generator_output)
    g_loss = criterion_g(discriminator_decision, real_labels)

    g_loss.backward()
    optimizer_g.step()

    return generator_output

def first(objects):
    if len(objects) >= 1:
        return objects[0]
    return objects

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")

args = parser.parse_args()

# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir)

# ========= SINGLE =========

train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Initialize models and move them to device
generator = SISUGenerator()
discriminator = SISUDiscriminator()

if args.generator is not None:
    generator.load_state_dict(torch.load(args.generator, weights_only=True))
if args.discriminator is not None:
    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))

generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss
criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss()

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

def start_training():
    generator_epochs = 5000
    for generator_epoch in range(generator_epochs):
        low_quality_audio = (torch.empty((1)), 1)
        high_quality_audio = (torch.empty((1)), 1)
        ai_enhanced_audio = (torch.empty((1)), 1)

        times_correct = 0

        # ========= TRAINING =========
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
        # for high_quality_clip, low_quality_clip in train_data_loader:
            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])

            # ========= LABELS =========
            batch_size = high_quality_clip[0].size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # ========= DISCRIMINATOR =========
            discriminator.train()
            discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)

            # ========= GENERATOR =========
            generator.train()
            generator_output = generator_train(low_quality_sample, real_labels)

            # ========= SAVE LATEST AUDIO =========
            high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0])
            low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0])
            ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0])

            print(high_quality_audio)

        print(f"Saved epoch {generator_epoch}!")
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])

        if generator_epoch % 10 == 0:
            print(f"Saved epoch {generator_epoch}!")
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])

            torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt")
            torch.save(generator.state_dict(), f"models/current-epoch-generator.pt")

    torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt")
    torch.save(generator.state_dict(), "models/epoch-5000-generator.pt")
    print("Training complete!")

start_training()
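The script writes checkpoints and sample wavs but has no inference path. Below is a minimal enhancement sketch (not part of the commit), assuming a checkpoint such as models/epoch-5000-generator.pt produced by the loop above, that AudioUtils.stereo_tensor_to_mono returns a (1, samples) tensor, and that input.wav / enhanced.wav are placeholder paths:

# Inference sketch: enhance one file with a saved generator checkpoint.
import torch
import torchaudio
from generator import SISUGenerator
from AudioUtils import stereo_tensor_to_mono

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = SISUGenerator().to(device)
generator.load_state_dict(torch.load("models/epoch-5000-generator.pt", map_location=device, weights_only=True))
generator.eval()

audio, sample_rate = torchaudio.load("input.wav")              # placeholder input path
audio = stereo_tensor_to_mono(audio).unsqueeze(0).to(device)   # -> (1, 1, samples)

with torch.no_grad():
    enhanced = generator(audio)

torchaudio.save("enhanced.wav", enhanced.squeeze(0).cpu(), sample_rate)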

View File

@@ -1,194 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
start_training()

training_utils.py (deleted)
View File

@@ -1,142 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
loss = torch.mean((mfccs_true - mfccs_pred)**2)
return loss
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
return loss
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
return loss
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
return loss
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
loss = torch.mean(norm_diff / (norm_true + eps))
return loss
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
with torch.no_grad():
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
d_loss = (d_loss_real + d_loss_fake) / 2.0
d_loss.backward()
# Optional: Gradient Clipping (can be helpful)
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer.step()
return d_loss
def generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
adv_criterion,
g_optimizer,
device,
mel_transform: T.MelSpectrogram,
stft_transform: T.Spectrogram,
mfcc_transform: T.MFCC,
lambda_adv: float = 1.0,
lambda_mel_l1: float = 10.0,
lambda_log_stft: float = 1.0,
lambda_mfcc: float = 1.0
):
g_optimizer.zero_grad()
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
mel_l1 = 0.0
log_stft_l1 = 0.0
mfcc_l = 0.0
# Calculate Mel L1 Loss if weight is positive
if lambda_mel_l1 > 0:
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
# Calculate Log STFT L1 Loss if weight is positive
if lambda_log_stft > 0:
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
# Calculate MFCC Loss if weight is positive
if lambda_mfcc > 0:
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
combined_loss = (lambda_adv * adversarial_loss) + \
(lambda_mel_l1 * mel_l1_tensor) + \
(lambda_log_stft * log_stft_l1_tensor) + \
(lambda_mfcc * mfcc_l_tensor)
combined_loss.backward()
# Optional: Gradient Clipping
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
g_optimizer.step()
# 6. Return values for logging
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
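These helpers all expect an instantiated torchaudio transform plus waveform tensors of shape (batch, 1, samples). A small usage sketch on dummy data (not part of the commit); it assumes the pre-commit revision where training_utils.py still exists, and reuses the transform settings from the removed training script (n_fft=2048, hop_length=256, n_mels=128):

# Usage sketch for the (now removed) spectral loss helpers.
import torch
import torchaudio.transforms as T
from training_utils import mel_spectrogram_l1_loss, log_stft_magnitude_loss, spectral_convergence_loss

sample_rate = 44100
mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_fft=2048, hop_length=256,
                                 win_length=2048, n_mels=128, power=1.0)
stft_transform = T.Spectrogram(n_fft=2048, win_length=2048, hop_length=256)

y_true = torch.randn(2, 1, 44100)   # stand-in for the reference clips
y_pred = torch.randn(2, 1, 44100)   # stand-in for generator output

print(mel_spectrogram_l1_loss(mel_transform, y_true, y_pred).item())
print(log_stft_magnitude_loss(stft_transform, y_true, y_pred).item())
print(spectral_convergence_loss(stft_transform, y_true, y_pred).item())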