27 Commits

Author SHA1 Message Date
3f23242d6f ⚗️ | Added some stupid ways for training + some makeup 2025-10-04 22:38:11 +03:00
0bc8fc2792 | Made training bit... spicier. 2025-09-10 19:52:53 +03:00
ff38cefdd3 🐛 | Fix loading wrong model. 2025-06-08 18:14:31 +03:00
03fdc050cc | Made training bit faster. 2025-06-07 20:43:52 +03:00
2ded03713d | Added app.py script so the model can be used. 2025-06-06 22:10:06 +03:00
a135c765da 🐛 | Misc fixes... 2025-05-05 00:50:56 +03:00
b1e18443ba | Added support for .mp3 and .flac loading... 2025-05-04 23:56:14 +03:00
660b41aef8 :albemic: | Real-time testing... 2025-05-04 22:48:57 +03:00
d70c86c257 | Implemented MFCC and STFT. 2025-04-26 17:03:28 +03:00
c04b072de6 | Added smarter ways that would've been needed from the begining. 2025-04-16 17:08:13 +03:00
b6d16e4f11 ♻️ | Restructured procject code. 2025-04-14 17:51:34 +03:00
nsiltala
3936b6c160 🐛 | Fixed NVIDIA training... again. 2025-04-07 14:49:07 +03:00
fbcd5803b8 🐛 | Fixed training on CPU and NVIDIA hardware. 2025-04-07 02:14:06 +03:00
9394bc6c5a :albemic: | Fat architecture. Hopefully better results. 2025-04-06 00:05:43 +03:00
f928d8c2cf :albemic: | More tests. 2025-03-25 21:51:29 +02:00
54338e55a9 :albemic: | Tests. 2025-03-25 19:50:51 +02:00
7e1c7e935a :albemic: | Experimenting with other model layouts. 2025-03-15 18:01:19 +02:00
416500f7fc | Removed/Updated dependencies. 2025-02-26 20:15:30 +02:00
8332b0df2d | Added ability to set epoch. 2025-02-26 19:36:43 +02:00
741dcce7b4 ⚗️ | Increase discriminator size and implement mfcc_loss for generator. 2025-02-23 13:52:01 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
13 changed files with 756 additions and 239 deletions

97
AudioUtils.py Normal file
View File

@@ -0,0 +1,97 @@
import torch
import torch.nn.functional as F
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
"""
Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
"""
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0) # (N,) -> (1, N)
if waveform.shape[0] > 1:
mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N)
else:
mono_waveform = waveform
return mono_waveform
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
"""
Stretch audio along time dimension to target_length.
Input assumed (1, N). Returns (1, target_length).
"""
if tensor.dim() == 1:
tensor = tensor.unsqueeze(0) # ensure (1, N)
tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate
stretched = F.interpolate(
tensor, size=target_length, mode="linear", align_corners=False
)
return stretched.squeeze(0) # back to (1, target_length)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
"""
Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
"""
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
current_length = audio_tensor.shape[-1]
if current_length < target_length:
padding_needed = target_length - current_length
padding_tuple = (0, padding_needed)
padded_audio_tensor = F.pad(
audio_tensor, padding_tuple, mode="constant", value=0
)
else:
padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long
return padded_audio_tensor
def split_audio(
audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
"""
Split into chunks of (1, chunk_size).
"""
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("chunk_size must be a positive integer.")
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
num_samples = audio_tensor.shape[-1]
if num_samples == 0:
return []
chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
"""
Reconstruct audio from chunks. Returns (1, N).
"""
if not chunks:
return torch.empty(1, 0)
chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
try:
reconstructed_tensor = torch.cat(chunks, dim=-1)
except RuntimeError as e:
raise RuntimeError(
f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
f"for concatenation along dim -1. Original error: {e}"
)
return reconstructed_tensor
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
max_val = torch.max(torch.abs(audio_tensor))
if max_val < eps:
return audio_tensor # silence, skip normalization
return audio_tensor / max_val

View File

@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
1. **Set Up**:
- Make sure you have Python installed (version 3.8 or higher).
- Install needed packages: `pip install -r requirements.txt`
- Install current version of PyTorch (CUDA/ROCm/What ever your device supports)
2. **Prepare Audio Data**:
- Put your audio files in the `dataset/good` folder.

0
__init__.py Normal file
View File

97
app.py Normal file
View File

@@ -0,0 +1,97 @@
import argparse
import torch
import torchaudio
import torchcodec
import tqdm
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
"--clip_length",
type=int,
default=16384,
help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
"--bitrate",
type=int,
default=192000,
help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()
if args.sample_rate < 8000:
print(
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
)
exit()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
generator = SISUGenerator().to(device)
generator = torch.compile(generator)
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output
if models_dir:
ckpt = torch.load(models_dir, map_location=device)
generator.load_state_dict(ckpt["G"])
else:
print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
)
def start():
# To Mono!
decoder = torchcodec.decoders.AudioDecoder(input_audio)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
audio = AudioUtils.normalize(audio)
resample_transform = torchaudio.transforms.Resample(
original_sample_rate, args.sample_rate
)
audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.to(device) for t in splitted_audio]
processed_audio = []
for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
processed_audio.append(generator(clip))
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
print(f"Saving {output_audio}!")
torchaudio.save_with_torchcodec(
uri=output_audio,
src=reconstructed_audio,
sample_rate=args.sample_rate,
channels_first=True,
compression=args.bitrate,
)
start()

