⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

This commit is contained in:
2025-12-06 18:04:18 +02:00
parent bf0a6e58e9
commit e3e555794e
3 changed files with 187 additions and 125 deletions

View File

@@ -3,6 +3,7 @@ import datetime
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from accelerate import Accelerator
@@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16")
# Models
# ---------------------------
generator = SISUGenerator()
# Note: SISUDiscriminator is now an ensemble of a Multi-Period Discriminator (MPD) and a Multi-Scale Discriminator (MSD)
discriminator = SISUDiscriminator()
accelerator.print("🔨 | Compiling models...")
# torch.compile usually helps, but if it raises errors on the discriminator's new
# list/tuple outputs, you may need to disable compilation for the discriminator (D).
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
@@ -108,21 +112,24 @@ models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
def save_ckpt(path, epoch):
def save_ckpt(path, epoch, loss=None, is_best=False):
# NOTE(review): the two `def` lines above look like the old and new signatures
# shown together by the diff view (change markers and indentation are missing
# from this text) — confirm against the real file.
# Synchronise all processes first; only the main process writes checkpoints.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# Presumably the previous implementation: a single checkpoint dict (epoch,
# unwrapped G/D weights, optimizer and scheduler states) written to `path`.
accelerator.save(
{
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict(),
},
path,
)
# Presumably the replacement: build the same checkpoint dict once so it can be
# written to more than one file.
state = {
"epoch": epoch,
"G": accelerator.unwrap_model(generator).state_dict(),
"D": accelerator.unwrap_model(discriminator).state_dict(),
"optG": optimizer_g.state_dict(),
"optD": optimizer_d.state_dict(),
"schedG": scheduler_g.state_dict(),
"schedD": scheduler_d.state_dict()
}
# NOTE(review): `path` is not used by the save calls below — the targets are
# always models_dir/last.pt and (when is_best) models_dir/best.pt.
accelerator.save(state, os.path.join(models_dir, "last.pt"))
if is_best:
accelerator.save(state, os.path.join(models_dir, "best.pt"))
# NOTE(review): `loss` defaults to None, and formatting None with `:.4f` raises
# TypeError, so callers passing is_best=True must also pass a numeric loss.
accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
start_epoch = 0
@@ -143,9 +150,8 @@ if args.resume:
else:
accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
accelerator.print("🏋️ | Started training...")
smallest_loss = float('inf')
accelerator.print("🏋️ | Started training...")
try:
for epoch in range(start_epoch, args.epochs):
@@ -172,7 +178,6 @@ try:
with accelerator.autocast():
d_loss = discriminator_train(
high_quality,
low_quality.detach(),
discriminator,
generator_output.detach()
)
@@ -218,7 +223,6 @@ try:
steps += 1
progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
# epoch averages & schedulers
if steps == 0:
accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
break
@@ -230,9 +234,6 @@ try:
scheduler_g.step(mean_g)
save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
if smallest_loss > mean_g:
smallest_loss = mean_g
save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
except Exception: