⚗️ | More architectural changes

This commit is contained in:
2025-11-18 21:34:59 +02:00
parent 3f23242d6f
commit 782a3bab28
8 changed files with 245 additions and 254 deletions

View File

@@ -3,95 +3,39 @@ import torch.nn.functional as F
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
""" mono_tensor = torch.mean(waveform, dim=0, keepdim=True)
Convert stereo (C, N) to mono (1, N). Ensures a channel dimension. return mono_tensor
"""
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0) # (N,) -> (1, N)
if waveform.shape[0] > 1:
mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N)
else:
mono_waveform = waveform
return mono_waveform
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor: def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
""" padding_amount = target_length - audio_tensor.size(-1)
Stretch audio along time dimension to target_length. if padding_amount <= 0:
Input assumed (1, N). Returns (1, target_length). return audio_tensor
"""
if tensor.dim() == 1:
tensor = tensor.unsqueeze(0) # ensure (1, N)
tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
stretched = F.interpolate(
tensor, size=target_length, mode="linear", align_corners=False
)
return stretched.squeeze(0) # back to (1, target_length)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
"""
Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
"""
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
current_length = audio_tensor.shape[-1]
if current_length < target_length:
padding_needed = target_length - current_length
padding_tuple = (0, padding_needed)
padded_audio_tensor = F.pad(
audio_tensor, padding_tuple, mode="constant", value=0
)
else:
padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long
return padded_audio_tensor return padded_audio_tensor
def split_audio( def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
audio_tensor: torch.Tensor, chunk_size: int = 128 chunks = list(torch.split(audio_tensor, chunk_size, dim=1))
) -> list[torch.Tensor]:
"""
Split into chunks of (1, chunk_size).
"""
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("chunk_size must be a positive integer.")
if audio_tensor.dim() == 1: if pad_last_tensor:
audio_tensor = audio_tensor.unsqueeze(0) last_chunk = chunks[-1]
num_samples = audio_tensor.shape[-1] if last_chunk.size(-1) < chunk_size:
if num_samples == 0: chunks[-1] = pad_tensor(last_chunk, chunk_size)
return []
chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
return chunks return chunks
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
""" reconstructed_tensor = torch.cat(chunks, dim=-1)
Reconstruct audio from chunks. Returns (1, N).
"""
if not chunks:
return torch.empty(1, 0)
chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
try:
reconstructed_tensor = torch.cat(chunks, dim=-1)
except RuntimeError as e:
raise RuntimeError(
f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
f"for concatenation along dim -1. Original error: {e}"
)
return reconstructed_tensor return reconstructed_tensor
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
max_val = torch.max(torch.abs(audio_tensor)) max_val = torch.max(torch.abs(audio_tensor))
if max_val < eps: if max_val < eps:
return audio_tensor # silence, skip normalization return audio_tensor
return audio_tensor / max_val return audio_tensor / max_val

55
app.py
View File