108
data.py
View File

@@ -1,49 +1,79 @@
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchaudio
import os
import random
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils
class AudioDataset(Dataset):
audio_sample_rates = [8000, 11025, 16000, 22050]
audio_sample_rates = [11025]
def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
self.target_duration = target_duration # Duration in seconds or None if not set
self.padding_mode = padding_mode
self.padding_value = padding_value
def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
self.clip_length = clip_length
self.normalize = normalize
input_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if os.path.isfile(os.path.join(input_dir, f))
and f.lower().endswith((".wav", ".mp3", ".flac"))
]
data = []
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float() # ensure float32
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
if normalize:
audio = AudioUtils.normalize(audio)
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(
original_sample_rate, mangled_sample_rate
)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, original_sample_rate
)
low_audio = resample_transform_high(resample_transform_low(audio))
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
if not splitted_high_quality_audio or not splitted_low_quality_audio:
continue # skip empty or invalid clips
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_high_quality_audio[-1], clip_length
)
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_low_quality_audio[-1], clip_length
)
for high_quality_data, low_quality_data in zip(
splitted_high_quality_audio, splitted_low_quality_audio
):
data.append(
(
(high_quality_data, low_quality_data),
(original_sample_rate, mangled_sample_rate),
)
)
self.audio_data = data
def __len__(self):
return len(self.input_files)
return len(self.audio_data)
def __getitem__(self, idx):
high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform(high_quality_audio)
# Calculate target length based on desired duration and 16000 Hz
# if self.target_duration is not None:
# target_length = int(self.target_duration * 44100)
# else:
# # Calculate duration of original high quality audio
# target_length = high_quality_wav.size(1)
# Pad both to the calculated target length
# high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
# low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
def stretch_tensor(self, tensor, target_length):
current_length = tensor.size(1)
scale_factor = target_length / current_length
# Resample the tensor using linear interpolation
tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
return tensor
return self.audio_data[idx]

View File

@@ -1,24 +1,75 @@
import torch.nn as nn
import torch.nn.utils as utils
def discriminator_block(
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
spectral_norm=True,
use_instance_norm=True,
):
padding = (kernel_size // 2) * dilation
conv_layer = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
)
if spectral_norm:
conv_layer = utils.spectral_norm(conv_layer)
layers = [conv_layer]
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
class AttentionBlock(nn.Module):
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
class SISUDiscriminator(nn.Module):
def __init__(self):
def __init__(self, layers=32):
super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv1d(2, 128, kernel_size=3, padding=1),
#nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(128, 256, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(256, 128, kernel_size=3, padding=1),
#nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(128, 64, kernel_size=3, padding=1),
#nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(64, 1, kernel_size=3, padding=1),
#nn.LeakyReLU(0.2, inplace=True),
discriminator_block(1, layers, kernel_size=7, stride=1),
discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2),
discriminator_block(
layers * 2,
1,
spectral_norm=False,
use_instance_norm=False,
),
)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x):
x = self.model(x)
x = self.global_avg_pool(x)
x = x.view(-1, 1) # Flatten to (batch_size, 1)
x = x.view(x.size(0), -1)
return x

