8 Commits
jax ... rt-test

Author SHA1 Message Date
3f23242d6f ⚗️ | Added some stupid ways for training + some makeup 2025-10-04 22:38:11 +03:00
0bc8fc2792 | Made training bit... spicier. 2025-09-10 19:52:53 +03:00
ff38cefdd3 🐛 | Fix loading wrong model. 2025-06-08 18:14:31 +03:00
03fdc050cc | Made training bit faster. 2025-06-07 20:43:52 +03:00
2ded03713d | Added app.py script so the model can be used. 2025-06-06 22:10:06 +03:00
a135c765da 🐛 | Misc fixes... 2025-05-05 00:50:56 +03:00
b1e18443ba | Added support for .mp3 and .flac loading... 2025-05-04 23:56:14 +03:00
660b41aef8 ⚗️ | Real-time testing... 2025-05-04 22:48:57 +03:00
14 changed files with 735 additions and 1211 deletions

AudioUtils.py

@@ -1,18 +1,97 @@
import torch
import torch.nn.functional as F
def stereo_tensor_to_mono(waveform):
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
"""
Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
"""
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0) # (N,) -> (1, N)
if waveform.shape[0] > 1:
# Average across channels
mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N)
else:
# Already mono
mono_waveform = waveform
return mono_waveform
def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
    return tensor
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Stretch audio along time dimension to target_length.
    Input assumed (1, N). Returns (1, target_length).
    """
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # ensure (1, N)
    tensor = tensor.unsqueeze(0)  # (1, 1, N) for interpolate
    stretched = F.interpolate(
        tensor, size=target_length, mode="linear", align_corners=False
    )
    return stretched.squeeze(0)  # back to (1, target_length)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
"""
Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
"""
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
current_length = audio_tensor.shape[-1]
if current_length < target_length:
padding_needed = target_length - current_length
padding_tuple = (0, padding_needed)
padded_audio_tensor = F.pad(
audio_tensor, padding_tuple, mode="constant", value=0
)
else:
padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long
return padded_audio_tensor
def split_audio(
audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
"""
Split into chunks of (1, chunk_size).
"""
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("chunk_size must be a positive integer.")
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
num_samples = audio_tensor.shape[-1]
if num_samples == 0:
return []
chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
"""
Reconstruct audio from chunks. Returns (1, N).
"""
if not chunks:
return torch.empty(1, 0)
chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
try:
reconstructed_tensor = torch.cat(chunks, dim=-1)
except RuntimeError as e:
raise RuntimeError(
f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
f"for concatenation along dim -1. Original error: {e}"
)
return reconstructed_tensor
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
max_val = torch.max(torch.abs(audio_tensor))
if max_val < eps:
return audio_tensor # silence, skip normalization
return audio_tensor / max_val
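
Example (not part of this diff): a minimal round-trip through the helpers above, assuming a stereo 40 000-sample clip and the 16384-sample chunk size app.py defaults to; shapes follow the docstrings.

import torch
import AudioUtils

waveform = torch.randn(2, 40000)                         # stereo input, (C, N)
mono = AudioUtils.stereo_tensor_to_mono(waveform)        # (1, 40000)
mono = AudioUtils.normalize(mono)                        # peak-normalize to [-1, 1]
chunks = AudioUtils.split_audio(mono, chunk_size=16384)  # list of (1, <=16384) chunks
chunks[-1] = AudioUtils.pad_tensor(chunks[-1], 16384)    # zero-pad the tail chunk
restored = AudioUtils.reconstruct_audio(chunks)          # (1, 49152), includes the padding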

__init__.py (new, empty file)

app.py (new file, 97 lines)

@@ -0,0 +1,97 @@
import argparse
import torch
import torchaudio
import torchcodec
import tqdm
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Audio upscaling script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
"--clip_length",
type=int,
default=16384,
help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
"--bitrate",
type=int,
default=192000,
help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()
if args.sample_rate < 8000:
print(
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
)
exit()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
generator = SISUGenerator().to(device)
generator = torch.compile(generator)
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output
if models_dir:
ckpt = torch.load(models_dir, map_location=device)
generator.load_state_dict(ckpt["G"])
else:
print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
)
def start():
# To Mono!
decoder = torchcodec.decoders.AudioDecoder(input_audio)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
audio = AudioUtils.normalize(audio)
resample_transform = torchaudio.transforms.Resample(
original_sample_rate, args.sample_rate
)
audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.to(device) for t in splitted_audio]
processed_audio = []
for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
processed_audio.append(generator(clip))
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
print(f"Saving {output_audio}!")
torchaudio.save_with_torchcodec(
uri=output_audio,
src=reconstructed_audio,
sample_rate=args.sample_rate,
channels_first=True,
compression=args.bitrate,
)
start()
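
Not part of the diff: the loop in start() calls the generator with autograd enabled; below is a small sketch of the same loop wrapped in torch.inference_mode(), a common inference-time variant (an assumption, not something the script currently does).

# Hypothetical variant of the per-chunk loop in start(), reusing the same names as above.
processed_audio = []
with torch.inference_mode():  # skip autograd bookkeeping while upscaling
    for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
        processed_audio.append(generator(clip))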

data.py

@@ -1,104 +1,79 @@
# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch
import torchaudio
import torchaudio.transforms as T # Keep using torchaudio transforms
# Import NumPy
import numpy as np
import os
import random
# Assume AudioUtils is available and works on PyTorch Tensors as before
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils
class AudioDatasetNumPy(Dataset): # Renamed slightly for clarity
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100 # Define your desired maximum length here
    def __init__(self, input_dir):
        """
        Initializes the dataset. Device argument is removed.
        """
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")
class AudioDataset(Dataset):
    audio_sample_rates = [11025]
    def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
        self.clip_length = clip_length
        self.normalize = normalize
        input_files = [
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith((".wav", ".mp3", ".flac"))
        ]
data = []
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float() # ensure float32
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
if normalize:
audio = AudioUtils.normalize(audio)
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(
original_sample_rate, mangled_sample_rate
)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, original_sample_rate
)
low_audio = resample_transform_high(resample_transform_low(audio))
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
if not splitted_high_quality_audio or not splitted_low_quality_audio:
continue # skip empty or invalid clips
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_high_quality_audio[-1], clip_length
)
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_low_quality_audio[-1], clip_length
)
for high_quality_data, low_quality_data in zip(
splitted_high_quality_audio, splitted_low_quality_audio
):
data.append(
(
(high_quality_data, low_quality_data),
(original_sample_rate, mangled_sample_rate),
)
)
self.audio_data = data
def __len__(self):
return len(self.input_files)
return len(self.audio_data)
def __getitem__(self, idx):
"""
Loads audio, processes it, and returns NumPy arrays.
"""
# --- Load and Resample using torchaudio (produces PyTorch tensors) ---
try:
high_quality_audio_pt, original_sample_rate = torchaudio.load(
self.input_files[idx], normalize=True
)
except Exception as e:
print(f"Error loading file {self.input_files[idx]}: {e}")
# Return None or raise error, or return dummy data if preferred
# Returning dummy data might hide issues
return None # Or handle appropriately
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates)
# Ensure sample rates are different before resampling
if original_sample_rate != mangled_sample_rate:
resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio_pt = resample_transform_low(high_quality_audio_pt)
resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio_pt = resample_transform_high(low_quality_audio_pt)
else:
# If rates match, just copy the tensor
low_quality_audio_pt = high_quality_audio_pt.clone()
# --- Process Stereo to Mono (still using PyTorch tensors) ---
# Assuming AudioUtils.stereo_tensor_to_mono expects PyTorch Tensor (C, L)
# and returns PyTorch Tensor (1, L)
try:
high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
except Exception as e:
# Handle cases where mono conversion might fail (e.g., already mono)
# This depends on how AudioUtils is implemented. Let's assume it handles it.
print(f"Warning: Mono conversion issue with {self.input_files[idx]}: {e}. Using original.")
high_quality_audio_pt_mono = high_quality_audio_pt if high_quality_audio_pt.shape[0] == 1 else torch.mean(high_quality_audio_pt, dim=0, keepdim=True)
low_quality_audio_pt_mono = low_quality_audio_pt if low_quality_audio_pt.shape[0] == 1 else torch.mean(low_quality_audio_pt, dim=0, keepdim=True)
# --- Convert to NumPy Arrays ---
high_quality_audio_np = high_quality_audio_pt_mono.numpy() # Shape (1, L)
low_quality_audio_np = low_quality_audio_pt_mono.numpy() # Shape (1, L)
# --- Pad or Truncate using NumPy ---
def process_numpy_audio(audio_np, max_len):
current_len = audio_np.shape[1] # Length is axis 1 for shape (1, L)
if current_len < max_len:
padding_needed = max_len - current_len
# np.pad format: ((pad_before_ax0, pad_after_ax0), (pad_before_ax1, pad_after_ax1), ...)
# We only pad axis 1 (length) at the end
audio_np = np.pad(audio_np, ((0, 0), (0, padding_needed)), mode='constant', constant_values=0)
elif current_len > max_len:
# Truncate axis 1 (length)
audio_np = audio_np[:, :max_len]
return audio_np
high_quality_audio_np = process_numpy_audio(high_quality_audio_np, self.MAX_LENGTH)
low_quality_audio_np = process_numpy_audio(low_quality_audio_np, self.MAX_LENGTH)
# --- Remove Device Handling ---
# .to(self.device) is removed.
# --- Return NumPy arrays and metadata ---
# Note: Arrays likely have shape (1, MAX_LENGTH) here
return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)
return self.audio_data[idx]
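
Not part of the diff: a sketch of how the precomputed pairs come back out of a DataLoader, matching the unpacking used in the new training loop; the dataset path and clip_length are placeholder values.

from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset("./dataset", clip_length=8000)   # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True)

(high_quality, low_quality), (orig_sr, mangled_sr) = next(iter(loader))
print(high_quality.shape, low_quality.shape)            # torch.Size([4, 1, 8000]) each
print(orig_sr[:2], mangled_sr[:2])                      # per-clip sample rates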

discriminator.py

@@ -1,145 +1,75 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
import torch.nn as nn
import torch.nn.utils as utils
def discriminator_block(
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
spectral_norm=True,
use_instance_norm=True,
):
padding = (kernel_size // 2) * dilation
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
)
if spectral_norm:
conv_layer = utils.spectral_norm(conv_layer)
layers = [conv_layer]
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
# Assume InstanceNorm1d and AttentionBlock are defined as in the generator conversion
# --- Custom InstanceNorm1d Implementation (from Generator) ---
class InstanceNorm1d(nn.Module):
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
mean = jnp.mean(x, axis=1, keepdims=True)
var = jnp.var(x, axis=1, keepdims=True)
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
if self.use_scale:
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- AttentionBlock Implementation (from Generator) ---
class AttentionBlock(nn.Module):
channels: int
@nn.compact
def __call__(self, x):
ks1 = (1,)
attention_weights = nn.Conv(features=self.channels // 4, kernel_size=ks1, padding='SAME')(x)
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(features=self.channels, kernel_size=ks1, padding='SAME')(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
# --- Converted Discriminator Modules ---
class DiscriminatorBlock(nn.Module):
"""Equivalent of the PyTorch discriminator_block function."""
in_channels: int # Needed for clarity, though not strictly used by layers if input shape is known
out_channels: int
kernel_size: int = 3
stride: int = 1
dilation: int = 1
# spectral_norm: bool = True # Flag for where SN would be applied
use_instance_norm: bool = True
negative_slope: float = 0.2
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L', C_out) - L' depends on stride/padding
"""
# Flax Conv expects kernel_size, stride, dilation as sequences (tuples)
ks = (self.kernel_size,)
st = (self.stride,)
di = (self.dilation,)
# Padding='SAME' works reasonably well for stride=1 and stride=2 downsampling
# NOTE: Spectral Norm is omitted here.
# If implementing, you'd wrap or replace nn.Conv with a spectral-normalized version.
# conv_layer = SpectralNormConv1D(...) or wrap(nn.Conv(...))
y = nn.Conv(
features=self.out_channels,
kernel_size=ks,
strides=st,
kernel_dilation=di,
padding='SAME' # Often used in GANs
)(x)
# Apply LeakyReLU first (as in the original code if IN is used)
y = nn.leaky_relu(y, negative_slope=self.negative_slope)
# Conditionally apply InstanceNorm
if self.use_instance_norm:
y = InstanceNorm1d(features=self.out_channels)(y)
return y
class SISUDiscriminator(nn.Module):
"""SISUDiscriminator model translated to Flax."""
base_channels: int = 16
def __init__(self, layers=32):
super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=7, stride=1),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2),
discriminator_block(
layers * 2,
1,
spectral_norm=False,
use_instance_norm=False,
),
)
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, 1) - logits
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
ch = self.base_channels
# Block 1: 1 -> ch, k=7, s=1, d=1, SN=T, IN=F
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=1, out_channels=ch, kernel_size=7, stride=1, use_instance_norm=False)(x)
# Block 2: ch -> ch*2, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch, out_channels=ch*2, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 3: ch*2 -> ch*4, k=5, s=1, d=2, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch*4, kernel_size=5, stride=1, dilation=2, use_instance_norm=True)(y)
# Attention Block
y = AttentionBlock(channels=ch*4)(y)
# Block 4: ch*4 -> ch*8, k=5, s=1, d=4, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*8, kernel_size=5, stride=1, dilation=4, use_instance_norm=True)(y)
# Block 5: ch*8 -> ch*4, k=5, s=2, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*8, out_channels=ch*4, kernel_size=5, stride=2, use_instance_norm=True)(y)
# Block 6: ch*4 -> ch*2, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*4, out_channels=ch*2, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 7: ch*2 -> ch, k=3, s=1, d=1, SN=T, IN=T
# NOTE: Spectral Norm omitted
y = DiscriminatorBlock(in_channels=ch*2, out_channels=ch, kernel_size=3, stride=1, use_instance_norm=True)(y)
# Block 8: ch -> 1, k=3, s=1, d=1, SN=F, IN=F
# NOTE: Spectral Norm omitted (as per original config)
y = DiscriminatorBlock(in_channels=ch, out_channels=1, kernel_size=3, stride=1, use_instance_norm=False)(y)
# Global Average Pooling (across Length dimension)
pooled = jnp.mean(y, axis=1) # Shape becomes (N, C=1)
# Flatten (optional, as shape is likely already (N, 1))
output = jnp.reshape(pooled, (pooled.shape[0], -1)) # Shape (N, 1)
return output
def forward(self, x):
x = self.model(x)
x = self.global_avg_pool(x)
x = x.view(x.size(0), -1)
return x
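
Not part of the diff: a quick shape check for the PyTorch discriminator, assuming the (batch, channels, length) layout used elsewhere in the repo; sizes are placeholders.

import torch
from discriminator import SISUDiscriminator

disc = SISUDiscriminator(layers=32)
dummy = torch.randn(4, 1, 8000)   # (batch, channels, length)
logits = disc(dummy)
print(logits.shape)               # torch.Size([4, 1]): one logit per clip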

file_utils.py (deleted)

@@ -1,28 +0,0 @@
import json
filepath = "my_data.json"
def write_data(filepath, data):
try:
with open(filepath, 'w') as f:
json.dump(data, f, indent=4) # Use indent for pretty formatting
print(f"Data written to '{filepath}'")
except Exception as e:
print(f"Error writing to file: {e}")
def read_data(filepath):
try:
with open(filepath, 'r') as f:
data = json.load(f)
print(f"Data read from '{filepath}'")
return data
except FileNotFoundError:
print(f"File not found: {filepath}")
return None
except json.JSONDecodeError:
print(f"Error decoding JSON from file: {filepath}")
return None
except Exception as e:
print(f"Error reading from file: {e}")
return None

generator.py

@@ -1,173 +1,81 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Tuple
import torch
import torch.nn as nn
# --- Custom InstanceNorm1d Implementation ---
class InstanceNorm1d(nn.Module):
"""
Flax implementation of Instance Normalization for 1D data (NLC format).
Normalizes across the 'L' dimension.
"""
features: int
epsilon: float = 1e-5
use_scale: bool = True
use_bias: bool = True
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor of shape (batch, length, features)
def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
return nn.Sequential(
nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
dilation=dilation,
padding=(kernel_size // 2) * dilation,
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(),
)
Returns:
Normalized tensor.
"""
if x.shape[-1] != self.features:
raise ValueError(f"Input features {x.shape[-1]} does not match InstanceNorm1d features {self.features}")
# Calculate mean and variance across the length dimension (axis=1)
# Keep dims for broadcasting
mean = jnp.mean(x, axis=1, keepdims=True)
# Variance calculation using mean needs care for numerical stability if needed,
# but jnp.var should handle it.
var = jnp.var(x, axis=1, keepdims=True)
# Normalize
normalized = (x - mean) / jnp.sqrt(var + self.epsilon)
# Apply learnable scale and bias if enabled
if self.use_scale:
# Parameter shape: (features,) to broadcast across N and L
scale = self.param('scale', nn.initializers.ones, (self.features,))
normalized *= scale
if self.use_bias:
# Parameter shape: (features,)
bias = self.param('bias', nn.initializers.zeros, (self.features,))
normalized += bias
return normalized
# --- Converted Modules ---
class ConvBlock(nn.Module):
"""Equivalent of the PyTorch conv_block function."""
out_channels: int
kernel_size: int = 3
dilation: int = 1
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C_in)
Returns:
Output tensor (N, L, C_out)
"""
# Flax Conv expects kernel_size and dilation as sequences (tuples)
ks = (self.kernel_size,)
di = (self.dilation,)
# Padding='SAME' attempts to preserve the length dimension for stride=1
x = nn.Conv(
features=self.out_channels,
kernel_size=ks,
kernel_dilation=di,
padding='SAME'
)(x)
x = InstanceNorm1d(features=self.out_channels)(x) # Use custom InstanceNorm
x = nn.PReLU()(x) # PReLU learns 'alpha' parameter per channel
return x
class AttentionBlock(nn.Module):
"""Simple Channel Attention Block in Flax."""
channels: int
"""
Simple Channel Attention Block. Learns to weight channels based on their importance.
"""
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Attention-weighted output tensor (N, L, C)
"""
# Flax Conv expects kernel_size as a sequence (tuple)
ks1 = (1,)
attention_weights = nn.Conv(
features=self.channels // 4, kernel_size=ks1, padding='SAME'
)(x)
# NOTE: PyTorch used inplace=True, JAX/Flax don't modify inplace
attention_weights = nn.relu(attention_weights)
attention_weights = nn.Conv(
features=self.channels, kernel_size=ks1, padding='SAME'
)(attention_weights)
attention_weights = nn.sigmoid(attention_weights)
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
class ResidualInResidualBlock(nn.Module):
"""ResidualInResidualBlock in Flax."""
channels: int
num_convs: int = 3
def __init__(self, channels, num_convs=3):
super(ResidualInResidualBlock, self).__init__()
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, C)
Returns:
Output tensor (N, L, C)
"""
self.conv_layers = nn.Sequential(
*[conv_block(channels, channels) for _ in range(num_convs)]
)
self.attention = AttentionBlock(channels)
def forward(self, x):
residual = x
y = x
# Sequentially apply ConvBlocks
for _ in range(self.num_convs):
y = ConvBlock(
out_channels=self.channels,
kernel_size=3, # Assuming kernel_size 3 as in original conv_block default
dilation=1 # Assuming dilation 1 as in original conv_block default
)(y)
x = self.conv_layers(x)
x = self.attention(x)
return x + residual
y = AttentionBlock(channels=self.channels)(y)
return y + residual
class SISUGenerator(nn.Module):
"""SISUGenerator model translated to Flax."""
channels: int = 16
num_rirb: int = 4
alpha: float = 1.0 # Non-learnable parameter, passed during init
def __init__(self, channels=16, num_rirb=4, alpha=1):
super(SISUGenerator, self).__init__()
self.alpha = alpha
@nn.compact
def __call__(self, x):
"""
Args:
x: Input tensor (N, L, 1) - assumes single channel input
Returns:
Output tensor (N, L, 1)
"""
if x.shape[-1] != 1:
raise ValueError(f"Input should have 1 channel (NLC format), got shape {x.shape}")
self.conv1 = nn.Sequential(
nn.Conv1d(1, channels, kernel_size=7, padding=3),
nn.InstanceNorm1d(channels),
nn.PReLU(),
)
self.rir_blocks = nn.Sequential(
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
)
self.final_layer = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
)
def forward(self, x):
residual_input = x
# Initial convolution block
# Flax Conv expects kernel_size as sequence
ks7 = (7,)
ks3 = (3,)
y = nn.Conv(features=self.channels, kernel_size=ks7, padding='SAME')(x)
y = InstanceNorm1d(features=self.channels)(y)
y = nn.PReLU()(y)
# Residual-in-Residual Blocks
rirb_out = y
for _ in range(self.num_rirb):
rirb_out = ResidualInResidualBlock(channels=self.channels)(rirb_out)
# Final layer
learned_residual = nn.Conv(
features=1, kernel_size=ks3, padding='SAME'
)(rirb_out)
# Combine with input residual
x = self.conv1(x)
x_rirb_out = self.rir_blocks(x)
learned_residual = self.final_layer(x_rirb_out)
output = residual_input + self.alpha * learned_residual
return output
return torch.tanh(output)
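
Not part of the diff: the matching check for the generator; the output keeps the input length and the final tanh keeps samples in [-1, 1]. Sizes are placeholders.

import torch
from generator import SISUGenerator

gen = SISUGenerator(channels=16, num_rirb=4)
low_quality = torch.randn(4, 1, 8000)   # (batch, channels, length)
with torch.no_grad():
    enhanced = gen(low_quality)
print(enhanced.shape)                   # torch.Size([4, 1, 8000]), same as the input
print(enhanced.abs().max() <= 1.0)      # tensor(True): clamped by the final tanh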

requirements.txt (deleted)

@@ -1,12 +0,0 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.3
pillow==11.0.0
setuptools==70.2.0
sympy==1.13.3
tqdm==4.67.1
typing_extensions==4.12.2


@@ -1,481 +1,245 @@
import jax
import jax.numpy as jnp
import optax
import tqdm
import pickle # Using pickle for simplicity to save JAX states
import os
import argparse
# You might need a JAX-compatible library for audio loading/saving or convert to numpy
import scipy.io.wavfile as wavfile # Example for saving audio
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler
import file_utils as Data
from data import AudioDatasetNumPy
from generator import SISUGenerator
from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--epoch", type=int, default=0, help="Starting epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
"--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
"--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
# ---------------------------
# Init accelerator
# ---------------------------
debug = args.debug
accelerator = Accelerator(mixed_precision="bf16")
# Initialize JAX random key
key = jax.random.PRNGKey(0)
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
discriminator = SISUDiscriminator()
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDatasetNumPy(dataset_dir) # Use your JAX dataset
train_data_loader = DataLoader(dataset, batch_size=4, shuffle=True) # Use your JAX DataLoader
accelerator.print("🔨 | Compiling models...")
models_dir = "models"
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
accelerator.print("✅ | Compiling done!")
# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset")
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory
train_loader = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
num_workers=args.num_workers,
pin_memory=pin_memory,
persistent_workers=pin_memory,
)
if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
accelerator.print("🪹 | There is no data to train with! Exiting...")
exit()
loader_batch_size = train_loader.batch_size
accelerator.print("✅ | Dataset fetched!")
# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
optimizer_g = optim.AdamW(
generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_d, mode="min", factor=0.5, patience=5
)
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.MSELoss()
# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
generator, discriminator, optimizer_g, optimizer_d, train_loader
)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= MODELS =========
def save_ckpt(path, epoch):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
accelerator.save(
{
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict(),
},
path,
)
start_epoch = 0
if args.resume:
ckpt_path = os.path.join(models_dir, "last.pt")
ckpt = torch.load(ckpt_path)
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
optimizer_g.load_state_dict(ckpt["optG"])
optimizer_d.load_state_dict(ckpt["optD"])
scheduler_g.load_state_dict(ckpt["schedG"])
scheduler_d.load_state_dict(ckpt["schedD"])
start_epoch = ckpt.get("epoch", 1)
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
real_buf = torch.full(
(loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
)
fake_buf = torch.zeros(
(loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
)
accelerator.print("🏋️ | Started training...")
try:
# Fetch the first batch
first_batch = next(iter(train_data_loader))
# The batch is a tuple: ((high_quality_audio_np, high_quality_sample_rate), (low_quality_audio_np, low_quality_sample_rate))
# We need the high-quality audio NumPy array batch for initialization
sample_input_np = first_batch[0][0] # Get the high-quality audio NumPy array batch
# Convert the NumPy array batch to a JAX array
sample_input_array = jnp.array(sample_input_np)
for epoch in range(start_epoch, args.epochs):
generator.train()
discriminator.train()
# === FIX ===
# Transpose the array from (batch, channels, length) to (batch, length, channels)
# The original shape from DataLoader is likely (batch, channels, length) like (4, 1, 44100)
# The generator expects NLC format (batch, length, channels) i.e., (4, 44100, 1)
if sample_input_array.ndim == 3 and sample_input_array.shape[1] == 1:
sample_input_array = jnp.transpose(sample_input_array, (0, 2, 1)) # Swap axes 1 and 2
running_d, running_g, steps = 0.0, 0.0, 0
print(sample_input_array.shape) # Should now print (4, 44100, 1)
# === END FIX ===
for i, (
(high_quality, low_quality),
(high_sample_rate, low_sample_rate),
) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
batch_size = high_quality.size(0)
except StopIteration:
print("Error: Data loader is empty. Cannot initialize models.")
exit() # Exit if no data is available
real_labels = real_buf[:batch_size].to(accelerator.device)
fake_labels = fake_buf[:batch_size].to(accelerator.device)
# --- Discriminator ---
optimizer_d.zero_grad(set_to_none=True)
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
low_quality,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
)
key, init_key_g, init_key_d = jax.random.split(key, 3)
generator_model = SISUGenerator()
discriminator_model = SISUDiscriminator()
accelerator.backward(d_loss)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
optimizer_d.step()
# Initialize parameters
generator_params = generator_model.init(init_key_g, sample_input_array)['params']
discriminator_params = discriminator_model.init(init_key_d, sample_input_array)['params']
# --- Generator ---
optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast():
g_total, g_adv = generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
criterion_d,
)
# Define apply functions
generator_apply_fn = generator_model.apply
discriminator_apply_fn = discriminator_model.apply
accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step()
d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean()
# Loss functions (JAX equivalents)
# Optax provides common loss functions. BCEWithLogitsLoss is equivalent to
# sigmoid_binary_cross_entropy in Optax combined with a sigmoid activation
# in the model output or handling logits directly. Assuming your discriminator
# outputs logits, optax.sigmoid_binary_cross_entropy is appropriate.
criterion_d = optax.sigmoid_binary_cross_entropy
criterion_l1 = optax.sigmoid_binary_cross_entropy # For Mel, STFT, MFCC losses
# Optimizers (using Optax)
optimizer_g = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
optimizer_d = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)
# Initialize optimizer states
generator_opt_state = optimizer_g.init(generator_params)
discriminator_opt_state = optimizer_d.init(discriminator_params)
# Schedulers - Optax has learning rate schedules. ReduceLROnPlateau
# is stateful and usually handled outside the jitted training step,
# or you can implement a custom learning rate schedule in Optax that
# takes a metric. For simplicity here, we won't directly replicate the
# PyTorch ReduceLROnPlateau but you could add logic based on losses
# in the main loop to adjust the learning rate if needed.
# Load saved state if continuing training
start_epoch = args.epoch
if args.continue_training:
try:
with open(f"{models_dir}/temp_generator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
generator_opt_state = loaded_state['opt_state']
with open(f"{models_dir}/temp_discriminator.pkl", 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
discriminator_opt_state = loaded_state['opt_state']
epoch_data = Data.read_data(f"{models_dir}/epoch_data.json")
start_epoch = epoch_data.get("epoch", 0) + 1
print(f"Continuing training from epoch {start_epoch}")
except FileNotFoundError:
print("Continue training requested but temp models not found. Starting from scratch.")
except Exception as e:
print(f"Error loading temp models: {e}. Starting from scratch.")
if args.generator is not None:
try:
with open(args.generator, 'rb') as f:
loaded_state = pickle.load(f)
generator_params = loaded_state['params']
print(f"Loaded generator from {args.generator}")
except FileNotFoundError:
print(f"Generator model not found at {args.generator}")
if args.discriminator is not None:
try:
with open(args.discriminator, 'rb') as f:
loaded_state = pickle.load(f)
discriminator_params = loaded_state['params']
print(f"Loaded discriminator from {args.discriminator}")
except FileNotFoundError:
print(f"Discriminator model not found at {args.discriminator}")
# Initialize JAX audio transforms
# mel_transform_fn = MelSpectrogramJAX(
# sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
# win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
# )
# stft_transform_fn = SpectrogramJAX(
# n_fft=n_fft, win_length=win_length, hop_length=hop_length
# )
# mfcc_transform_fn = MFCCJAX(
# sample_rate,
# n_mfcc,
# melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
# )
# ========= JAX TRAINING STEPS =========
@jax.jit
def discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio, # JAX array (batch, length, channels)
low_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
fake_labels, # JAX array
discriminator_apply_fn,
generator_apply_fn,
discriminator_optimizer,
criterion_d,
key # JAX random key
):
# Split key for potential randomness in model application
key, disc_key, gen_key = jax.random.split(key, 3)
def loss_fn(d_params):
# Generate fake audio
# Note: Generator is not being trained in this step, so its parameters are static
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
enhanced_audio, _ = generator_apply_fn({'params': generator_params}, gen_key, low_quality_audio)
# Pass data through the discriminator
# Ensure enhanced_audio has a leading dimension if not already present (e.g., batch size)
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Ensure high_quality_audio is in the expected NLC format (batch, length, channels)
if high_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
high_quality_audio = jnp.expand_dims(high_quality_audio, axis=-1)
elif high_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
high_quality_audio = jnp.expand_dims(jnp.expand_dims(high_quality_audio, axis=0), axis=-1)
real_output = discriminator_apply_fn({'params': d_params}, disc_key, high_quality_audio)
fake_output = discriminator_apply_fn({'params': d_params}, disc_key, enhanced_audio)
# Calculate loss (criterion_d is assumed to be Optax's sigmoid_binary_cross_entropy or similar)
# Ensure the shapes match the labels (batch_size, 1)
real_output = real_output.reshape(-1, 1)
fake_output = fake_output.reshape(-1, 1)
real_loss = jnp.mean(criterion_d(real_output, real_labels))
fake_loss = jnp.mean(criterion_d(fake_output, fake_labels))
total_loss = real_loss + fake_loss
return total_loss, (real_loss, fake_loss)
# Compute gradients
# Use jax.value_and_grad to get both the loss value and the gradients
(loss, (real_loss, fake_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(discriminator_params)
# Apply updates
updates, new_discriminator_opt_state = discriminator_optimizer.update(grads, discriminator_opt_state, discriminator_params)
new_discriminator_params = optax.apply_updates(discriminator_params, updates)
return new_discriminator_params, new_discriminator_opt_state, loss, key
@jax.jit
def generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio, # JAX array (batch, length, channels)
high_quality_audio, # JAX array (batch, length, channels)
real_labels, # JAX array
generator_apply_fn,
discriminator_apply_fn,
generator_optimizer,
criterion_d, # Adversarial loss
criterion_l1, # Feature matching loss
key # JAX random key
):
# Split key for potential randomness
key, gen_key, disc_key = jax.random.split(key, 3)
def loss_fn(g_params):
# Ensure low_quality_audio is in the expected NLC format (batch, length, channels)
if low_quality_audio.ndim == 2: # Assuming (batch, length), add channel dim
low_quality_audio = jnp.expand_dims(low_quality_audio, axis=-1)
elif low_quality_audio.ndim == 1: # Assuming (length), add batch and channel dims
low_quality_audio = jnp.expand_dims(jnp.expand_dims(low_quality_audio, axis=0), axis=-1)
# Generate enhanced audio
enhanced_audio, _ = generator_apply_fn({'params': g_params}, gen_key, low_quality_audio)
# Ensure enhanced_audio has a leading dimension if not already present
if enhanced_audio.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio = jnp.expand_dims(enhanced_audio, axis=0)
elif enhanced_audio.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio = jnp.expand_dims(jnp.expand_dims(enhanced_audio, axis=0), axis=-1)
# Calculate adversarial loss (generator wants discriminator to think fake is real)
# Note: Discriminator is not being trained in this step, so its parameters are static
fake_output = discriminator_apply_fn({'params': discriminator_params}, disc_key, enhanced_audio)
# Ensure the shape matches the labels (batch_size, 1)
fake_output = fake_output.reshape(-1, 1)
adversarial_loss = jnp.mean(criterion_d(fake_output, real_labels)) # Generator wants fake_output to be close to real_labels (1s)
# Feature matching losses (assuming you add these back later)
# You would need to implement JAX versions of your audio transforms
# mel_loss = criterion_l1(mel_transform_fn(enhanced_audio), mel_transform_fn(high_quality_audio))
# stft_loss = criterion_l1(stft_transform_fn(enhanced_audio), stft_transform_fn(high_quality_audio))
# mfcc_loss = criterion_l1(mfcc_transform_fn(enhanced_audio), mfcc_transform_fn(high_quality_audio))
# combined_loss = adversarial_loss + mel_loss + stft_loss + mfcc_loss
combined_loss = adversarial_loss # For now, only adversarial loss
# Return combined_loss and any other metrics needed for logging/analysis
# For now, just adversarial loss and enhanced_audio
return combined_loss, (adversarial_loss, enhanced_audio) # Add other losses here when implemented
# Compute gradients
# Update: loss_fn now returns (loss, (aux1, aux2, ...))
(loss, (adversarial_loss_val, enhanced_audio)), grads = jax.value_and_grad(loss_fn, has_aux=True)(generator_params)
# Apply updates
updates, new_generator_opt_state = generator_optimizer.update(grads, generator_opt_state, generator_params)
new_generator_params = optax.apply_updates(generator_params, updates)
# Return the loss components separately along with the enhanced audio and key
return new_generator_params, new_generator_opt_state, loss, adversarial_loss_val, enhanced_audio, key
# ========= MAIN TRAINING LOOP =========
def start_training():
global generator_params, discriminator_params, generator_opt_state, discriminator_opt_state, key
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
current_epoch = start_epoch + generator_epoch
# These will hold the last processed audio examples from a batch for saving
last_high_quality_audio = None
last_low_quality_audio = None
last_ai_enhanced_audio = None
last_sample_rate = None
# Use tqdm for progress bar
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {current_epoch}"):
# high_quality_clip and low_quality_clip are tuples: (audio_array, sample_rate_array)
# Extract audio arrays and sample rates (assuming batch dimension is first)
# The arrays are NumPy arrays at this point, likely in (batch, channels, length) format
high_quality_audio_batch_np = high_quality_clip[0]
low_quality_audio_batch_np = low_quality_clip[0]
sample_rate_batch_np = high_quality_clip[1] # Assuming sample rates are the same for paired clips
# Convert NumPy arrays to JAX arrays and transpose to NLC format (batch, length, channels)
# Only transpose if the shape is (batch, channels, length)
if high_quality_audio_batch_np.ndim == 3 and high_quality_audio_batch_np.shape[1] == 1:
high_quality_audio_batch = jnp.transpose(jnp.array(high_quality_audio_batch_np), (0, 2, 1))
if torch.isfinite(d_val):
running_d += d_val.item()
else:
high_quality_audio_batch = jnp.array(high_quality_audio_batch_np) # Assume already NLC or handle other cases
accelerator.print(
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
)
if low_quality_audio_batch_np.ndim == 3 and low_quality_audio_batch_np.shape[1] == 1:
low_quality_audio_batch = jnp.transpose(jnp.array(low_quality_audio_batch_np), (0, 2, 1))
if torch.isfinite(g_val):
running_g += g_val.item()
else:
low_quality_audio_batch = jnp.array(low_quality_audio_batch_np) # Assume already NLC or handle other cases
accelerator.print(
f"🫥 | NaN in generator loss at step {i}, skipping update."
)
sample_rate_batch = jnp.array(sample_rate_batch_np)
steps += 1
# epoch averages & schedulers
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
batch_size = high_quality_audio_batch.shape[0]
# Create labels - JAX arrays
real_labels = jnp.ones((batch_size, 1))
fake_labels = jnp.zeros((batch_size, 1))
mean_d = running_d / steps
mean_g = running_g / steps
# Split key for each batch
key, batch_key = jax.random.split(key)
scheduler_d.step(mean_d)
scheduler_g.step(mean_g)
# ========= DISCRIMINATOR =========
# Call the jitted discriminator training step
discriminator_params, discriminator_opt_state, d_loss, batch_key = discriminator_train_step(
discriminator_params,
generator_params,
discriminator_opt_state,
high_quality_audio_batch,
low_quality_audio_batch,
real_labels,
fake_labels,
discriminator_apply_fn,
generator_apply_fn,
optimizer_d,
criterion_d,
batch_key
)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
# ========= GENERATOR =========
# Call the jitted generator training step
generator_params, generator_opt_state, combined_loss, adversarial_loss, enhanced_audio_batch, batch_key = generator_train_step(
generator_params,
discriminator_params,
generator_opt_state,
low_quality_audio_batch,
high_quality_audio_batch,
real_labels, # Generator tries to make fake data look real
generator_apply_fn,
discriminator_apply_fn,
optimizer_g,
criterion_d,
criterion_l1,
batch_key
)
# Print debug logs (requires waiting for JIT compilation on first step)
if debug:
# Use .block_until_ready() to ensure computation is finished before printing
# In a real scenario, you might want to log metrics less frequently
d_loss_val = d_loss.block_until_ready().item()
combined_loss_val = combined_loss.block_until_ready().item()
adversarial_loss_val = adversarial_loss.block_until_ready().item()
# Assuming other losses are returned by generator_train_step and unpacked
# mel_loss_val = mel_l1_tensor.block_until_ready().item() if mel_l1_tensor is not None else 0
# stft_loss_val = log_stft_l1_tensor.block_until_ready().item() if log_stft_l1_tensor is not None else 0
# mfcc_loss_val = mfcc_l_tensor.block_until_ready().item() if mfcc_l_tensor is not None else 0
print(f"D_LOSS: {d_loss_val:.4f}, G_COMBINED_LOSS: {combined_loss_val:.4f}, G_ADVERSARIAL_LOSS: {adversarial_loss_val:.4f}")
# Print other losses here when implemented and returned
# Schedulers - Implement your learning rate scheduling logic here if needed
# based on the losses (e.g., reducing learning rate if loss plateaus).
# This logic would typically live outside the jitted step function.
# For Optax, you might use a schedule within the optimizer definition
# or update the learning rate of the optimizer manually.
# ========= SAVE LATEST AUDIO (from the last batch processed) =========
# Access the first sample of the batch for saving
# Ensure enhanced_audio_batch has a batch dimension and is in NLC format
if enhanced_audio_batch.ndim == 2: # Assuming (length, channel), add batch dim
enhanced_audio_batch = jnp.expand_dims(enhanced_audio_batch, axis=0)
elif enhanced_audio_batch.ndim == 1: # Assuming (length), add batch and channel dims
enhanced_audio_batch = jnp.expand_dims(jnp.expand_dims(enhanced_audio_batch, axis=0), axis=-1)
last_high_quality_audio = high_quality_audio_batch[0]
last_low_quality_audio = low_quality_audio_batch[0]
last_ai_enhanced_audio = enhanced_audio_batch[0]
last_sample_rate = sample_rate_batch[0].item() # Assuming sample rate is scalar per batch item
# Save audio files periodically (outside the batch loop)
if generator_epoch % 25 == 0 and last_high_quality_audio is not None:
print(f"Saving audio for epoch {current_epoch}!")
try:
# Convert JAX arrays to NumPy arrays for saving
# Transpose back to (length, channels) or (length) if needed by wavfile.write
# Assuming the models output (length, 1) or (length) after removing batch dim
low_quality_audio_np_save = jax.device_get(last_low_quality_audio)
ai_enhanced_audio_np_save = jax.device_get(last_ai_enhanced_audio)
high_quality_audio_np_save = jax.device_get(last_high_quality_audio)
# Remove the channel dimension if it's 1 for saving with wavfile
if low_quality_audio_np_save.shape[-1] == 1:
low_quality_audio_np_save = low_quality_audio_np_save.squeeze(axis=-1)
if ai_enhanced_audio_np_save.shape[-1] == 1:
ai_enhanced_audio_np_save = ai_enhanced_audio_np_save.squeeze(axis=-1)
if high_quality_audio_np_save.shape[-1] == 1:
high_quality_audio_np_save = high_quality_audio_np_save.squeeze(axis=-1)
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-crap.wav", last_sample_rate, low_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-ai.wav", last_sample_rate, ai_enhanced_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
wavfile.write(f"{audio_output_dir}/epoch-{current_epoch}-audio-orig.wav", last_sample_rate, high_quality_audio_np_save.astype(jnp.int16)) # Assuming audio is int16
except Exception as e:
print(f"Error saving audio files: {e}")
# Save model states periodically (outside the batch loop)
# Use pickle to save parameters and optimizer states
try:
with open(f"{models_dir}/temp_discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params), 'opt_state': jax.device_get(discriminator_opt_state)}, f)
with open(f"{models_dir}/temp_generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params), 'opt_state': jax.device_get(generator_opt_state)}, f)
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": current_epoch})
except Exception as e:
print(f"Error saving temp model states: {e}")
# Save final model states after all epochs
print("Training complete! Saving final models.")
except Exception:
try:
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-discriminator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(discriminator_params)}, f)
with open(f"{models_dir}/epoch-{start_epoch + generator_epochs - 1}-generator.pkl", 'wb') as f:
pickle.dump({'params': jax.device_get(generator_params)}, f)
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
except Exception as e:
print(f"Error saving final model states: {e}")
accelerator.print("😬 | Failed saving crash checkpoint:", e)
raise
start_training()
accelerator.print("🏁 | Training finished.")
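
Not part of the diff: a minimal sketch of reading the checkpoint that save_ckpt writes; the key names match the dict it saves, "./models/last.pt" is the path --resume loads, and compiling before load_state_dict mirrors what app.py does.

import torch
from generator import SISUGenerator

ckpt = torch.load("./models/last.pt", map_location="cpu")
print(ckpt["epoch"], sorted(ckpt.keys()))   # epoch counter plus the G/D/opt/sched entries

generator = torch.compile(SISUGenerator())  # compiled first so state_dict keys line up, as in app.py
generator.load_state_dict(ckpt["G"])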


@@ -1,194 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from training_utils import discriminator_train, generator_train
import file_utils as Data
import torchaudio.transforms as T
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Parameters
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR =========
discriminator.train()
d_loss = discriminator_train(
high_quality_sample,
low_quality_sample,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
low_quality_sample,
high_quality_sample,
real_labels,
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug:
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch
if generator_epoch % 25 == 0:
print(f"Saved epoch {new_epoch}!")
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt")
torch.save(generator, "models/epoch-5000-generator.pt")
print("Training complete!")
start_training()
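# Example invocations (file name and argument defaults are assumptions based on
# the args.* fields used above, not verified against the full parser):
#   python training.py --device cuda --epoch 0 --debug
#   python training.py --continue_training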

View File

@@ -1,142 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
loss = torch.mean((mfccs_true - mfccs_pred)**2)
return loss
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
return loss
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
return loss
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
return loss
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
loss = torch.mean(norm_diff / (norm_true + eps))
return loss
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
with torch.no_grad():
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
d_loss = (d_loss_real + d_loss_fake) / 2.0
d_loss.backward()
# Optional: Gradient Clipping (can be helpful)
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer.step()
return d_loss
def generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
adv_criterion,
g_optimizer,
device,
mel_transform: T.MelSpectrogram,
stft_transform: T.Spectrogram,
mfcc_transform: T.MFCC,
lambda_adv: float = 1.0,
lambda_mel_l1: float = 10.0,
lambda_log_stft: float = 1.0,
lambda_mfcc: float = 1.0
):
g_optimizer.zero_grad()
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
mel_l1 = 0.0
log_stft_l1 = 0.0
mfcc_l = 0.0
# Calculate Mel L1 Loss if weight is positive
if lambda_mel_l1 > 0:
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
# Calculate Log STFT L1 Loss if weight is positive
if lambda_log_stft > 0:
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
# Calculate MFCC Loss if weight is positive
if lambda_mfcc > 0:
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
combined_loss = (lambda_adv * adversarial_loss) + \
(lambda_mel_l1 * mel_l1_tensor) + \
(lambda_log_stft * log_stft_l1_tensor) + \
(lambda_mfcc * mfcc_l_tensor)
combined_loss.backward()
# Optional: Gradient Clipping
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
g_optimizer.step()
# 6. Return values for logging
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor

View File

@@ -0,0 +1,87 @@
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
"""
Multi-resolution STFT loss.
Combines spectral convergence loss and log-magnitude loss
across multiple STFT resolutions.
"""
def __init__(
self,
fft_sizes: List[int] = [1024, 2048, 512],
hop_sizes: List[int] = [120, 240, 50],
win_lengths: List[int] = [600, 1200, 240],
eps: float = 1e-7,
):
super().__init__()
self.eps = eps
self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList()
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
stft = T.Spectrogram(
n_fft=n_fft,
hop_length=hop_len,
win_length=win_len,
window_fn=torch.hann_window,  # builds the Hann window of win_length directly
power=None,  # Keep complex output
center=True,
pad_mode="reflect",
normalized=False,
)
self.stft_transforms.append(stft)
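# Each resolution keeps its own Spectrogram module; their windows are
# registered as buffers, so they follow the module when moved to a device in forward().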
def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Args:
y_true: (B, T) or (B, 1, T) waveform
y_pred: (B, T) or (B, 1, T) waveform
"""
# Ensure correct shape (B, T)
if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1:
y_pred = y_pred.squeeze(1)
sc_loss = 0.0
mag_loss = 0.0
for stft in self.stft_transforms:
stft = stft.to(y_pred.device)
# STFTs with power=None keep phase information (complex-valued output)
stft_true = stft(y_true)
stft_pred = stft(y_pred)
# Magnitudes
stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred)
# --- Spectral Convergence Loss ---
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
# --- Log STFT Magnitude Loss ---
mag_loss += F.l1_loss(
torch.log(stft_mag_pred + self.eps),
torch.log(stft_mag_true + self.eps),
)
# Average across resolutions
sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss
return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}

60
utils/TrainingTools.py Normal file
View File

@@ -0,0 +1,60 @@
import torch
# Kept in case it's needed again...
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two)
return torch.mean(absolute_difference)
def discriminator_train(
high_quality,
low_quality,
high_labels,
low_labels,
discriminator,
generator,
criterion,
):
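# Computes the discriminator loss only; backward() and optimizer.step() are
# left to the caller (the helper removed in this change stepped the optimizer itself).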
decision_high = discriminator(high_quality)
d_loss_high = criterion(decision_high, high_labels)
# print(f"High-quality decision: {decision_high} | {d_loss_high}")
decision_low = discriminator(low_quality)
d_loss_low = criterion(decision_low, low_labels)
# print(f"Low-quality decision: {decision_low} | {d_loss_low}")
with torch.no_grad():
generator_quality = generator(low_quality)
decision_gen = discriminator(generator_quality)
d_loss_gen = criterion(decision_gen, low_labels)
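# Noised copies of the real samples are scored as "low" as well, adding a
# regularising term to the averaged loss below.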
noise = torch.rand_like(high_quality) * 0.08
decision_noise = discriminator(high_quality + noise)
d_loss_noise = criterion(decision_noise, low_labels)
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
return d_loss
def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
):
generator_output = generator(low_quality)
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
# Signal similarity
similarity_loss = signal_mae(generator_output, high_quality)
combined_loss = adversarial_loss + (similarity_loss * 100)
return combined_loss, adversarial_loss
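# Minimal usage sketch (assumed calling convention; optimizers, labels and
# models come from the training script):
#   d_loss = discriminator_train(hq, lq, real_labels, fake_labels, D, G, criterion)
#   optimizer_d.zero_grad(); d_loss.backward(); optimizer_d.step()
#   g_loss, adv_loss = generator_train(lq, hq, real_labels, G, D, criterion)
#   optimizer_g.zero_grad(); g_loss.backward(); optimizer_g.step()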

0
utils/__init__.py Normal file
View File