⚗️ | Added some stupid ways for training + some makeup
This commit is contained in:
@@ -1,71 +1,97 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def stereo_tensor_to_mono(waveform):
|
||||
|
||||
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
|
||||
"""
|
||||
if waveform.dim() == 1:
|
||||
waveform = waveform.unsqueeze(0) # (N,) -> (1, N)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
# Average across channels
|
||||
mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N)
|
||||
else:
|
||||
# Already mono
|
||||
mono_waveform = waveform
|
||||
return mono_waveform
|
||||
|
||||
def stretch_tensor(tensor, target_length):
|
||||
scale_factor = target_length / tensor.size(1)
|
||||
|
||||
tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
|
||||
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Stretch audio along time dimension to target_length.
|
||||
Input assumed (1, N). Returns (1, target_length).
|
||||
"""
|
||||
if tensor.dim() == 1:
|
||||
tensor = tensor.unsqueeze(0) # ensure (1, N)
|
||||
|
||||
return tensor
|
||||
tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate
|
||||
stretched = F.interpolate(
|
||||
tensor, size=target_length, mode="linear", align_corners=False
|
||||
)
|
||||
return stretched.squeeze(0) # back to (1, target_length)
|
||||
|
||||
|
||||
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
|
||||
"""
|
||||
Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
|
||||
"""
|
||||
if audio_tensor.dim() == 1:
|
||||
audio_tensor = audio_tensor.unsqueeze(0)
|
||||
|
||||
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128):
|
||||
current_length = audio_tensor.shape[-1]
|
||||
|
||||
if current_length < target_length:
|
||||
padding_needed = target_length - current_length
|
||||
|
||||
padding_tuple = (0, padding_needed)
|
||||
padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0)
|
||||
padded_audio_tensor = F.pad(
|
||||
audio_tensor, padding_tuple, mode="constant", value=0
|
||||
)
|
||||
else:
|
||||
padded_audio_tensor = audio_tensor
|
||||
padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long
|
||||
|
||||
return padded_audio_tensor
|
||||
|
||||
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]:
|
||||
|
||||
def split_audio(
|
||||
audio_tensor: torch.Tensor, chunk_size: int = 128
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Split into chunks of (1, chunk_size).
|
||||
"""
|
||||
if not isinstance(chunk_size, int) or chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be a positive integer.")
|
||||
|
||||
# Handle scalar tensor edge case if necessary
|
||||
if audio_tensor.dim() == 0:
|
||||
return [audio_tensor] if audio_tensor.numel() > 0 else []
|
||||
|
||||
# Identify the dimension to split (usually the last one, representing time/samples)
|
||||
split_dim = -1
|
||||
num_samples = audio_tensor.shape[split_dim]
|
||||
if audio_tensor.dim() == 1:
|
||||
audio_tensor = audio_tensor.unsqueeze(0)
|
||||
|
||||
num_samples = audio_tensor.shape[-1]
|
||||
if num_samples == 0:
|
||||
return [] # Return empty list if the dimension to split is empty
|
||||
|
||||
# Use torch.split to divide the tensor into chunks
|
||||
# It handles the last chunk being potentially smaller automatically.
|
||||
chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))
|
||||
return []
|
||||
|
||||
chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
|
||||
return chunks
|
||||
|
||||
|
||||
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Reconstruct audio from chunks. Returns (1, N).
|
||||
"""
|
||||
if not chunks:
|
||||
return torch.empty(0)
|
||||
|
||||
if len(chunks) == 1 and chunks[0].dim() == 0:
|
||||
return chunks[0]
|
||||
|
||||
concat_dim = -1
|
||||
return torch.empty(1, 0)
|
||||
|
||||
chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
|
||||
try:
|
||||
reconstructed_tensor = torch.cat(chunks, dim=concat_dim)
|
||||
reconstructed_tensor = torch.cat(chunks, dim=-1)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
|
||||
f"for concatenation along dimension {concat_dim}. Original error: {e}"
|
||||
f"for concatenation along dim -1. Original error: {e}"
|
||||
)
|
||||
|
||||
return reconstructed_tensor
|
||||
|
||||
|
||||
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
max_val = torch.max(torch.abs(audio_tensor))
|
||||
if max_val < eps:
|
||||
return audio_tensor # silence, skip normalization
|
||||
return audio_tensor / max_val
|
||||
|
0
__init__.py
Normal file
0
__init__.py
Normal file
1
app.py
1
app.py
@@ -68,6 +68,7 @@ def start():
|
||||
original_sample_rate = decoded_samples.sample_rate
|
||||
|
||||
audio = AudioUtils.stereo_tensor_to_mono(audio)
|
||||
audio = AudioUtils.normalize(audio)
|
||||
|
||||
resample_transform = torchaudio.transforms.Resample(
|
||||
original_sample_rate, args.sample_rate
|
||||
|
34
data.py
34
data.py
@@ -12,12 +12,15 @@ import AudioUtils
|
||||
class AudioDataset(Dataset):
|
||||
audio_sample_rates = [11025]
|
||||
|
||||
def __init__(self, input_dir, clip_length=16384):
|
||||
def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
|
||||
self.clip_length = clip_length
|
||||
self.normalize = normalize
|
||||
|
||||
input_files = [
|
||||
os.path.join(root, f)
|
||||
for root, _, files in os.walk(input_dir)
|
||||
for f in files
|
||||
if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac")
|
||||
os.path.join(input_dir, f)
|
||||
for f in os.listdir(input_dir)
|
||||
if os.path.isfile(os.path.join(input_dir, f))
|
||||
and f.lower().endswith((".wav", ".mp3", ".flac"))
|
||||
]
|
||||
|
||||
data = []
|
||||
@@ -25,14 +28,15 @@ class AudioDataset(Dataset):
|
||||
input_files, desc=f"Processing {len(input_files)} audio file(s)"
|
||||
):
|
||||
decoder = decoders.AudioDecoder(audio_clip)
|
||||
|
||||
decoded_samples = decoder.get_all_samples()
|
||||
audio = decoded_samples.data
|
||||
|
||||
audio = decoded_samples.data.float() # ensure float32
|
||||
original_sample_rate = decoded_samples.sample_rate
|
||||
|
||||
audio = AudioUtils.stereo_tensor_to_mono(audio)
|
||||
if normalize:
|
||||
audio = AudioUtils.normalize(audio)
|
||||
|
||||
# Generate low-quality audio with random downsampling
|
||||
mangled_sample_rate = random.choice(self.audio_sample_rates)
|
||||
resample_transform_low = torchaudio.transforms.Resample(
|
||||
original_sample_rate, mangled_sample_rate
|
||||
@@ -41,25 +45,27 @@ class AudioDataset(Dataset):
|
||||
mangled_sample_rate, original_sample_rate
|
||||
)
|
||||
|
||||
low_audio = resample_transform_low(audio)
|
||||
low_audio = resample_transform_high(low_audio)
|
||||
low_audio = resample_transform_high(resample_transform_low(audio))
|
||||
|
||||
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
|
||||
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
|
||||
|
||||
if not splitted_high_quality_audio or not splitted_low_quality_audio:
|
||||
continue # skip empty or invalid clips
|
||||
|
||||
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
|
||||
splitted_high_quality_audio[-1], clip_length
|
||||
)
|
||||
|
||||
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
|
||||
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
|
||||
splitted_low_quality_audio[-1], clip_length
|
||||
)
|
||||
|
||||
for high_quality_sample, low_quality_sample in zip(
|
||||
for high_quality_data, low_quality_data in zip(
|
||||
splitted_high_quality_audio, splitted_low_quality_audio
|
||||
):
|
||||
data.append(
|
||||
(
|
||||
(high_quality_sample, low_quality_sample),
|
||||
(high_quality_data, low_quality_data),
|
||||
(original_sample_rate, mangled_sample_rate),
|
||||
)
|
||||
)
|
||||
|
@@ -49,74 +49,18 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class SISUDiscriminator(nn.Module):
|
||||
def __init__(self, base_channels=16):
|
||||
def __init__(self, layers=32):
|
||||
super(SISUDiscriminator, self).__init__()
|
||||
layers = base_channels
|
||||
self.model = nn.Sequential(
|
||||
discriminator_block(
|
||||
1,
|
||||
layers,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=False,
|
||||
),
|
||||
discriminator_block(
|
||||
layers,
|
||||
layers * 2,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(
|
||||
layers * 2,
|
||||
layers * 4,
|
||||
kernel_size=5,
|
||||
stride=1,
|
||||
dilation=2,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(1, layers, kernel_size=7, stride=1),
|
||||
discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
|
||||
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
|
||||
AttentionBlock(layers * 4),
|
||||
discriminator_block(
|
||||
layers * 4,
|
||||
layers * 8,
|
||||
kernel_size=5,
|
||||
stride=1,
|
||||
dilation=4,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(
|
||||
layers * 8,
|
||||
layers * 4,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(
|
||||
layers * 4,
|
||||
layers * 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
|
||||
discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2),
|
||||
discriminator_block(
|
||||
layers * 2,
|
||||
layers,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
spectral_norm=True,
|
||||
use_instance_norm=True,
|
||||
),
|
||||
discriminator_block(
|
||||
layers,
|
||||
1,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
spectral_norm=False,
|
||||
use_instance_norm=False,
|
||||
),
|
||||
|
@@ -1,30 +0,0 @@
|
||||
import json
|
||||
|
||||
filepath = "my_data.json"
|
||||
|
||||
def write_data(filepath, data, debug=False):
|
||||
try:
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=4) # Use indent for pretty formatting
|
||||
if debug:
|
||||
print(f"Data written to '{filepath}'")
|
||||
except Exception as e:
|
||||
print(f"Error writing to file: {e}")
|
||||
|
||||
|
||||
def read_data(filepath, debug=False):
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
data = json.load(f)
|
||||
if debug:
|
||||
print(f"Data read from '{filepath}'")
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {filepath}")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error decoding JSON from file: {filepath}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error reading from file: {e}")
|
||||
return None
|
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@@ -52,7 +53,7 @@ class ResidualInResidualBlock(nn.Module):
|
||||
|
||||
|
||||
class SISUGenerator(nn.Module):
|
||||
def __init__(self, channels=16, num_rirb=4, alpha=1.0):
|
||||
def __init__(self, channels=16, num_rirb=4, alpha=1):
|
||||
super(SISUGenerator, self).__init__()
|
||||
self.alpha = alpha
|
||||
|
||||
@@ -66,7 +67,9 @@ class SISUGenerator(nn.Module):
|
||||
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
|
||||
)
|
||||
|
||||
self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
|
||||
self.final_layer = nn.Sequential(
|
||||
nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
residual_input = x
|
||||
@@ -75,4 +78,4 @@ class SISUGenerator(nn.Module):
|
||||
learned_residual = self.final_layer(x_rirb_out)
|
||||
output = residual_input + self.alpha * learned_residual
|
||||
|
||||
return output
|
||||
return torch.tanh(output)
|
||||
|
@@ -1,12 +0,0 @@
|
||||
filelock==3.16.1
|
||||
fsspec==2024.10.0
|
||||
Jinja2==3.1.4
|
||||
MarkupSafe==2.1.5
|
||||
mpmath==1.3.0
|
||||
networkx==3.4.2
|
||||
numpy==2.2.3
|
||||
pillow==11.0.0
|
||||
setuptools==70.2.0
|
||||
sympy==1.13.3
|
||||
tqdm==4.67.1
|
||||
typing_extensions==4.12.2
|
246
training.py
246
training.py
@@ -4,25 +4,20 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchaudio.transforms as T
|
||||
import tqdm
|
||||
from torch.amp import GradScaler, autocast
|
||||
from torch.utils.data import DataLoader
|
||||
from accelerate import Accelerator
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
import training_utils
|
||||
from data import AudioDataset
|
||||
from discriminator import SISUDiscriminator
|
||||
from generator import SISUGenerator
|
||||
from training_utils import discriminator_train, generator_train
|
||||
from utils.TrainingTools import discriminator_train, generator_train
|
||||
|
||||
# ---------------------------
|
||||
# Argument parsing
|
||||
# ---------------------------
|
||||
parser = argparse.ArgumentParser(description="Training script (safer defaults)")
|
||||
parser.add_argument("--resume", action="store_true", help="Resume training")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="Device (cuda, cpu, mps)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs", type=int, default=5000, help="Number of training epochs"
|
||||
)
|
||||
@@ -35,86 +30,54 @@ parser.add_argument(
|
||||
args = parser.parse_args()
|
||||
|
||||
# ---------------------------
|
||||
# Device setup
|
||||
# Init accelerator
|
||||
# ---------------------------
|
||||
# Use requested device only if available
|
||||
device = torch.device(
|
||||
args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu"
|
||||
)
|
||||
print(f"Using device: {device}")
|
||||
# sensible performance flags
|
||||
if device.type == "cuda":
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# optional: torch.set_float32_matmul_precision("high")
|
||||
debug = args.debug
|
||||
|
||||
# ---------------------------
|
||||
# Audio transforms
|
||||
# ---------------------------
|
||||
sample_rate = 44100
|
||||
n_fft = 1024
|
||||
win_length = n_fft
|
||||
hop_length = n_fft // 4
|
||||
n_mels = 96
|
||||
# n_mfcc = 13
|
||||
|
||||
# mfcc_transform = T.MFCC(
|
||||
# sample_rate=sample_rate,
|
||||
# n_mfcc=n_mfcc,
|
||||
# melkwargs=dict(
|
||||
# n_fft=n_fft,
|
||||
# hop_length=hop_length,
|
||||
# win_length=win_length,
|
||||
# n_mels=n_mels,
|
||||
# power=1.0,
|
||||
# ),
|
||||
# ).to(device)
|
||||
|
||||
mel_transform = T.MelSpectrogram(
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mels=n_mels,
|
||||
power=1.0,
|
||||
).to(device)
|
||||
|
||||
stft_transform = T.Spectrogram(
|
||||
n_fft=n_fft, win_length=win_length, hop_length=hop_length
|
||||
).to(device)
|
||||
|
||||
# training_utils.init(mel_transform, stft_transform, mfcc_transform)
|
||||
training_utils.init(mel_transform, stft_transform)
|
||||
|
||||
# ---------------------------
|
||||
# Dataset / DataLoader
|
||||
# ---------------------------
|
||||
dataset_dir = "./dataset/good"
|
||||
dataset = AudioDataset(dataset_dir)
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
)
|
||||
accelerator = Accelerator(mixed_precision="bf16")
|
||||
|
||||
# ---------------------------
|
||||
# Models
|
||||
# ---------------------------
|
||||
generator = SISUGenerator().to(device)
|
||||
discriminator = SISUDiscriminator().to(device)
|
||||
generator = SISUGenerator()
|
||||
discriminator = SISUDiscriminator()
|
||||
|
||||
accelerator.print("🔨 | Compiling models...")
|
||||
|
||||
generator = torch.compile(generator)
|
||||
discriminator = torch.compile(discriminator)
|
||||
|
||||
accelerator.print("✅ | Compiling done!")
|
||||
|
||||
# ---------------------------
|
||||
# Dataset / DataLoader
|
||||
# ---------------------------
|
||||
accelerator.print("📊 | Fetching dataset...")
|
||||
dataset = AudioDataset("./dataset")
|
||||
|
||||
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
|
||||
pin_memory = torch.cuda.is_available() and not args.no_pin_memory
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=(sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=pin_memory,
|
||||
persistent_workers=pin_memory,
|
||||
)
|
||||
|
||||
if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
|
||||
accelerator.print("🪹 | There is no data to train with! Exiting...")
|
||||
exit()
|
||||
|
||||
loader_batch_size = train_loader.batch_size
|
||||
|
||||
accelerator.print("✅ | Dataset fetched!")
|
||||
|
||||
# ---------------------------
|
||||
# Losses / Optimizers / Scalers
|
||||
# ---------------------------
|
||||
criterion_g = nn.BCEWithLogitsLoss()
|
||||
criterion_d = nn.BCEWithLogitsLoss()
|
||||
|
||||
optimizer_g = optim.AdamW(
|
||||
generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
|
||||
@@ -123,9 +86,6 @@ optimizer_d = optim.AdamW(
|
||||
discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
|
||||
)
|
||||
|
||||
# Use modern GradScaler signature; choose device_type based on runtime device.
|
||||
scaler = GradScaler(device=device)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer_g, mode="min", factor=0.5, patience=5
|
||||
)
|
||||
@@ -133,6 +93,17 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer_d, mode="min", factor=0.5, patience=5
|
||||
)
|
||||
|
||||
criterion_g = nn.BCEWithLogitsLoss()
|
||||
criterion_d = nn.MSELoss()
|
||||
|
||||
# ---------------------------
|
||||
# Prepare accelerator
|
||||
# ---------------------------
|
||||
|
||||
generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
|
||||
generator, discriminator, optimizer_g, optimizer_d, train_loader
|
||||
)
|
||||
|
||||
# ---------------------------
|
||||
# Checkpoint helpers
|
||||
# ---------------------------
|
||||
@@ -141,44 +112,45 @@ os.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
|
||||
def save_ckpt(path, epoch):
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"G": generator.state_dict(),
|
||||
"D": discriminator.state_dict(),
|
||||
"optG": optimizer_g.state_dict(),
|
||||
"optD": optimizer_d.state_dict(),
|
||||
"scaler": scaler.state_dict(),
|
||||
"schedG": scheduler_g.state_dict(),
|
||||
"schedD": scheduler_d.state_dict(),
|
||||
},
|
||||
path,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"G": accelerator.unwrap_model(generator).state_dict(),
|
||||
"D": accelerator.unwrap_model(discriminator).state_dict(),
|
||||
"optG": optimizer_g.state_dict(),
|
||||
"optD": optimizer_d.state_dict(),
|
||||
"schedG": scheduler_g.state_dict(),
|
||||
"schedD": scheduler_d.state_dict(),
|
||||
},
|
||||
path,
|
||||
)
|
||||
|
||||
|
||||
start_epoch = 0
|
||||
if args.resume:
|
||||
ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device)
|
||||
generator.load_state_dict(ckpt["G"])
|
||||
discriminator.load_state_dict(ckpt["D"])
|
||||
ckpt_path = os.path.join(models_dir, "last.pt")
|
||||
ckpt = torch.load(ckpt_path)
|
||||
|
||||
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
|
||||
accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
|
||||
optimizer_g.load_state_dict(ckpt["optG"])
|
||||
optimizer_d.load_state_dict(ckpt["optD"])
|
||||
scaler.load_state_dict(ckpt["scaler"])
|
||||
scheduler_g.load_state_dict(ckpt["schedG"])
|
||||
scheduler_d.load_state_dict(ckpt["schedD"])
|
||||
|
||||
start_epoch = ckpt.get("epoch", 1)
|
||||
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
|
||||
|
||||
# ---------------------------
|
||||
# Training loop (safer)
|
||||
# ---------------------------
|
||||
real_buf = torch.full(
|
||||
(loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
|
||||
)
|
||||
fake_buf = torch.zeros(
|
||||
(loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if not train_loader or not train_loader.batch_size:
|
||||
print("There is no data to train with! Exiting...")
|
||||
exit()
|
||||
|
||||
max_batch = max(1, train_loader.batch_size)
|
||||
real_buf = torch.full((max_batch, 1), 0.9, device=device) # label smoothing
|
||||
fake_buf = torch.zeros(max_batch, 1, device=device)
|
||||
accelerator.print("🏋️ | Started training...")
|
||||
|
||||
try:
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
@@ -193,15 +165,12 @@ try:
|
||||
) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
|
||||
batch_size = high_quality.size(0)
|
||||
|
||||
high_quality = high_quality.to(device, non_blocking=True)
|
||||
low_quality = low_quality.to(device, non_blocking=True)
|
||||
|
||||
real_labels = real_buf[:batch_size]
|
||||
fake_labels = fake_buf[:batch_size]
|
||||
real_labels = real_buf[:batch_size].to(accelerator.device)
|
||||
fake_labels = fake_buf[:batch_size].to(accelerator.device)
|
||||
|
||||
# --- Discriminator ---
|
||||
optimizer_d.zero_grad(set_to_none=True)
|
||||
with autocast(device_type=device.type):
|
||||
with accelerator.autocast():
|
||||
d_loss = discriminator_train(
|
||||
high_quality,
|
||||
low_quality,
|
||||
@@ -212,15 +181,14 @@ try:
|
||||
criterion_d,
|
||||
)
|
||||
|
||||
scaler.scale(d_loss).backward()
|
||||
scaler.unscale_(optimizer_d)
|
||||
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
|
||||
scaler.step(optimizer_d)
|
||||
accelerator.backward(d_loss)
|
||||
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
|
||||
optimizer_d.step()
|
||||
|
||||
# --- Generator ---
|
||||
optimizer_g.zero_grad(set_to_none=True)
|
||||
with autocast(device_type=device.type):
|
||||
g_out, g_total, g_adv = generator_train(
|
||||
with accelerator.autocast():
|
||||
g_total, g_adv = generator_train(
|
||||
low_quality,
|
||||
high_quality,
|
||||
real_labels,
|
||||
@@ -229,20 +197,32 @@ try:
|
||||
criterion_d,
|
||||
)
|
||||
|
||||
scaler.scale(g_total).backward()
|
||||
scaler.unscale_(optimizer_g)
|
||||
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
|
||||
scaler.step(optimizer_g)
|
||||
accelerator.backward(g_total)
|
||||
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
|
||||
optimizer_g.step()
|
||||
|
||||
scaler.update()
|
||||
d_val = accelerator.gather(d_loss.detach()).mean()
|
||||
g_val = accelerator.gather(g_total.detach()).mean()
|
||||
|
||||
if torch.isfinite(d_val):
|
||||
running_d += d_val.item()
|
||||
else:
|
||||
accelerator.print(
|
||||
f"🫥 | NaN in discriminator loss at step {i}, skipping update."
|
||||
)
|
||||
|
||||
if torch.isfinite(g_val):
|
||||
running_g += g_val.item()
|
||||
else:
|
||||
accelerator.print(
|
||||
f"🫥 | NaN in generator loss at step {i}, skipping update."
|
||||
)
|
||||
|
||||
running_d += float(d_loss.detach().cpu().item())
|
||||
running_g += float(g_total.detach().cpu().item())
|
||||
steps += 1
|
||||
|
||||
# epoch averages & schedulers
|
||||
if steps == 0:
|
||||
print("No steps in epoch (empty dataloader?). Exiting.")
|
||||
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
|
||||
break
|
||||
|
||||
mean_d = running_d / steps
|
||||
@@ -252,22 +232,14 @@ try:
|
||||
scheduler_g.step(mean_g)
|
||||
|
||||
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
|
||||
print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
|
||||
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
|
||||
|
||||
except Exception:
|
||||
try:
|
||||
save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
|
||||
print(f"Saved crash checkpoint for epoch {epoch}")
|
||||
accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
|
||||
except Exception as e:
|
||||
print("Failed saving crash checkpoint:", e)
|
||||
accelerator.print("😬 | Failed saving crash checkpoint:", e)
|
||||
raise
|
||||
|
||||
try:
|
||||
torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt"))
|
||||
torch.save(
|
||||
discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt")
|
||||
)
|
||||
except Exception as e:
|
||||
print("Failed to save final states:", e)
|
||||
|
||||
print("Training finished.")
|
||||
accelerator.print("🏁 | Training finished.")
|
||||
|
@@ -1,154 +0,0 @@
|
||||
import torch
|
||||
import torchaudio.transforms as T
|
||||
|
||||
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
|
||||
mel_transform: T.MelSpectrogram
|
||||
stft_transform: T.Spectrogram
|
||||
# mfcc_transform: T.MFCC
|
||||
|
||||
|
||||
# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC):
|
||||
# """Initializes the global transform variables for the module."""
|
||||
# global mel_transform, stft_transform, mfcc_transform
|
||||
# mel_transform = mel_trans
|
||||
# stft_transform = stft_trans
|
||||
# mfcc_transform = mfcc_trans
|
||||
|
||||
|
||||
def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram):
|
||||
"""Initializes the global transform variables for the module."""
|
||||
global mel_transform, stft_transform
|
||||
mel_transform = mel_trans
|
||||
stft_transform = stft_trans
|
||||
|
||||
|
||||
# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
|
||||
# """Computes the Mean Squared Error (MSE) loss on MFCCs."""
|
||||
# mfccs_true = mfcc_transform(y_true)
|
||||
# mfccs_pred = mfcc_transform(y_pred)
|
||||
# return F.mse_loss(mfccs_pred, mfccs_true)
|
||||
|
||||
|
||||
# def mel_spectrogram_loss(
|
||||
# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1"
|
||||
# ) -> torch.Tensor:
|
||||
# """Calculates L1 or L2 loss on the Mel Spectrogram."""
|
||||
# mel_spec_true = mel_transform(y_true)
|
||||
# mel_spec_pred = mel_transform(y_pred)
|
||||
# if loss_type == "l1":
|
||||
# return F.l1_loss(mel_spec_pred, mel_spec_true)
|
||||
# elif loss_type == "l2":
|
||||
# return F.mse_loss(mel_spec_pred, mel_spec_true)
|
||||
# else:
|
||||
# raise ValueError("loss_type must be 'l1' or 'l2'")
|
||||
|
||||
|
||||
# def log_stft_magnitude_loss(
|
||||
# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7
|
||||
# ) -> torch.Tensor:
|
||||
# """Calculates L1 loss on the log STFT magnitude."""
|
||||
# stft_mag_true = stft_transform(y_true)
|
||||
# stft_mag_pred = stft_transform(y_pred)
|
||||
# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps))
|
||||
|
||||
|
||||
stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
||||
)
|
||||
|
||||
|
||||
def discriminator_train(
|
||||
high_quality,
|
||||
low_quality,
|
||||
real_labels,
|
||||
fake_labels,
|
||||
discriminator,
|
||||
generator,
|
||||
criterion,
|
||||
):
|
||||
discriminator_decision_from_real = discriminator(high_quality)
|
||||
d_loss_real = criterion(discriminator_decision_from_real, real_labels)
|
||||
|
||||
with torch.no_grad():
|
||||
generator_output = generator(low_quality)
|
||||
discriminator_decision_from_fake = discriminator(generator_output)
|
||||
d_loss_fake = criterion(
|
||||
discriminator_decision_from_fake,
|
||||
fake_labels.expand_as(discriminator_decision_from_fake),
|
||||
)
|
||||
|
||||
d_loss = (d_loss_real + d_loss_fake) / 2.0
|
||||
|
||||
return d_loss
|
||||
|
||||
|
||||
def generator_train(
|
||||
low_quality,
|
||||
high_quality,
|
||||
real_labels,
|
||||
generator,
|
||||
discriminator,
|
||||
adv_criterion,
|
||||
lambda_adv: float = 1.0,
|
||||
lambda_feat: float = 10.0,
|
||||
lambda_stft: float = 2.5,
|
||||
):
|
||||
generator_output = generator(low_quality)
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
# adversarial_loss = adv_criterion(
|
||||
# discriminator_decision, real_labels.expand_as(discriminator_decision)
|
||||
# )
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
|
||||
|
||||
combined_loss = lambda_adv * adversarial_loss
|
||||
|
||||
stft_losses = stft_loss_fn(high_quality, generator_output)
|
||||
stft_loss = stft_losses["total"]
|
||||
|
||||
combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss)
|
||||
|
||||
return generator_output, combined_loss, adversarial_loss
|
||||
|
||||
|
||||
# def generator_train(
|
||||
# low_quality,
|
||||
# high_quality,
|
||||
# real_labels,
|
||||
# generator,
|
||||
# discriminator,
|
||||
# adv_criterion,
|
||||
# lambda_adv: float = 1.0,
|
||||
# lambda_mel_l1: float = 10.0,
|
||||
# lambda_log_stft: float = 1.0,
|
||||
|
||||
# ):
|
||||
# generator_output = generator(low_quality)
|
||||
|
||||
# discriminator_decision = discriminator(generator_output)
|
||||
# adversarial_loss = adv_criterion(
|
||||
# discriminator_decision, real_labels.expand_as(discriminator_decision)
|
||||
# )
|
||||
|
||||
# combined_loss = lambda_adv * adversarial_loss
|
||||
|
||||
# if lambda_mel_l1 > 0:
|
||||
# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1")
|
||||
# combined_loss += lambda_mel_l1 * mel_l1_loss
|
||||
# else:
|
||||
# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging
|
||||
|
||||
# if lambda_log_stft > 0:
|
||||
# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output)
|
||||
# combined_loss += lambda_log_stft * log_stft_loss
|
||||
# else:
|
||||
# log_stft_loss = torch.tensor(0.0, device=low_quality.device)
|
||||
|
||||
# if lambda_mfcc > 0:
|
||||
# mfcc_loss_val = mfcc_loss(high_quality, generator_output)
|
||||
# combined_loss += lambda_mfcc * mfcc_loss_val
|
||||
# else:
|
||||
# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device)
|
||||
|
||||
# return generator_output, combined_loss, adversarial_loss
|
@@ -8,8 +8,9 @@ import torchaudio.transforms as T
|
||||
|
||||
class MultiResolutionSTFTLoss(nn.Module):
|
||||
"""
|
||||
Computes a loss based on multiple STFT resolutions, including both
|
||||
spectral convergence and log STFT magnitude components.
|
||||
Multi-resolution STFT loss.
|
||||
Combines spectral convergence loss and log-magnitude loss
|
||||
across multiple STFT resolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -20,43 +21,67 @@ class MultiResolutionSTFTLoss(nn.Module):
|
||||
eps: float = 1e-7,
|
||||
):
|
||||
super().__init__()
|
||||
self.stft_transforms = nn.ModuleList(
|
||||
[
|
||||
T.Spectrogram(
|
||||
n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None
|
||||
)
|
||||
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths)
|
||||
]
|
||||
)
|
||||
|
||||
self.eps = eps
|
||||
self.n_resolutions = len(fft_sizes)
|
||||
|
||||
self.stft_transforms = nn.ModuleList()
|
||||
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
|
||||
window = torch.hann_window(win_len)
|
||||
stft = T.Spectrogram(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_len,
|
||||
win_length=win_len,
|
||||
window_fn=lambda _: window,
|
||||
power=None, # Keep complex output
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
)
|
||||
self.stft_transforms.append(stft)
|
||||
|
||||
def forward(
|
||||
self, y_true: torch.Tensor, y_pred: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
sc_loss = 0.0 # Spectral Convergence Loss
|
||||
mag_loss = 0.0 # Log STFT Magnitude Loss
|
||||
"""
|
||||
Args:
|
||||
y_true: (B, T) or (B, 1, T) waveform
|
||||
y_pred: (B, T) or (B, 1, T) waveform
|
||||
"""
|
||||
# Ensure correct shape (B, T)
|
||||
if y_true.dim() == 3 and y_true.size(1) == 1:
|
||||
y_true = y_true.squeeze(1)
|
||||
if y_pred.dim() == 3 and y_pred.size(1) == 1:
|
||||
y_pred = y_pred.squeeze(1)
|
||||
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
|
||||
for stft in self.stft_transforms:
|
||||
stft.to(y_pred.device) # Ensure transform is on the correct device
|
||||
stft = stft.to(y_pred.device)
|
||||
|
||||
# Get complex STFTs
|
||||
# Complex STFTs: (B, F, T, 2)
|
||||
stft_true = stft(y_true)
|
||||
stft_pred = stft(y_pred)
|
||||
|
||||
# Get magnitudes
|
||||
# Magnitudes
|
||||
stft_mag_true = torch.abs(stft_true)
|
||||
stft_mag_pred = torch.abs(stft_pred)
|
||||
|
||||
# --- Spectral Convergence Loss ---
|
||||
# || |S_true| - |S_pred| ||_F / || |S_true| ||_F
|
||||
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
|
||||
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
|
||||
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
|
||||
|
||||
# --- Log STFT Magnitude Loss ---
|
||||
mag_loss += F.l1_loss(
|
||||
torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps)
|
||||
torch.log(stft_mag_pred + self.eps),
|
||||
torch.log(stft_mag_true + self.eps),
|
||||
)
|
||||
|
||||
# Average across resolutions
|
||||
sc_loss /= self.n_resolutions
|
||||
mag_loss /= self.n_resolutions
|
||||
total_loss = sc_loss + mag_loss
|
||||
|
||||
return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}
|
||||
|
60
utils/TrainingTools.py
Normal file
60
utils/TrainingTools.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
# In case if needed again...
|
||||
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||
#
|
||||
# stft_loss_fn = MultiResolutionSTFTLoss(
|
||||
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
|
||||
# )
|
||||
|
||||
|
||||
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
|
||||
absolute_difference = torch.abs(input_one - input_two)
|
||||
return torch.mean(absolute_difference)
|
||||
|
||||
|
||||
def discriminator_train(
|
||||
high_quality,
|
||||
low_quality,
|
||||
high_labels,
|
||||
low_labels,
|
||||
discriminator,
|
||||
generator,
|
||||
criterion,
|
||||
):
|
||||
decision_high = discriminator(high_quality)
|
||||
d_loss_high = criterion(decision_high, high_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
|
||||
|
||||
decision_low = discriminator(low_quality)
|
||||
d_loss_low = criterion(decision_low, low_labels)
|
||||
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
|
||||
|
||||
with torch.no_grad():
|
||||
generator_quality = generator(low_quality)
|
||||
decision_gen = discriminator(generator_quality)
|
||||
d_loss_gen = criterion(decision_gen, low_labels)
|
||||
|
||||
noise = torch.rand_like(high_quality) * 0.08
|
||||
decision_noise = discriminator(high_quality + noise)
|
||||
d_loss_noise = criterion(decision_noise, low_labels)
|
||||
|
||||
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
|
||||
|
||||
return d_loss
|
||||
|
||||
|
||||
def generator_train(
|
||||
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
|
||||
):
|
||||
generator_output = generator(low_quality)
|
||||
|
||||
discriminator_decision = discriminator(generator_output)
|
||||
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
|
||||
|
||||
# Signal similarity
|
||||
similarity_loss = signal_mae(generator_output, high_quality)
|
||||
|
||||
combined_loss = adversarial_loss + (similarity_loss * 100)
|
||||
|
||||
return combined_loss, adversarial_loss
|
Reference in New Issue
Block a user