View File

@@ -1,27 +1,81 @@
import torch
import torch.nn as nn
def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
return nn.Sequential(
nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
dilation=dilation,
padding=(kernel_size // 2) * dilation,
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(),
)
class AttentionBlock(nn.Module):
"""
Simple Channel Attention Block. Learns to weight channels based on their importance.
"""
def __init__(self, channels):
super(AttentionBlock, self).__init__()
self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
attention_weights = self.attention(x)
return x * attention_weights
class ResidualInResidualBlock(nn.Module):
def __init__(self, channels, num_convs=3):
super(ResidualInResidualBlock, self).__init__()
self.conv_layers = nn.Sequential(
*[conv_block(channels, channels) for _ in range(num_convs)]
)
self.attention = AttentionBlock(channels)
def forward(self, x):
residual = x
x = self.conv_layers(x)
x = self.attention(x)
return x + residual
class SISUGenerator(nn.Module):
def __init__(self, upscale_scale=1): # No noise_dim parameter
def __init__(self, channels=16, num_rirb=4, alpha=1):
super(SISUGenerator, self).__init__()
self.layers1 = nn.Sequential(
nn.Conv1d(2, 128, kernel_size=3, padding=1),
# nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(128, 256, kernel_size=3, padding=1),
# nn.LeakyReLU(0.2, inplace=True),
self.alpha = alpha
self.conv1 = nn.Sequential(
nn.Conv1d(1, channels, kernel_size=7, padding=3),
nn.InstanceNorm1d(channels),
nn.PReLU(),
)
self.layers2 = nn.Sequential(
nn.Conv1d(256, 128, kernel_size=3, padding=1),
# nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(128, 64, kernel_size=3, padding=1),
# nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(64, 2, kernel_size=3, padding=1),
# nn.Tanh()
self.rir_blocks = nn.Sequential(
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
)
def forward(self, x, scale):
x = self.layers1(x)
upsample = nn.Upsample(scale_factor=scale, mode='nearest')
x = upsample(x)
x = self.layers2(x)
return x
self.final_layer = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
)
def forward(self, x):
residual_input = x
x = self.conv1(x)
x_rirb_out = self.rir_blocks(x)
learned_residual = self.final_layer(x_rirb_out)
output = residual_input + self.alpha * learned_residual
return torch.tanh(output)

View File

@@ -1,12 +0,0 @@
filelock>=3.16.1
fsspec>=2024.10.0
Jinja2>=3.1.4
MarkupSafe>=2.1.5
mpmath>=1.3.0
networkx>=3.4.2
numpy>=2.1.2
pillow>=11.0.0
setuptools>=70.2.0
sympy>=1.13.1
tqdm>=4.67.1
typing_extensions>=4.12.2

10
test.py
View File

@@ -1,10 +0,0 @@
import torch.nn as nn
import torch
from discriminator import SISUDiscriminator
discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000) # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

View File

@@ -1,183 +1,245 @@
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train
# Mel Spectrogram Loss
class MelSpectrogramLoss(nn.Module):
def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
super(MelSpectrogramLoss, self).__init__()
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels
).to(device) # Move to device
# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
"--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
"--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()
def forward(self, y_pred, y_true):
mel_pred = self.mel_transform(y_pred)
mel_true = self.mel_transform(y_true)
return F.l1_loss(mel_pred, mel_true)
# ---------------------------
# Init accelerator
# ---------------------------
def snr(y_true, y_pred):
noise = y_true - y_pred
signal_power = torch.mean(y_true ** 2)
noise_power = torch.mean(noise ** 2)
snr_db = 10 * torch.log10(signal_power / noise_power)
return snr_db
accelerator = Accelerator(mixed_precision="bf16")
def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
optimizer_d.zero_grad()
discriminator_decision_from_real = discriminator(high_quality)
# TODO: Experiment with criterions HERE!
d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
generator_output = generator(low_quality, scale)
discriminator_decision_from_fake = discriminator(generator_output.detach())
# TODO: Experiment with criterions HERE!
d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
d_loss = (d_loss_real + d_loss_fake) / 2.0
d_loss.backward()
nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
optimizer_d.step()
return d_loss
def generator_train(low_quality, scale, real_labels):
optimizer_g.zero_grad()
generator_output = generator(low_quality, scale)
discriminator_decision = discriminator(generator_output)
# TODO: Fix this shit
g_loss = criterion_g(discriminator_decision, real_labels)
g_loss.backward()
optimizer_g.step()
return generator_output
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, target_duration=2.0)
dataset_size = len(dataset)
train_size = int(dataset_size * .9)
val_size = int(dataset_size-train_size)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
# Initialize models and move them to device
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
discriminator = SISUDiscriminator()
generator = generator.to(device)
discriminator = discriminator.to(device)
accelerator.print("🔨 | Compiling models...")
# Loss
criterion_g = nn.L1Loss()
criterion_g_mel = MelSpectrogramLoss().to(device)
criterion_d = nn.BCEWithLogitsLoss()
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
accelerator.print("✅ | Compiling done!")
# Scheduler
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset")
def start_training():
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory
# Training loop
train_loader = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
num_workers=args.num_workers,
pin_memory=pin_memory,
persistent_workers=pin_memory,
)
# ========= DISCRIMINATOR PRE-TRAINING =========
discriminator_epochs = 1
for discriminator_epoch in range(discriminator_epochs):
if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
accelerator.print("🪹 | There is no data to train with! Exiting...")
exit()
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
high_quality_sample = high_quality_clip[0].to(device)
low_quality_sample = low_quality_clip[0].to(device)
loader_batch_size = train_loader.batch_size
scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
accelerator.print("✅ | Dataset fetched!")
# ========= LABELS =========
batch_size = high_quality_sample.size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
# ========= DISCRIMINATOR =========
discriminator.train()
discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
optimizer_g = optim.AdamW(
generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_d, mode="min", factor=0.5, patience=5
)
generator_epochs = 500
for generator_epoch in range(generator_epochs):
low_quality_audio = (torch.empty((1)), 1)
high_quality_audio = (torch.empty((1)), 1)
ai_enhanced_audio = (torch.empty((1)), 1)
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.MSELoss()
# ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
high_quality_sample = high_quality_clip[0].to(device)
low_quality_sample = low_quality_clip[0].to(device)
# ---------------------------
# Prepare accelerator
# ---------------------------
scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
generator, discriminator, optimizer_g, optimizer_d, train_loader
)
# ========= LABELS =========
batch_size = high_quality_clip[0].size(0)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
# ========= DISCRIMINATOR =========
discriminator.train()
for _ in range(3):
discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
# ========= GENERATOR =========
generator.train()
generator_output = generator_train(low_quality_sample, scale, real_labels)
def save_ckpt(path, epoch):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
accelerator.save(
{
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict(),
},
path,
)
# ========= SAVE LATEST AUDIO =========
high_quality_audio = high_quality_clip
low_quality_audio = low_quality_clip
ai_enhanced_audio = (generator_output, high_quality_clip[1])
metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
print(f"Generator metric {metric}!")
scheduler_g.step(metric)
start_epoch = 0
if args.resume:
ckpt_path = os.path.join(models_dir, "last.pt")
ckpt = torch.load(ckpt_path)
if generator_epoch % 10 == 0:
print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
optimizer_g.load_state_dict(ckpt["optG"])
optimizer_d.load_state_dict(ckpt["optD"])
scheduler_g.load_state_dict(ckpt["schedG"])
scheduler_d.load_state_dict(ckpt["schedD"])
if generator_epoch % 50 == 0:
torch.save(discriminator.state_dict(), f"models/epoch-{generator_epoch}-discriminator.pt")
torch.save(generator.state_dict(), f"models/epoch-{generator_epoch}-generator.pt")
start_epoch = ckpt.get("epoch", 1)
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
torch.save(discriminator.state_dict(), "models/epoch-500-discriminator.pt")
torch.save(generator.state_dict(), "models/epoch-500-generator.pt")
print("Training complete!")
real_buf = torch.full(
(loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
)
fake_buf = torch.zeros(
(loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
)
start_training()
accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
generator.train()
discriminator.train()
running_d, running_g, steps = 0.0, 0.0, 0
for i, (
(high_quality, low_quality),
(high_sample_rate, low_sample_rate),
) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
batch_size = high_quality.size(0)
real_labels = real_buf[:batch_size].to(accelerator.device)
fake_labels = fake_buf[:batch_size].to(accelerator.device)
# --- Discriminator ---
optimizer_d.zero_grad(set_to_none=True)
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
low_quality,
real_labels,
fake_labels,
discriminator,
generator,
criterion_d,
)
accelerator.backward(d_loss)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
optimizer_d.step()
# --- Generator ---
optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast():
g_total, g_adv = generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
criterion_d,
)
accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step()
d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean()
if torch.isfinite(d_val):
running_d += d_val.item()
else:
accelerator.print(
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
)
if torch.isfinite(g_val):
running_g += g_val.item()
else:
accelerator.print(
f"🫥 | NaN in generator loss at step {i}, skipping update."
)
steps += 1
# epoch averages & schedulers
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
mean_d = running_d / steps
mean_g = running_g / steps
scheduler_d.step(mean_d)
scheduler_g.step(mean_g)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception:
try:
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
except Exception as e:
accelerator.print("😬 | Failed saving crash checkpoint:", e)
raise
accelerator.print("🏁 | Training finished.")

