1 Commits

Author SHA1 Message Date
1717e7a008 ⚗️ | Experimenting... 2025-02-10 19:35:50 +02:00
12 changed files with 254 additions and 740 deletions

View File

@@ -1,41 +1,18 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
def stereo_tensor_to_mono(waveform):
    """Collapse a multi-channel waveform to a single channel.

    The first dimension is treated as the channel axis. When it holds more
    than one entry the channels are averaged (the channel axis is kept, with
    size 1); otherwise the input object is returned untouched.
    """
    if waveform.shape[0] <= 1:
        # Nothing to mix down: hand back the very same tensor.
        return waveform
    return torch.mean(waveform, dim=0, keepdim=True)
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: def stretch_tensor(tensor, target_length):
mono_tensor = torch.mean(waveform, dim=0, keepdim=True) scale_factor = target_length / tensor.size(1)
return mono_tensor
tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor: return tensor
padding_amount = target_length - audio_tensor.size(-1)
if padding_amount <= 0:
return audio_tensor
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
return padded_audio_tensor
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
    """Cut a waveform into consecutive chunks along its last (time) axis.

    Args:
        audio_tensor: Waveform whose last dimension is time.
        chunk_size: Number of samples per chunk.
        pad_last_tensor: When True, a trailing chunk shorter than
            ``chunk_size`` is padded to full length via ``pad_tensor``.

    Returns:
        The chunks in order. Without padding the final chunk may be shorter
        than ``chunk_size``.
    """
    # Split on the last axis: it is the axis whose length is checked and
    # padded below, so 1-D and batched inputs are handled the same way as the
    # usual (channels, time) layout.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    if pad_last_tensor:
        last_chunk = chunks[-1]
        if last_chunk.size(-1) < chunk_size:
            chunks[-1] = pad_tensor(last_chunk, chunk_size)
    return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """Join chunks back into one tensor by concatenating along the last axis."""
    return torch.cat(chunks, dim=-1)
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale a waveform so that its largest absolute sample becomes 1.

    Args:
        audio_tensor: Waveform to scale.
        eps: Peak threshold below which the signal is treated as silence.

    Returns:
        ``audio_tensor`` divided by its peak magnitude. The input is returned
        unchanged when it is empty or when its peak is below ``eps``.
    """
    # A max-reduction over zero elements raises, so empty input passes through.
    if audio_tensor.numel() == 0:
        return audio_tensor
    peak = torch.max(torch.abs(audio_tensor))
    if peak < eps:
        # Dividing (near-)silence by its peak would only amplify noise.
        return audio_tensor
    return audio_tensor / peak

View File

@@ -18,7 +18,6 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
1. **Set Up**: 1. **Set Up**:
- Make sure you have Python installed (version 3.8 or higher). - Make sure you have Python installed (version 3.8 or higher).
- Install needed packages: `pip install -r requirements.txt` - Install needed packages: `pip install -r requirements.txt`
- Install the current version of PyTorch for your hardware (CUDA, ROCm, or whatever your device supports)
2. **Prepare Audio Data**: 2. **Prepare Audio Data**:
- Put your audio files in the `dataset/good` folder. - Put your audio files in the `dataset/good` folder.

View File

128
app.py
View File

@@ -1,128 +0,0 @@
import argparse
import torch
import torchaudio
import torchcodec
import tqdm
from accelerate import Accelerator
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
    "--clip_length",
    type=int,
    default=8000,
    help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
    "--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
    "--bitrate",
    type=int,
    default=192000,
    help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")

args = parser.parse_args()

# Reject sample rates below 8000 through argparse itself: parser.error()
# prints the usage plus the message to stderr and exits with a non-zero
# status, so callers and shell scripts can detect the failure.
if args.sample_rate < 8000:
    parser.error(
        "Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
    )
# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Prepare accelerator
# ---------------------------
generator = accelerator.prepare(generator)

# ---------------------------
# Checkpoint helpers
# ---------------------------
# NOTE: despite its name, models_dir holds the value of --model, which is
# passed straight to torch.load below, i.e. it is a checkpoint file path.
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output

if models_dir:
    # Deserialize onto the CPU first so a checkpoint written on a different
    # accelerator can still be read; load_state_dict then copies the values
    # into the generator's already-placed parameters.
    ckpt = torch.load(models_dir, map_location="cpu")
    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.print("💾 | Loaded model!")
else:
    # Only a warning: the script continues with an untrained generator.
    print(
        "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
    )
def start():
    """Upscale the input file clip by clip and write the result to disk.

    Reads ``input_audio``, peak-normalizes it, resamples it to
    ``args.sample_rate``, runs every channel of every clip through the
    generator and saves the concatenated output to ``output_audio``.
    """
    decoder = torchcodec.decoders.AudioDecoder(input_audio)
    decoded_samples = decoder.get_all_samples()
    audio = decoded_samples.data
    original_sample_rate = decoded_samples.sample_rate

    # Multichannel audio is supported: channels are kept and fed to the
    # generator one at a time instead of being mixed down to mono.
    # audio = AudioUtils.stereo_tensor_to_mono(audio)
    audio = AudioUtils.normalize(audio)

    resample_transform = torchaudio.transforms.Resample(
        original_sample_rate, args.sample_rate
    )
    audio = resample_transform(audio)

    splitted_audio = AudioUtils.split_audio(audio, clip_length)

    processed_audio = []
    with torch.no_grad():
        for chunk in tqdm.tqdm(splitted_audio, desc="Processing..."):
            # Add a leading batch axis and move only the current clip to the
            # device, so device memory does not grow with the file length.
            clip = chunk.unsqueeze(0).to(accelerator.device)
            channels = []
            for audio_channel in torch.split(clip, 1, dim=1):
                output_piece = generator(audio_channel)
                channels.append(output_piece.detach().cpu())
            processed_audio.append(torch.cat(channels, dim=1))

    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
    # Drop the batch axis that was added for the generator.
    reconstructed_audio = reconstructed_audio.squeeze(0)

    print(f"🔊 | Saving {output_audio}!")
    torchaudio.save_with_torchcodec(
        uri=output_audio,
        src=reconstructed_audio,
        sample_rate=args.sample_rate,
        channels_first=True,
        compression=args.bitrate,
    )


start()

95
data.py
View File

@@ -1,71 +1,52 @@
import os from torch.utils.data import Dataset
import random import torch.nn.functional as F
import torch import torch
import torchaudio import torchaudio
import torchcodec.decoders as decoders import os
import tqdm import random
from torch.utils.data import Dataset from AudioUtils import stereo_tensor_to_mono, stretch_tensor
import AudioUtils
class AudioDataset(Dataset): class AudioDataset(Dataset):
audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100] audio_sample_rates = [11025]
def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True): def __init__(self, input_dir):
self.clip_length = clip_length self.input_files = [
self.normalize = normalize os.path.join(root, f)
for root, _, files in os.walk(input_dir)
input_files = [ for f in files if f.endswith('.wav')
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if os.path.isfile(os.path.join(input_dir, f))
and f.lower().endswith((".wav", ".mp3", ".flac"))
] ]
data = []
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float()
original_sample_rate = decoded_samples.sample_rate
if normalize:
audio = AudioUtils.normalize(audio)
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
if not splitted_high_quality_audio:
continue
for splitted_audio_clip in splitted_high_quality_audio:
for audio_clip in torch.split(splitted_audio_clip, 1):
data.append((audio_clip, original_sample_rate))
self.audio_data = data
def __len__(self): def __len__(self):
return len(self.audio_data) return len(self.input_files)
def __getitem__(self, idx): def __getitem__(self, idx):
audio_clip = self.audio_data[idx] # Load high-quality audio
high_quality_path = self.input_files[idx]
high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path)
high_quality_audio = stereo_tensor_to_mono(high_quality_audio)
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates) mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_low(high_quality_audio)
resample_transform_low = torchaudio.transforms.Resample( resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
audio_clip[1], mangled_sample_rate low_quality_audio = resample_high(low_quality_audio)
)
resample_transform_high = torchaudio.transforms.Resample( # Pad or truncate to match a fixed length
mangled_sample_rate, audio_clip[1] target_length = 44100 # Adjust this based on your data
) high_quality_audio = self.pad_or_truncate(high_quality_audio, target_length)
low_quality_audio = self.pad_or_truncate(low_quality_audio, target_length)
low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0])) return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]] def pad_or_truncate(self, tensor, target_length):
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]: current_length = tensor.size(1)
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length) if current_length < target_length:
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate)) # Pad with zeros
padding = target_length - current_length
tensor = F.pad(tensor, (0, padding))
else:
# Truncate to target length
tensor = tensor[:, :target_length]
return tensor

View File

@@ -1,70 +1,38 @@
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.utils as utils import torch.nn.utils as utils
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
def discriminator_block( padding = (kernel_size // 2) * dilation
in_channels, return nn.Sequential(
out_channels, utils.spectral_norm(
kernel_size=15, nn.Conv1d(in_channels, out_channels,
stride=1,
dilation=1
):
padding = dilation * (kernel_size - 1) // 2
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
dilation=dilation, dilation=dilation,
padding=padding padding=padding
) )
),
conv_layer = utils.spectral_norm(conv_layer) nn.BatchNorm1d(out_channels),
leaky_relu = nn.LeakyReLU(0.2) nn.LeakyReLU(0.2, inplace=True)
return nn.Sequential(conv_layer, leaky_relu)
class AttentionBlock(nn.Module):
    """Channel-gating block: ``x + x * gate(x)``.

    The gate is a two-layer stack of kernel-size-1 convolutions (a bottleneck
    of ``channels // 4`` feature maps, then back to ``channels``) followed by
    a sigmoid, so each gate value lies between 0 and 1.
    """

    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        # Keep at least one bottleneck channel; channels // 4 is 0 for fewer
        # than four channels, which would leave the gate with no features.
        bottleneck = max(channels // 4, 1)
        self.attention = nn.Sequential(
            nn.Conv1d(channels, bottleneck, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(bottleneck, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the input plus the input scaled by its gate values."""
        attention_weights = self.attention(x)
        return x + (x * attention_weights)
class SISUDiscriminator(nn.Module): class SISUDiscriminator(nn.Module):
def __init__(self, layers=8): def __init__(self):
super(SISUDiscriminator, self).__init__() super(SISUDiscriminator, self).__init__()
self.discriminator_blocks = nn.Sequential( layers = 4
# 1 -> 32 self.model = nn.Sequential(
discriminator_block(2, layers), discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1),
AttentionBlock(layers), discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1),
# 32 -> 64 discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4),
discriminator_block(layers, layers * 2, dilation=2), discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8),
# 64 -> 128 discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16),
discriminator_block(layers * 2, layers * 4, dilation=4), discriminator_block(layers * 2, layers, kernel_size=5, dilation=2),
AttentionBlock(layers * 4), discriminator_block(layers, 1, kernel_size=3, stride=1)
# 128 -> 256
discriminator_block(layers * 4, layers * 8, stride=4),
# 256 -> 512
# discriminator_block(layers * 8, layers * 16, stride=4)
) )
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x): def forward(self, x):
x = self.discriminator_blocks(x) x = self.model(x)
x = self.final_conv(x) x = self.global_avg_pool(x)
x = self.avg_pool(x) return x.view(-1, 1)
return x.squeeze(2)

