Compare commits

..

1 Commits

Author SHA1 Message Date
1717e7a008 ⚗️ | Experimenting... 2025-02-10 19:35:50 +02:00
12 changed files with 255 additions and 882 deletions

View File

@@ -1,41 +1,18 @@
import torch
import torch.nn.functional as F
def stereo_tensor_to_mono(waveform):
if waveform.shape[0] > 1:
# Average across channels
mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
else:
# Already mono
mono_waveform = waveform
return mono_waveform
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
mono_tensor = torch.mean(waveform, dim=0, keepdim=True)
return mono_tensor
def stretch_tensor(tensor, target_length):
scale_factor = target_length / tensor.size(1)
tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
padding_amount = target_length - audio_tensor.size(-1)
if padding_amount <= 0:
return audio_tensor
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
return padded_audio_tensor
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
chunks = list(torch.split(audio_tensor, chunk_size, dim=1))
if pad_last_tensor:
last_chunk = chunks[-1]
if last_chunk.size(-1) < chunk_size:
chunks[-1] = pad_tensor(last_chunk, chunk_size)
return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
reconstructed_tensor = torch.cat(chunks, dim=-1)
return reconstructed_tensor
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
max_val = torch.max(torch.abs(audio_tensor))
if max_val < eps:
return audio_tensor
return audio_tensor / max_val
return tensor

View File

@@ -2,7 +2,7 @@
### Overview
SISU (Super Ingenious Sound Upscaler) is an experimental project that is mostly vibe coded and uses GANs (Generative Adversarial Networks) to make low-quality audio better. The goal is to take not-so-good-sounding audio and turn it into high-quality, clear audio.
SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Adversarial Networks) to make low-quality audio better. The goal is to take not-so-good-sounding audio and turn it into high-quality, clear audio.
### Structure of the Project
@@ -18,7 +18,6 @@ SISU (Super Ingenious Sound Upscaler) is an experimental project that is mostly
1. **Set Up**:
- Make sure you have Python installed (version 3.8 or higher).
- Install needed packages: `pip install -r requirements.txt`
- Install current version of PyTorch (CUDA/ROCm/What ever your device supports)
2. **Prepare Audio Data**:
- Put your audio files in the `dataset/good` folder.

View File

128
app.py
View File

@@ -1,128 +0,0 @@
import argparse
import torch
import torchaudio
import torchcodec
import tqdm
from accelerate import Accelerator
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
"--clip_length",
type=int,
default=8000,
help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
"--bitrate",
type=int,
default=192000,
help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()
if args.sample_rate < 8000:
print(
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
)
exit()
# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")
# ---------------------------
# Prepare accelerator
# ---------------------------
generator = accelerator.prepare(generator)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output
if models_dir:
ckpt = torch.load(models_dir)
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.print("💾 | Loaded model!")
else:
print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
)
def start():
# To Mono!
decoder = torchcodec.decoders.AudioDecoder(input_audio)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
# Support for multichannel audio
# audio = AudioUtils.stereo_tensor_to_mono(audio)
audio = AudioUtils.normalize(audio)
resample_transform = torchaudio.transforms.Resample(
original_sample_rate, args.sample_rate
)
audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio]
processed_audio = []
with torch.no_grad():
for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
channels = []
for audio_channel in torch.split(clip, 1, dim=1):
output_piece = generator(audio_channel)
channels.append(output_piece.detach().cpu())
output_clip = torch.cat(channels, dim=1)
processed_audio.append(output_clip)
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
reconstructed_audio = reconstructed_audio.squeeze(0)
print(f"🔊 | Saving {output_audio}!")
torchaudio.save_with_torchcodec(
uri=output_audio,
src=reconstructed_audio,
sample_rate=args.sample_rate,
channels_first=True,
compression=args.bitrate,
)
start()

95
data.py
View File

