Compare commits

..

1 Commits
main ... jax

Author SHA1 Message Date
5735557ec3 💩 | VERY CRUDE JAX implementation... 2025-04-30 23:45:05 +03:00
6 changed files with 979 additions and 268 deletions

109
data.py
View File

@ -1,53 +1,104 @@
# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import torchaudio.transforms as T # Keep using torchaudio transforms
# Import NumPy
import numpy as np
import os
import random
import torchaudio.transforms as T
# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils
class AudioDataset(Dataset):
class AudioDatasetNumPy(Dataset): # Renamed slightly for clarity
audio_sample_rates = [11025]
MAX_LENGTH = 44100 # Define your desired maximum length here
def __init__(self, input_dir, device):
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
self.device = device
def __init__(self, input_dir):
"""
Initializes the dataset. Device argument is removed.
"""
self.input_files = [
os.path.join(root, f)
for root, _, files in os.walk(input_dir)
for f in files if f.endswith('.wav')
]
if not self.input_files:
print(f"Warning: No .wav files found in {input_dir}")
def __len__(self):
return len(self.input_files)
def __getitem__(self, idx):
# Load high-quality audio
high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
"""
Loads audio, processes it, and returns NumPy arrays.
"""
# --- Load and Resample using torchaudio (produces PyTorch tensors) ---
try:
high_quality_audio_pt, original_sample_rate = torchaudio.load(
self.input_files[idx], normalize=True
)
except Exception as e:
print(f"Error loading file {self.input_files[idx]}: {e}")
# Return None or raise error, or return dummy data if preferred
# Returning dummy data might hide issues
return None # Or handle appropriately
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform_low(high_quality_audio)
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio = resample_transform_high(low_quality_audio)
# Ensure sample rates are different before resampling
if original_sample_rate != mangled_sample_rate:
resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)
high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)
resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
else:
# If rates match, just copy the tensor
low_quality_audio_pt = high_quality_audio_pt.clone()
# Pad or truncate high-quality audio
if high_quality_audio.shape[1] < self.MAX_LENGTH:
padding = self.MAX_LENGTH - high_quality_audio.shape[1]
high_quality_audio = F.pad(high_quality_audio, (0, padding))
elif high_quality_audio.shape[1] > self.MAX_LENGTH:
high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]
# Pad or truncate low-quality audio
if low_quality_audio.shape[1] < self.MAX_LENGTH:
padding = self.MAX_LENGTH - low_quality_audio.shape[1]
low_quality_audio = F.pad(low_quality_audio, (0, padding))
elif low_quality_audio.shape[1] > self.MAX_LENGTH:
low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]
# --- Process Stereo to Mono (still using PyTorch tensors) ---
# Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
# and returns PyTorch Tensor (1, L)
try:
high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
except Exception as e:
# Handle cases where mono conversion might fail (e.g., already mono)
# This depends on how AudioUtils is implemented. Let's assume it handles it.
print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)
high_quality_audio = high_quality_audio.to(self.device)
low_quality_audio = low_quality_audio.to(self.device)
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
# --- Convert to NumPy Arrays ---
high_quality_audio_np = high_quality_audio_pt_mono.numpy() # Shape (1, L)
low_quality_audio_np = low_quality_audio_pt_mono.numpy() # Shape (1, L)
# --- Pad or Truncate using NumPy ---
def process_numpy_audio(audio_np, max_len):
current_len = audio_np.shape[1] # Length is axis 1 for shape (1, L)
if current_len < max_len:
padding_needed = max_len - current_len
# np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
# We only pad axis 1 (length) at the end
audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
elif current_len > max_len:
# Truncate axis 1 (length)
audio_np = audio_np[:, :max_len]
return audio_np
high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)
# --- Remove Device Handling ---
# .to(self.device) is removed.
# --- Return NumPy arrays and metadata ---
# Note: Arrays likely have shape (1, MAX_LENGTH) here
return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)

View File

