⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
142 discriminator.py
@@ -1,70 +1,98 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils
+import numpy as np
 
-def discriminator_block(
-    in_channels,
-    out_channels,
-    kernel_size=15,
-    stride=1,
-    dilation=1
-):
-    padding = dilation * (kernel_size - 1) // 2
-
-    conv_layer = nn.Conv1d(
-        in_channels,
-        out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        dilation=dilation,
-        padding=padding
-    )
-
-    conv_layer = utils.spectral_norm(conv_layer)
-    leaky_relu = nn.LeakyReLU(0.2)
-
-    return nn.Sequential(conv_layer, leaky_relu)
-
-
-class AttentionBlock(nn.Module):
-    def __init__(self, channels):
-        super(AttentionBlock, self).__init__()
-        self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
-            nn.ReLU(),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
+class PatchEmbedding(nn.Module):
+    """
+    Converts raw audio into a sequence of embeddings (tokens).
+    Small patch_size = Higher Precision (more tokens, finer detail).
+    Large patch_size = Lower Precision (fewer tokens, more global).
+    """
+    def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True):
+        super().__init__()
+        # We use a Conv1d with stride=patch_size to create non-overlapping patches
+        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+        if spectral_norm:
+            self.proj = utils.spectral_norm(self.proj)
 
     def forward(self, x):
-        attention_weights = self.attention(x)
-        return x + (x * attention_weights)
+        # x shape: (batch, 1, 8000)
+        x = self.proj(x)  # shape: (batch, embed_dim, num_patches)
+        x = x.transpose(1, 2)  # shape: (batch, num_patches, embed_dim)
+        return x
 
 
-class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=8):
-        super(SISUDiscriminator, self).__init__()
-        self.discriminator_blocks = nn.Sequential(
-            # 1 -> 32
-            discriminator_block(2, layers),
-            AttentionBlock(layers),
-            # 32 -> 64
-            discriminator_block(layers, layers * 2, dilation=2),
-            # 64 -> 128
-            discriminator_block(layers * 2, layers * 4, dilation=4),
-            AttentionBlock(layers * 4),
-            # 128 -> 256
-            discriminator_block(layers * 4, layers * 8, stride=4),
-            # 256 -> 512
-            # discriminator_block(layers * 8, layers * 16, stride=4)
-        )
+class TransformerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        audio_length=8000,
+        patch_size=16,   # Lower this for higher precision (e.g., 8 or 16)
+        embed_dim=128,   # Dimension of the transformer tokens
+        depth=4,         # Number of Transformer blocks
+        heads=4,         # Number of attention heads
+        mlp_dim=256,     # Hidden dimension of the feed-forward layer
+        spectral_norm=True
+    ):
+        super().__init__()
+
+        # 1. Calculate sequence length
+        self.num_patches = audio_length // patch_size
+
+        # 2. Patch Embedding (Tokenizer)
+        self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm)
+
+        # 3. Class Token (like in BERT/ViT) to aggregate global info
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # 4. Positional Embedding (Learnable)
+        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+        # 5. Transformer Encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=heads,
+            dim_feedforward=mlp_dim,
+            dropout=0.1,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
 
-        self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1)
+        # 6. Final Classification Head
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, 1)
 
-        self.avg_pool = nn.AdaptiveAvgPool1d(1)
+        if spectral_norm:
+            self.head = utils.spectral_norm(self.head)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.normal_(self.cls_token, std=0.02)
+        nn.init.normal_(self.pos_embed, std=0.02)
 
     def forward(self, x):
-        x = self.discriminator_blocks(x)
-        x = self.final_conv(x)
-        x = self.avg_pool(x)
-        return x.squeeze(2)
+        b, c, t = x.shape
+
+        # --- 1. Tokenize Audio ---
+        x = self.patch_embed(x)  # (Batch, Num_Patches, Embed_Dim)
+
+        # --- 2. Add CLS Token ---
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # (Batch, Num_Patches + 1, Embed_Dim)
+
+        # --- 3. Add Positional Embeddings ---
+        x = x + self.pos_embed
+
+        # --- 4. Transformer Layers ---
+        x = self.transformer(x)
+
+        # --- 5. Classification (Use only CLS token) ---
+        cls_output = x[:, 0]  # Take the first token
+        cls_output = self.norm(cls_output)
+
+        score = self.head(cls_output)  # (Batch, 1)
+
+        return score
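Reviewer note: a minimal shape check for the new TransformerDiscriminator. This is a sketch, not part of the commit; it assumes the class is importable from discriminator.py with the defaults shown above.

import torch
from discriminator import TransformerDiscriminator

disc = TransformerDiscriminator(audio_length=8000, patch_size=16)
wave = torch.randn(2, 1, 8000)   # (batch, channels, samples)
score = disc(wave)               # patches -> CLS token -> LayerNorm -> Linear head
print(score.shape)               # torch.Size([2, 1]): one realism score per clip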
40 generator.py
@@ -1,19 +1,20 @@
 import torch
 import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
 
 def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
     padding = (kernel_size - 1) // 2 * dilation
 
     return nn.Sequential(
-        nn.Conv1d(
+        weight_norm(nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1),
     )
 
@@ -22,9 +23,9 @@ class AttentionBlock(nn.Module):
     def __init__(self, channels):
         super(AttentionBlock, self).__init__()
         self.attention = nn.Sequential(
-            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
             nn.ReLU(inplace=True),
-            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
             nn.Sigmoid(),
         )
 
@@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module):
         x = self.attention(x)
         return x + residual
 
-def UpsampleBlock(in_channels, out_channels):
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
     return nn.Sequential(
-        nn.ConvTranspose1d(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=4,
-            stride=2,
+            kernel_size=3,
+            stride=1,
             padding=1
-        ),
-        nn.InstanceNorm1d(out_channels),
+        )),
         nn.PReLU(num_parameters=1, init=0.1)
     )
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=32, num_rirb=1):
+    def __init__(self, channels=32, num_rirb=4):
         super(SISUGenerator, self).__init__()
 
         self.first_conv = GeneratorBlock(1, channels)
@@ -73,10 +74,9 @@ class SISUGenerator(nn.Module):
         self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
         self.downsample_2_attn = AttentionBlock(channels * 4)
 
-        self.rirb = ResidualInResidualBlock(channels * 4)
-        # self.rirb = nn.Sequential(
-        #     *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
-        # )
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
 
         self.upsample = UpsampleBlock(channels * 4, channels * 2)
         self.upsample_attn = AttentionBlock(channels * 2)
@@ -87,13 +87,15 @@ class SISUGenerator(nn.Module):
         self.compress_2 = GeneratorBlock(channels * 2, channels)
 
         self.final_conv = nn.Sequential(
-            nn.Conv1d(channels, 1, kernel_size=7, padding=3),
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
             nn.Tanh()
         )
 
 
     def forward(self, x):
         residual_input = x
 
+        # Encoding
         x1 = self.first_conv(x)
 
         x2 = self.downsample(x1)
@@ -102,8 +104,10 @@ class SISUGenerator(nn.Module):
         x3 = self.downsample_2(x2)
         x3 = self.downsample_2_attn(x3)
 
+        # Bottleneck (Deep Residual processing)
        x_rirb = self.rirb(x3)
 
+        # Decoding with Skip Connections
         up1 = self.upsample(x_rirb)
         up1 = self.upsample_attn(up1)
 
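Reviewer note: the UpsampleBlock rewrite replaces a strided ConvTranspose1d with nearest-neighbour upsampling followed by a plain convolution, a common remedy for the checkerboard artifacts transposed convolutions can produce; both variants map (B, C_in, T) to (B, C_out, 2*T). A standalone shape comparison (hypothetical check, not part of the commit):

import torch
import torch.nn as nn

x = torch.randn(1, 128, 2000)

# Old: transposed conv, kernel_size=4, stride=2, padding=1 -> doubles the length
old_up = nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)

# New: nearest upsample, then a stride-1 conv that keeps the length
new_up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
)

print(old_up(x).shape)  # torch.Size([1, 64, 4000])
print(new_up(x).shape)  # torch.Size([1, 64, 4000]) -- same shape, no checkerboard pattern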
44 training.py
@@ -3,7 +3,6 @@ import datetime
 import os
 
 import torch
-import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
     "--epochs", type=int, default=5000, help="Number of training epochs"
 )
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")  # Increased workers slightly
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
 parser.add_argument(
     "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
 
-criterion_d = nn.MSELoss()
-
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,23 +128,25 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
-    ckpt = torch.load(ckpt_path)
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path)
 
         accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
         accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
         optimizer_g.load_state_dict(ckpt["optG"])
         optimizer_d.load_state_dict(ckpt["optD"])
         scheduler_g.load_state_dict(ckpt["schedG"])
         scheduler_d.load_state_dict(ckpt["schedD"])
 
         start_epoch = ckpt.get("epoch", 1)
         accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
 
 accelerator.print("🏋️ | Started training...")
 
+smallest_loss = float('inf')
+
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
 
@@ -179,10 +173,7 @@ try:
                 d_loss = discriminator_train(
                     high_quality,
                     low_quality.detach(),
-                    real_labels,
-                    fake_labels,
                     discriminator,
-                    criterion_d,
                     generator_output.detach()
                 )
 
@@ -197,10 +188,8 @@ try:
                 g_total, g_adv = generator_train(
                     low_quality,
                     high_quality,
-                    real_labels,
                     generator,
                     discriminator,
-                    criterion_d,
                     generator_output
                 )
 
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 
 except Exception:
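Reviewer note: real_buf/fake_buf and criterion_d can be deleted because the new LSGAN helpers fold the 0/1 targets into the loss expression itself, so no label tensors need to be allocated or sliced per batch. A quick equivalence sketch (illustrative, not in the commit):

import torch
import torch.nn as nn

d_out = torch.rand(8, 1)                                      # stand-in discriminator scores
mse_vs_labels = nn.MSELoss()(d_out, torch.ones_like(d_out))   # old style: explicit label buffer
lsgan_inline = torch.mean((d_out - 1) ** 2)                   # new style: target baked into the formula
assert torch.allclose(mse_vs_labels, lsgan_inline)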
@@ -1,58 +1,113 @@
 import torch
+import torch.nn.functional as F
 from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss
 
-# stft_loss_fn = MultiResolutionSTFTLoss(
-#     fft_sizes=[512, 1024, 2048, 4096],
-#     hop_sizes=[128, 256, 512, 1024],
-#     win_lengths=[512, 1024, 2048, 4096]
-# )
 stft_loss_fn = MultiResolutionSTFTLoss(
     fft_sizes=[512, 1024, 2048],
     hop_sizes=[64, 128, 256],
     win_lengths=[256, 512, 1024]
 )
 
-def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor:
-    absolute_difference = torch.abs(input_one - input_two)
-    return torch.mean(absolute_difference)
+def feature_matching_loss(fmap_r, fmap_g):
+    """
+    Computes L1 distance between real and fake feature maps.
+    """
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            # Stop gradient on real features to save memory/computation
+            rl = rl.detach()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    # Scale by number of feature maps to keep loss magnitude reasonable
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    """
+    Least Squares GAN Loss (LSGAN) for the Discriminator.
+    Objective: Real -> 1, Fake -> 0
+    """
+    loss = 0
+    r_losses = []
+    g_losses = []
+
+    # Iterate over both MPD and MSD outputs
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        # Real should be 1.0
+        r_loss = torch.mean((dr - 1) ** 2)
+        # Fake should be 0.0
+        g_loss = torch.mean(dg ** 2)
+
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_adv_loss(disc_generated_outputs):
+    """
+    Least Squares GAN Loss for the Generator.
+    Objective: Fake -> 1 (Fool the discriminator)
+    """
+    loss = 0
+    for dg in zip(disc_generated_outputs):
+        dg = dg[0]  # Unpack tuple
+        loss += torch.mean((dg - 1) ** 2)
+    return loss
+
+
 def discriminator_train(
     high_quality,
     low_quality,
-    high_labels,
-    low_labels,
     discriminator,
-    criterion,
     generator_output
 ):
-    real_pair = torch.cat((low_quality, high_quality), dim=1)
-    decision_real = discriminator(real_pair)
-    d_loss_real = criterion(decision_real, high_labels)
-
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-    decision_fake = discriminator(fake_pair)
-    d_loss_fake = criterion(decision_fake, low_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
+    # 1. Forward pass through the Ensemble Discriminator
+    # Note: We pass inputs separately now: (Real_Target, Fake_Candidate)
+    # We detach generator_output because we are only optimizing D here
+    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
+
+    # 2. Calculate Loss (LSGAN)
+    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+
     return d_loss
 
 
 def generator_train(
-    low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output):
-    fake_pair = torch.cat((low_quality, generator_output), dim=1)
-
-    discriminator_decision = discriminator(fake_pair)
-    adversarial_loss = adv_criterion(discriminator_decision, real_labels)
-
-    mae_loss = signal_mae(generator_output, high_quality)
+    low_quality,
+    high_quality,
+    generator,
+    discriminator,
+    generator_output
+):
+    # 1. Forward pass through Discriminator
+    # We do NOT detach generator_output here, we need gradients for G
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)
+
+    # 2. Adversarial Loss (Try to fool D into thinking G is Real)
+    loss_gen_adv = generator_adv_loss(y_d_gs)
+
+    # 3. Feature Matching Loss (Force G to match internal features of D)
+    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
+
+    # 4. Mel-Spectrogram / STFT Loss (Audio Quality)
     stft_loss = stft_loss_fn(high_quality, generator_output)["total"]
 
-    lambda_mae = 10.0
-    lambda_stft = 2.5
-    lambda_adv = 2.5
-    combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss)
-    return combined_loss, adversarial_loss
+    # -----------------------------------------
+    # 5. Combine Losses
+    # -----------------------------------------
+    # Recommended weights for HiFi-GAN/EnCodec style architectures:
+    # STFT is dominant (45), FM provides stability (2), Adv provides texture (1)
+    lambda_stft = 45.0
+    lambda_fm = 2.0
+    lambda_adv = 1.0
+
+    combined_loss = (lambda_stft * stft_loss) + \
+                    (lambda_fm * loss_fm) + \
+                    (lambda_adv * loss_gen_adv)
+
+    return combined_loss, loss_gen_adv
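Reviewer note: a dummy-tensor sketch of how the new helpers compose, meant to run in the same module as the definitions above. All shapes are invented for illustration; each list entry stands for one sub-discriminator, as in HiFi-GAN's MPD/MSD ensembles.

import torch

real_scores = [torch.rand(4, 1) for _ in range(3)]   # per-discriminator scores for real audio
fake_scores = [torch.rand(4, 1) for _ in range(3)]   # per-discriminator scores for generated audio
fmap_real = [[torch.rand(4, 16, 100) for _ in range(2)] for _ in range(3)]  # per-layer feature maps
fmap_fake = [[torch.rand(4, 16, 100) for _ in range(2)] for _ in range(3)]

d_loss, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)  # drives real->1, fake->0
adv = generator_adv_loss(fake_scores)                 # drives fake->1
fm = feature_matching_loss(fmap_real, fmap_fake)      # L1 gap between D's internal features
print(d_loss.item(), adv.item(), fm.item())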