@@ -4,6 +4,7 @@ import torch
import torchaudio import torchaudio
import torchcodec import torchcodec
import tqdm import tqdm
from accelerate import Accelerator
import AudioUtils import AudioUtils
from generator import SISUGenerator from generator import SISUGenerator
@@ -15,7 +16,7 @@ parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument( parser.add_argument(
"--clip_length", "--clip_length",
type=int, type=int,
default=16384, default=8000,
help="Internal clip length, leave unspecified if unsure", help="Internal clip length, leave unspecified if unsure",
) )
parser.add_argument( parser.add_argument(
@@ -38,21 +39,44 @@ if args.sample_rate < 8000:
) )
exit() exit()
device = torch.device(args.device if torch.cuda.is_available() else "cpu") # ---------------------------
print(f"Using device: {device}") # Init accelerator
# ---------------------------
generator = SISUGenerator().to(device) accelerator = Accelerator(mixed_precision="bf16")
# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()
accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator) generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")
# ---------------------------
# Prepare accelerator
# ---------------------------
generator = accelerator.prepare(generator)
# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = args.model models_dir = args.model
clip_length = args.clip_length clip_length = args.clip_length
input_audio = args.input input_audio = args.input
output_audio = args.output output_audio = args.output
if models_dir: if models_dir:
ckpt = torch.load(models_dir, map_location=device) ckpt = torch.load(models_dir)
generator.load_state_dict(ckpt["G"])
accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
accelerator.print("💾 | Loaded model!")
else: else:
print( print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
@@ -67,7 +91,8 @@ def start():
audio = decoded_samples.data audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio) # Support for multichannel audio
# audio = AudioUtils.stereo_tensor_to_mono(audio)
audio = AudioUtils.normalize(audio) audio = AudioUtils.normalize(audio)
resample_transform = torchaudio.transforms.Resample( resample_transform = torchaudio.transforms.Resample(
@@ -77,14 +102,20 @@ def start():
audio = resample_transform(audio) audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length) splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.to(device) for t in splitted_audio] splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio]
processed_audio = [] processed_audio = []
with torch.no_grad():
for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
processed_audio.append(generator(clip)) channels = []
for audio_channel in torch.split(clip, 1, dim=1):
output_piece = generator(audio_channel)
channels.append(output_piece.detach().cpu())
output_clip = torch.cat(channels, dim=1)
processed_audio.append(output_clip)
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
print(f"Saving {output_audio}!") reconstructed_audio = reconstructed_audio.squeeze(0)
print(f"🔊 | Saving {output_audio}!")
torchaudio.save_with_torchcodec( torchaudio.save_with_torchcodec(
uri=output_audio, uri=output_audio,
src=reconstructed_audio, src=reconstructed_audio,

62
data.py
View File

@@ -1,6 +1,7 @@
import os import os
import random import random
import torch
import torchaudio import torchaudio
import torchcodec.decoders as decoders import torchcodec.decoders as decoders
import tqdm import tqdm
@@ -10,9 +11,9 @@ import AudioUtils
class AudioDataset(Dataset): class AudioDataset(Dataset):
audio_sample_rates = [11025] audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]
def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True): def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
self.clip_length = clip_length self.clip_length = clip_length
self.normalize = normalize self.normalize = normalize
@@ -30,45 +31,20 @@ class AudioDataset(Dataset):
decoder = decoders.AudioDecoder(audio_clip) decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples() decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data.float() # ensure float32 audio = decoded_samples.data.float()
original_sample_rate = decoded_samples.sample_rate original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
if normalize: if normalize:
audio = AudioUtils.normalize(audio) audio = AudioUtils.normalize(audio)
mangled_sample_rate = random.choice(self.audio_sample_rates) splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
resample_transform_low = torchaudio.transforms.Resample(
original_sample_rate, mangled_sample_rate
)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, original_sample_rate
)
low_audio = resample_transform_high(resample_transform_low(audio)) if not splitted_high_quality_audio:
continue
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) for splitted_audio_clip in splitted_high_quality_audio:
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) for audio_clip in torch.split(splitted_audio_clip, 1):
data.append((audio_clip, original_sample_rate))
if not splitted_high_quality_audio or not splitted_low_quality_audio:
continue # skip empty or invalid clips
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_high_quality_audio[-1], clip_length
)
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_low_quality_audio[-1], clip_length
)
for high_quality_data, low_quality_data in zip(
splitted_high_quality_audio, splitted_low_quality_audio
):
data.append(
(
(high_quality_data, low_quality_data),
(original_sample_rate, mangled_sample_rate),
)
)
self.audio_data = data self.audio_data = data
@@ -76,4 +52,20 @@ class AudioDataset(Dataset):
return len(self.audio_data) return len(self.audio_data)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.audio_data[idx] audio_clip = self.audio_data[idx]
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(
audio_clip[1], mangled_sample_rate
)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, audio_clip[1]
)
low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0]))
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))

View File