@ -1,63 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.utils as utils
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
padding = (kernel_size // 2) * dilation
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding
)
if spectral_norm:
conv_layer = utils.spectral_norm(conv_layer)
layers = [conv_layer]
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion
# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
mean = jnp.mean(x, axis=1, keepdims=True)
var = jnp.var(x, axis=1, keepdims=True)
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
if self.use_scale:
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
attention_weights = self.attention(x)
channels: int
@nn.compact
def __call__(self, x):
ks1 = (1,)
attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
return x * attention_weights
# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
"""Equivalent of the PyTorch discriminator_block function."""
in_channels: int # Needed for clarity, though not strictly used by layers if input shape is known
out_channels: int
kernel_size: int = 3
stride: int = 1
dilation: int = 1
# spectral_norm: bool = True # Flag for where SN would be applied
use_instance_norm: bool = True
negative_slope: float = 0.2
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L', C_out) - L' depends on stride/padding
"""
# Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
ks = (self.kernel_size,)
st = (self.stride,)
di = (self.dilation,)
# Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
# NOTE: Spectral Norm is omitted here.
# If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
# conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
y = nn.Conv(
features=self.out_channels,
kernel_size=ks,
strides=st,
kernel_dilation=di,
padding='SAME' # Often used in GANs
)(x)
# Apply LeakyReLU first (as in the original code if IN is used)
y = nn.leaky_relu(y, negative_slope=self.negative_slope)
# Conditionally apply InstanceNorm
if self.use_instance_norm:
y = InstanceNorm1d(features=self.out_channels)(y)
return y
class SISUDiscriminator(nn.Module):
def __init__(self, base_channels=16):
super(SISUDiscriminator, self).__init__()
layers = base_channels
self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
)
"""SISUDiscriminator model translated to Flax."""
base_channels: int = 16
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, 1) - logits
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
def forward(self, x):
x = self.model(x)
x = self.global_avg_pool(x)
x = x.view(x.size(0), -1)
return x
ch = self.base_channels
# Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)
# Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)
# Attention Block
y = AttentionBlock(channels=ch*4)(y)
# Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)
# Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
# NOTE: Spectral Norm omitted (as per original config)
y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)
# Global Average Pooling (across Length dimension)
pooled = jnp.mean(y, axis=1) # Shape becomes (N, C=1)
# Flatten (optional, as shape is likely already (N, 1))
output = jnp.reshape(pooled, (pooled.shape[0], -1)) # Shape (N, 1)
return output

View File

@ -1,74 +1,173 @@
import torch
import torch.nn as nn
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
return nn.Sequential(
nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
dilation=dilation,
padding=(kernel_size // 2) * dilation
),
nn.InstanceNorm1d(out_channels),
nn.PReLU()
)
# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
"""
Flax implementation of Instance Normalization for 1D data (NLC format).
Normalizes across the 'L' dimension.
"""
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor of shape (batch, length, features)
Returns:
Normalized tensor.
"""
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
# Calculate mean and variance across the length dimension (axis=1)
# Keep dims for broadcasting
mean = jnp.mean(x, axis=1, keepdims=True)
# Variance calculation using mean needs care for numerical stability if needed,
# but jnp.var should handle it.
var = jnp.var(x, axis=1, keepdims=True)
# Normalize
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
# Apply learnable scale and bias if enabled
if self.use_scale:
# Parameter shape: (features,) to broadcast across N and L
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
# Parameter shape: (features,)
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- Converted Modules ---
class ConvBlock(nn.Module):
"""Equivalent of the PyTorch conv_block function."""
out_channels: int
kernel_size: int = 3
dilation: int = 1
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L, C_out)
"""
# Flax Conv expects kernel_size and dilation as sequences (tuples)
ks = (self.kernel_size,)
di = (self.dilation,)
# Padding='SAME' attempts to preserve the length dimension for stride=1
x = nn.Conv(
features=self.out_channels,
kernel_size=ks,
kernel_dilation=di,
padding='SAME'
)(x)
x = InstanceNorm1d(features=self.out_channels)(x) # Use custom InstanceNorm
x = nn.PReLU()(x) # PReLU learns 'alpha' parameter per channel
return x
class AttentionBlock(nn.Module):
"""
Simple Channel Attention Block. Learns to weight channels based on their importance.
"""
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid()
)
"""Simple Channel Attention Block in Flax."""
channels: int
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Attention-weighted output tensor (N, L, C)
"""
# Flax Conv expects kernel_size as a sequence (tuple)
ks1 = (1,)
attention_weights = nn.Conv(
features=self.channels // 4, kernel_size=ks1, padding='SAME'
)(x)
# NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(
features=self.channels, kernel_size=ks1, padding='SAME'
)(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
class ResidualInResidualBlock(nn.Module):
def __init__(self, channels, num_convs=3):
super(ResidualInResidualBlock, self).__init__()
"""ResidualInResidualBlock in Flax."""
channels: int
num_convs: int = 3
self.conv_layers = nn.Sequential(
*[conv_block(channels, channels) for _ in range(num_convs)]
)
self.attention = AttentionBlock(channels)
def forward(self, x):
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Output tensor (N, L, C)
"""
residual = x
x = self.conv_layers(x)
x = self.attention(x)
return x + residual
y = x
# Sequentially apply ConvBlocks
for _ in range(self.num_convs):
y = ConvBlock(
out_channels=self.channels,
kernel_size=3, # Assuming kernel_size 3 as in original conv_block default
dilation=1 # Assuming dilation 1 as in original conv_block default
)(y)
y = AttentionBlock(channels=self.channels)(y)
return y + residual
class SISUGenerator(nn.Module):
def __init__(self, channels=16, num_rirb=4, alpha=1.0):
super(SISUGenerator, self).__init__()
self.alpha = alpha
"""SISUGenerator model translated to Flax."""
channels: int = 16
num_rirb: int = 4
alpha: float = 1.0 # Non-learnable parameter, passed during init
self.conv1 = nn.Sequential(
nn.Conv1d(1, channels, kernel_size=7, padding=3),
nn.InstanceNorm1d(channels),
nn.PReLU(),
)
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, L, 1)
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
self.rir_blocks = nn.Sequential(
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
)
self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
def forward(self, x):
residual_input = x
x = self.conv1(x)
x_rirb_out = self.rir_blocks(x)
learned_residual = self.final_layer(x_rirb_out)
output = residual_input + self.alpha * learned_residual
# Initial convolution block
# Flax Conv expects kernel_size as sequence
ks7 = (7,)
ks3 = (3,)
y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
y = InstanceNorm1d(features=self.channels)(y)
y = nn.PReLU()(y)
# Residual-in-Residual Blocks
rirb_out = y
for _ in range(self.num_rirb):
rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)
# Final layer
learned_residual = nn.Conv(
features=1, kernel_size=ks3, padding='SAME'
)(rirb_out)
# Combine with input residual
output = residual_input + self.alpha * learned_residual
return output

View File

@ -1,46 +1,33 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import jax
import jax.numpy as jnp
import optax
import tqdm
import argparse
import math
import pickle # Using pickle for simplicity to save JAX states
import os
import argparse
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile # Example for saving audio
from torch.utils.data import random_split
import torch
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
@ -49,146 +36,446 @@ win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize JAX random key
key = jax.random.PRNGKey(0)
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
dataset = AudioDatasetNumPy(dataset_dir) # Use your JAX dataset
train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True) # Use your JAX DataLoader
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
try:
# Fetch the first batch
first_batch = next(iter(train_data_loader))
# The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
# We need the high-quality audio NumPy array batch for initialization
sample_input_np = first_batch[0][0] # Get the high-quality audio NumPy array batch
# Convert the NumPy array batch to a JAX array
sample_input_array = jnp.array(sample_input_np)
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
# === FIX ===
# Transpose the array from (batch, channels, length) to (batch, length, channels)
# The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
# The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1)) # Swap axes 1 and 2
print(sample_input_array.shape) # Should now print (4, 44100, 1)
# === END FIX ===
except StopIteration:
print("Error: Data loader is empty. Cannot initialize models.")
exit() # Exit if no data is available
key, init_key_g, init_key_d = jax.random.split(key, 3)
generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()
# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']
# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply
# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy # For Mel, STFT, MFCC losses
# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)
# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.
# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
try:
with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
generator_opt_state = loaded_state['opt_state']
with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
discriminator_opt_state = loaded_state['opt_state']
epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
start_epoch = epoch_data.get("epoch", 0) + 1
print(f"Continuing training from epoch {start_epoch}")
except FileNotFoundError:
print("Continue training requested but temp models not found. Starting from scratch.")
except Exception as e:
print(f"Error loading temp models: {e}. Starting from scratch.")
generator = generator.to(device)
discriminator = discriminator.to(device)
if args.generator is not None:
try:
with open(args.generator, 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
print(f"Loaded generator from {args.generator}")
except FileNotFoundError:
print(f"Generator model not found at {args.generator}")
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
if args.discriminator is not None:
try:
with open(args.discriminator, 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
print(f"Loaded discriminator from {args.discriminator}")
except FileNotFoundError:
print(f"Discriminator model not found at {args.discriminator}")
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
# sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
# win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
# )
# stft_transform_fn = SpectrogramJAX(
# n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )
# mfcc_transform_fn = MFCCJAX(
# sample_rate,
# n_mfcc,
# melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )
# ========= JAX TRAINING STEPS =========
@jax.jit
def discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio, # JAX array (batch, length, channels)
low_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
fake_labels, # JAX array
discriminator_apply_fn,
generator_apply_fn,
discriminator_optimizer,
criterion_d,
key # JAX random key
):
# Split key for potential randomness in model application
key, disc_key, gen_key = jax.random.split(key, 3)
def loss_fn(d_params):
# Generate fake audio
# Note: Generator is not being trained in this step, so its parameters are static
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)
# Pass data through the discriminator
# Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
if high_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
elif high_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)
real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)
# Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
# Ensure the shapes match the labels (batch_size, 1)
real_output = real_output.reshape(-1, 1)
fake_output = fake_output.reshape(-1, 1)
real_loss = jnp.mean(criterion_d(real_output, real_labels))
fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
total_loss = real_loss + fake_loss
return total_loss, (real_loss, fake_loss)
# Compute gradients
# Use jax.value_and_grad to get both the loss value and the gradients
(loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)
# Apply updates
updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
new_discriminator_params = optax.apply_updates(discriminator_params, updates)
return new_discriminator_params, new_discriminator_opt_state, loss, key
@jax.jit
def generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio, # JAX array (batch, length, channels)
high_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
generator_apply_fn,
discriminator_apply_fn,
generator_optimizer,
criterion_d, # Adversarial loss
criterion_l1, # Feature matching loss
key # JAX random key
):
# Split key for potential randomness
key, gen_key, disc_key = jax.random.split(key, 3)
def loss_fn(g_params):
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
# Generate enhanced audio
enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)
# Ensure enhanced_audio has a leading dimension if not already present
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Calculate adversarial loss (generator wants discriminator to think fake is real)
# Note: Discriminator is not being trained in this step, so its parameters are static
fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
# Ensure the shape matches the labels (batch_size, 1)
fake_output = fake_output.reshape(-1, 1)
adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels)) # Generator wants fake_output to be close to real_labels (1s)
# Feature matching losses (assuming you add these back later)
# You would need to implement JAX versions of your audio transforms
# mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
# stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
# mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))
# combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
combined_loss = adversarial_loss # For now, only adversarial loss
# Return combined_loss and any other metrics needed for logging/analysis
# For now, just adversarial loss and enhanced_audio
return combined_loss, (adversarial_loss, enhanced_audio) # Add other losses here when implemented
# Compute gradients
# Update: loss_fn now returns (loss, (aux1, aux2, ...))
(loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)
# Apply updates
updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
new_generator_params = optax.apply_updates(generator_params, updates)
# Return the loss components separately along with the enhanced audio and key
return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key
# ========= MAIN TRAINING LOOP =========
def start_training():
global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
current_epoch = start_epoch + generator_epoch
times_correct = 0
# These will hold the last processed audio examples from a batch for saving
last_high_quality_audio = None
last_low_quality_audio = None
last_ai_enhanced_audio = None
last_sample_rate = None
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# Use tqdm for progress bar
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):
# high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
# Extract audio arrays and sample rates (assuming batch dimension is first)
# The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
high_quality_audio_batch_np = high_quality_clip[0]
low_quality_audio_batch_np = low_quality_clip[0]
sample_rate_batch_np = high_quality_clip[1] # Assuming sample rates are the same for paired clips
# Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
# Only transpose if the shape is (batch, channels, length)
if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
else:
high_quality_audio_batch = jnp.array(high_quality_audio_batch_np) # Assume already NLC or handle other cases
if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
else:
low_quality_audio_batch = jnp.array(low_quality_audio_batch_np) # Assume already NLC or handle other cases
sample_rate_batch = jnp.array(sample_rate_batch_np)
batch_size = high_quality_audio_batch.shape[0]
# Create labels - JAX arrays
real_labels = jnp.ones((batch_size, 1))
fake_labels = jnp.zeros((batch_size, 1))
# Split key for each batch
key, batch_key = jax.random.split(key)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
# Call the jitted discriminator training step
discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio_batch,
low_quality_audio_batch,
real_labels,
fake_labels,
discriminator,
generator,
discriminator_apply_fn,
generator_apply_fn,
optimizer_d,
criterion_d,
optimizer_d
batch_key
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
# Call the jitted generator training step
generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio_batch,
high_quality_audio_batch,
real_labels, # Generator tries to make fake data look real
generator_apply_fn,
discriminator_apply_fn,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
criterion_d,
criterion_l1,
batch_key
)
# Print debug logs (requires waiting for JIT compilation on first step)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# Use .block_until_ready() to ensure computation is finished before printing
# In a real scenario, you might want to log metrics less frequently
d_loss_val = d_loss.block_until_ready().item()
combined_loss_val = combined_loss.block_until_ready().item()
adversarial_loss_val = adversarial_loss.block_until_ready().item()
# Assuming other losses are returned by generator_train_step and unpacked
# mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
# stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
# mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
# Print other losses here when implemented and returned
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
# Schedulers - Implement your learning rate scheduling logic here if needed
# based on the losses (e.g., reducing learning rate if loss plateaus).
# This logic would typically live outside the jitted step function.
# For Optax, you might use a schedule within the optimizer definition
# or update the learning rate of the optimizer manually.
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
# ========= SAVE LATEST AUDIO (from the last batch processed) =========
# Access the first sample of the batch for saving
# Ensure enhanced_audio_batch has a batch dimension and is in NLC format
if enhanced_audio_batch.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
elif enhanced_audio_batch.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)
last_high_quality_audio = high_quality_audio_batch[0]
last_low_quality_audio = low_quality_audio_batch[0]
last_ai_enhanced_audio = enhanced_audio_batch[0]
last_sample_rate = sample_rate_batch[0].item() # Assuming sample rate is scalar per batch item
# Save audio files periodically (outside the batch loop)
if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
print(f"Saving audio for epoch {current_epoch}!")
try:
# Convert JAX arrays to NumPy arrays for saving
# Transpose back to (length, channels) or (length) if needed by wavfile.write
# Assuming the models output (length, 1) or (length) after removing batch dim
low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
high_quality_audio_np_save = jax.device_get(last_high_quality_audio)
# Remove the channel dimension if it's 1 for saving with wavfile
if low_quality_audio_np_save.shape[-1] == 1:
low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
if ai_enhanced_audio_np_save.shape[-1] == 1:
ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
if high_quality_audio_np_save.shape[-1] == 1:
high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
except Exception as e:
print(f"Error saving audio files: {e}")
# Save model states periodically (outside the batch loop)
# Use pickle to save parameters and optimizer states
try:
with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
except Exception as e:
print(f"Error saving temp model states: {e}")
# Save final model states after all epochs
print("Training complete! Saving final models.")
try:
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params)}, f)
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params)}, f)
except Exception as e:
print(f"Error saving final model states: {e}")
start_training()

194
training.txt Normal file
View File

@ -0,0 +1,194 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
start_training()

View File

@ -20,12 +20,10 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
# Ensure same time dimension length (due to potential framing differences)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
# L1 Loss (Mean Absolute Error)
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
return loss