9 Commits
jax ... rt-test

Author SHA1 Message Date
782a3bab28 ⚗️ | More architectural changes 2025-11-18 21:34:59 +02:00
3f23242d6f ⚗️ | Added some stupid ways for training + some makeup 2025-10-04 22:38:11 +03:00
0bc8fc2792 | Made training bit... spicier. 2025-09-10 19:52:53 +03:00
ff38cefdd3 🐛 | Fix loading wrong model. 2025-06-08 18:14:31 +03:00
03fdc050cc | Made training bit faster. 2025-06-07 20:43:52 +03:00
2ded03713d | Added app.py script so the model can be used. 2025-06-06 22:10:06 +03:00
a135c765da 🐛 | Misc fixes... 2025-05-05 00:50:56 +03:00
b1e18443ba | Added support for .mp3 and .flac loading... 2025-05-04 23:56:14 +03:00
660b41aef8 :albemic: | Real-time testing... 2025-05-04 22:48:57 +03:00
13 changed files with 677 additions and 451 deletions

View File

@@ -1,18 +1,41 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
def stereo_tensor_to_mono(waveform):
if waveform.shape[0] > 1:
# Average across channels
mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
else:
# Already mono
mono_waveform = waveform
return mono_waveform
def stretch_tensor(tensor, target_length): def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
scale_factor = target_length / tensor.size(1) mono_tensor = torch.mean(waveform, dim=0, keepdim=True)
return mono_tensor
tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
return tensor def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
padding_amount = target_length - audio_tensor.size(-1)
if padding_amount <= 0:
return audio_tensor
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
return padded_audio_tensor
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
chunks = list(torch.split(audio_tensor, chunk_size, dim=1))
if pad_last_tensor:
last_chunk = chunks[-1]
if last_chunk.size(-1) < chunk_size:
chunks[-1] = pad_tensor(last_chunk, chunk_size)
return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
reconstructed_tensor = torch.cat(chunks, dim=-1)
return reconstructed_tensor
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
max_val = torch.max(torch.abs(audio_tensor))
if max_val < eps:
return audio_tensor
return audio_tensor / max_val

0
__init__.py Normal file
View File

128
app.py Normal file
View File

@@ -0,0 +1,128 @@
import argparse
import torch
import torchaudio
import torchcodec
import tqdm
from accelerate import Accelerator
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
"--clip_length",
type=int,
default=8000,
help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
"--bitrate",
type=int,
default=192000,
help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()
if args.sample_rate < 8000:
print(
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
)
exit()
# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")
# ---------------------------
# Prepare accelerator
# ---------------------------
generator = accelerator.prepare(generator)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output
if models_dir:
ckpt = torch.load(models_dir)
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.print("💾 | Loaded model!")
else:
print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
)
def start():
# To Mono!
decoder = torchcodec.decoders.AudioDecoder(input_audio)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
# Support for multichannel audio
# audio = AudioUtils.stereo_tensor_to_mono(audio)
audio = AudioUtils.normalize(audio)
resample_transform = torchaudio.transforms.Resample(
original_sample_rate, args.sample_rate
)
audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio]
processed_audio = []
with torch.no_grad():
for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
channels = []
for audio_channel in torch.split(clip, 1, dim=1):
output_piece = generator(audio_channel)
channels.append(output_piece.detach().cpu())
output_clip = torch.cat(channels, dim=1)
processed_audio.append(output_clip)
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
reconstructed_audio = reconstructed_audio.squeeze(0)
print(f"🔊 | Saving {output_audio}!")
torchaudio.save_with_torchcodec(
uri=output_audio,
src=reconstructed_audio,
sample_rate=args.sample_rate,
channels_first=True,
compression=args.bitrate,
)
start()

98
data.py
View File