View File

@@ -0,0 +1,87 @@
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
"""
Multi-resolution STFT loss.
Combines spectral convergence loss and log-magnitude loss
across multiple STFT resolutions.
"""
def __init__(
self,
fft_sizes: List[int] = [1024, 2048, 512],
hop_sizes: List[int] = [120, 240, 50],
win_lengths: List[int] = [600, 1200, 240],
eps: float = 1e-7,
):
super().__init__()
self.eps = eps
self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList()
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
window = torch.hann_window(win_len)
stft = T.Spectrogram(
n_fft=n_fft,
hop_length=hop_len,
win_length=win_len,
window_fn=lambda _: window,
power=None, # Keep complex output
center=True,
pad_mode="reflect",
normalized=False,
)
self.stft_transforms.append(stft)
def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Args:
y_true: (B, T) or (B, 1, T) waveform
y_pred: (B, T) or (B, 1, T) waveform
"""
# Ensure correct shape (B, T)
if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1:
y_pred = y_pred.squeeze(1)
sc_loss = 0.0
mag_loss = 0.0
for stft in self.stft_transforms:
stft = stft.to(y_pred.device)
# Complex STFTs: (B, F, T, 2)
stft_true = stft(y_true)
stft_pred = stft(y_pred)
# Magnitudes
stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred)
# --- Spectral Convergence Loss ---
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
# --- Log STFT Magnitude Loss ---
mag_loss += F.l1_loss(
torch.log(stft_mag_pred + self.eps),
torch.log(stft_mag_true + self.eps),
)
# Average across resolutions
sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss
return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}

60
utils/TrainingTools.py Normal file
View File

@@ -0,0 +1,60 @@
import torch
# In case if needed again...
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two)
return torch.mean(absolute_difference)
def discriminator_train(
high_quality,
low_quality,
high_labels,
low_labels,
discriminator,
generator,
criterion,
):
decision_high = discriminator(high_quality)
d_loss_high = criterion(decision_high, high_labels)
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
decision_low = discriminator(low_quality)
d_loss_low = criterion(decision_low, low_labels)
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
with torch.no_grad():
generator_quality = generator(low_quality)
decision_gen = discriminator(generator_quality)
d_loss_gen = criterion(decision_gen, low_labels)
noise = torch.rand_like(high_quality) * 0.08
decision_noise = discriminator(high_quality + noise)
d_loss_noise = criterion(decision_noise, low_labels)
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
return d_loss
def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
):
generator_output = generator(low_quality)
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
# Signal similarity
similarity_loss = signal_mae(generator_output, high_quality)
combined_loss = adversarial_loss + (similarity_loss * 100)
return combined_loss, adversarial_loss

0
utils/__init__.py Normal file
View File