@@ -1,71 +1,52 @@
import os
import random
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils
import os
import random
from AudioUtils import stereo_tensor_to_mono, stretch_tensor
class AudioDataset(Dataset):
audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]
audio_sample_rates = [11025]
def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
self.clip_length = clip_length
self.normalize = normalize
input_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if os.path.isfile(os.path.join(input_dir, f))
and f.lower().endswith((".wav", ".mp3", ".flac"))
def __init__(self, input_dir):
self.input_files = [
os.path.join(root, f)
for root, _, files in os.walk(input_dir)
for f in files if f.endswith('.wav')
]
data = []
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float()
original_sample_rate = decoded_samples.sample_rate
if normalize:
audio = AudioUtils.normalize(audio)
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
if not splitted_high_quality_audio:
continue
for splitted_audio_clip in splitted_high_quality_audio:
for audio_clip in torch.split(splitted_audio_clip, 1):
data.append((audio_clip, original_sample_rate))
self.audio_data = data
def __len__(self):
return len(self.audio_data)
return len(self.input_files)
def __getitem__(self, idx):
audio_clip = self.audio_data[idx]
# Load high-quality audio
high_quality_path = self.input_files[idx]
high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path)
high_quality_audio = stereo_tensor_to_mono(high_quality_audio)
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_low(high_quality_audio)
resample_transform_low = torchaudio.transforms.Resample(
audio_clip[1], mangled_sample_rate
)
resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio = resample_high(low_quality_audio)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, audio_clip[1]
)
# Pad or truncate to match a fixed length
target_length = 44100 # Adjust this based on your data
high_quality_audio = self.pad_or_truncate(high_quality_audio, target_length)
low_quality_audio = self.pad_or_truncate(low_quality_audio, target_length)
low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0]))
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
def pad_or_truncate(self, tensor, target_length):
current_length = tensor.size(1)
if current_length < target_length:
# Pad with zeros
padding = target_length - current_length
tensor = F.pad(tensor, (0, padding))
else:
# Truncate to target length
tensor = tensor[:, :target_length]
return tensor

View File