@@ -1,53 +1,71 @@
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os import os
import random import random
import torchaudio.transforms as T
import torch
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils import AudioUtils
class AudioDataset(Dataset):
audio_sample_rates = [11025]
MAX_LENGTH = 44100 # Define your desired maximum length here
def __init__(self, input_dir, device): class AudioDataset(Dataset):
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]
self.device = device
def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
self.clip_length = clip_length
self.normalize = normalize
input_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if os.path.isfile(os.path.join(input_dir, f))
and f.lower().endswith((".wav", ".mp3", ".flac"))
]
data = []
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float()
original_sample_rate = decoded_samples.sample_rate
if normalize:
audio = AudioUtils.normalize(audio)
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
if not splitted_high_quality_audio:
continue
for splitted_audio_clip in splitted_high_quality_audio:
for audio_clip in torch.split(splitted_audio_clip, 1):
data.append((audio_clip, original_sample_rate))
self.audio_data = data
def __len__(self): def __len__(self):
return len(self.input_files) return len(self.audio_data)
def __getitem__(self, idx): def __getitem__(self, idx):
# Load high-quality audio audio_clip = self.audio_data[idx]
high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates) mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform_low(high_quality_audio)
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) resample_transform_low = torchaudio.transforms.Resample(
low_quality_audio = resample_transform_high(low_quality_audio) audio_clip[1], mangled_sample_rate
)
high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) resample_transform_high = torchaudio.transforms.Resample(
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio) mangled_sample_rate, audio_clip[1]
)
# Pad or truncate high-quality audio low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0]))
if high_quality_audio.shape[1] < self.MAX_LENGTH: if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
padding = self.MAX_LENGTH - high_quality_audio.shape[1] low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
high_quality_audio = F.pad(high_quality_audio, (0, padding)) elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
elif high_quality_audio.shape[1] > self.MAX_LENGTH: low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH] return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
# Pad or truncate low-quality audio
if low_quality_audio.shape[1] < self.MAX_LENGTH:
padding = self.MAX_LENGTH - low_quality_audio.shape[1]
low_quality_audio = F.pad(low_quality_audio, (0, padding))
elif low_quality_audio.shape[1] > self.MAX_LENGTH:
low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]
high_quality_audio = high_quality_audio.to(self.device)
low_quality_audio = low_quality_audio.to(self.device)
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

View File

@@ -1,9 +1,16 @@
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.utils as utils import torch.nn.utils as utils
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
padding = (kernel_size // 2) * dilation def discriminator_block(
in_channels,
out_channels,
kernel_size=15,
stride=1,
dilation=1
):
padding = dilation * (kernel_size - 1) // 2
conv_layer = nn.Conv1d( conv_layer = nn.Conv1d(
in_channels, in_channels,
out_channels, out_channels,
@@ -13,51 +20,51 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila
padding=padding padding=padding
) )
if spectral_norm: conv_layer = utils.spectral_norm(conv_layer)
conv_layer = utils.spectral_norm(conv_layer) leaky_relu = nn.LeakyReLU(0.2)
layers = [conv_layer] return nn.Sequential(conv_layer, leaky_relu)
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
def __init__(self, channels): def __init__(self, channels):
super(AttentionBlock, self).__init__() super(AttentionBlock, self).__init__()
self.attention = nn.Sequential( self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1), nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True), nn.ReLU(),
nn.Conv1d(channels // 4, channels, kernel_size=1), nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid() nn.Sigmoid(),
) )
def forward(self, x): def forward(self, x):
attention_weights = self.attention(x) attention_weights = self.attention(x)
return x * attention_weights return x + (x * attention_weights)
class SISUDiscriminator(nn.Module): class SISUDiscriminator(nn.Module):
def __init__(self, base_channels=16): def __init__(self, layers=8):
super(SISUDiscriminator, self).__init__() super(SISUDiscriminator, self).__init__()
layers = base_channels self.discriminator_blocks = nn.Sequential(
self.model = nn.Sequential( # 1 -> 32
discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False), discriminator_block(2, layers),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), AttentionBlock(layers),
discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True), # 32 -> 64
discriminator_block(layers, layers * 2, dilation=2),
# 64 -> 128
discriminator_block(layers * 2, layers * 4, dilation=4),
AttentionBlock(layers * 4), AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True), # 128 -> 256
discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), discriminator_block(layers * 4, layers * 8, stride=4),
discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), # 256 -> 512
discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), # discriminator_block(layers * 8, layers * 16, stride=4)
discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
) )
self.global_avg_pool = nn.AdaptiveAvgPool1d(1) self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x): def forward(self, x):
x = self.model(x) x = self.discriminator_blocks(x)
x = self.global_avg_pool(x) x = self.final_conv(x)
x = x.view(x.size(0), -1) x = self.avg_pool(x)
return x return x.squeeze(2)

