⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
This commit is contained in:
247
discriminator.py
247
discriminator.py
@@ -1,98 +1,179 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.utils as utils
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
|
||||||
|
|
||||||
class PatchEmbedding(nn.Module):
|
# -------------------------------------------------------------------
|
||||||
"""
|
# 1. Multi-Period Discriminator (MPD)
|
||||||
Converts raw audio into a sequence of embeddings (tokens).
|
# Captures periodic structures (pitch/timbre) by folding audio.
|
||||||
Small patch_size = Higher Precision (more tokens, finer detail).
|
# -------------------------------------------------------------------
|
||||||
Large patch_size = Lower Precision (fewer tokens, more global).
|
|
||||||
"""
|
|
||||||
def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
|
|
||||||
super().__init__()
|
|
||||||
# We use a Conv1d with stride=patch_size to create non-overlapping patches
|
|
||||||
self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
||||||
|
|
||||||
if spectral_norm:
|
class DiscriminatorP(nn.Module):
|
||||||
self.proj = utils.spectral_norm(self.proj)
|
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorP, self).__init__()
|
||||||
|
self.period = period
|
||||||
|
self.use_spectral_norm = use_spectral_norm
|
||||||
|
|
||||||
|
# Use spectral_norm for stability, or weight_norm for performance
|
||||||
|
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
||||||
|
|
||||||
|
# We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
|
||||||
|
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
|
||||||
|
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
|
||||||
|
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
|
||||||
|
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x shape: (batch, 1, 8000)
|
fmap = []
|
||||||
x = self.proj(x) # shape: (batch, embed_dim, num_patches)
|
|
||||||
x = x.transpose(1, 2) # shape: (batch, num_patches, embed_dim)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class TransformerDiscriminator(nn.Module):
|
# 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
audio_length=8000,
|
|
||||||
patch_size=16, # Lower this for higher precision (e.g., 8 or 16)
|
|
||||||
embed_dim=128, # Dimension of the transformer tokens
|
|
||||||
depth=4, # Number of Transformer blocks
|
|
||||||
heads=4, # Number of attention heads
|
|
||||||
mlp_dim=256, # Hidden dimension of the feed-forward layer
|
|
||||||
spectral_norm=True
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# 1. Calculate sequence length
|
|
||||||
self.num_patches = audio_length // patch_size
|
|
||||||
|
|
||||||
# 2. Patch Embedding (Tokenizer)
|
|
||||||
self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
|
|
||||||
|
|
||||||
# 3. Class Token (like in BERT/ViT) to aggregate global info
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
|
||||||
|
|
||||||
# 4. Positional Embedding (Learnable)
|
|
||||||
self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
|
|
||||||
|
|
||||||
# 5. Transformer Encoder
|
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
|
||||||
d_model=embed_dim,
|
|
||||||
nhead=heads,
|
|
||||||
dim_feedforward=mlp_dim,
|
|
||||||
dropout=0.1,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True
|
|
||||||
)
|
|
||||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
|
|
||||||
|
|
||||||
# 6. Final Classification Head
|
|
||||||
self.norm = nn.LayerNorm(embed_dim)
|
|
||||||
self.head = nn.Linear(embed_dim, 1)
|
|
||||||
|
|
||||||
if spectral_norm:
|
|
||||||
self.head = utils.spectral_norm(self.head)
|
|
||||||
|
|
||||||
# Initialize weights
|
|
||||||
self._init_weights()
|
|
||||||
|
|
||||||
def _init_weights(self):
|
|
||||||
nn.init.normal_(self.cls_token, std=0.02)
|
|
||||||
nn.init.normal_(self.pos_embed, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, t = x.shape
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # Pad if not divisible by period
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
|
||||||
# --- 1. Tokenize Audio ---
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
x = self.patch_embed(x) # (Batch, Num_Patches, Embed_Dim)
|
|
||||||
|
|
||||||
# --- 2. Add CLS Token ---
|
for l in self.convs:
|
||||||
cls_tokens = self.cls_token.expand(b, -1, -1)
|
x = l(x)
|
||||||
x = torch.cat((cls_tokens, x), dim=1) # (Batch, Num_Patches + 1, Embed_Dim)
|
x = F.leaky_relu(x, 0.1)
|
||||||
|
fmap.append(x) # Store feature map for Feature Matching Loss
|
||||||
|
|
||||||
# --- 3. Add Positional Embeddings ---
|
x = self.conv_post(x)
|
||||||
x = x + self.pos_embed
|
fmap.append(x)
|
||||||
|
|
||||||
# --- 4. Transformer Layers ---
|
# Flatten back to 1D for score
|
||||||
x = self.transformer(x)
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
# --- 5. Classification (Use only CLS token) ---
|
return x, fmap
|
||||||
cls_output = x[:, 0] # Take the first token
|
|
||||||
cls_output = self.norm(cls_output)
|
|
||||||
|
|
||||||
score = self.head(cls_output) # (Batch, 1)
|
|
||||||
|
|
||||||
return score
|
class MultiPeriodDiscriminator(nn.Module):
|
||||||
|
def __init__(self, periods=[2, 3, 5, 7, 11]):
|
||||||
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
|
self.discriminators = nn.ModuleList([
|
||||||
|
DiscriminatorP(p) for p in periods
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
y_d_rs = [] # Real scores
|
||||||
|
y_d_gs = [] # Generated (Fake) scores
|
||||||
|
fmap_rs = [] # Real feature maps
|
||||||
|
fmap_gs = [] # Generated (Fake) feature maps
|
||||||
|
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# 2. Multi-Scale Discriminator (MSD)
|
||||||
|
# Captures structure at different audio resolutions (raw, x0.5, x0.25).
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
class DiscriminatorS(nn.Module):
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorS, self).__init__()
|
||||||
|
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
||||||
|
|
||||||
|
# Standard 1D Convolutions with large receptive field
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||||
|
norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||||
|
norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||||
|
norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||||
|
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||||
|
])
|
||||||
|
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, 0.1)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDiscriminator(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MultiScaleDiscriminator, self).__init__()
|
||||||
|
# 3 Scales: Original, Downsampled x2, Downsampled x4
|
||||||
|
self.discriminators = nn.ModuleList([
|
||||||
|
DiscriminatorS(use_spectral_norm=True),
|
||||||
|
DiscriminatorS(),
|
||||||
|
DiscriminatorS(),
|
||||||
|
])
|
||||||
|
self.meanpools = nn.ModuleList([
|
||||||
|
nn.AvgPool1d(4, 2, padding=2),
|
||||||
|
nn.AvgPool1d(4, 2, padding=2)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
if i != 0:
|
||||||
|
# Downsample input for subsequent discriminators
|
||||||
|
y = self.meanpools[i-1](y)
|
||||||
|
y_hat = self.meanpools[i-1](y_hat)
|
||||||
|
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# 3. Master Wrapper
|
||||||
|
# Combines MPD and MSD into one class to fit your training script.
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SISUDiscriminator(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SISUDiscriminator, self).__init__()
|
||||||
|
self.mpd = MultiPeriodDiscriminator()
|
||||||
|
self.msd = MultiScaleDiscriminator()
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
# Return format:
|
||||||
|
# scores_real, scores_fake, features_real, features_fake
|
||||||
|
|
||||||
|
# Run Multi-Period
|
||||||
|
mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
|
||||||
|
|
||||||
|
# Run Multi-Scale
|
||||||
|
msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
|
||||||
|
|
||||||
|
# Combine all results
|
||||||
|
return (
|
||||||
|
mpd_y_d_rs + msd_y_d_rs, # All real scores
|
||||||
|
mpd_y_d_gs + msd_y_d_gs, # All fake scores
|
||||||
|
mpd_fmap_rs + msd_fmap_rs, # All real feature maps
|
||||||
|
mpd_fmap_gs + msd_fmap_gs # All fake feature maps
|
||||||
|
)
|
||||||
|
|||||||
41
training.py
41
training.py
@@ -3,6 +3,7 @@ import datetime
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import tqdm
|
import tqdm
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
@@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16")
|
|||||||
# Models
|
# Models
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
generator = SISUGenerator()
|
generator = SISUGenerator()
|
||||||
|
# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
|
||||||
discriminator = SISUDiscriminator()
|
discriminator = SISUDiscriminator()
|
||||||
|
|
||||||
accelerator.print("🔨 | Compiling models...")
|
accelerator.print("🔨 | Compiling models...")
|
||||||
|
|
||||||
|
# Torch compile is great, but if you hit errors with the new List/Tuple outputs
|
||||||
|
# of the discriminator, you might need to disable it for D.
|
||||||
generator = torch.compile(generator)
|
generator = torch.compile(generator)
|
||||||
discriminator = torch.compile(discriminator)
|
discriminator = torch.compile(discriminator)
|
||||||
|
|
||||||
@@ -108,21 +112,24 @@ models_dir = "./models"
|
|||||||
os.makedirs(models_dir, exist_ok=True)
|
os.makedirs(models_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def save_ckpt(path, epoch):
|
def save_ckpt(path, epoch, loss=None, is_best=False):
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.save(
|
state = {
|
||||||
{
|
"epoch": epoch,
|
||||||
"epoch": epoch,
|
"G": accelerator.unwrap_model(generator).state_dict(),
|
||||||
"G": accelerator.unwrap_model(generator).state_dict(),
|
"D": accelerator.unwrap_model(discriminator).state_dict(),
|
||||||
"D": accelerator.unwrap_model(discriminator).state_dict(),
|
"optG": optimizer_g.state_dict(),
|
||||||
"optG": optimizer_g.state_dict(),
|
"optD": optimizer_d.state_dict(),
|
||||||
"optD": optimizer_d.state_dict(),
|
"schedG": scheduler_g.state_dict(),
|
||||||
"schedG": scheduler_g.state_dict(),
|
"schedD": scheduler_d.state_dict()
|
||||||
"schedD": scheduler_d.state_dict(),
|
}
|
||||||
},
|
|
||||||
path,
|
accelerator.save(state, os.path.join(models_dir, "last.pt"))
|
||||||
)
|
|
||||||
|
if is_best:
|
||||||
|
accelerator.save(state, os.path.join(models_dir, "best.pt"))
|
||||||
|
accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
|
||||||
|
|
||||||
|
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
@@ -143,9 +150,8 @@ if args.resume:
|
|||||||
else:
|
else:
|
||||||
accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
|
accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
|
||||||
|
|
||||||
accelerator.print("🏋️ | Started training...")
|
|
||||||
|
|
||||||
smallest_loss = float('inf')
|
accelerator.print("🏋️ | Started training...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
@@ -172,7 +178,6 @@ try:
|
|||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
d_loss = discriminator_train(
|
d_loss = discriminator_train(
|
||||||
high_quality,
|
high_quality,
|
||||||
low_quality.detach(),
|
|
||||||
discriminator,
|
discriminator,
|
||||||
generator_output.detach()
|
generator_output.detach()
|
||||||
)
|
)
|
||||||
@@ -218,7 +223,6 @@ try:
|
|||||||
steps += 1
|
steps += 1
|
||||||
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
|
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
|
||||||
|
|
||||||
# epoch averages & schedulers
|
|
||||||
if steps == 0:
|
if steps == 0:
|
||||||
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
|
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
|
||||||
break
|
break
|
||||||
@@ -230,9 +234,6 @@ try:
|
|||||||
scheduler_g.step(mean_g)
|
scheduler_g.step(mean_g)
|
||||||
|
|
||||||
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
|
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
|
||||||
if smallest_loss > mean_g:
|
|
||||||
smallest_loss = mean_g
|
|
||||||
save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
|
|
||||||
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
|
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
|
||||||
|
|
||||||
|
# Keep STFT settings as is
|
||||||
stft_loss_fn = MultiResolutionSTFTLoss(
|
stft_loss_fn = MultiResolutionSTFTLoss(
|
||||||
fft_sizes=[512, 1024, 2048],
|
fft_sizes=[512, 1024, 2048],
|
||||||
hop_sizes=[64, 128, 256],
|
hop_sizes=[64, 128, 256],
|
||||||
win_lengths=[256, 512, 1024]
|
win_lengths=[256, 512, 1024]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def feature_matching_loss(fmap_r, fmap_g):
|
def feature_matching_loss(fmap_r, fmap_g):
|
||||||
"""
|
"""
|
||||||
Computes L1 distance between real and fake feature maps.
|
Computes L1 distance between real and fake feature maps.
|
||||||
@@ -16,11 +17,9 @@ def feature_matching_loss(fmap_r, fmap_g):
|
|||||||
loss = 0
|
loss = 0
|
||||||
for dr, dg in zip(fmap_r, fmap_g):
|
for dr, dg in zip(fmap_r, fmap_g):
|
||||||
for rl, gl in zip(dr, dg):
|
for rl, gl in zip(dr, dg):
|
||||||
# Stop gradient on real features to save memory/computation
|
|
||||||
rl = rl.detach()
|
rl = rl.detach()
|
||||||
loss += torch.mean(torch.abs(rl - gl))
|
loss += torch.mean(torch.abs(rl - gl))
|
||||||
|
|
||||||
# Scale by number of feature maps to keep loss magnitude reasonable
|
|
||||||
return loss * 2
|
return loss * 2
|
||||||
|
|
||||||
|
|
||||||
@@ -33,11 +32,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
|||||||
r_losses = []
|
r_losses = []
|
||||||
g_losses = []
|
g_losses = []
|
||||||
|
|
||||||
# Iterate over both MPD and MSD outputs
|
|
||||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||||
# Real should be 1.0
|
|
||||||
r_loss = torch.mean((dr - 1) ** 2)
|
r_loss = torch.mean((dr - 1) ** 2)
|
||||||
# Fake should be 0.0
|
|
||||||
g_loss = torch.mean(dg ** 2)
|
g_loss = torch.mean(dg ** 2)
|
||||||
|
|
||||||
loss += (r_loss + g_loss)
|
loss += (r_loss + g_loss)
|
||||||
@@ -61,16 +57,11 @@ def generator_adv_loss(disc_generated_outputs):
|
|||||||
|
|
||||||
def discriminator_train(
|
def discriminator_train(
|
||||||
high_quality,
|
high_quality,
|
||||||
low_quality,
|
|
||||||
discriminator,
|
discriminator,
|
||||||
generator_output
|
generator_output
|
||||||
):
|
):
|
||||||
# 1. Forward pass through the Ensemble Discriminator
|
|
||||||
# Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
|
|
||||||
# We detach generator_output because we are only optimizing D here
|
|
||||||
y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
|
y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
|
||||||
|
|
||||||
# 2. Calculate Loss (LSGAN)
|
|
||||||
d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
||||||
|
|
||||||
return d_loss
|
return d_loss
|
||||||
@@ -83,25 +74,14 @@ def generator_train(
|
|||||||
discriminator,
|
discriminator,
|
||||||
generator_output
|
generator_output
|
||||||
):
|
):
|
||||||
# 1. Forward pass through Discriminator
|
|
||||||
# We do NOT detach generator_output here, we need gradients for G
|
|
||||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
|
y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
|
||||||
|
|
||||||
# 2. Adversarial Loss (Try to fool D into thinking G is Real)
|
|
||||||
loss_gen_adv = generator_adv_loss(y_d_gs)
|
loss_gen_adv = generator_adv_loss(y_d_gs)
|
||||||
|
|
||||||
# 3. Feature Matching Loss (Force G to match internal features of D)
|
|
||||||
loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
|
loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
|
||||||
|
|
||||||
# 4. Mel-Spectrogram / STFT Loss (Audio Quality)
|
|
||||||
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
|
stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
|
||||||
|
|
||||||
# -----------------------------------------
|
|
||||||
# 5. Combine Losses
|
|
||||||
# -----------------------------------------
|
|
||||||
# Recommended weights for HiFi-GAN/EnCodec style architectures:
|
|
||||||
# STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
|
|
||||||
|
|
||||||
lambda_stft = 45.0
|
lambda_stft = 45.0
|
||||||
lambda_fm = 2.0
|
lambda_fm = 2.0
|
||||||
lambda_adv = 1.0
|
lambda_adv = 1.0
|
||||||
|
|||||||
Reference in New Issue
Block a user