@@ -1,179 +1,38 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
import torch.nn.utils as utils
# -------------------------------------------------------------------
# 1. Multi-Period Discriminator (MPD)
# Captures periodic structures (pitch/timbre) by folding audio.
# -------------------------------------------------------------------
class DiscriminatorP(nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
# Use spectral_norm for stability, or weight_norm for performance
norm_f = spectral_norm if use_spectral_norm else weight_norm
# We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
self.convs = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
b, c, t = x.shape
if t % self.period != 0: # Pad if not divisible by period
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x) # Store feature map for Feature Matching Loss
x = self.conv_post(x)
fmap.append(x)
# Flatten back to 1D for score
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(nn.Module):
def __init__(self, periods=[2, 3, 5, 7, 11]):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
DiscriminatorP(p) for p in periods
])
def forward(self, y, y_hat):
y_d_rs = [] # Real scores
y_d_gs = [] # Generated (Fake) scores
fmap_rs = [] # Real feature maps
fmap_gs = [] # Generated (Fake) feature maps
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# -------------------------------------------------------------------
# 2. Multi-Scale Discriminator (MSD)
# Captures structure at different audio resolutions (raw, x0.5, x0.25).
# -------------------------------------------------------------------
class DiscriminatorS(nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = spectral_norm if use_spectral_norm else weight_norm
# Standard 1D Convolutions with large receptive field
self.convs = nn.ModuleList([
norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(nn.Module):
def __init__(self):
super(MultiScaleDiscriminator, self).__init__()
# 3 Scales: Original, Downsampled x2, Downsampled x4
self.discriminators = nn.ModuleList([
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
])
self.meanpools = nn.ModuleList([
nn.AvgPool1d(4, 2, padding=2),
nn.AvgPool1d(4, 2, padding=2)
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
# Downsample input for subsequent discriminators
y = self.meanpools[i-1](y)
y_hat = self.meanpools[i-1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# -------------------------------------------------------------------
# 3. Master Wrapper
# Combines MPD and MSD into one class to fit your training script.
# -------------------------------------------------------------------
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
padding = (kernel_size // 2) * dilation
return nn.Sequential(
utils.spectral_norm(
nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding
)
),
nn.BatchNorm1d(out_channels),
nn.LeakyReLU(0.2, inplace=True)
)
class SISUDiscriminator(nn.Module):
def __init__(self):
super(SISUDiscriminator, self).__init__()
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiScaleDiscriminator()
def forward(self, y, y_hat):
# Return format:
# scores_real, scores_fake, features_real, features_fake
# Run Multi-Period
mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
# Run Multi-Scale
msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
# Combine all results
return (
mpd_y_d_rs + msd_y_d_rs, # All real scores
mpd_y_d_gs + msd_y_d_gs, # All fake scores
mpd_fmap_rs + msd_fmap_rs, # All real feature maps
mpd_fmap_gs + msd_fmap_gs # All fake feature maps
layers = 4
self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1),
discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4),
discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8),
discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16),
discriminator_block(layers * 2, layers, kernel_size=5, dilation=2),
discriminator_block(layers, 1, kernel_size=3, stride=1)
)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x):
x = self.model(x)
x = self.global_avg_pool(x)
return x.view(-1, 1)

View File

@@ -1,126 +1,41 @@
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
padding = (kernel_size - 1) // 2 * dilation
def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1):
padding = (kernel_size // 2) * dilation
return nn.Sequential(
weight_norm(nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding
)),
nn.PReLU(num_parameters=1, init=0.1),
)
class AttentionBlock(nn.Module):
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
nn.ReLU(inplace=True),
weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x + (x * attention_weights)
class ResidualInResidualBlock(nn.Module):
def __init__(self, channels, num_convs=3):
super(ResidualInResidualBlock, self).__init__()
self.conv_layers = nn.Sequential(
*[GeneratorBlock(channels, channels) for _ in range(num_convs)]
)
self.attention = AttentionBlock(channels)
def forward(self, x):
residual = x
x = self.conv_layers(x)
x = self.attention(x)
return x + residual
def UpsampleBlock(in_channels, out_channels, scale_factor=2):
return nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
weight_norm(nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1
)),
nn.PReLU(num_parameters=1, init=0.1)
nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
nn.BatchNorm1d(out_channels),
nn.PReLU(),
nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
nn.BatchNorm1d(out_channels)
)
class SISUGenerator(nn.Module):
def __init__(self, channels=32, num_rirb=4):
def __init__(self):
super(SISUGenerator, self).__init__()
self.first_conv = GeneratorBlock(1, channels)
self.downsample = GeneratorBlock(channels, channels * 2, stride=2)
self.downsample_attn = AttentionBlock(channels * 2)
self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
self.downsample_2_attn = AttentionBlock(channels * 4)
self.rirb = nn.Sequential(
*[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
layers = 4
self.conv1 = nn.Sequential(
nn.Conv1d(1, layers, kernel_size=7, padding=3),
nn.BatchNorm1d(layers),
nn.PReLU()
)
self.upsample = UpsampleBlock(channels * 4, channels * 2)
self.upsample_attn = AttentionBlock(channels * 2)
self.compress_1 = GeneratorBlock(channels * 4, channels * 2)
self.upsample_2 = UpsampleBlock(channels * 2, channels)
self.upsample_2_attn = AttentionBlock(channels)
self.compress_2 = GeneratorBlock(channels * 2, channels)
self.final_conv = nn.Sequential(
weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
nn.Tanh()
self.conv_blocks = nn.Sequential(
conv_residual_block(layers, layers, kernel_size=3, dilation=1),
conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2),
conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16),
conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8),
conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2),
conv_residual_block(layers, layers, kernel_size=3, dilation=1)
)
self.final_layer = nn.Sequential(
nn.Conv1d(layers, 1, kernel_size=3, padding=1)
)
def forward(self, x):
residual_input = x
# Encoding
x1 = self.first_conv(x)
x2 = self.downsample(x1)
x2 = self.downsample_attn(x2)
x3 = self.downsample_2(x2)
x3 = self.downsample_2_attn(x3)
# Bottleneck (Deep Residual processing)
x_rirb = self.rirb(x3)
# Decoding with Skip Connections
up1 = self.upsample(x_rirb)
up1 = self.upsample_attn(up1)
cat1 = torch.cat((up1, x2), dim=1)
comp1 = self.compress_1(cat1)
up2 = self.upsample_2(comp1)
up2 = self.upsample_2_attn(up2)
cat2 = torch.cat((up2, x1), dim=1)
comp2 = self.compress_2(cat2)
learned_residual = self.final_conv(comp2)
output = residual_input + learned_residual
return output
residual = x
x = self.conv1(x)
x = self.conv_blocks(x) + x # Adding residual connection after blocks
x = self.final_layer(x)
return x + residual

