💩 | VERY CRUDE JAX implementation...
This commit is contained in:
parent d70c86c257
commit 5735557ec3
data.py (109 lines changed)
@@ -1,53 +1,104 @@
# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import torchaudio.transforms as T # Keep using torchaudio transforms

# Import NumPy
import numpy as np

import os
import random
import torchaudio.transforms as T
# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils

class AudioDataset(Dataset):
class AudioDatasetNumPy(Dataset): # Renamed slightly for clarity
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100 # Define your desired maximum length here

    def __init__(self, input_dir, device):
        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
        self.device = device
    def __init__(self, input_dir):
        """
        Initializes the dataset. Device argument is removed.
        """
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # Load high-quality audio
        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
        """
        Loads audio, processes it, and returns NumPy arrays.
        """
        # --- Load and Resample using torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                self.input_files[idx], normalize=True
            )
        except Exception as e:
            print(f"Error loading file {self.input_files[idx]}: {e}")
            # Return None or raise error, or return dummy data if preferred
            # Returning dummy data might hide issues
            return None # Or handle appropriately

        # Generate low-quality audio with random downsampling
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resample_transform_low(high_quality_audio)

        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
        low_quality_audio = resample_transform_high(low_quality_audio)
        # Ensure sample rates are different before resampling
        if original_sample_rate != mangled_sample_rate:
            resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
            low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)

        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)
            resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
        else:
            # If rates match, just copy the tensor
            low_quality_audio_pt = high_quality_audio_pt.clone()

        # Pad or truncate high-quality audio
        if high_quality_audio.shape[1] < self.MAX_LENGTH:
            padding = self.MAX_LENGTH - high_quality_audio.shape[1]
            high_quality_audio = F.pad(high_quality_audio, (0, padding))
        elif high_quality_audio.shape[1] > self.MAX_LENGTH:
            high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]

        # Pad or truncate low-quality audio
        if low_quality_audio.shape[1] < self.MAX_LENGTH:
            padding = self.MAX_LENGTH - low_quality_audio.shape[1]
            low_quality_audio = F.pad(low_quality_audio, (0, padding))
        elif low_quality_audio.shape[1] > self.MAX_LENGTH:
            low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]
        # --- Process Stereo to Mono (still using PyTorch tensors) ---
        # Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
        # and returns PyTorch Tensor (1, L)
        try:
            high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
            low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
        except Exception as e:
            # Handle cases where mono conversion might fail (e.g., already mono)
            # This depends on how AudioUtils is implemented. Let's assume it handles it.
            print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
            high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
            low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)

        high_quality_audio = high_quality_audio.to(self.device)
        low_quality_audio = low_quality_audio.to(self.device)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
        # --- Convert to NumPy Arrays ---
        high_quality_audio_np = high_quality_audio_pt_mono.numpy() # Shape (1, L)
        low_quality_audio_np = low_quality_audio_pt_mono.numpy() # Shape (1, L)

        # --- Pad or Truncate using NumPy ---
        def process_numpy_audio(audio_np, max_len):
            current_len = audio_np.shape[1] # Length is axis 1 for shape (1, L)
            if current_len < max_len:
                padding_needed = max_len - current_len
                # np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
                # We only pad axis 1 (length) at the end
                audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
            elif current_len > max_len:
                # Truncate axis 1 (length)
                audio_np = audio_np[:, :max_len]
            return audio_np

        high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
        low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)

        # --- Remove Device Handling ---
        # .to(self.device) is removed.

        # --- Return NumPy arrays and metadata ---
        # Note: Arrays likely have shape (1, MAX_LENGTH) here
        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)
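Since AudioDatasetNumPy now returns NumPy arrays while training.py keeps feeding it through the PyTorch DataLoader, the usual JAX-with-PyTorch-DataLoader pattern is a NumPy collate function, so batches stay as NumPy instead of being converted to torch tensors. A minimal sketch under that assumption (the numpy_collate name is illustrative, not part of this commit):

import numpy as np

def numpy_collate(batch):
    # batch: list of ((hq_audio, hq_sr), (lq_audio, lq_sr)) items from AudioDatasetNumPy
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)  # stack per-sample (1, L) arrays into a (B, 1, L) batch
    if isinstance(batch[0], (tuple, list)):
        return tuple(numpy_collate(samples) for samples in zip(*batch))
    return np.asarray(batch)  # plain scalars such as sample rates

# Hypothetical usage:
# train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=numpy_collate)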
discriminator.py (192 lines changed)
@@ -1,63 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.utils as utils
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple

def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
    padding = (kernel_size // 2) * dilation
    conv_layer = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=padding
    )

    if spectral_norm:
        conv_layer = utils.spectral_norm(conv_layer)

    layers = [conv_layer]
    layers.append(nn.LeakyReLU(0.2, inplace=True))

    if use_instance_norm:
        layers.append(nn.InstanceNorm1d(out_channels))

    return nn.Sequential(*layers)
# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion
# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True
    @nn.compact
    def __call__(self, x):
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
        mean = jnp.mean(x, axis=1, keepdims=True)
        var = jnp.var(x, axis=1, keepdims=True)
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias
        return normalized

# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
    channels: int
    @nn.compact
    def __call__(self, x):
        ks1 = (1,)
        attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)
        return x * attention_weights

# --- Converted Discriminator Modules ---

class DiscriminatorBlock(nn.Module):
    """Equivalent of the PyTorch discriminator_block function."""
    in_channels: int # Needed for clarity, though not strictly used by layers if input shape is known
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    # spectral_norm: bool = True # Flag for where SN would be applied
    use_instance_norm: bool = True
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L', C_out) - L' depends on stride/padding
        """
        # Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
        ks = (self.kernel_size,)
        st = (self.stride,)
        di = (self.dilation,)

        # Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
        # NOTE: Spectral Norm is omitted here.
        # If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
        # conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
        y = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            strides=st,
            kernel_dilation=di,
            padding='SAME' # Often used in GANs
        )(x)

        # Apply LeakyReLU first (as in the original code if IN is used)
        y = nn.leaky_relu(y, negative_slope=self.negative_slope)

        # Conditionally apply InstanceNorm
        if self.use_instance_norm:
            y = InstanceNorm1d(features=self.out_channels)(y)

        return y

class SISUDiscriminator(nn.Module):
    def __init__(self, base_channels=16):
        super(SISUDiscriminator, self).__init__()
        layers = base_channels
        self.model = nn.Sequential(
            discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
            AttentionBlock(layers * 4),
            discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
        )
    """SISUDiscriminator model translated to Flax."""
    base_channels: int = 16

        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, 1) - logits
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

    def forward(self, x):
        x = self.model(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        return x
        ch = self.base_channels

        # Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)

        # Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)

        # Attention Block
        y = AttentionBlock(channels=ch*4)(y)

        # Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)

        # Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)

        # Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
        # NOTE: Spectral Norm omitted
        y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)

        # Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
        # NOTE: Spectral Norm omitted (as per original config)
        y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)

        # Global Average Pooling (across Length dimension)
        pooled = jnp.mean(y, axis=1) # Shape becomes (N, C=1)

        # Flatten (optional, as shape is likely already (N, 1))
        output = jnp.reshape(pooled, (pooled.shape[0], -1)) # Shape (N, 1)

        return output
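The converted blocks note that the spectral normalization from the PyTorch version is omitted. One common way to approximate it in JAX is power iteration on the flattened convolution kernel; the sketch below is a minimal illustration of that idea only (the function name, iteration count, and how it would be wired into DiscriminatorBlock are assumptions, not part of this commit):

import jax
import jax.numpy as jnp

def spectral_norm_estimate(kernel, key, n_iter=5):
    # Estimate the largest singular value of a conv kernel via power iteration.
    w = kernel.reshape(-1, kernel.shape[-1])  # flatten to (in_dims, out_features)
    u = jax.random.normal(key, (w.shape[0],))
    v = jnp.zeros((w.shape[1],))
    for _ in range(n_iter):
        v = w.T @ u
        v = v / (jnp.linalg.norm(v) + 1e-12)
        u = w @ v
        u = u / (jnp.linalg.norm(u) + 1e-12)
    return u @ (w @ v)  # approximate sigma_max

# A spectrally normalized kernel would then be kernel / spectral_norm_estimate(kernel, key).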
generator.py (215 lines changed)
@@ -1,74 +1,173 @@
import torch
import torch.nn as nn
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple

def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
    return nn.Sequential(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=(kernel_size // 2) * dilation
        ),
        nn.InstanceNorm1d(out_channels),
        nn.PReLU()
    )
# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
    """
    Flax implementation of Instance Normalization for 1D data (NLC format).
    Normalizes across the 'L' dimension.
    """
    features: int
    epsilon: float = 1e-5
    use_scale: bool = True
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor of shape (batch, length, features)

        Returns:
            Normalized tensor.
        """
        if x.shape[-1] != self.features:
            raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")

        # Calculate mean and variance across the length dimension (axis=1)
        # Keep dims for broadcasting
        mean = jnp.mean(x, axis=1, keepdims=True)
        # Variance calculation using mean needs care for numerical stability if needed,
        # but jnp.var should handle it.
        var = jnp.var(x, axis=1, keepdims=True)

        # Normalize
        normalized = (x - mean) / jnp.sqrt(var + self.epsilon)

        # Apply learnable scale and bias if enabled
        if self.use_scale:
            # Parameter shape: (features,) to broadcast across N and L
            scale = self.param('scale', nn.initializers.ones, (self.features,))
            normalized *= scale
        if self.use_bias:
            # Parameter shape: (features,)
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            normalized += bias

        return normalized

# --- Converted Modules ---

class ConvBlock(nn.Module):
    """Equivalent of the PyTorch conv_block function."""
    out_channels: int
    kernel_size: int = 3
    dilation: int = 1

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C_in)
        Returns:
            Output tensor (N, L, C_out)
        """
        # Flax Conv expects kernel_size and dilation as sequences (tuples)
        ks = (self.kernel_size,)
        di = (self.dilation,)

        # Padding='SAME' attempts to preserve the length dimension for stride=1
        x = nn.Conv(
            features=self.out_channels,
            kernel_size=ks,
            kernel_dilation=di,
            padding='SAME'
        )(x)
        x = InstanceNorm1d(features=self.out_channels)(x) # Use custom InstanceNorm
        x = nn.PReLU()(x) # PReLU learns 'alpha' parameter per channel
        return x

class AttentionBlock(nn.Module):
    """
    Simple Channel Attention Block. Learns to weight channels based on their importance.
    """
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid()
        )
    """Simple Channel Attention Block in Flax."""
    channels: int

    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Attention-weighted output tensor (N, L, C)
        """
        # Flax Conv expects kernel_size as a sequence (tuple)
        ks1 = (1,)
        attention_weights = nn.Conv(
            features=self.channels // 4, kernel_size=ks1, padding='SAME'
        )(x)
        # NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
        attention_weights = nn.relu(attention_weights)
        attention_weights = nn.Conv(
            features=self.channels, kernel_size=ks1, padding='SAME'
        )(attention_weights)
        attention_weights = nn.sigmoid(attention_weights)

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class ResidualInResidualBlock(nn.Module):
    def __init__(self, channels, num_convs=3):
        super(ResidualInResidualBlock, self).__init__()
    """ResidualInResidualBlock in Flax."""
    channels: int
    num_convs: int = 3

        self.conv_layers = nn.Sequential(
            *[conv_block(channels, channels) for _ in range(num_convs)]
        )

        self.attention = AttentionBlock(channels)

    def forward(self, x):
    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, C)
        Returns:
            Output tensor (N, L, C)
        """
        residual = x
        x = self.conv_layers(x)
        x = self.attention(x)
        return x + residual
        y = x
        # Sequentially apply ConvBlocks
        for _ in range(self.num_convs):
            y = ConvBlock(
                out_channels=self.channels,
                kernel_size=3, # Assuming kernel_size 3 as in original conv_block default
                dilation=1 # Assuming dilation 1 as in original conv_block default
            )(y)

        y = AttentionBlock(channels=self.channels)(y)
        return y + residual

class SISUGenerator(nn.Module):
    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
        super(SISUGenerator, self).__init__()
        self.alpha = alpha
    """SISUGenerator model translated to Flax."""
    channels: int = 16
    num_rirb: int = 4
    alpha: float = 1.0 # Non-learnable parameter, passed during init

        self.conv1 = nn.Sequential(
            nn.Conv1d(1, channels, kernel_size=7, padding=3),
            nn.InstanceNorm1d(channels),
            nn.PReLU(),
        )
    @nn.compact
    def __call__(self, x):
        """
        Args:
            x: Input tensor (N, L, 1) - assumes single channel input
        Returns:
            Output tensor (N, L, 1)
        """
        if x.shape[-1] != 1:
            raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")

        self.rir_blocks = nn.Sequential(
            *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
        )

        self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)

    def forward(self, x):
        residual_input = x
        x = self.conv1(x)
        x_rirb_out = self.rir_blocks(x)
        learned_residual = self.final_layer(x_rirb_out)
        output = residual_input + self.alpha * learned_residual

        # Initial convolution block
        # Flax Conv expects kernel_size as sequence
        ks7 = (7,)
        ks3 = (3,)
        y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
        y = InstanceNorm1d(features=self.channels)(y)
        y = nn.PReLU()(y)

        # Residual-in-Residual Blocks
        rirb_out = y
        for _ in range(self.num_rirb):
            rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)

        # Final layer
        learned_residual = nn.Conv(
            features=1, kernel_size=ks3, padding='SAME'
        )(rirb_out)

        # Combine with input residual
        output = residual_input + self.alpha * learned_residual
        return output
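Because SISUGenerator is now a Flax linen module, its parameters live outside the module and inputs are expected in NLC layout (batch, length, channels). A minimal, hedged usage sketch (the batch size, length, and PRNG seed below are illustrative only):

import jax
import jax.numpy as jnp
from generator import SISUGenerator

model = SISUGenerator()
dummy = jnp.zeros((2, 44100, 1))  # (batch, length, channels), NLC format
params = model.init(jax.random.PRNGKey(0), dummy)["params"]
output = model.apply({"params": params}, dummy)  # array of shape (2, 44100, 1)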
training.py (535 lines changed)
@@ -1,46 +1,33 @@
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
import torchaudio
import jax
import jax.numpy as jnp
import optax
import tqdm

import argparse

import math

import pickle # Using pickle for simplicity to save JAX states
import os
import argparse
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile # Example for saving audio

from torch.utils.data import random_split
import torch
from torch.utils.data import DataLoader

import AudioUtils
from data import AudioDataset
import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from discriminator import SISUDiscriminator

from training_utils import discriminator_train, generator_train
import file_utils as Data

import torchaudio.transforms as T

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")

args = parser.parse_args()

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Parameters
sample_rate = 44100
n_fft = 2048
@@ -49,146 +36,446 @@ win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC

mfcc_transform = T.MFCC(
    sample_rate,
    n_mfcc,
    melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)

mel_transform = T.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
    win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)

stft_transform = T.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)

debug = args.debug

# Initialize JAX random key
key = jax.random.PRNGKey(0)

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
dataset = AudioDatasetNumPy(dataset_dir) # Use your JAX dataset
train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True) # Use your JAX DataLoader

models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= SINGLE =========

train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)


# ========= MODELS =========

generator = SISUGenerator()
discriminator = SISUDiscriminator()
try:
    # Fetch the first batch
    first_batch = next(iter(train_data_loader))
    # The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
    # We need the high-quality audio NumPy array batch for initialization
    sample_input_np = first_batch[0][0] # Get the high-quality audio NumPy array batch
    # Convert the NumPy array batch to a JAX array
    sample_input_array = jnp.array(sample_input_np)

epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
    # === FIX ===
    # Transpose the array from (batch, channels, length) to (batch, length, channels)
    # The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
    # The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
    if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
        sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1)) # Swap axes 1 and 2

    print(sample_input_array.shape) # Should now print (4, 44100, 1)
    # === END FIX ===

except StopIteration:
    print("Error: Data loader is empty. Cannot initialize models.")
    exit() # Exit if no data is available


key, init_key_g, init_key_d = jax.random.split(key, 3)
generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()

# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']

# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply


# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy # For Mel, STFT, MFCC losses

# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)

# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)

# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.


# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
    epoch = epoch_from_file["epoch"] + 1
else:
    if args.generator is not None:
        generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
    if args.discriminator is not None:
        discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
    try:
        with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
            generator_opt_state = loaded_state['opt_state']
        with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
            discriminator_opt_state = loaded_state['opt_state']
        epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
        start_epoch = epoch_data.get("epoch", 0) + 1
        print(f"Continuing training from epoch {start_epoch}")
    except FileNotFoundError:
        print("Continue training requested but temp models not found. Starting from scratch.")
    except Exception as e:
        print(f"Error loading temp models: {e}. Starting from scratch.")

generator = generator.to(device)
discriminator = discriminator.to(device)
if args.generator is not None:
    try:
        with open(args.generator, 'rb') as f:
            loaded_state = pickle.load(f)
            generator_params = loaded_state['params']
        print(f"Loaded generator from {args.generator}")
    except FileNotFoundError:
        print(f"Generator model not found at {args.generator}")

# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
if args.discriminator is not None:
    try:
        with open(args.discriminator, 'rb') as f:
            loaded_state = pickle.load(f)
            discriminator_params = loaded_state['params']
        print(f"Loaded discriminator from {args.discriminator}")
    except FileNotFoundError:
        print(f"Discriminator model not found at {args.discriminator}")

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
#     sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
#     win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
# )

# stft_transform_fn = SpectrogramJAX(
#     n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )

# mfcc_transform_fn = MFCCJAX(
#     sample_rate,
#     n_mfcc,
#     melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )


# ========= JAX TRAINING STEPS =========

@jax.jit
def discriminator_train_step(
    discriminator_params,
    generator_params,
    discriminator_opt_state,
    high_quality_audio, # JAX array (batch, length, channels)
    low_quality_audio, # JAX array (batch, length, channels)
    real_labels, # JAX array
    fake_labels, # JAX array
    discriminator_apply_fn,
    generator_apply_fn,
    discriminator_optimizer,
    criterion_d,
    key # JAX random key
):
    # Split key for potential randomness in model application
    key, disc_key, gen_key = jax.random.split(key, 3)

    def loss_fn(d_params):
        # Generate fake audio
        # Note: Generator is not being trained in this step, so its parameters are static
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)

        enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)

        # Pass data through the discriminator
        # Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
        if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)

        # Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
        if high_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
            high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
        elif high_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
            high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)

        real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
        fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)

        # Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
        # Ensure the shapes match the labels (batch_size, 1)
        real_output = real_output.reshape(-1, 1)
        fake_output = fake_output.reshape(-1, 1)

        real_loss = jnp.mean(criterion_d(real_output, real_labels))
        fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
        total_loss = real_loss + fake_loss
        return total_loss, (real_loss, fake_loss)

    # Compute gradients
    # Use jax.value_and_grad to get both the loss value and the gradients
    (loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)

    # Apply updates
    updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
    new_discriminator_params = optax.apply_updates(discriminator_params, updates)

    return new_discriminator_params, new_discriminator_opt_state, loss, key


@jax.jit
def generator_train_step(
    generator_params,
    discriminator_params,
    generator_opt_state,
    low_quality_audio, # JAX array (batch, length, channels)
    high_quality_audio, # JAX array (batch, length, channels)
    real_labels, # JAX array
    generator_apply_fn,
    discriminator_apply_fn,
    generator_optimizer,
    criterion_d, # Adversarial loss
    criterion_l1, # Feature matching loss
    key # JAX random key
):
    # Split key for potential randomness
    key, gen_key, disc_key = jax.random.split(key, 3)

    def loss_fn(g_params):
        # Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
        if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
            low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
        elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
            low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)

        # Generate enhanced audio
        enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)

        # Ensure enhanced_audio has a leading dimension if not already present
        if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
            enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
        elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
            enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)

        # Calculate adversarial loss (generator wants discriminator to think fake is real)
        # Note: Discriminator is not being trained in this step, so its parameters are static
        fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
        # Ensure the shape matches the labels (batch_size, 1)
        fake_output = fake_output.reshape(-1, 1)
        adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels)) # Generator wants fake_output to be close to real_labels (1s)

        # Feature matching losses (assuming you add these back later)
        # You would need to implement JAX versions of your audio transforms
        # mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
        # stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
        # mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))

        # combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
        combined_loss = adversarial_loss # For now, only adversarial loss

        # Return combined_loss and any other metrics needed for logging/analysis
        # For now, just adversarial loss and enhanced_audio
        return combined_loss, (adversarial_loss, enhanced_audio) # Add other losses here when implemented

    # Compute gradients
    # Update: loss_fn now returns (loss, (aux1, aux2, ...))
    (loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)

    # Apply updates
    updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
    new_generator_params = optax.apply_updates(generator_params, updates)

    # Return the loss components separately along with the enhanced audio and key
    return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key


# ========= MAIN TRAINING LOOP =========

def start_training():
    global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key
    generator_epochs = 5000

    for generator_epoch in range(generator_epochs):
        low_quality_audio = (torch.empty((1)), 1)
        high_quality_audio = (torch.empty((1)), 1)
        ai_enhanced_audio = (torch.empty((1)), 1)
        current_epoch = start_epoch + generator_epoch

        times_correct = 0
        # These will hold the last processed audio examples from a batch for saving
        last_high_quality_audio = None
        last_low_quality_audio = None
        last_ai_enhanced_audio = None
        last_sample_rate = None

        # ========= TRAINING =========
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
        # for high_quality_clip, low_quality_clip in train_data_loader:
            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])

            # ========= LABELS =========
            batch_size = high_quality_clip[0].size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
        # Use tqdm for progress bar
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):

            # high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
            # Extract audio arrays and sample rates (assuming batch dimension is first)
            # The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
            high_quality_audio_batch_np = high_quality_clip[0]
            low_quality_audio_batch_np = low_quality_clip[0]
            sample_rate_batch_np = high_quality_clip[1] # Assuming sample rates are the same for paired clips

            # Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
            # Only transpose if the shape is (batch, channels, length)
            if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
                high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
            else:
                high_quality_audio_batch = jnp.array(high_quality_audio_batch_np) # Assume already NLC or handle other cases

            if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
                low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
            else:
                low_quality_audio_batch = jnp.array(low_quality_audio_batch_np) # Assume already NLC or handle other cases

            sample_rate_batch = jnp.array(sample_rate_batch_np)

            batch_size = high_quality_audio_batch.shape[0]
            # Create labels - JAX arrays
            real_labels = jnp.ones((batch_size, 1))
            fake_labels = jnp.zeros((batch_size, 1))

            # Split key for each batch
            key, batch_key = jax.random.split(key)

            # ========= DISCRIMINATOR =========
            discriminator.train()
            d_loss = discriminator_train(
                high_quality_sample,
                low_quality_sample,
            # Call the jitted discriminator training step
            discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
                discriminator_params,
                generator_params,
                discriminator_opt_state,
                high_quality_audio_batch,
                low_quality_audio_batch,
                real_labels,
                fake_labels,
                discriminator,
                generator,
                discriminator_apply_fn,
                generator_apply_fn,
                optimizer_d,
                criterion_d,
                optimizer_d
                batch_key
            )

            # ========= GENERATOR =========
            generator.train()
            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
                low_quality_sample,
                high_quality_sample,
                real_labels,
                generator,
                discriminator,
                criterion_d,
            # Call the jitted generator training step
            generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
                generator_params,
                discriminator_params,
                generator_opt_state,
                low_quality_audio_batch,
                high_quality_audio_batch,
                real_labels, # Generator tries to make fake data look real
                generator_apply_fn,
                discriminator_apply_fn,
                optimizer_g,
                device,
                mel_transform,
                stft_transform,
                mfcc_transform
                criterion_d,
                criterion_l1,
                batch_key
            )

            # Print debug logs (requires waiting for JIT compilation on first step)
            if debug:
                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
            scheduler_d.step(d_loss.detach())
            scheduler_g.step(adversarial_loss.detach())
                # Use .block_until_ready() to ensure computation is finished before printing
                # In a real scenario, you might want to log metrics less frequently
                d_loss_val = d_loss.block_until_ready().item()
                combined_loss_val = combined_loss.block_until_ready().item()
                adversarial_loss_val = adversarial_loss.block_until_ready().item()
                # Assuming other losses are returned by generator_train_step and unpacked
                # mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
                # stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
                # mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
                print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
                # Print other losses here when implemented and returned

            # ========= SAVE LATEST AUDIO =========
            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])

            new_epoch = generator_epoch+epoch

            if generator_epoch % 25 == 0:
                print(f"Saved epoch {new_epoch}!")
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])

            #if debug:
            #    print(generator.state_dict().keys())
            #    print(discriminator.state_dict().keys())
            torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
            torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
            # Schedulers - Implement your learning rate scheduling logic here if needed
            # based on the losses (e.g., reducing learning rate if loss plateaus).
            # This logic would typically live outside the jitted step function.
            # For Optax, you might use a schedule within the optimizer definition
            # or update the learning rate of the optimizer manually.


    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
    torch.save(generator, "models/epoch-5000-generator.pt")
    print("Training complete!")
            # ========= SAVE LATEST AUDIO (from the last batch processed) =========
            # Access the first sample of the batch for saving
            # Ensure enhanced_audio_batch has a batch dimension and is in NLC format
            if enhanced_audio_batch.ndim == 2: # Assuming (length, channel), add batch dim
                enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
            elif enhanced_audio_batch.ndim == 1: # Assuming (length), add batch and channel dims
                enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)

            last_high_quality_audio = high_quality_audio_batch[0]
            last_low_quality_audio = low_quality_audio_batch[0]
            last_ai_enhanced_audio = enhanced_audio_batch[0]
            last_sample_rate = sample_rate_batch[0].item() # Assuming sample rate is scalar per batch item

        # Save audio files periodically (outside the batch loop)
        if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
            print(f"Saving audio for epoch {current_epoch}!")
            try:
                # Convert JAX arrays to NumPy arrays for saving
                # Transpose back to (length, channels) or (length) if needed by wavfile.write
                # Assuming the models output (length, 1) or (length) after removing batch dim
                low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
                ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
                high_quality_audio_np_save = jax.device_get(last_high_quality_audio)

                # Remove the channel dimension if it's 1 for saving with wavfile
                if low_quality_audio_np_save.shape[-1] == 1:
                    low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
                if ai_enhanced_audio_np_save.shape[-1] == 1:
                    ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
                if high_quality_audio_np_save.shape[-1] == 1:
                    high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)

                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
                wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
            except Exception as e:
                print(f"Error saving audio files: {e}")

        # Save model states periodically (outside the batch loop)
        # Use pickle to save parameters and optimizer states
        try:
            with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
            with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
                pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
        except Exception as e:
            print(f"Error saving temp model states: {e}")

    # Save final model states after all epochs
    print("Training complete! Saving final models.")
    try:
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(discriminator_params)}, f)
        with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
            pickle.dump({'params': jax.device_get(generator_params)}, f)
    except Exception as e:
        print(f"Error saving final model states: {e}")


start_training()
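A note on the audio saving above: scipy.io.wavfile.write expects integer PCM (or float32) data, while torchaudio.load(normalize=True) yields floats in [-1, 1], so casting those floats straight to int16 would mostly round to zero. A small hedged sketch of the usual scaling step (the float_to_int16 helper is illustrative, not part of this commit):

import numpy as np

def float_to_int16(audio):
    # Scale float audio in [-1, 1] to 16-bit PCM before scipy.io.wavfile.write.
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767.0).astype(np.int16)

# Hypothetical usage:
# wavfile.write(path, sample_rate, float_to_int16(audio_np))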
training.txt (194 lines changed, new file)
@@ -0,0 +1,194 @@
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
import torchaudio
import tqdm

import argparse

import math

import os

from torch.utils.data import random_split
from torch.utils.data import DataLoader

import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator

from training_utils import discriminator_train, generator_train
import file_utils as Data

import torchaudio.transforms as T

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")

args = parser.parse_args()

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC

mfcc_transform = T.MFCC(
    sample_rate,
    n_mfcc,
    melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)

mel_transform = T.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
    win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)

stft_transform = T.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)

debug = args.debug

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)

# ========= SINGLE =========

train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)


# ========= MODELS =========

generator = SISUGenerator()
discriminator = SISUDiscriminator()

epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")

if args.continue_training:
    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
    epoch = epoch_from_file["epoch"] + 1
else:
    if args.generator is not None:
        generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
    if args.discriminator is not None:
        discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))

generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

def start_training():
    generator_epochs = 5000
    for generator_epoch in range(generator_epochs):
        low_quality_audio = (torch.empty((1)), 1)
        high_quality_audio = (torch.empty((1)), 1)
        ai_enhanced_audio = (torch.empty((1)), 1)

        times_correct = 0

        # ========= TRAINING =========
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
        # for high_quality_clip, low_quality_clip in train_data_loader:
            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])

            # ========= LABELS =========
            batch_size = high_quality_clip[0].size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # ========= DISCRIMINATOR =========
            discriminator.train()
            d_loss = discriminator_train(
                high_quality_sample,
                low_quality_sample,
                real_labels,
                fake_labels,
                discriminator,
                generator,
                criterion_d,
                optimizer_d
            )

            # ========= GENERATOR =========
            generator.train()
            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
                low_quality_sample,
                high_quality_sample,
                real_labels,
                generator,
                discriminator,
                criterion_d,
                optimizer_g,
                device,
                mel_transform,
                stft_transform,
                mfcc_transform
            )

            if debug:
                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
            scheduler_d.step(d_loss.detach())
            scheduler_g.step(adversarial_loss.detach())

            # ========= SAVE LATEST AUDIO =========
            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])

            new_epoch = generator_epoch+epoch

            if generator_epoch % 25 == 0:
                print(f"Saved epoch {new_epoch}!")
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
                torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])

            #if debug:
            #    print(generator.state_dict().keys())
            #    print(discriminator.state_dict().keys())
            torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
            torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
            Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})


    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
    torch.save(generator, "models/epoch-5000-generator.pt")
    print("Training complete!")

start_training()
@@ -20,12 +20,10 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)

    # Ensure same time dimension length (due to potential framing differences)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]

    # L1 Loss (Mean Absolute Error)
    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
    return loss
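The hunk above (apparently from training_utils.py, judging by the imports elsewhere in this commit) computes an L1 loss between mel spectrograms in PyTorch. The same length-trimmed L1 can be written directly in jnp once a JAX mel transform exists; a minimal sketch under that assumption (the function name is illustrative, not part of this commit):

import jax.numpy as jnp

def spectrogram_l1_loss(spec_true, spec_pred):
    # L1 loss between two spectrogram-like arrays, trimmed to a common time length.
    min_len = min(spec_true.shape[-1], spec_pred.shape[-1])
    return jnp.mean(jnp.abs(spec_true[..., :min_len] - spec_pred[..., :min_len]))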