View File

@@ -1,28 +0,0 @@
import json
filepath = "my_data.json"
def write_data(filepath, data):
try:
with open(filepath, 'w') as f:
json.dump(data, f, indent=4) # Use indent for pretty formatting
print(f"Data written to '{filepath}'")
except Exception as e:
print(f"Error writing to file: {e}")
def read_data(filepath):
try:
with open(filepath, 'r') as f:
data = json.load(f)
print(f"Data read from '{filepath}'")
return data
except FileNotFoundError:
print(f"File not found: {filepath}")
return None
except json.JSONDecodeError:
print(f"Error decoding JSON from file: {filepath}")
return None
except Exception as e:
print(f"Error reading from file: {e}")
return None

View File

@@ -1,42 +1,44 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
padding = (kernel_size - 1) // 2 * dilation
return nn.Sequential( return nn.Sequential(
nn.Conv1d( nn.Conv1d(
in_channels, in_channels,
out_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride,
dilation=dilation, dilation=dilation,
padding=(kernel_size // 2) * dilation padding=padding
), ),
nn.InstanceNorm1d(out_channels), nn.InstanceNorm1d(out_channels),
nn.PReLU() nn.PReLU(num_parameters=1, init=0.1),
) )
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
"""
Simple Channel Attention Block. Learns to weight channels based on their importance.
"""
def __init__(self, channels): def __init__(self, channels):
super(AttentionBlock, self).__init__() super(AttentionBlock, self).__init__()
self.attention = nn.Sequential( self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1), nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1), nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid() nn.Sigmoid(),
) )
def forward(self, x): def forward(self, x):
attention_weights = self.attention(x) attention_weights = self.attention(x)
return x * attention_weights return x + (x * attention_weights)
class ResidualInResidualBlock(nn.Module): class ResidualInResidualBlock(nn.Module):
def __init__(self, channels, num_convs=3): def __init__(self, channels, num_convs=3):
super(ResidualInResidualBlock, self).__init__() super(ResidualInResidualBlock, self).__init__()
self.conv_layers = nn.Sequential( self.conv_layers = nn.Sequential(
*[conv_block(channels, channels) for _ in range(num_convs)] *[GeneratorBlock(channels, channels) for _ in range(num_convs)]
) )
self.attention = AttentionBlock(channels) self.attention = AttentionBlock(channels)
@@ -47,28 +49,74 @@ class ResidualInResidualBlock(nn.Module):
x = self.attention(x) x = self.attention(x)
return x + residual return x + residual
def UpsampleBlock(in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2,
padding=1
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(num_parameters=1, init=0.1)
)
class SISUGenerator(nn.Module): class SISUGenerator(nn.Module):
def __init__(self, channels=16, num_rirb=4, alpha=1.0): def __init__(self, channels=32, num_rirb=1):
super(SISUGenerator, self).__init__() super(SISUGenerator, self).__init__()
self.alpha = alpha
self.conv1 = nn.Sequential( self.first_conv = GeneratorBlock(1, channels)
nn.Conv1d(1, channels, kernel_size=7, padding=3),
nn.InstanceNorm1d(channels), self.downsample = GeneratorBlock(channels, channels * 2, stride=2)
nn.PReLU(), self.downsample_attn = AttentionBlock(channels * 2)
self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
self.downsample_2_attn = AttentionBlock(channels * 4)
self.rirb = ResidualInResidualBlock(channels * 4)
# self.rirb = nn.Sequential(
# *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
# )
self.upsample = UpsampleBlock(channels * 4, channels * 2)
self.upsample_attn = AttentionBlock(channels * 2)
self.compress_1 = GeneratorBlock(channels * 4, channels * 2)
self.upsample_2 = UpsampleBlock(channels * 2, channels)
self.upsample_2_attn = AttentionBlock(channels)
self.compress_2 = GeneratorBlock(channels * 2, channels)
self.final_conv = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=7, padding=3),
nn.Tanh()
) )
self.rir_blocks = nn.Sequential(
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
)
self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
def forward(self, x): def forward(self, x):
residual_input = x residual_input = x
x = self.conv1(x) x1 = self.first_conv(x)
x_rirb_out = self.rir_blocks(x)
learned_residual = self.final_layer(x_rirb_out)
output = residual_input + self.alpha * learned_residual
x2 = self.downsample(x1)
x2 = self.downsample_attn(x2)
x3 = self.downsample_2(x2)
x3 = self.downsample_2_attn(x3)
x_rirb = self.rirb(x3)
up1 = self.upsample(x_rirb)
up1 = self.upsample_attn(up1)
cat1 = torch.cat((up1, x2), dim=1)
comp1 = self.compress_1(cat1)
up2 = self.upsample_2(comp1)
up2 = self.upsample_2_attn(up2)
cat2 = torch.cat((up2, x1), dim=1)
comp2 = self.compress_2(cat2)
learned_residual = self.final_conv(comp2)
output = residual_input + learned_residual
return output return output