14
requirements.txt Normal file
View File

@@ -0,0 +1,14 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.1
pytorch-triton-rocm==3.2.0+git0d4682f0
setuptools==70.2.0
sympy==1.13.1
torch==2.6.0.dev20241222+rocm6.2.4
torchaudio==2.6.0.dev20241222+rocm6.2.4
tqdm==4.67.1
typing_extensions==4.12.2

View File

@@ -1,247 +1,164 @@
import argparse
import datetime
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler
import argparse
import math
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import AudioUtils
from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train
from discriminator import SISUDiscriminator
def perceptual_loss(y_true, y_pred):
return torch.mean((y_true - y_pred) ** 2)
def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
optimizer_d.zero_grad()
# Forward pass for real samples
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output.detach())
d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
# Combine real and fake losses
d_loss = (d_loss_real + d_loss_fake) / 2.0
# Backward pass and optimization
d_loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
optimizer_d.step()
return d_loss
def generator_train(low_quality, real_labels):
optimizer_g.zero_grad()
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
g_loss = criterion_g(discriminator_decision, real_labels)
g_loss.backward()
optimizer_g.step()
return generator_output
def first(objects):
if len(objects) >= 1:
return objects[0]
return objects
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
"--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers") # Increased workers slightly
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
"--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()
# ---------------------------
# Init accelerator
# ---------------------------
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
accelerator = Accelerator(mixed_precision="bf16")
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir)
# ---------------------------
# Models
# ---------------------------
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Initialize models and move them to device
generator = SISUGenerator()
# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
discriminator = SISUDiscriminator()
accelerator.print("🔨 | Compiling models...")
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
# Torch compile is great, but if you hit errors with the new List/Tuple outputs
# of the discriminator, you might need to disable it for D.
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
generator = generator.to(device)
discriminator = discriminator.to(device)
accelerator.print("✅ | Compiling done!")
# Loss
criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss()
# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset", 8192)
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
train_loader = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
num_workers=args.num_workers,
pin_memory=pin_memory,
persistent_workers=pin_memory,
)
def start_training():
generator_epochs = 5000
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
accelerator.print("🪹 | There is no data to train with! Exiting...")
exit()
times_correct = 0
loader_batch_size = train_loader.batch_size
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
# for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
accelerator.print("✅ | Dataset fetched!")
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
optimizer_g = optim.AdamW(
generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_d, mode="min", factor=0.5, patience=5
)
# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
generator, discriminator, optimizer_g, optimizer_d, train_loader
)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
def save_ckpt(path, epoch, loss=None, is_best=False):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state = {
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict()
}
accelerator.save(state, os.path.join(models_dir, "last.pt"))
if is_best:
accelerator.save(state, os.path.join(models_dir, "best.pt"))
accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
start_epoch = 0
if args.resume:
ckpt_path = os.path.join(models_dir, "last.pt")
if os.path.exists(ckpt_path):
ckpt = torch.load(ckpt_path)
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
optimizer_g.load_state_dict(ckpt["optG"])
optimizer_d.load_state_dict(ckpt["optD"])
scheduler_g.load_state_dict(ckpt["schedG"])
scheduler_d.load_state_dict(ckpt["schedD"])
start_epoch = ckpt.get("epoch", 1)
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
else:
accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
generator.train()
# ========= DISCRIMINATOR =========
discriminator.train()
discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)
discriminator_time = 0
generator_time = 0
# ========= GENERATOR =========
generator.train()
generator_output = generator_train(low_quality_sample, real_labels)
running_d, running_g, steps = 0.0, 0.0, 0
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0])
low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0])
ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0])
print(high_quality_audio)
progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
for i, (
(high_quality, low_quality),
(high_sample_rate, low_sample_rate),
) in enumerate(progress_bar):
with accelerator.autocast():
generator_output = generator(low_quality)
#metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
#print(f"Generator metric {metric}!")
#scheduler_g.step(metric)
# --- Discriminator ---
d_time = datetime.datetime.now()
optimizer_d.zero_grad(set_to_none=True)
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
discriminator,
generator_output.detach()
)
if generator_epoch % 10 == 0:
print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
accelerator.backward(d_loss)
optimizer_d.step()
discriminator_time = (datetime.datetime.now() - d_time).microseconds
torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt")
torch.save(generator.state_dict(), f"models/current-epoch-generator.pt")
# --- Generator ---
g_time = datetime.datetime.now()
optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast():
g_total, g_adv = generator_train(
low_quality,
high_quality,
generator,
discriminator,
generator_output
)
torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt")
torch.save(generator.state_dict(), "models/epoch-5000-generator.pt")
print("Training complete!")
accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step()
generator_time = (datetime.datetime.now() - g_time).microseconds
d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean()
if torch.isfinite(d_val):
running_d += d_val.item()
else:
accelerator.print(
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
)
if torch.isfinite(g_val):
running_g += g_val.item()
else:
accelerator.print(
f"🫥 | NaN in generator loss at step {i}, skipping update."
)
steps += 1
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
mean_d = running_d / steps
mean_g = running_g / steps
scheduler_d.step(mean_d)
scheduler_g.step(mean_g)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
try:
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
except Exception as e:
accelerator.print("😬 | Failed saving crash checkpoint:", e)
raise
accelerator.print("🏁 | Training finished.")
start_training()