@@ -5,32 +5,25 @@ import torch.nn.utils as utils
def discriminator_block( def discriminator_block(
in_channels, in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=15,
stride=1, stride=1,
dilation=1, dilation=1
spectral_norm=True,
use_instance_norm=True,
): ):
padding = (kernel_size // 2) * dilation padding = dilation * (kernel_size - 1) // 2
conv_layer = nn.Conv1d( conv_layer = nn.Conv1d(
in_channels, in_channels,
out_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
dilation=dilation, dilation=dilation,
padding=padding, padding=padding
) )
if spectral_norm: conv_layer = utils.spectral_norm(conv_layer)
conv_layer = utils.spectral_norm(conv_layer) leaky_relu = nn.LeakyReLU(0.2)
layers = [conv_layer] return nn.Sequential(conv_layer, leaky_relu)
layers.append(nn.LeakyReLU(0.2, inplace=True))
if use_instance_norm:
layers.append(nn.InstanceNorm1d(out_channels))
return nn.Sequential(*layers)
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
@@ -38,38 +31,40 @@ class AttentionBlock(nn.Module):
super(AttentionBlock, self).__init__() super(AttentionBlock, self).__init__()
self.attention = nn.Sequential( self.attention = nn.Sequential(
nn.Conv1d(channels, channels // 4, kernel_size=1), nn.Conv1d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True), nn.ReLU(),
nn.Conv1d(channels // 4, channels, kernel_size=1), nn.Conv1d(channels // 4, channels, kernel_size=1),
nn.Sigmoid(), nn.Sigmoid(),
) )
def forward(self, x): def forward(self, x):
attention_weights = self.attention(x) attention_weights = self.attention(x)
return x * attention_weights return x + (x * attention_weights)
class SISUDiscriminator(nn.Module): class SISUDiscriminator(nn.Module):
def __init__(self, layers=32): def __init__(self, layers=8):
super(SISUDiscriminator, self).__init__() super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential( self.discriminator_blocks = nn.Sequential(
discriminator_block(1, layers, kernel_size=7, stride=1), # 1 -> 32
discriminator_block(layers, layers * 2, kernel_size=5, stride=2), discriminator_block(2, layers),
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), AttentionBlock(layers),
# 32 -> 64
discriminator_block(layers, layers * 2, dilation=2),
# 64 -> 128
discriminator_block(layers * 2, layers * 4, dilation=4),
AttentionBlock(layers * 4), AttentionBlock(layers * 4),
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), # 128 -> 256
discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2), discriminator_block(layers * 4, layers * 8, stride=4),
discriminator_block( # 256 -> 512
layers * 2, # discriminator_block(layers * 8, layers * 16, stride=4)
1,
spectral_norm=False,
use_instance_norm=False,
),
) )
self.global_avg_pool = nn.AdaptiveAvgPool1d(1) self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x): def forward(self, x):
x = self.model(x) x = self.discriminator_blocks(x)
x = self.global_avg_pool(x) x = self.final_conv(x)
x = x.view(x.size(0), -1) x = self.avg_pool(x)
return x return x.squeeze(2)

View File

@@ -2,25 +2,23 @@ import torch
import torch.nn as nn import torch.nn as nn
def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
padding = (kernel_size - 1) // 2 * dilation
return nn.Sequential( return nn.Sequential(
nn.Conv1d( nn.Conv1d(
in_channels, in_channels,
out_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride,
dilation=dilation, dilation=dilation,
padding=(kernel_size // 2) * dilation, padding=padding
), ),
nn.InstanceNorm1d(out_channels), nn.InstanceNorm1d(out_channels),
nn.PReLU(), nn.PReLU(num_parameters=1, init=0.1),
) )
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
"""
Simple Channel Attention Block. Learns to weight channels based on their importance.
"""
def __init__(self, channels): def __init__(self, channels):
super(AttentionBlock, self).__init__() super(AttentionBlock, self).__init__()
self.attention = nn.Sequential( self.attention = nn.Sequential(
@@ -32,7 +30,7 @@ class AttentionBlock(nn.Module):
def forward(self, x): def forward(self, x):
attention_weights = self.attention(x) attention_weights = self.attention(x)
return x * attention_weights return x + (x * attention_weights)
class ResidualInResidualBlock(nn.Module): class ResidualInResidualBlock(nn.Module):
@@ -40,7 +38,7 @@ class ResidualInResidualBlock(nn.Module):
super(ResidualInResidualBlock, self).__init__() super(ResidualInResidualBlock, self).__init__()
self.conv_layers = nn.Sequential( self.conv_layers = nn.Sequential(
*[conv_block(channels, channels) for _ in range(num_convs)] *[GeneratorBlock(channels, channels) for _ in range(num_convs)]
) )
self.attention = AttentionBlock(channels) self.attention = AttentionBlock(channels)
@@ -51,31 +49,74 @@ class ResidualInResidualBlock(nn.Module):
x = self.attention(x) x = self.attention(x)
return x + residual return x + residual
def UpsampleBlock(in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=4,
stride=2,
padding=1
),
nn.InstanceNorm1d(out_channels),
nn.PReLU(num_parameters=1, init=0.1)
)
class SISUGenerator(nn.Module): class SISUGenerator(nn.Module):
def __init__(self, channels=16, num_rirb=4, alpha=1): def __init__(self, channels=32, num_rirb=1):
super(SISUGenerator, self).__init__() super(SISUGenerator, self).__init__()
self.alpha = alpha
self.conv1 = nn.Sequential( self.first_conv = GeneratorBlock(1, channels)
nn.Conv1d(1, channels, kernel_size=7, padding=3),
nn.InstanceNorm1d(channels), self.downsample = GeneratorBlock(channels, channels * 2, stride=2)
nn.PReLU(), self.downsample_attn = AttentionBlock(channels * 2)
self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
self.downsample_2_attn = AttentionBlock(channels * 4)
self.rirb = ResidualInResidualBlock(channels * 4)
# self.rirb = nn.Sequential(
# *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
# )
self.upsample = UpsampleBlock(channels * 4, channels * 2)
self.upsample_attn = AttentionBlock(channels * 2)
self.compress_1 = GeneratorBlock(channels * 4, channels * 2)
self.upsample_2 = UpsampleBlock(channels * 2, channels)
self.upsample_2_attn = AttentionBlock(channels)
self.compress_2 = GeneratorBlock(channels * 2, channels)
self.final_conv = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=7, padding=3),
nn.Tanh()
) )
self.rir_blocks = nn.Sequential(
*[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
)
self.final_layer = nn.Sequential(
nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
)
def forward(self, x): def forward(self, x):
residual_input = x residual_input = x
x = self.conv1(x) x1 = self.first_conv(x)
x_rirb_out = self.rir_blocks(x)
learned_residual = self.final_layer(x_rirb_out)
output = residual_input + self.alpha * learned_residual
return torch.tanh(output) x2 = self.downsample(x1)
x2 = self.downsample_attn(x2)
x3 = self.downsample_2(x2)
x3 = self.downsample_2_attn(x3)
x_rirb = self.rirb(x3)
up1 = self.upsample(x_rirb)
up1 = self.upsample_attn(up1)
cat1 = torch.cat((up1, x2), dim=1)
comp1 = self.compress_1(cat1)
up2 = self.upsample_2(comp1)
up2 = self.upsample_2_attn(up2)
cat2 = torch.cat((up2, x1), dim=1)
comp2 = self.compress_2(cat2)
learned_residual = self.final_conv(comp2)
output = residual_input + learned_residual
return output

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import torch import torch
@@ -52,7 +53,7 @@ accelerator.print("✅ | Compiling done!")
# Dataset / DataLoader # Dataset / DataLoader
# --------------------------- # ---------------------------
accelerator.print("📊 | Fetching dataset...") accelerator.print("📊 | Fetching dataset...")
dataset = AudioDataset("./dataset") dataset = AudioDataset("./dataset", 8192)
sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
pin_memory = torch.cuda.is_available() and not args.no_pin_memory pin_memory = torch.cuda.is_available() and not args.no_pin_memory
@@ -93,7 +94,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_d, mode="min", factor=0.5, patience=5 optimizer_d, mode="min", factor=0.5, patience=5
) )
criterion_g = nn.BCEWithLogitsLoss()
criterion_d = nn.MSELoss() criterion_d = nn.MSELoss()
# --------------------------- # ---------------------------
@@ -143,12 +143,8 @@ if args.resume:
start_epoch = ckpt.get("epoch", 1) start_epoch = ckpt.get("epoch", 1)
accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!") accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
real_buf = torch.full( real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
(loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32 fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
)
fake_buf = torch.zeros(
(loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
)
accelerator.print("🏋️ | Started training...") accelerator.print("🏋️ | Started training...")
@@ -157,35 +153,45 @@ try:
generator.train() generator.train()
discriminator.train() discriminator.train()
discriminator_time = 0
generator_time = 0
running_d, running_g, steps = 0.0, 0.0, 0 running_d, running_g, steps = 0.0, 0.0, 0
progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
for i, ( for i, (
(high_quality, low_quality), (high_quality, low_quality),
(high_sample_rate, low_sample_rate), (high_sample_rate, low_sample_rate),
) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): ) in enumerate(progress_bar):
batch_size = high_quality.size(0) batch_size = high_quality.size(0)
real_labels = real_buf[:batch_size].to(accelerator.device) real_labels = real_buf[:batch_size].to(accelerator.device)
fake_labels = fake_buf[:batch_size].to(accelerator.device) fake_labels = fake_buf[:batch_size].to(accelerator.device)
with accelerator.autocast():
generator_output = generator(low_quality)
# --- Discriminator --- # --- Discriminator ---
d_time = datetime.datetime.now()
optimizer_d.zero_grad(set_to_none=True) optimizer_d.zero_grad(set_to_none=True)
with accelerator.autocast(): with accelerator.autocast():
d_loss = discriminator_train( d_loss = discriminator_train(
high_quality, high_quality,
low_quality, low_quality.detach(),
real_labels, real_labels,
fake_labels, fake_labels,
discriminator, discriminator,
generator,
criterion_d, criterion_d,
generator_output.detach()
) )
accelerator.backward(d_loss) accelerator.backward(d_loss)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
optimizer_d.step() optimizer_d.step()
discriminator_time = (datetime.datetime.now() - d_time).microseconds
# --- Generator --- # --- Generator ---
g_time = datetime.datetime.now()
optimizer_g.zero_grad(set_to_none=True) optimizer_g.zero_grad(set_to_none=True)
with accelerator.autocast(): with accelerator.autocast():
g_total, g_adv = generator_train( g_total, g_adv = generator_train(
@@ -195,11 +201,13 @@ try:
generator, generator,
discriminator, discriminator,
criterion_d, criterion_d,
generator_output
) )
accelerator.backward(g_total) accelerator.backward(g_total)
torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
optimizer_g.step() optimizer_g.step()
generator_time = (datetime.datetime.now() - g_time).microseconds
d_val = accelerator.gather(d_loss.detach()).mean() d_val = accelerator.gather(d_loss.detach()).mean()
g_val = accelerator.gather(g_total.detach()).mean() g_val = accelerator.gather(g_total.detach()).mean()
@@ -219,6 +227,7 @@ try:
) )
steps += 1 steps += 1
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
# epoch averages & schedulers # epoch averages & schedulers
if steps == 0: if steps == 0:

View File

@@ -7,18 +7,13 @@ import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module): class MultiResolutionSTFTLoss(nn.Module):
"""
Multi-resolution STFT loss.
Combines spectral convergence loss and log-magnitude loss
across multiple STFT resolutions.
"""
def __init__( def __init__(
self, self,
fft_sizes: List[int] = [1024, 2048, 512], fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
hop_sizes: List[int] = [120, 240, 50], hop_sizes: List[int] = [64, 128, 256, 512, 1024],
win_lengths: List[int] = [600, 1200, 240], win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
eps: float = 1e-7, eps: float = 1e-7,
center: bool = True
): ):
super().__init__() super().__init__()
@@ -26,15 +21,14 @@ class MultiResolutionSTFTLoss(nn.Module):
self.n_resolutions = len(fft_sizes) self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList() self.stft_transforms = nn.ModuleList()
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths): for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)):
window = torch.hann_window(win_len)
stft = T.Spectrogram( stft = T.Spectrogram(
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_len, hop_length=hop_len,
win_length=win_len, win_length=win_len,
window_fn=lambda _: window, window_fn=torch.hann_window,
power=None, # Keep complex output power=None,
center=True, center=center,
pad_mode="reflect", pad_mode="reflect",
normalized=False, normalized=False,
) )
@@ -43,12 +37,6 @@ class MultiResolutionSTFTLoss(nn.Module):
def forward( def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""
Args:
y_true: (B, T) or (B, 1, T) waveform
y_pred: (B, T) or (B, 1, T) waveform
"""
# Ensure correct shape (B, T)
if y_true.dim() == 3 and y_true.size(1) == 1: if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1) y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1: if y_pred.dim() == 3 and y_pred.size(1) == 1:
@@ -58,28 +46,21 @@ class MultiResolutionSTFTLoss(nn.Module):
mag_loss = 0.0 mag_loss = 0.0
for stft in self.stft_transforms: for stft in self.stft_transforms:
stft = stft.to(y_pred.device) stft.window = stft.window.to(y_true.device)
# Complex STFTs: (B, F, T, 2)
stft_true = stft(y_true) stft_true = stft(y_true)
stft_pred = stft(y_pred) stft_pred = stft(y_pred)
# Magnitudes
stft_mag_true = torch.abs(stft_true) stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred) stft_mag_pred = torch.abs(stft_pred)
# --- Spectral Convergence Loss ---
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
# --- Log STFT Magnitude Loss --- log_mag_pred = torch.log(stft_mag_pred + self.eps)
mag_loss += F.l1_loss( log_mag_true = torch.log(stft_mag_true + self.eps)
torch.log(stft_mag_pred + self.eps), mag_loss += F.l1_loss(log_mag_pred, log_mag_true)
torch.log(stft_mag_true + self.eps),
)
# Average across resolutions
sc_loss /= self.n_resolutions sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss total_loss = sc_loss + mag_loss

View File

@@ -1,12 +1,17 @@
import torch import torch
# In case if needed again... from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[512, 1024, 2048, 4096],
# hop_sizes=[128, 256, 512, 1024],
# win_lengths=[512, 1024, 2048, 4096]
# )
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[512, 1024, 2048],
hop_sizes=[64, 128, 256],
win_lengths=[256, 512, 1024]
)
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor: def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two) absolute_difference = torch.abs(input_one - input_two)
@@ -19,42 +24,35 @@ def discriminator_train(
high_labels, high_labels,
low_labels, low_labels,
discriminator, discriminator,
generator,
criterion, criterion,
generator_output
): ):
decision_high = discriminator(high_quality)
d_loss_high = criterion(decision_high, high_labels)
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
decision_low = discriminator(low_quality) real_pair = torch.cat((low_quality, high_quality), dim=1)
d_loss_low = criterion(decision_low, low_labels) decision_real = discriminator(real_pair)
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}") d_loss_real = criterion(decision_real, high_labels)
with torch.no_grad(): fake_pair = torch.cat((low_quality, generator_output), dim=1)
generator_quality = generator(low_quality) decision_fake = discriminator(fake_pair)
decision_gen = discriminator(generator_quality) d_loss_fake = criterion(decision_fake, low_labels)
d_loss_gen = criterion(decision_gen, low_labels)
noise = torch.rand_like(high_quality) * 0.08
decision_noise = discriminator(high_quality + noise)
d_loss_noise = criterion(decision_noise, low_labels)
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
d_loss = (d_loss_real + d_loss_fake) / 2.0
return d_loss return d_loss
def generator_train( def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
):
generator_output = generator(low_quality)
discriminator_decision = discriminator(generator_output) fake_pair = torch.cat((low_quality, generator_output), dim=1)
discriminator_decision = discriminator(fake_pair)
adversarial_loss = adv_criterion(discriminator_decision, real_labels) adversarial_loss = adv_criterion(discriminator_decision, real_labels)
# Signal similarity mae_loss = signal_mae(generator_output, high_quality)
similarity_loss = signal_mae(generator_output, high_quality) stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
combined_loss = adversarial_loss + (similarity_loss * 100)
lambda_mae = 10.0
lambda_stft = 2.5
lambda_adv = 2.5
combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
return combined_loss, adversarial_loss return combined_loss, adversarial_loss