View File

@@ -1,12 +0,0 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.3
pillow==11.0.0
setuptools==70.2.0
sympy==1.13.3
tqdm==4.67.1
typing_extensions==4.12.2

View File

@@ -1,194 +1,254 @@
import argparse
import datetime
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler
import argparse
import math
import os
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train
from training_utils import discriminator_train, generator_train # ---------------------------
import file_utils as Data # Argument parsing
# ---------------------------
import torchaudio.transforms as T parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
# Init script argument parser parser.add_argument(
parser = argparse.ArgumentParser(description="Training script") "--epochs", type=int, default=5000, help="Number of training epochs"
parser.add_argument("--generator", type=str, default=None, )
help="Path to the generator model file") parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--discriminator", type=str, default=None, parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
help="Path to the discriminator model file") parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument(
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
parser.add_argument("--debug", action="store_true", help="Print debug logs") )
parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
args = parser.parse_args() args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else "cpu") # ---------------------------
print(f"Using device: {device}") # Init accelerator
# ---------------------------
# Parameters accelerator = Accelerator(mixed_precision="bf16")
sample_rate = 44100
n_fft = 2048
hop_length = 256
win_length = n_fft
n_mels = 128
n_mfcc = 20 # If using MFCC
mfcc_transform = T.MFCC(
sample_rate,
n_mfcc,
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
).to(device)
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# ========= MODELS =========
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator() generator = SISUGenerator()
discriminator = SISUDiscriminator() discriminator = SISUDiscriminator()
epoch: int = args.epoch accelerator.print("🔨 | Compiling models...")
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training: generator = torch.compile(generator)
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) discriminator = torch.compile(discriminator)
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
generator = generator.to(device) accelerator.print("✅ | Compiling done!")
discriminator = discriminator.to(device)
# Loss # ---------------------------
criterion_g = nn.BCEWithLogitsLoss() # Dataset / DataLoader
criterion_d = nn.BCEWithLogitsLoss() # ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset", 8192)
# Optimizers sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) pin_memory = torch.cuda.is_available() and not args.no_pin_memory
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Scheduler train_loader = DataLoader(
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) dataset,
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
num_workers=args.num_workers,
pin_memory=pin_memory,
persistent_workers=pin_memory,
)
def start_training(): if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
generator_epochs = 5000 accelerator.print("🪹 | There is no data to train with! Exiting...")
for generator_epoch in range(generator_epochs): exit()
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
times_correct = 0 loader_batch_size = train_loader.batch_size
# ========= TRAINING ========= accelerator.print("✅ | Dataset fetched!")
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS ========= # ---------------------------
batch_size = high_quality_clip[0].size(0) # Losses / Optimizers / Scalers
real_labels = torch.ones(batch_size, 1).to(device) # ---------------------------
fake_labels = torch.zeros(batch_size, 1).to(device)
# ========= DISCRIMINATOR ========= optimizer_g = optim.AdamW(
discriminator.train() generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
d_loss = discriminator_train( )
high_quality_sample, optimizer_d = optim.AdamW(
low_quality_sample, discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
real_labels, )
fake_labels,
discriminator,
generator,
criterion_d,
optimizer_d
)
# ========= GENERATOR ========= scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
generator.train() optimizer_g, mode="min", factor=0.5, patience=5
generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( )
low_quality_sample, scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
high_quality_sample, optimizer_d, mode="min", factor=0.5, patience=5
real_labels, )
generator,
discriminator,
criterion_d,
optimizer_g,
device,
mel_transform,
stft_transform,
mfcc_transform
)
if debug: criterion_d = nn.MSELoss()
print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
scheduler_d.step(d_loss.detach())
scheduler_g.step(adversarial_loss.detach())
# ========= SAVE LATEST AUDIO ========= # ---------------------------
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) # Prepare accelerator
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) # ---------------------------
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
new_epoch = generator_epoch+epoch generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
generator, discriminator, optimizer_g, optimizer_d, train_loader
)
if generator_epoch % 25 == 0: # ---------------------------
print(f"Saved epoch {new_epoch}!") # Checkpoint helpers
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. # ---------------------------
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) models_dir = "./models"
torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) os.makedirs(models_dir, exist_ok=True)
#if debug:
# print(generator.state_dict().keys())
# print(discriminator.state_dict().keys())
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
torch.save(discriminator, "models/epoch-5000-discriminator.pt") def save_ckpt(path, epoch):
torch.save(generator, "models/epoch-5000-generator.pt") accelerator.wait_for_everyone()
print("Training complete!") if accelerator.is_main_process:
accelerator.save(
{
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict(),
},
path,
)
start_training()
start_epoch = 0
if args.resume:
ckpt_path = os.path.join(models_dir, "last.pt")
ckpt = torch.load(ckpt_path)
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
optimizer_g.load_state_dict(ckpt["optG"])
optimizer_d.load_state_dict(ckpt["optD"])
scheduler_g.load_state_dict(ckpt["schedG"])
scheduler_d.load_state_dict(ckpt["schedD"])
start_epoch = ckpt.get("epoch", 1)
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
generator.train()
discriminator.train()
discriminator_time = 0
generator_time = 0
running_d, running_g, steps = 0.0, 0.0, 0
progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
for i, (
(high_quality, low_quality),
(high_sample_rate, low_sample_rate),
) in enumerate(progress_bar):
batch_size = high_quality.size(0)
real_labels = real_buf[:batch_size].to(accelerator.device)
fake_labels = fake_buf[:batch_size].to(accelerator.device)
with accelerator.autocast():
generator_output = generator(low_quality)
# --- Discriminator ---
d_time = datetime.datetime.now()
optimizer_d.zero_grad(set_to_none=True)
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
low_quality.detach(),
real_labels,
fake_labels,
discriminator,
criterion_d,
generator_output.detach()
)
accelerator.backward(d_loss)
optimizer_d.step()
discriminator_time = (datetime.datetime.now() - d_time).microseconds
# --- Generator ---
g_time = datetime.datetime.now()
optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast():
g_total, g_adv = generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
criterion_d,
generator_output
)
accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step()
generator_time = (datetime.datetime.now() - g_time).microseconds
d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean()
if torch.isfinite(d_val):
running_d += d_val.item()
else:
accelerator.print(
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
)
if torch.isfinite(g_val):
running_g += g_val.item()
else:
accelerator.print(
f"🫥 | NaN in generator loss at step {i}, skipping update."
)
steps += 1
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
# epoch averages & schedulers
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
mean_d = running_d / steps
mean_g = running_g / steps
scheduler_d.step(mean_d)
scheduler_g.step(mean_g)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
try:
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
except Exception as e:
accelerator.print("😬 | Failed saving crash checkpoint:", e)
raise
accelerator.print("🏁 | Training finished.")

