⚗️ | Added some stupid ways for training + some makeup

2025-10-04 22:38:11 +03:00
parent 0bc8fc2792
commit 3f23242d6f
12 changed files with 304 additions and 463 deletions


@@ -1,71 +1,97 @@
import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """
    Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (N,) -> (1, N)
    if waveform.shape[0] > 1:
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, N)
    else:
        mono_waveform = waveform
    return mono_waveform


def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Stretch audio along the time dimension to target_length.
    Input assumed (1, N). Returns (1, target_length).
    """
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # ensure (1, N)
    tensor = tensor.unsqueeze(0)  # (1, 1, N) for interpolate
    stretched = F.interpolate(
        tensor, size=target_length, mode="linear", align_corners=False
    )
    return stretched.squeeze(0)  # back to (1, target_length)


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """
    Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
    """
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    current_length = audio_tensor.shape[-1]
    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)
        padded_audio_tensor = F.pad(
            audio_tensor, padding_tuple, mode="constant", value=0
        )
    else:
        padded_audio_tensor = audio_tensor[..., :target_length]  # crop if too long
    return padded_audio_tensor


def split_audio(
    audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
    """
    Split into chunks of (1, chunk_size).
    """
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    num_samples = audio_tensor.shape[-1]
    if num_samples == 0:
        return []
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    return chunks


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """
    Reconstruct audio from chunks. Returns (1, N).
    """
    if not chunks:
        return torch.empty(1, 0)
    chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
    try:
        reconstructed_tensor = torch.cat(chunks, dim=-1)
    except RuntimeError as e:
        raise RuntimeError(
            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dim -1. Original error: {e}"
        )
    return reconstructed_tensor


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor  # silence, skip normalization
    return audio_tensor / max_val
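
A quick round-trip sketch of how these helpers compose (a minimal example, assuming the module is imported as AudioUtils the way data.py and app.py do):

    import torch
    import AudioUtils

    wave = torch.randn(2, 1000)                             # fake stereo clip, shape (C, N)
    mono = AudioUtils.stereo_tensor_to_mono(wave)           # (1, 1000)
    mono = AudioUtils.normalize(mono)                       # peak-normalized into [-1, 1]

    chunks = AudioUtils.split_audio(mono, chunk_size=128)   # 8 chunks, the last only 104 samples
    chunks[-1] = AudioUtils.pad_tensor(chunks[-1], 128)     # zero-pad the tail chunk to 128
    restored = AudioUtils.reconstruct_audio(chunks)         # (1, 1024): original samples + padding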

__init__.py (new, empty file)

app.py

@@ -68,6 +68,7 @@ def start():
    original_sample_rate = decoded_samples.sample_rate
    audio = AudioUtils.stereo_tensor_to_mono(audio)
    audio = AudioUtils.normalize(audio)

    resample_transform = torchaudio.transforms.Resample(
        original_sample_rate, args.sample_rate

data.py

@@ -12,12 +12,15 @@ import AudioUtils
class AudioDataset(Dataset):
    audio_sample_rates = [11025]

    def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
        self.clip_length = clip_length
        self.normalize = normalize

        input_files = [
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith((".wav", ".mp3", ".flac"))
        ]

        data = []
@@ -25,14 +28,15 @@ class AudioDataset(Dataset):
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoder = decoders.AudioDecoder(audio_clip)
            decoded_samples = decoder.get_all_samples()
            audio = decoded_samples.data.float()  # ensure float32
            original_sample_rate = decoded_samples.sample_rate

            audio = AudioUtils.stereo_tensor_to_mono(audio)
            if normalize:
                audio = AudioUtils.normalize(audio)

            mangled_sample_rate = random.choice(self.audio_sample_rates)
            resample_transform_low = torchaudio.transforms.Resample(
                original_sample_rate, mangled_sample_rate
@@ -41,25 +45,27 @@ class AudioDataset(Dataset):
                mangled_sample_rate, original_sample_rate
            )

            low_audio = resample_transform_high(resample_transform_low(audio))

            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)

            if not splitted_high_quality_audio or not splitted_low_quality_audio:
                continue  # skip empty or invalid clips

            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_high_quality_audio[-1], clip_length
            )
            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_low_quality_audio[-1], clip_length
            )

            for high_quality_data, low_quality_data in zip(
                splitted_high_quality_audio, splitted_low_quality_audio
            ):
                data.append(
                    (
                        (high_quality_data, low_quality_data),
                        (original_sample_rate, mangled_sample_rate),
                    )
                )


@@ -49,74 +49,18 @@ class AttentionBlock(nn.Module):
class SISUDiscriminator(nn.Module):
    def __init__(self, layers=32):
        super(SISUDiscriminator, self).__init__()

        self.model = nn.Sequential(
            discriminator_block(1, layers, kernel_size=7, stride=1),
            discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
            discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
            AttentionBlock(layers * 4),
            discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
            discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2),
            discriminator_block(
                layers * 2,
                1,
                spectral_norm=False,
                use_instance_norm=False,
            ),


@@ -1,30 +0,0 @@
import json
filepath = "my_data.json"
def write_data(filepath, data, debug=False):
try:
with open(filepath, 'w') as f:
json.dump(data, f, indent=4) # Use indent for pretty formatting
if debug:
print(f"Data written to '{filepath}'")
except Exception as e:
print(f"Error writing to file: {e}")
def read_data(filepath, debug=False):
try:
with open(filepath, 'r') as f:
data = json.load(f)
if debug:
print(f"Data read from '{filepath}'")
return data
except FileNotFoundError:
print(f"File not found: {filepath}")
return None
except json.JSONDecodeError:
print(f"Error decoding JSON from file: {filepath}")
return None
except Exception as e:
print(f"Error reading from file: {e}")
return None


@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
@@ -52,7 +53,7 @@ class ResidualInResidualBlock(nn.Module):
class SISUGenerator(nn.Module):
    def __init__(self, channels=16, num_rirb=4, alpha=1):
        super(SISUGenerator, self).__init__()
        self.alpha = alpha
@@ -66,7 +67,9 @@ class SISUGenerator(nn.Module):
            *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
        )
        self.final_layer = nn.Sequential(
            nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
        )

    def forward(self, x):
        residual_input = x
@@ -75,4 +78,4 @@ class SISUGenerator(nn.Module):
        learned_residual = self.final_layer(x_rirb_out)
        output = residual_input + self.alpha * learned_residual
        return torch.tanh(output)
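
A hypothetical smoke test for the reworked generator (assuming generator.py exposes SISUGenerator as train imports it; since the model adds a learned residual to its input and ends in tanh, the output should keep the input length and stay within [-1, 1]):

    import torch
    from generator import SISUGenerator

    g = SISUGenerator(channels=16, num_rirb=4, alpha=1)
    x = torch.randn(2, 1, 8000)        # (batch, 1, samples)
    with torch.no_grad():
        y = g(x)
    assert y.shape == x.shape          # residual design preserves length
    assert y.abs().max() <= 1.0        # tanh bounds the output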


@@ -1,12 +0,0 @@
filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.3
pillow==11.0.0
setuptools==70.2.0
sympy==1.13.3
tqdm==4.67.1
typing_extensions==4.12.2


@@ -4,25 +4,20 @@ import os
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler

from data import AudioDataset
from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

# ---------------------------
# Argument parsing
# ---------------------------
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
parser.add_argument("--resume", action="store_true", help="Resume training")
parser.add_argument(
    "--epochs", type=int, default=5000, help="Number of training epochs"
)
@@ -35,86 +30,54 @@ parser.add_argument(
args = parser.parse_args()

# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
discriminator = SISUDiscriminator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Dataset / DataLoader
# ---------------------------
accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset")

sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory

train_loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    shuffle=(sampler is None),
    num_workers=args.num_workers,
    pin_memory=pin_memory,
    persistent_workers=pin_memory,
)

if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
    accelerator.print("🪹 | There is no data to train with! Exiting...")
    exit()

loader_batch_size = train_loader.batch_size
accelerator.print("✅ | Dataset fetched!")

# ---------------------------
# Losses / Optimizers / Scalers
# ---------------------------
optimizer_g = optim.AdamW(
    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
@@ -123,9 +86,6 @@ optimizer_d = optim.AdamW(
    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
)

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, mode="min", factor=0.5, patience=5
)
@@ -133,6 +93,17 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_d, mode="min", factor=0.5, patience=5
)

criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.MSELoss()

# ---------------------------
# Prepare accelerator
# ---------------------------
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
    generator, discriminator, optimizer_g, optimizer_d, train_loader
)

# ---------------------------
# Checkpoint helpers
# ---------------------------
@@ -141,14 +112,15 @@ os.makedirs(models_dir, exist_ok=True)
def save_ckpt(path, epoch):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save(
            {
                "epoch": epoch,
                "G": accelerator.unwrap_model(generator).state_dict(),
                "D": accelerator.unwrap_model(discriminator).state_dict(),
                "optG": optimizer_g.state_dict(),
                "optD": optimizer_d.state_dict(),
                "schedG": scheduler_g.state_dict(),
                "schedD": scheduler_d.state_dict(),
            },
@@ -158,27 +130,27 @@ def save_ckpt(path, epoch):
start_epoch = 0
if args.resume:
    ckpt_path = os.path.join(models_dir, "last.pt")
    ckpt = torch.load(ckpt_path)

    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])
    start_epoch = ckpt.get("epoch", 1)
    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")

real_buf = torch.full(
    (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
)
fake_buf = torch.zeros(
    (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
)

accelerator.print("🏋️ | Started training...")

try:
    for epoch in range(start_epoch, args.epochs):
@@ -193,15 +165,12 @@ try:
        ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
            batch_size = high_quality.size(0)

            real_labels = real_buf[:batch_size].to(accelerator.device)
            fake_labels = fake_buf[:batch_size].to(accelerator.device)

            # --- Discriminator ---
            optimizer_d.zero_grad(set_to_none=True)
            with accelerator.autocast():
                d_loss = discriminator_train(
                    high_quality,
                    low_quality,
@@ -212,15 +181,14 @@ try:
                    criterion_d,
                )

            accelerator.backward(d_loss)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
            optimizer_d.step()

            # --- Generator ---
            optimizer_g.zero_grad(set_to_none=True)
            with accelerator.autocast():
                g_total, g_adv = generator_train(
                    low_quality,
                    high_quality,
                    real_labels,
@@ -229,20 +197,32 @@ try:
                    criterion_d,
                )

            accelerator.backward(g_total)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
            optimizer_g.step()

            d_val = accelerator.gather(d_loss.detach()).mean()
            g_val = accelerator.gather(g_total.detach()).mean()

            if torch.isfinite(d_val):
                running_d += d_val.item()
            else:
                accelerator.print(
                    f"🫥 | NaN in discriminator loss at step {i}, skipping update."
                )

            if torch.isfinite(g_val):
                running_g += g_val.item()
            else:
                accelerator.print(
                    f"🫥 | NaN in generator loss at step {i}, skipping update."
                )

            steps += 1

        # epoch averages & schedulers
        if steps == 0:
            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
            break

        mean_d = running_d / steps
@@ -252,22 +232,14 @@ try:
        scheduler_g.step(mean_g)

        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
        accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")

except Exception:
    try:
        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
        accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
    except Exception as e:
        accelerator.print("😬 | Failed saving crash checkpoint:", e)
    raise

accelerator.print("🏁 | Training finished.")
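
For reference, a typical way to run this script under Accelerate (assuming it is saved as train.py; --batch_size, --num_workers, and --no_pin_memory are read by the script but their add_argument calls fall outside this hunk) would be to run `accelerate config` once, then `accelerate launch train.py --epochs 5000 --batch_size 8 --num_workers 4`, and `accelerate launch train.py --resume` to continue from the last checkpoint.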


@@ -1,154 +0,0 @@
import torch
import torchaudio.transforms as T
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
mel_transform: T.MelSpectrogram
stft_transform: T.Spectrogram
# mfcc_transform: T.MFCC
# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
# """Initializes the global transform variables for the module."""
# global mel_transform, stft_transform, mfcc_transform
# mel_transform = mel_trans
# stft_transform = stft_trans
# mfcc_transform = mfcc_trans
def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
"""Initializes the global transform variables for the module."""
global mel_transform, stft_transform
mel_transform = mel_trans
stft_transform = stft_trans
# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
# """Computes the Mean Squared Error (MSE) loss on MFCCs."""
# mfccs_true = mfcc_transform(y_true)
# mfccs_pred = mfcc_transform(y_pred)
# return F.mse_loss(mfccs_pred, mfccs_true)
# def mel_spectrogram_loss(
# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
# ) -> torch.Tensor:
# """Calculates L1 or L2 loss on the Mel Spectrogram."""
# mel_spec_true = mel_transform(y_true)
# mel_spec_pred = mel_transform(y_pred)
# if loss_type == "l1":
# return F.l1_loss(mel_spec_pred, mel_spec_true)
# elif loss_type == "l2":
# return F.mse_loss(mel_spec_pred, mel_spec_true)
# else:
# raise ValueError("loss_type must be 'l1' or 'l2'")
# def log_stft_magnitude_loss(
# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
# ) -> torch.Tensor:
# """Calculates L1 loss on the log STFT magnitude."""
# stft_mag_true = stft_transform(y_true)
# stft_mag_pred = stft_transform(y_pred)
# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
)
def discriminator_train(
high_quality,
low_quality,
real_labels,
fake_labels,
discriminator,
generator,
criterion,
):
discriminator_decision_from_real = discriminator(high_quality)
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
with torch.no_grad():
generator_output = generator(low_quality)
discriminator_decision_from_fake = discriminator(generator_output)
d_loss_fake = criterion(
discriminator_decision_from_fake,
fake_labels.expand_as(discriminator_decision_from_fake),
)
d_loss = (d_loss_real + d_loss_fake) / 2.0
return d_loss
def generator_train(
low_quality,
high_quality,
real_labels,
generator,
discriminator,
adv_criterion,
lambda_adv: float = 1.0,
lambda_feat: float = 10.0,
lambda_stft: float = 2.5,
):
generator_output = generator(low_quality)
discriminator_decision = discriminator(generator_output)
# adversarial_loss = adv_criterion(
# discriminator_decision, real_labels.expand_as(discriminator_decision)
# )
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
combined_loss = lambda_adv * adversarial_loss
stft_losses = stft_loss_fn(high_quality, generator_output)
stft_loss = stft_losses["total"]
combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss)
return generator_output, combined_loss, adversarial_loss
# def generator_train(
# low_quality,
# high_quality,
# real_labels,
# generator,
# discriminator,
# adv_criterion,
# lambda_adv: float = 1.0,
# lambda_mel_l1: float = 10.0,
# lambda_log_stft: float = 1.0,
# ):
# generator_output = generator(low_quality)
# discriminator_decision = discriminator(generator_output)
# adversarial_loss = adv_criterion(
# discriminator_decision, real_labels.expand_as(discriminator_decision)
# )
# combined_loss = lambda_adv * adversarial_loss
# if lambda_mel_l1 > 0:
# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
# combined_loss += lambda_mel_l1 * mel_l1_loss
# else:
# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging
# if lambda_log_stft > 0:
# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
# combined_loss += lambda_log_stft * log_stft_loss
# else:
# log_stft_loss = torch.tensor(0.0, device=low_quality.device)
# if lambda_mfcc > 0:
# mfcc_loss_val = mfcc_loss(high_quality, generator_output)
# combined_loss += lambda_mfcc * mfcc_loss_val
# else:
# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)
# return generator_output, combined_loss, adversarial_loss


@@ -8,8 +8,9 @@ import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
    """
    Multi-resolution STFT loss.
    Combines spectral convergence loss and log-magnitude loss
    across multiple STFT resolutions.
    """

    def __init__(
@@ -20,43 +21,67 @@ class MultiResolutionSTFTLoss(nn.Module):
        eps: float = 1e-7,
    ):
        super().__init__()
        self.eps = eps
        self.n_resolutions = len(fft_sizes)

        self.stft_transforms = nn.ModuleList()
        for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
            window = torch.hann_window(win_len)
            stft = T.Spectrogram(
                n_fft=n_fft,
                hop_length=hop_len,
                win_length=win_len,
                window_fn=lambda _: window,
                power=None,  # Keep complex output
                center=True,
                pad_mode="reflect",
                normalized=False,
            )
            self.stft_transforms.append(stft)

    def forward(
        self, y_true: torch.Tensor, y_pred: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            y_true: (B, T) or (B, 1, T) waveform
            y_pred: (B, T) or (B, 1, T) waveform
        """
        # Ensure correct shape (B, T)
        if y_true.dim() == 3 and y_true.size(1) == 1:
            y_true = y_true.squeeze(1)
        if y_pred.dim() == 3 and y_pred.size(1) == 1:
            y_pred = y_pred.squeeze(1)

        sc_loss = 0.0
        mag_loss = 0.0

        for stft in self.stft_transforms:
            stft = stft.to(y_pred.device)

            # Complex STFTs
            stft_true = stft(y_true)
            stft_pred = stft(y_pred)

            # Magnitudes
            stft_mag_true = torch.abs(stft_true)
            stft_mag_pred = torch.abs(stft_pred)

            # --- Spectral Convergence Loss ---
            norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
            norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
            sc_loss += torch.mean(norm_diff / (norm_true + self.eps))

            # --- Log STFT Magnitude Loss ---
            mag_loss += F.l1_loss(
                torch.log(stft_mag_pred + self.eps),
                torch.log(stft_mag_true + self.eps),
            )

        # Average across resolutions
        sc_loss /= self.n_resolutions
        mag_loss /= self.n_resolutions

        total_loss = sc_loss + mag_loss
        return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}
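
A minimal usage sketch (the constructor arguments mirror how the removed training_utils.py instantiated this loss; shapes follow the docstring above, and requires_grad stands in for a generator output):

    import torch
    from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

    loss_fn = MultiResolutionSTFTLoss(
        fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
    )
    target = torch.randn(4, 1, 8000)                    # (B, 1, T) reference waveform
    pred = torch.randn(4, 1, 8000, requires_grad=True)  # (B, 1, T) generated waveform
    losses = loss_fn(target, pred)                      # dict with "total", "sc", "mag"
    losses["total"].backward()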

utils/TrainingTools.py (new file)

@@ -0,0 +1,60 @@
import torch
# In case it's needed again...
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two)
return torch.mean(absolute_difference)
def discriminator_train(
high_quality,
low_quality,
high_labels,
low_labels,
discriminator,
generator,
criterion,
):
decision_high = discriminator(high_quality)
d_loss_high = criterion(decision_high, high_labels)
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
decision_low = discriminator(low_quality)
d_loss_low = criterion(decision_low, low_labels)
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
with torch.no_grad():
generator_quality = generator(low_quality)
decision_gen = discriminator(generator_quality)
d_loss_gen = criterion(decision_gen, low_labels)
noise = torch.rand_like(high_quality) * 0.08
decision_noise = discriminator(high_quality + noise)
d_loss_noise = criterion(decision_noise, low_labels)
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
return d_loss
def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
):
generator_output = generator(low_quality)
discriminator_decision = discriminator(generator_output)
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
# Signal similarity
similarity_loss = signal_mae(generator_output, high_quality)
combined_loss = adversarial_loss + (similarity_loss * 100)
return combined_loss, adversarial_loss
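
A minimal single-step sketch of how these helpers are wired together, mirroring the calls in the training script above. The tiny Conv1d stand-ins for SISUGenerator and SISUDiscriminator are hypothetical, only there so the snippet runs; as in the hunks above, criterion_d (MSELoss) is passed to both helpers:

    import torch
    import torch.nn as nn
    from utils.TrainingTools import discriminator_train, generator_train

    # Hypothetical stand-ins: any (B, 1, T) -> (B, 1, T) generator and (B, 1, T) -> (B, 1) critic work.
    generator = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    discriminator = nn.Sequential(
        nn.Conv1d(1, 1, kernel_size=3, padding=1), nn.AdaptiveAvgPool1d(1), nn.Flatten()
    )

    high_quality = torch.randn(4, 1, 8000)
    low_quality = torch.randn(4, 1, 8000)
    real_labels = torch.ones(4, 1)
    fake_labels = torch.zeros(4, 1)
    criterion_d = nn.MSELoss()

    d_loss = discriminator_train(
        high_quality, low_quality, real_labels, fake_labels,
        discriminator, generator, criterion_d,
    )
    d_loss.backward()

    g_total, g_adv = generator_train(
        low_quality, high_quality, real_labels, generator, discriminator, criterion_d
    )
    g_total.backward()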