View File

@@ -1,122 +1,41 @@
import torch
import torch.nn as nn import torch.nn as nn
def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1):
def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): padding = (kernel_size // 2) * dilation
padding = (kernel_size - 1) // 2 * dilation
return nn.Sequential( return nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
in_channels, nn.BatchNorm1d(out_channels),
out_channels, nn.PReLU(),
kernel_size=kernel_size, nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
stride=stride, nn.BatchNorm1d(out_channels)
dilation=dilation,
padding=padding
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(num_parameters=1, init=0.1),
)
class AttentionBlock(nn.Module):
    """Channel-gating block: ``x + x * gate(x)``.

    The gate is a two-layer stack of kernel-size-1 convolutions (a bottleneck
    of ``channels // 4`` feature maps, then back to ``channels``) followed by
    a sigmoid, so each gate value lies between 0 and 1.
    """

    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        # Keep at least one bottleneck channel; channels // 4 is 0 for fewer
        # than four channels, which would leave the gate with no features.
        bottleneck = max(channels // 4, 1)
        self.attention = nn.Sequential(
            nn.Conv1d(channels, bottleneck, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(bottleneck, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the input plus the input scaled by its gate values."""
        attention_weights = self.attention(x)
        return x + (x * attention_weights)
class ResidualInResidualBlock(nn.Module):
    """Stack of same-width generator blocks plus attention, with a skip path.

    ``num_convs`` GeneratorBlock layers (``channels`` in, ``channels`` out)
    run in sequence, their result passes through an AttentionBlock, and the
    block input is added back on top.
    """

    def __init__(self, channels, num_convs=3):
        super(ResidualInResidualBlock, self).__init__()
        stacked = []
        for _ in range(num_convs):
            stacked.append(GeneratorBlock(channels, channels))
        self.conv_layers = nn.Sequential(*stacked)
        self.attention = AttentionBlock(channels)

    def forward(self, x):
        """Return attention(conv_layers(x)) + x."""
        attended = self.attention(self.conv_layers(x))
        return attended + x
def UpsampleBlock(in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2,
padding=1
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(num_parameters=1, init=0.1)
) )
class SISUGenerator(nn.Module): class SISUGenerator(nn.Module):
def __init__(self, channels=32, num_rirb=1): def __init__(self):
super(SISUGenerator, self).__init__() super(SISUGenerator, self).__init__()
layers = 4
self.first_conv = GeneratorBlock(1, channels) self.conv1 = nn.Sequential(
nn.Conv1d(1, layers, kernel_size=7, padding=3),
self.downsample = GeneratorBlock(channels, channels * 2, stride=2) nn.BatchNorm1d(layers),
self.downsample_attn = AttentionBlock(channels * 2) nn.PReLU()
self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
self.downsample_2_attn = AttentionBlock(channels * 4)
self.rirb = ResidualInResidualBlock(channels * 4)
# self.rirb = nn.Sequential(
# *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
# )
self.upsample = UpsampleBlock(channels * 4, channels * 2)
self.upsample_attn = AttentionBlock(channels * 2)
self.compress_1 = GeneratorBlock(channels * 4, channels * 2)
self.upsample_2 = UpsampleBlock(channels * 2, channels)
self.upsample_2_attn = AttentionBlock(channels)
self.compress_2 = GeneratorBlock(channels * 2, channels)
self.final_conv = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=7, padding=3),
nn.Tanh()
) )
self.conv_blocks = nn.Sequential(
conv_residual_block(layers, layers, kernel_size=3, dilation=1),
conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2),
conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16),
conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8),
conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2),
conv_residual_block(layers, layers, kernel_size=3, dilation=1)
)
self.final_layer = nn.Sequential(
nn.Conv1d(layers, 1, kernel_size=3, padding=1)
)
def forward(self, x): def forward(self, x):
residual_input = x residual = x
x1 = self.first_conv(x) x = self.conv1(x)
x = self.conv_blocks(x) + x # Adding residual connection after blocks
x2 = self.downsample(x1) x = self.final_layer(x)
x2 = self.downsample_attn(x2) return x + residual
x3 = self.downsample_2(x2)
x3 = self.downsample_2_attn(x3)
x_rirb = self.rirb(x3)
up1 = self.upsample(x_rirb)
up1 = self.upsample_attn(up1)
cat1 = torch.cat((up1, x2), dim=1)
comp1 = self.compress_1(cat1)
up2 = self.upsample_2(comp1)
up2 = self.upsample_2_attn(up2)
cat2 = torch.cat((up2, x1), dim=1)
comp2 = self.compress_2(cat2)
learned_residual = self.final_conv(comp2)
output = residual_input + learned_residual
return output

14
requirements.txt Normal file
View File

@@ -0,0 +1,14 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.1
pytorch-triton-rocm==3.2.0+git0d4682f0
setuptools==70.2.0
sympy==1.13.1
torch==2.6.0.dev20241222+rocm6.2.4
torchaudio==2.6.0.dev20241222+rocm6.2.4
tqdm==4.67.1
typing_extensions==4.12.2

View File

@@ -1,254 +1,164 @@
import argparse
import datetime
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler
import argparse
import math
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train from discriminator import SISUDiscriminator
def perceptual_loss(y_true, y_pred):
    """Return the mean of the squared element-wise difference of the inputs."""
    difference = y_true - y_pred
    return torch.mean(difference ** 2)
def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
    """Run one optimization step of the discriminator.

    Args:
        high_quality: Pair whose first element is the batch of real audio.
        low_quality: Pair whose first element is the batch of degraded audio
            fed to the generator.
        real_labels: Targets for the real batch.
        fake_labels: Targets for the generated batch.

    Returns:
        The averaged real/fake discriminator loss for this step.

    NOTE(review): criterion_d is created as nn.BCELoss in this script, which
    expects inputs in [0, 1] — confirm the discriminator output is squashed
    accordingly.
    """
    optimizer_d.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

    # Forward pass for fake samples. The generator is not updated in this
    # step, so skip building its autograd graph instead of detaching it
    # afterwards.
    with torch.no_grad():
        generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

    # Combine real and fake losses
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    # Backward pass and optimization
    d_loss.backward()
    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
    optimizer_d.step()

    return d_loss
def generator_train(low_quality, real_labels):
    """Run one optimization step of the generator and return its output.

    The generator processes ``low_quality[0]``, the discriminator scores the
    result, and ``criterion_g`` compares that score against ``real_labels``.
    """
    optimizer_g.zero_grad()

    # Score the freshly generated batch with the discriminator.
    fake_audio = generator(low_quality[0])
    verdict = discriminator(fake_audio)

    adversarial_loss = criterion_g(verdict, real_labels)
    adversarial_loss.backward()
    optimizer_g.step()

    return fake_audio
def first(objects):
    """Return the first element of ``objects``, or ``objects`` itself if empty."""
    return objects[0] if len(objects) >= 1 else objects
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
"--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
"--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args() args = parser.parse_args()
# --------------------------- # Check for CUDA availability
# Init accelerator device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --------------------------- print(f"Using device: {device}")
accelerator = Accelerator(mixed_precision="bf16") # Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir)
# --------------------------- # ========= SINGLE =========
# Models
# --------------------------- train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Initialize models and move them to device
generator = SISUGenerator() generator = SISUGenerator()
discriminator = SISUDiscriminator() discriminator = SISUDiscriminator()
accelerator.print("🔨 | Compiling models...") if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
generator = torch.compile(generator) generator = generator.to(device)
discriminator = torch.compile(discriminator) discriminator = discriminator.to(device)
accelerator.print("✅ | Compiling done!") # Loss
criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss()
# --------------------------- # Optimizers
# Dataset / DataLoader optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# --------------------------- optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset", 8192)
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None # Scheduler
pin_memory = torch.cuda.is_available() and not args.no_pin_memory scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
train_loader = DataLoader( def start_training():
dataset, generator_epochs = 5000
sampler=sampler, for generator_epoch in range(generator_epochs):
batch_size=args.batch_size, low_quality_audio = (torch.empty((1)), 1)
shuffle=(sampler is None), high_quality_audio = (torch.empty((1)), 1)
num_workers=args.num_workers, ai_enhanced_audio = (torch.empty((1)), 1)
pin_memory=pin_memory,
persistent_workers=pin_memory,
)
if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0: times_correct = 0
accelerator.print("🪹 | There is no data to train with! Exiting...")
exit()
loader_batch_size = train_loader.batch_size # ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
accelerator.print("✅ | Dataset fetched!") # ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# --------------------------- # ========= DISCRIMINATOR =========
# Losses / Optimizers / Scalers
# ---------------------------
optimizer_g = optim.AdamW(
generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_d, mode="min", factor=0.5, patience=5
)
criterion_d = nn.MSELoss()
# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
generator, discriminator, optimizer_g, optimizer_d, train_loader
)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
def save_ckpt(path, epoch):
    """Write a training checkpoint to ``path`` from the main process.

    The checkpoint bundles the epoch number with the state dicts of both
    unwrapped models, both optimizers and both schedulers.
    """
    accelerator.wait_for_everyone()
    if not accelerator.is_main_process:
        return

    snapshot = {
        "epoch": epoch,
        "G": accelerator.unwrap_model(generator).state_dict(),
        "D": accelerator.unwrap_model(discriminator).state_dict(),
        "optG": optimizer_g.state_dict(),
        "optD": optimizer_d.state_dict(),
        "schedG": scheduler_g.state_dict(),
        "schedD": scheduler_d.state_dict(),
    }
    accelerator.save(snapshot, path)
start_epoch = 0
if args.resume:
    ckpt_path = os.path.join(models_dir, "last.pt")
    # Deserialize onto the CPU; the load_state_dict calls below copy the
    # values into objects that already live on the right device.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])

    # last.pt is written after an epoch has finished, so training continues
    # with the following epoch instead of repeating the stored one. A
    # checkpoint without an "epoch" entry still resumes at 1, as before.
    start_epoch = ckpt.get("epoch", 0) + 1
    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")

# Pre-allocated label columns, sized to the loader batch size.
real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)

accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
generator.train()
discriminator.train() discriminator.train()
discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)
discriminator_time = 0 # ========= GENERATOR =========
generator_time = 0 generator.train()
generator_output = generator_train(low_quality_sample, real_labels)
running_d, running_g, steps = 0.0, 0.0, 0 # ========= SAVE LATEST AUDIO =========
high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0])
low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0])
ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0])
print(high_quality_audio)
progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs") print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
for i, ( #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
(high_quality, low_quality), #print(f"Generator metric {metric}!")
(high_sample_rate, low_sample_rate), #scheduler_g.step(metric)
) in enumerate(progress_bar):
batch_size = high_quality.size(0)
real_labels = real_buf[:batch_size].to(accelerator.device) if generator_epoch % 10 == 0:
fake_labels = fake_buf[:batch_size].to(accelerator.device) print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
with accelerator.autocast(): torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt")
generator_output = generator(low_quality) torch.save(generator.state_dict(), f"models/current-epoch-generator.pt")
# --- Discriminator --- torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt")
d_time = datetime.datetime.now() torch.save(generator.state_dict(), "models/epoch-5000-generator.pt")
optimizer_d.zero_grad(set_to_none=True) print("Training complete!")
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
low_quality.detach(),
real_labels,
fake_labels,
discriminator,
criterion_d,
generator_output.detach()
)
accelerator.backward(d_loss) start_training()
optimizer_d.step()
discriminator_time = (datetime.datetime.now() - d_time).microseconds
# --- Generator ---
g_time = datetime.datetime.now()
optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast():
g_total, g_adv = generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
criterion_d,
generator_output
)
accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step()
generator_time = (datetime.datetime.now() - g_time).microseconds
d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean()
if torch.isfinite(d_val):
running_d += d_val.item()
else:
accelerator.print(
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
)
if torch.isfinite(g_val):
running_g += g_val.item()
else:
accelerator.print(
f"🫥 | NaN in generator loss at step {i}, skipping update."
)
steps += 1
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
# epoch averages & schedulers
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
mean_d = running_d / steps
mean_g = running_g / steps
scheduler_d.step(mean_d)
scheduler_g.step(mean_g)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
try:
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
except Exception as e:
accelerator.print("😬 | Failed saving crash checkpoint:", e)
raise
accelerator.print("🏁 | Training finished.")

View File

@@ -1,68 +0,0 @@
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
    """Spectral loss averaged over several STFT resolutions.

    For every resolution the loss adds a spectral-convergence term (norm of
    the magnitude difference relative to the norm of the target magnitude)
    and an L1 distance between log-magnitudes.

    Args:
        fft_sizes: FFT size per resolution.
        hop_sizes: Hop length per resolution.
        win_lengths: Window length per resolution.
        eps: Small constant added to denominators and log arguments.
        center: Passed through to each spectrogram transform.

    Raises:
        ValueError: If the three size lists are empty or differ in length.
    """

    def __init__(
        self,
        # The list defaults are only read, never mutated.
        fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
        hop_sizes: List[int] = [64, 128, 256, 512, 1024],
        win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
        eps: float = 1e-7,
        center: bool = True
    ):
        super().__init__()

        # zip() would silently drop the surplus entries of unequal lists and
        # the averages in forward() would then use the wrong divisor.
        if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
            raise ValueError(
                "fft_sizes, hop_sizes and win_lengths must have the same length"
            )
        if len(fft_sizes) == 0:
            raise ValueError("at least one STFT resolution is required")

        self.eps = eps
        self.n_resolutions = len(fft_sizes)

        self.stft_transforms = nn.ModuleList()
        for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
            stft = T.Spectrogram(
                n_fft=n_fft,
                hop_length=hop_len,
                win_length=win_len,
                window_fn=torch.hann_window,
                power=None,  # keep the complex spectrogram; magnitude is taken in forward()
                center=center,
                pad_mode="reflect",
                normalized=False,
            )
            self.stft_transforms.append(stft)

    def forward(
        self, y_true: torch.Tensor, y_pred: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Compute the loss between a target and a predicted waveform.

        Returns:
            A dict with ``"sc"`` (spectral convergence), ``"mag"``
            (log-magnitude L1) and ``"total"`` (their sum), each averaged
            over the resolutions.
        """
        # Drop a singleton middle axis so 3-D (N, 1, L) input becomes (N, L).
        if y_true.dim() == 3 and y_true.size(1) == 1:
            y_true = y_true.squeeze(1)
        if y_pred.dim() == 3 and y_pred.size(1) == 1:
            y_pred = y_pred.squeeze(1)

        sc_loss = 0.0
        mag_loss = 0.0

        for stft in self.stft_transforms:
            # Make sure the analysis window sits on the same device as the input.
            stft.window = stft.window.to(y_true.device)
            stft_true = stft(y_true)
            stft_pred = stft(y_pred)

            stft_mag_true = torch.abs(stft_true)
            stft_mag_pred = torch.abs(stft_pred)

            norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
            norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
            sc_loss += torch.mean(norm_diff / (norm_true + self.eps))

            log_mag_pred = torch.log(stft_mag_pred + self.eps)
            log_mag_true = torch.log(stft_mag_true + self.eps)
            mag_loss += F.l1_loss(log_mag_pred, log_mag_true)

        sc_loss /= self.n_resolutions
        mag_loss /= self.n_resolutions
        total_loss = sc_loss + mag_loss

        return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}

View File

@@ -1,58 +0,0 @@
import torch
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# Alternative four-resolution configuration, currently disabled:
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[512, 1024, 2048, 4096],
# hop_sizes=[128, 256, 512, 1024],
# win_lengths=[512, 1024, 2048, 4096]
# )
# Module-level loss instance shared by the training helpers in this file;
# it is built once at import time with three STFT resolutions.
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[512, 1024, 2048],
hop_sizes=[64, 128, 256],
win_lengths=[256, 512, 1024]
)
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute element-wise difference of two tensors."""
    return torch.mean(torch.abs(input_one - input_two))
def discriminator_train(
    high_quality,
    low_quality,
    high_labels,
    low_labels,
    discriminator,
    criterion,
    generator_output
):
    """Compute the discriminator loss for one batch (no optimizer step).

    The discriminator judges pairs built by concatenating the low-quality
    input with a candidate along dim 1: first the real high-quality audio
    (scored against ``high_labels``), then the generator output (scored
    against ``low_labels``). The two losses are averaged.
    """

    def pair_loss(candidate, labels):
        # Condition the discriminator on the low-quality input.
        paired = torch.cat((low_quality, candidate), dim=1)
        return criterion(discriminator(paired), labels)

    loss_on_real = pair_loss(high_quality, high_labels)
    loss_on_fake = pair_loss(generator_output, low_labels)
    return (loss_on_real + loss_on_fake) / 2.0
def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    generator_output,
    lambda_mae=10.0,
    lambda_stft=2.5,
    lambda_adv=2.5,
):
    """Compute the generator loss for one batch (no optimizer step).

    Args:
        low_quality: Degraded input batch; paired with the generator output
            along dim 1 before it is shown to the discriminator.
        high_quality: Target batch for the reconstruction terms.
        real_labels: Labels the discriminator decision is compared against.
        generator: Not used by this function; kept for call compatibility.
        discriminator: Callable scoring the (low_quality, output) pair.
        adv_criterion: Loss applied to the discriminator decision.
        generator_output: Generator result for ``low_quality``.
        lambda_mae: Weight of the waveform mean-absolute-error term.
        lambda_stft: Weight of the multi-resolution STFT term.
        lambda_adv: Weight of the adversarial term.

    Returns:
        ``(combined_loss, adversarial_loss)``.
    """
    fake_pair = torch.cat((low_quality, generator_output), dim=1)

    discriminator_decision = discriminator(fake_pair)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    mae_loss = signal_mae(generator_output, high_quality)
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
    return combined_loss, adversarial_loss

View File