View File

@@ -1,144 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
loss = torch.mean((mfccs_true - mfccs_pred)**2)
return loss
def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
# Ensure same time dimension length (due to potential framing differences)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
# L1 Loss (Mean Absolute Error)
loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
return loss
def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
mel_spec_true = mel_transform(y_true)
mel_spec_pred = mel_transform(y_pred)
min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
mel_spec_true = mel_spec_true[..., :min_len]
mel_spec_pred = mel_spec_pred[..., :min_len]
loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
return loss
def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
return loss
def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
stft_mag_true = stft_transform(y_true)
stft_mag_pred = stft_transform(y_pred)
min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
stft_mag_true = stft_mag_true[..., :min_len]
stft_mag_pred = stft_mag_pred[..., :min_len]
norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
loss = torch.mean(norm_diff / (norm_true + eps))
return loss
def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
optimizer.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
with torch.no_grad():
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
d_loss = (d_loss_real + d_loss_fake) / 2.0
d_loss.backward()
# Optional: Gradient Clipping (can be helpful)
# nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer.step()
return d_loss
def generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
adv_criterion,
g_optimizer,
device,
mel_transform: T.MelSpectrogram,
stft_transform: T.Spectrogram,
mfcc_transform: T.MFCC,
lambda_adv: float = 1.0,
lambda_mel_l1: float = 10.0,
lambda_log_stft: float = 1.0,
lambda_mfcc: float = 1.0
):
g_optimizer.zero_grad()
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
mel_l1 = 0.0
log_stft_l1 = 0.0
mfcc_l = 0.0
# Calculate Mel L1 Loss if weight is positive
if lambda_mel_l1 > 0:
mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
# Calculate Log STFT L1 Loss if weight is positive
if lambda_log_stft > 0:
log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
# Calculate MFCC Loss if weight is positive
if lambda_mfcc > 0:
mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
combined_loss = (lambda_adv * adversarial_loss) + \
(lambda_mel_l1 * mel_l1_tensor) + \
(lambda_log_stft * log_stft_l1_tensor) + \
(lambda_mfcc * mfcc_l_tensor)
combined_loss.backward()
# Optional: Gradient Clipping
# nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
g_optimizer.step()
# 6. Return values for logging
return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor

View File

@@ -0,0 +1,68 @@
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
def __init__(
self,
fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
hop_sizes: List[int] = [64, 128, 256, 512, 1024],
win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
eps: float = 1e-7,
center: bool = True
):
super().__init__()
self.eps = eps
self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList()
for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)):
stft = T.Spectrogram(
n_fft=n_fft,
hop_length=hop_len,
win_length=win_len,
window_fn=torch.hann_window,
power=None,
center=center,
pad_mode="reflect",
normalized=False,
)
self.stft_transforms.append(stft)
def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]:
if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1:
y_pred = y_pred.squeeze(1)
sc_loss = 0.0
mag_loss = 0.0
for stft in self.stft_transforms:
stft.window = stft.window.to(y_true.device)
stft_true = stft(y_true)
stft_pred = stft(y_pred)
stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred)
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
log_mag_pred = torch.log(stft_mag_pred + self.eps)
log_mag_true = torch.log(stft_mag_true + self.eps)
mag_loss += F.l1_loss(log_mag_pred, log_mag_true)
sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss
return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}

58
utils/TrainingTools.py Normal file
View File

@@ -0,0 +1,58 @@
import torch
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[512, 1024, 2048, 4096],
# hop_sizes=[128, 256, 512, 1024],
# win_lengths=[512, 1024, 2048, 4096]
# )
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[512, 1024, 2048],
hop_sizes=[64, 128, 256],
win_lengths=[256, 512, 1024]
)
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two)
return torch.mean(absolute_difference)
def discriminator_train(
high_quality,
low_quality,
high_labels,
low_labels,
discriminator,
criterion,
generator_output
):
real_pair = torch.cat((low_quality, high_quality), dim=1)
decision_real = discriminator(real_pair)
d_loss_real = criterion(decision_real, high_labels)
fake_pair = torch.cat((low_quality, generator_output), dim=1)
decision_fake = discriminator(fake_pair)
d_loss_fake = criterion(decision_fake, low_labels)
d_loss = (d_loss_real + d_loss_fake) / 2.0
return d_loss
def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
fake_pair = torch.cat((low_quality, generator_output), dim=1)
discriminator_decision = discriminator(fake_pair)
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
mae_loss = signal_mae(generator_output, high_quality)
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
lambda_mae = 10.0
lambda_stft = 2.5
lambda_adv = 2.5
combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
return combined_loss, adversarial_loss

0
utils/__init__.py Normal file
View File