View File

@@ -1,68 +0,0 @@
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
def __init__(
self,
fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
hop_sizes: List[int] = [64, 128, 256, 512, 1024],
win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
eps: float = 1e-7,
center: bool = True
):
super().__init__()
self.eps = eps
self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList()
for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)):
stft = T.Spectrogram(
n_fft=n_fft,
hop_length=hop_len,
win_length=win_len,
window_fn=torch.hann_window,
power=None,
center=center,
pad_mode="reflect",
normalized=False,
)
self.stft_transforms.append(stft)
def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]:
if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1:
y_pred = y_pred.squeeze(1)
sc_loss = 0.0
mag_loss = 0.0
for stft in self.stft_transforms:
stft.window = stft.window.to(y_true.device)
stft_true = stft(y_true)
stft_pred = stft(y_pred)
stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred)
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
log_mag_pred = torch.log(stft_mag_pred + self.eps)
log_mag_true = torch.log(stft_mag_true + self.eps)
mag_loss += F.l1_loss(log_mag_pred, log_mag_true)
sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss
return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}

View File

@@ -1,93 +0,0 @@
import torch
import torch.nn.functional as F
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# Keep STFT settings as is
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[512, 1024, 2048],
hop_sizes=[64, 128, 256],
win_lengths=[256, 512, 1024]
)
def feature_matching_loss(fmap_r, fmap_g):
"""
Computes L1 distance between real and fake feature maps.
"""
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.detach()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
"""
Least Squares GAN Loss (LSGAN) for the Discriminator.
Objective: Real -> 1, Fake -> 0
"""
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((dr - 1) ** 2)
g_loss = torch.mean(dg ** 2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_adv_loss(disc_generated_outputs):
"""
Least Squares GAN Loss for the Generator.
Objective: Fake -> 1 (Fool the discriminator)
"""
loss = 0
for dg in zip(disc_generated_outputs):
dg = dg[0] # Unpack tuple
loss += torch.mean((dg - 1) ** 2)
return loss
def discriminator_train(
high_quality,
discriminator,
generator_output
):
y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
return d_loss
def generator_train(
low_quality,
high_quality,
generator,
discriminator,
generator_output
):
y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
loss_gen_adv = generator_adv_loss(y_d_gs)
loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
lambda_stft = 45.0
lambda_fm = 2.0
lambda_adv = 1.0
combined_loss = (lambda_stft * stft_loss) + \
(lambda_fm * loss_fm) + \
(lambda_adv * loss_gen_adv)
return combined_loss, loss_gen_adv

View File