Made training a bit... spicier.

2025-09-10 19:52:53 +03:00
parent ff38cefdd3
commit 0bc8fc2792
8 changed files with 581 additions and 303 deletions

74
app.py
View File

@@ -1,33 +1,49 @@
import argparse

import torch
import torchaudio
import torchcodec
import tqdm

import AudioUtils
from generator import SISUGenerator

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
    "--clip_length",
    type=int,
    default=16384,
    help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
    "--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
    "--bitrate",
    type=int,
    default=192000,
    help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()

if args.sample_rate < 8000:
    print(
        "Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
    )
    exit()

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

generator = SISUGenerator().to(device)
generator = torch.compile(generator)

models_dir = args.model
clip_length = args.clip_length

@@ -35,17 +51,30 @@ input_audio = args.input
output_audio = args.output

if models_dir:
    ckpt = torch.load(models_dir, map_location=device)
    generator.load_state_dict(ckpt["G"])
else:
    print(
        "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
    )


def start():
    # To Mono!
    decoder = torchcodec.decoders.AudioDecoder(input_audio)
    decoded_samples = decoder.get_all_samples()
    audio = decoded_samples.data
    original_sample_rate = decoded_samples.sample_rate
    audio = AudioUtils.stereo_tensor_to_mono(audio)

    resample_transform = torchaudio.transforms.Resample(
        original_sample_rate, args.sample_rate
    )
    audio = resample_transform(audio)

    splitted_audio = AudioUtils.split_audio(audio, clip_length)
    splitted_audio_on_device = [t.to(device) for t in splitted_audio]
    processed_audio = []

@@ -55,6 +84,13 @@ def start():
    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
    print(f"Saving {output_audio}!")
    torchaudio.save_with_torchcodec(
        uri=output_audio,
        src=reconstructed_audio,
        sample_rate=args.sample_rate,
        channels_first=True,
        compression=args.bitrate,
    )


start()
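Note on the new checkpoint format loaded above: app.py now reads a dict with a "G" key (matching the training script's save_ckpt), so an older flat state_dict file would need wrapping first. A minimal sketch, with placeholder paths:

# Hypothetical one-off migration from a flat state_dict to the new layout.
import torch

old_state = torch.load("models/old_generator.pt", map_location="cpu")  # flat state_dict
torch.save({"G": old_state}, "models/generator_ckpt.pt")               # layout app.py expects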

59
data.py
View File

@@ -1,41 +1,68 @@
import os
import random

import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    audio_sample_rates = [11025]

    def __init__(self, input_dir, clip_length=16384):
        input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files
            if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac")
        ]

        data = []

        for audio_clip in tqdm.tqdm(
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoder = decoders.AudioDecoder(audio_clip)
            decoded_samples = decoder.get_all_samples()
            audio = decoded_samples.data
            original_sample_rate = decoded_samples.sample_rate
            audio = AudioUtils.stereo_tensor_to_mono(audio)

            # Generate low-quality audio with random downsampling
            mangled_sample_rate = random.choice(self.audio_sample_rates)
            resample_transform_low = torchaudio.transforms.Resample(
                original_sample_rate, mangled_sample_rate
            )
            resample_transform_high = torchaudio.transforms.Resample(
                mangled_sample_rate, original_sample_rate
            )
            low_audio = resample_transform_low(audio)
            low_audio = resample_transform_high(low_audio)

            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_high_quality_audio[-1], clip_length
            )

            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_low_quality_audio[-1], clip_length
            )

            for high_quality_sample, low_quality_sample in zip(
                splitted_high_quality_audio, splitted_low_quality_audio
            ):
                data.append(
                    (
                        (high_quality_sample, low_quality_sample),
                        (original_sample_rate, mangled_sample_rate),
                    )
                )

        self.audio_data = data
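The dataset leans on three AudioUtils helpers that are not part of this diff. A minimal sketch of what they are assumed to do (the actual implementations live in AudioUtils.py and may differ):

import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(audio: torch.Tensor) -> torch.Tensor:
    # Average channels: (channels, samples) -> (1, samples)
    return audio.mean(dim=0, keepdim=True)

def split_audio(audio: torch.Tensor, clip_length: int) -> list:
    # Chop the last dimension into consecutive clips; the final clip may be short.
    return list(torch.split(audio, clip_length, dim=-1))

def pad_tensor(clip: torch.Tensor, clip_length: int) -> torch.Tensor:
    # Right-pad the last clip with zeros up to clip_length samples.
    return F.pad(clip, (0, clip_length - clip.size(-1)))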

discriminator.py
View File

@@ -1,8 +1,16 @@
import torch.nn as nn
import torch.nn.utils as utils


def discriminator_block(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    dilation=1,
    spectral_norm=True,
    use_instance_norm=True,
):
    padding = (kernel_size // 2) * dilation
    conv_layer = nn.Conv1d(
        in_channels,

@@ -10,7 +18,7 @@ def discriminator_block(
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=padding,
    )

    if spectral_norm:

@@ -24,6 +32,7 @@ def discriminator_block(
    return nn.Sequential(*layers)


class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()

@@ -31,27 +40,86 @@ class AttentionBlock(nn.Module):
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights


class SISUDiscriminator(nn.Module):
    def __init__(self, base_channels=16):
        super(SISUDiscriminator, self).__init__()
        layers = base_channels
        self.model = nn.Sequential(
            discriminator_block(
                1,
                layers,
                kernel_size=7,
                stride=1,
                spectral_norm=True,
                use_instance_norm=False,
            ),
            discriminator_block(
                layers,
                layers * 2,
                kernel_size=5,
                stride=2,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            discriminator_block(
                layers * 2,
                layers * 4,
                kernel_size=5,
                stride=1,
                dilation=2,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            AttentionBlock(layers * 4),
            discriminator_block(
                layers * 4,
                layers * 8,
                kernel_size=5,
                stride=1,
                dilation=4,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            discriminator_block(
                layers * 8,
                layers * 4,
                kernel_size=5,
                stride=2,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            discriminator_block(
                layers * 4,
                layers * 2,
                kernel_size=3,
                stride=1,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            discriminator_block(
                layers * 2,
                layers,
                kernel_size=3,
                stride=1,
                spectral_norm=True,
                use_instance_norm=True,
            ),
            discriminator_block(
                layers,
                1,
                kernel_size=3,
                stride=1,
                spectral_norm=False,
                use_instance_norm=False,
            ),
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
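A quick, hypothetical sanity check of the stack above: the convolutional blocks end in a single channel, which the adaptive pooling squeezes to one value per clip.

import torch
from discriminator import SISUDiscriminator

disc = SISUDiscriminator(base_channels=16)
dummy = torch.randn(2, 1, 16384)          # (batch, channels, samples)
features = disc.model(dummy)              # final block outputs 1 channel
pooled = disc.global_avg_pool(features)   # -> (2, 1, 1)
print(features.shape, pooled.shape)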

generator.py
View File

@@ -1,6 +1,6 @@
import torch.nn as nn


def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
    return nn.Sequential(
        nn.Conv1d(

@@ -8,29 +8,32 @@ def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=(kernel_size // 2) * dilation,
        ),
        nn.InstanceNorm1d(out_channels),
        nn.PReLU(),
    )


class AttentionBlock(nn.Module):
    """
    Simple Channel Attention Block. Learns to weight channels based on their importance.
    """

    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels // 4, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights


class ResidualInResidualBlock(nn.Module):
    def __init__(self, channels, num_convs=3):
        super(ResidualInResidualBlock, self).__init__()

@@ -47,6 +50,7 @@ class ResidualInResidualBlock(nn.Module):
        x = self.attention(x)
        return x + residual


class SISUGenerator(nn.Module):
    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
        super(SISUGenerator, self).__init__()
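A quick sanity check of conv_block (hypothetical usage): the dilation-aware padding keeps the output length equal to the input length.

import torch
from generator import conv_block

block = conv_block(1, 16, kernel_size=3, dilation=2)
x = torch.randn(2, 1, 1024)   # (batch, channels, samples)
print(block(x).shape)         # torch.Size([2, 16, 1024])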

View File

@@ -1,65 +1,74 @@
import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import tqdm
from torch.amp import GradScaler, autocast
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import training_utils
from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from training_utils import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)"
)
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
parser.add_argument(
    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
)
args = parser.parse_args()

# ---------------------------
# Device setup
# ---------------------------
# Use requested device only if available
device = torch.device(
    args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu"
)
print(f"Using device: {device}")

# sensible performance flags
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    # optional: torch.set_float32_matmul_precision("high")

debug = args.debug

# ---------------------------
# Audio transforms
# ---------------------------
sample_rate = 44100
n_fft = 1024
win_length = n_fft
hop_length = n_fft // 4
n_mels = 96
# n_mfcc = 13

# mfcc_transform = T.MFCC(
#     sample_rate=sample_rate,
#     n_mfcc=n_mfcc,
#     melkwargs=dict(
#         n_fft=n_fft,
#         hop_length=hop_length,
#         win_length=win_length,
#         n_mels=n_mels,
#         power=1.0,
#     ),
# ).to(device)

mel_transform = T.MelSpectrogram(
    sample_rate=sample_rate,
@@ -67,138 +76,198 @@ mel_transform = T.MelSpectrogram(
    hop_length=hop_length,
    win_length=win_length,
    n_mels=n_mels,
    power=1.0,
).to(device)

stft_transform = T.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length
).to(device)

# training_utils.init(mel_transform, stft_transform, mfcc_transform)
training_utils.init(mel_transform, stft_transform)

# ---------------------------
# Dataset / DataLoader
# ---------------------------
dataset_dir = "./dataset/good"
dataset = AudioDataset(dataset_dir)

train_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
    persistent_workers=True,
)

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator().to(device)
discriminator = SISUDiscriminator().to(device)

generator = torch.compile(generator)
discriminator = torch.compile(discriminator)

# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.BCEWithLogitsLoss()

optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)
optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

# Use modern GradScaler signature; choose device_type based on runtime device.
scaler = GradScaler(device=device)

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)


def save_ckpt(path, epoch):
    torch.save(
        {
            "epoch": epoch,
            "G": generator.state_dict(),
            "D": discriminator.state_dict(),
            "optG": optimizer_g.state_dict(),
            "optD": optimizer_d.state_dict(),
            "scaler": scaler.state_dict(),
            "schedG": scheduler_g.state_dict(),
            "schedD": scheduler_d.state_dict(),
        },
        path,
    )


start_epoch = 0
if args.resume:
    ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device)
    generator.load_state_dict(ckpt["G"])
    discriminator.load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scaler.load_state_dict(ckpt["scaler"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])
    start_epoch = ckpt.get("epoch", 1)

# ---------------------------
# Training loop (safer)
# ---------------------------
if not train_loader or not train_loader.batch_size:
    print("There is no data to train with! Exiting...")
    exit()

max_batch = max(1, train_loader.batch_size)
real_buf = torch.full((max_batch, 1), 0.9, device=device)  # label smoothing
fake_buf = torch.zeros(max_batch, 1, device=device)

try:
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()

        running_d, running_g, steps = 0.0, 0.0, 0

        for i, (
            (high_quality, low_quality),
            (high_sample_rate, low_sample_rate),
        ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
            batch_size = high_quality.size(0)

            high_quality = high_quality.to(device, non_blocking=True)
            low_quality = low_quality.to(device, non_blocking=True)

            real_labels = real_buf[:batch_size]
            fake_labels = fake_buf[:batch_size]

            # --- Discriminator ---
            optimizer_d.zero_grad(set_to_none=True)
            with autocast(device_type=device.type):
                d_loss = discriminator_train(
                    high_quality,
                    low_quality,
                    real_labels,
                    fake_labels,
                    discriminator,
                    generator,
                    criterion_d,
                )

            scaler.scale(d_loss).backward()
            scaler.unscale_(optimizer_d)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
            scaler.step(optimizer_d)

            # --- Generator ---
            optimizer_g.zero_grad(set_to_none=True)
            with autocast(device_type=device.type):
                g_out, g_total, g_adv = generator_train(
                    low_quality,
                    high_quality,
                    real_labels,
                    generator,
                    discriminator,
                    criterion_d,
                )

            scaler.scale(g_total).backward()
            scaler.unscale_(optimizer_g)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
            scaler.step(optimizer_g)

            scaler.update()

            running_d += float(d_loss.detach().cpu().item())
            running_g += float(g_total.detach().cpu().item())
            steps += 1

        # epoch averages & schedulers
        if steps == 0:
            print("No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
        mean_g = running_g / steps
        scheduler_d.step(mean_d)
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")

except Exception:
    try:
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
        print(f"Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        print("Failed saving crash checkpoint:", e)
    raise

try:
    torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt"))
    torch.save(
        discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt")
    )
except Exception as e:
    print("Failed to save final states:", e)

print("Training finished.")

training_utils.py
View File

@@ -1,89 +1,88 @@
import torch
import torchaudio.transforms as T

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

mel_transform: T.MelSpectrogram
stft_transform: T.Spectrogram
# mfcc_transform: T.MFCC


# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
#     """Initializes the global transform variables for the module."""
#     global mel_transform, stft_transform, mfcc_transform
#     mel_transform = mel_trans
#     stft_transform = stft_trans
#     mfcc_transform = mfcc_trans


def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
    """Initializes the global transform variables for the module."""
    global mel_transform, stft_transform
    mel_transform = mel_trans
    stft_transform = stft_trans


# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
#     """Computes the Mean Squared Error (MSE) loss on MFCCs."""
#     mfccs_true = mfcc_transform(y_true)
#     mfccs_pred = mfcc_transform(y_pred)
#     return F.mse_loss(mfccs_pred, mfccs_true)


# def mel_spectrogram_loss(
#     y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
# ) -> torch.Tensor:
#     """Calculates L1 or L2 loss on the Mel Spectrogram."""
#     mel_spec_true = mel_transform(y_true)
#     mel_spec_pred = mel_transform(y_pred)
#     if loss_type == "l1":
#         return F.l1_loss(mel_spec_pred, mel_spec_true)
#     elif loss_type == "l2":
#         return F.mse_loss(mel_spec_pred, mel_spec_true)
#     else:
#         raise ValueError("loss_type must be 'l1' or 'l2'")


# def log_stft_magnitude_loss(
#     y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
# ) -> torch.Tensor:
#     """Calculates L1 loss on the log STFT magnitude."""
#     stft_mag_true = stft_transform(y_true)
#     stft_mag_pred = stft_transform(y_pred)
#     return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))


stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
)


def discriminator_train(
    high_quality,
    low_quality,
    real_labels,
    fake_labels,
    discriminator,
    generator,
    criterion,
):
    discriminator_decision_from_real = discriminator(high_quality)
    d_loss_real = criterion(discriminator_decision_from_real, real_labels)

    with torch.no_grad():
        generator_output = generator(low_quality)

    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion(
        discriminator_decision_from_fake,
        fake_labels.expand_as(discriminator_decision_from_fake),
    )

    d_loss = (d_loss_real + d_loss_fake) / 2.0

    return d_loss


def generator_train(
    low_quality,
    high_quality,
@@ -91,52 +90,65 @@ def generator_train(
    generator,
    discriminator,
    adv_criterion,
    lambda_adv: float = 1.0,
    lambda_feat: float = 10.0,
    lambda_stft: float = 2.5,
):
    generator_output = generator(low_quality)

    discriminator_decision = discriminator(generator_output)
    # adversarial_loss = adv_criterion(
    #     discriminator_decision, real_labels.expand_as(discriminator_decision)
    # )
    adversarial_loss = adv_criterion(discriminator_decision, real_labels)

    combined_loss = lambda_adv * adversarial_loss

    stft_losses = stft_loss_fn(high_quality, generator_output)
    stft_loss = stft_losses["total"]

    combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss)

    return generator_output, combined_loss, adversarial_loss


# def generator_train(
#     low_quality,
#     high_quality,
#     real_labels,
#     generator,
#     discriminator,
#     adv_criterion,
#     lambda_adv: float = 1.0,
#     lambda_mel_l1: float = 10.0,
#     lambda_log_stft: float = 1.0,
# ):
#     generator_output = generator(low_quality)

#     discriminator_decision = discriminator(generator_output)
#     adversarial_loss = adv_criterion(
#         discriminator_decision, real_labels.expand_as(discriminator_decision)
#     )

#     combined_loss = lambda_adv * adversarial_loss

#     if lambda_mel_l1 > 0:
#         mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
#         combined_loss += lambda_mel_l1 * mel_l1_loss
#     else:
#         mel_l1_loss = torch.tensor(0.0, device=low_quality.device)  # For logging

#     if lambda_log_stft > 0:
#         log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
#         combined_loss += lambda_log_stft * log_stft_loss
#     else:
#         log_stft_loss = torch.tensor(0.0, device=low_quality.device)

#     if lambda_mfcc > 0:
#         mfcc_loss_val = mfcc_loss(high_quality, generator_output)
#         combined_loss += lambda_mfcc * mfcc_loss_val
#     else:
#         mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)

#     return generator_output, combined_loss, adversarial_loss
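A hedged smoke test of the new loss path outside the training loop. Shapes assume the generator preserves clip length and the discriminator returns one logit per clip, as the training script relies on:

import torch
import torch.nn as nn
import training_utils
from discriminator import SISUDiscriminator
from generator import SISUGenerator

gen, disc = SISUGenerator(), SISUDiscriminator()
low = torch.randn(2, 1, 16384)
high = torch.randn(2, 1, 16384)
real = torch.full((2, 1), 0.9)   # smoothed "real" labels, as in training
fake = torch.zeros(2, 1)

d_loss = training_utils.discriminator_train(
    high, low, real, fake, disc, gen, nn.BCEWithLogitsLoss()
)
out, total, adv = training_utils.generator_train(
    low, high, real, gen, disc, nn.BCEWithLogitsLoss()
)
print(d_loss.item(), total.item(), adv.item())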

utils/MultiResolutionSTFTLoss.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T


class MultiResolutionSTFTLoss(nn.Module):
    """
    Computes a loss based on multiple STFT resolutions, including both
    spectral convergence and log STFT magnitude components.
    """

    def __init__(
        self,
        fft_sizes: List[int] = [1024, 2048, 512],
        hop_sizes: List[int] = [120, 240, 50],
        win_lengths: List[int] = [600, 1200, 240],
        eps: float = 1e-7,
    ):
        super().__init__()
        self.stft_transforms = nn.ModuleList(
            [
                T.Spectrogram(
                    n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None
                )
                for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )
        self.eps = eps

    def forward(
        self, y_true: torch.Tensor, y_pred: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        sc_loss = 0.0  # Spectral Convergence Loss
        mag_loss = 0.0  # Log STFT Magnitude Loss

        for stft in self.stft_transforms:
            stft.to(y_pred.device)  # Ensure transform is on the correct device

            # Get complex STFTs
            stft_true = stft(y_true)
            stft_pred = stft(y_pred)

            # Get magnitudes
            stft_mag_true = torch.abs(stft_true)
            stft_mag_pred = torch.abs(stft_pred)

            # --- Spectral Convergence Loss ---
            # || |S_true| - |S_pred| ||_F / || |S_true| ||_F
            norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
            norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
            sc_loss += torch.mean(norm_diff / (norm_true + self.eps))

            # --- Log STFT Magnitude Loss ---
            mag_loss += F.l1_loss(
                torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps)
            )

        total_loss = sc_loss + mag_loss
        return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}
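A minimal usage sketch of the loss on dummy tensors (shapes are illustrative):

import torch
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

loss_fn = MultiResolutionSTFTLoss()
clean = torch.randn(4, 1, 16384)                    # (batch, channels, samples)
degraded = clean + 0.05 * torch.randn_like(clean)   # lightly perturbed copy
losses = loss_fn(clean, degraded)
print(losses["total"], losses["sc"], losses["mag"])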

0
utils/__init__.py Normal file
View File