⚗️ | More architectural changes

2025-11-18 21:34:59 +02:00
parent 3f23242d6f
commit 782a3bab28
8 changed files with 245 additions and 254 deletions

View File: utils/MultiResolutionSTFTLoss.py

@@ -7,18 +7,13 @@ import torchaudio.transforms as T
class MultiResolutionSTFTLoss(nn.Module):
"""
Multi-resolution STFT loss.
Combines spectral convergence loss and log-magnitude loss
across multiple STFT resolutions.
"""
def __init__(
self,
fft_sizes: List[int] = [1024, 2048, 512],
hop_sizes: List[int] = [120, 240, 50],
win_lengths: List[int] = [600, 1200, 240],
fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
hop_sizes: List[int] = [64, 128, 256, 512, 1024],
win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
eps: float = 1e-7,
center: bool = True
):
super().__init__()
@@ -26,15 +21,14 @@ class MultiResolutionSTFTLoss(nn.Module):
self.n_resolutions = len(fft_sizes)
self.stft_transforms = nn.ModuleList()
for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths):
window = torch.hann_window(win_len)
for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)):
stft = T.Spectrogram(
n_fft=n_fft,
hop_length=hop_len,
win_length=win_len,
window_fn=lambda _: window,
power=None, # Keep complex output
center=True,
window_fn=torch.hann_window,
power=None,
center=center,
pad_mode="reflect",
normalized=False,
)
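A note on the window handling change above: torchaudio's T.Spectrogram calls window_fn(win_length) internally and registers the result as a buffer, so passing window_fn=torch.hann_window directly still gives each resolution a Hann window of the matching length, without capturing a pre-built tensor in a lambda. A minimal standalone check, not part of this commit:

import torch
import torchaudio.transforms as T

# Spectrogram builds its window as window_fn(win_length) and stores it as the
# `window` buffer, so this matches a manually constructed Hann window.
spec = T.Spectrogram(n_fft=1024, win_length=512, hop_length=128,
                     window_fn=torch.hann_window, power=None, center=True)
assert torch.equal(spec.window, torch.hann_window(512))
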
@@ -43,12 +37,6 @@ class MultiResolutionSTFTLoss(nn.Module):
def forward(
self, y_true: torch.Tensor, y_pred: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Args:
y_true: (B, T) or (B, 1, T) waveform
y_pred: (B, T) or (B, 1, T) waveform
"""
# Ensure correct shape (B, T)
if y_true.dim() == 3 and y_true.size(1) == 1:
y_true = y_true.squeeze(1)
if y_pred.dim() == 3 and y_pred.size(1) == 1:
@@ -58,28 +46,21 @@ class MultiResolutionSTFTLoss(nn.Module):
mag_loss = 0.0
for stft in self.stft_transforms:
stft = stft.to(y_pred.device)
# Complex STFTs: (B, F, T, 2)
stft.window = stft.window.to(y_true.device)
stft_true = stft(y_true)
stft_pred = stft(y_pred)
# Magnitudes
stft_mag_true = torch.abs(stft_true)
stft_mag_pred = torch.abs(stft_pred)
# --- Spectral Convergence Loss ---
norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
sc_loss += torch.mean(norm_diff / (norm_true + self.eps))
# --- Log STFT Magnitude Loss ---
mag_loss += F.l1_loss(
torch.log(stft_mag_pred + self.eps),
torch.log(stft_mag_true + self.eps),
)
log_mag_pred = torch.log(stft_mag_pred + self.eps)
log_mag_true = torch.log(stft_mag_true + self.eps)
mag_loss += F.l1_loss(log_mag_pred, log_mag_true)
# Average across resolutions
sc_loss /= self.n_resolutions
mag_loss /= self.n_resolutions
total_loss = sc_loss + mag_loss
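Taken together, the reworked forward pass computes, per resolution, a spectral convergence term ||M_true - M_pred||_F / (||M_true||_F + eps) and an L1 log-magnitude term mean(|log(M_pred + eps) - log(M_true + eps)|), averages each over the resolutions, and sums them into the total. A minimal usage sketch (not part of the commit; the (B, T) shape handling and the "total" key follow from this file and the training code below, the 48000-sample length is just an example):

import torch
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

loss_fn = MultiResolutionSTFTLoss()
clean = torch.randn(4, 48000)      # (B, T) reference waveforms
generated = torch.randn(4, 48000)  # (B, T) generator outputs
losses = loss_fn(clean, generated)
total = losses["total"]            # scalar: mean SC loss + mean log-magnitude loss
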

View File

@@ -1,12 +1,17 @@
import torch
# In case if needed again...
# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
#
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
# )
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
# stft_loss_fn = MultiResolutionSTFTLoss(
# fft_sizes=[512, 1024, 2048, 4096],
# hop_sizes=[128, 256, 512, 1024],
# win_lengths=[512, 1024, 2048, 4096]
# )
stft_loss_fn = MultiResolutionSTFTLoss(
fft_sizes=[512, 1024, 2048],
hop_sizes=[64, 128, 256],
win_lengths=[256, 512, 1024]
)
def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
absolute_difference = torch.abs(input_one - input_two)
@@ -19,42 +24,35 @@ def discriminator_train(
high_labels,
low_labels,
discriminator,
generator,
criterion,
generator_output
):
decision_high = discriminator(high_quality)
d_loss_high = criterion(decision_high, high_labels)
# print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}")
decision_low = discriminator(low_quality)
d_loss_low = criterion(decision_low, low_labels)
# print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}")
real_pair = torch.cat((low_quality, high_quality), dim=1)
decision_real = discriminator(real_pair)
d_loss_real = criterion(decision_real, high_labels)
with torch.no_grad():
generator_quality = generator(low_quality)
decision_gen = discriminator(generator_quality)
d_loss_gen = criterion(decision_gen, low_labels)
noise = torch.rand_like(high_quality) * 0.08
decision_noise = discriminator(high_quality + noise)
d_loss_noise = criterion(decision_noise, low_labels)
d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0
fake_pair = torch.cat((low_quality, generator_output), dim=1)
decision_fake = discriminator(fake_pair)
d_loss_fake = criterion(decision_fake, low_labels)
d_loss = (d_loss_real + d_loss_fake) / 2.0
return d_loss
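The discriminator step above replaces the earlier four-term scheme (real, fake, detached generator output, noise-perturbed real) with a conditional, pix2pix-style pairing: the low-quality input is concatenated with either the real high-quality signal or the generator output along the channel dimension, and the two decisions are averaged. A rough shape sketch; the (B, 1, T) waveform layout is an assumption, not taken from the model code in this commit:

low = torch.randn(8, 1, 16384)              # degraded input, assumed (B, 1, T)
high = torch.randn(8, 1, 16384)             # clean target
fake = torch.randn(8, 1, 16384)             # stand-in for generator_output
real_pair = torch.cat((low, high), dim=1)   # (8, 2, 16384), scored against high_labels
fake_pair = torch.cat((low, fake), dim=1)   # (8, 2, 16384), scored against low_labels

Since the pairs have twice the channel count of a single waveform, the discriminator's first layer presumably has to accept the doubled number of input channels.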
def generator_train(
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion
):
generator_output = generator(low_quality)
low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
discriminator_decision = discriminator(generator_output)
fake_pair = torch.cat((low_quality, generator_output), dim=1)
discriminator_decision = discriminator(fake_pair)
adversarial_loss = adv_criterion(discriminator_decision, real_labels)
# Signal similarity
similarity_loss = signal_mae(generator_output, high_quality)
combined_loss = adversarial_loss + (similarity_loss * 100)
mae_loss = signal_mae(generator_output, high_quality)
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
lambda_mae = 10.0
lambda_stft = 2.5
lambda_adv = 2.5
combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
return combined_loss, adversarial_loss
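The generator objective changes from adversarial + 100x waveform MAE to a three-term weighted sum: 10x MAE, 2.5x multi-resolution STFT loss, and 2.5x adversarial loss. generator_output is now passed in rather than recomputed, so the training loop can presumably reuse a single generator forward pass for both the discriminator step (detached) and this one. A worked example of the new weighting, with made-up loss values:

mae_loss, stft_loss, adversarial_loss = 0.02, 0.8, 0.7   # illustrative values only
combined = 10.0 * mae_loss + 2.5 * stft_loss + 2.5 * adversarial_loss
# = 0.2 + 2.0 + 1.75 = 3.95; the spectral and adversarial terms dominate
# unless the raw waveform error is large.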