⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN

training.py | 41 lines changed
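
The MultiPeriodDiscriminator itself is defined outside this file, so it does not appear in the diff below. For reference, a simplified sketch of a HiFi-GAN-style period branch, assuming the usual fold-by-period trick (the class name, channel sizes, and kernel shapes here are illustrative, not this project's actual code):

import torch.nn as nn
import torch.nn.functional as F

class PeriodDiscriminator(nn.Module):
    """One MPD branch: folds the 1D waveform into 2D by `period`, then 2D convs."""
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 32, (5, 1), stride=(3, 1), padding=(2, 0)),
            nn.Conv2d(32, 128, (5, 1), stride=(3, 1), padding=(2, 0)),
            nn.Conv2d(128, 512, (5, 1), stride=(3, 1), padding=(2, 0)),
        ])
        self.post = nn.Conv2d(512, 1, (3, 1), padding=(1, 0))

    def forward(self, x):
        # x: (batch, 1, samples); right-pad so the length divides the period
        b, c, t = x.shape
        if t % self.period:
            pad = self.period - (t % self.period)
            x = F.pad(x, (0, pad), mode="reflect")
            t = t + pad
        x = x.view(b, c, t // self.period, self.period)  # (B, 1, T/p, p)
        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        return self.post(x), features  # score map + intermediate feature maps

HiFi-GAN runs one such branch per period (typically 2, 3, 5, 7, 11), so the prime periods sample the waveform at non-overlapping phases.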
@@ -3,6 +3,7 @@ import datetime
 import os
 
 import torch
+import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16")
 # Models
 # ---------------------------
 generator = SISUGenerator()
+# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
 discriminator = SISUDiscriminator()
 
 accelerator.print("🔨 | Compiling models...")
 
+# Torch compile is great, but if you hit errors with the new List/Tuple outputs
+# of the discriminator, you might need to disable it for D.
 generator = torch.compile(generator)
 discriminator = torch.compile(discriminator)
 
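
For context on the "Ensemble (MPD + MSD)" note: a minimal sketch of how such a wrapper commonly combines sub-discriminators, reusing the hypothetical PeriodDiscriminator above (again an assumption, not the repo's actual SISUDiscriminator):

import torch.nn as nn

class EnsembleDiscriminator(nn.Module):
    """Hypothetical MPD + MSD wrapper: returns per-branch score and feature lists."""
    def __init__(self, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.branches = nn.ModuleList(PeriodDiscriminator(p) for p in periods)
        # MSD branches (strided 1D convs on progressively avg-pooled audio)
        # would be appended to self.branches as well; omitted for brevity.

    def forward(self, x):
        scores, feats = [], []
        for branch in self.branches:
            score, feat = branch(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats

These List/Tuple outputs are what the torch.compile caveat above refers to: the compiled discriminator now returns containers of tensors instead of a single tensor, and if tracing chokes on that, leaving D eager (or trying torch.compile(discriminator, dynamic=True)) is the usual fallback.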
@@ -108,21 +112,24 @@ models_dir = "./models"
 os.makedirs(models_dir, exist_ok=True)
 
 
-def save_ckpt(path, epoch):
+def save_ckpt(path, epoch, loss=None, is_best=False):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        accelerator.save(
-            {
-                "epoch": epoch,
-                "G": accelerator.unwrap_model(generator).state_dict(),
-                "D": accelerator.unwrap_model(discriminator).state_dict(),
-                "optG": optimizer_g.state_dict(),
-                "optD": optimizer_d.state_dict(),
-                "schedG": scheduler_g.state_dict(),
-                "schedD": scheduler_d.state_dict(),
-            },
-            path,
-        )
+        state = {
+            "epoch": epoch,
+            "G": accelerator.unwrap_model(generator).state_dict(),
+            "D": accelerator.unwrap_model(discriminator).state_dict(),
+            "optG": optimizer_g.state_dict(),
+            "optD": optimizer_d.state_dict(),
+            "schedG": scheduler_g.state_dict(),
+            "schedD": scheduler_d.state_dict()
+        }
+
+        accelerator.save(state, os.path.join(models_dir, "last.pt"))
+
+        if is_best:
+            accelerator.save(state, os.path.join(models_dir, "best.pt"))
+            accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
 
 
 start_epoch = 0
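
Two things are worth noting about the new save_ckpt: the path parameter is now effectively unused, since the body derives its own last.pt and best.pt paths, and nothing in this commit passes loss or is_best yet. A hypothetical call site that would exercise the new flags (not part of this commit):

# Hypothetical epoch-end wiring; the commit itself still calls save_ckpt with defaults.
is_best = mean_g < smallest_loss
if is_best:
    smallest_loss = mean_g
save_ckpt(os.path.join(models_dir, "last.pt"), epoch, loss=mean_g, is_best=is_best)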
@@ -143,9 +150,8 @@ if args.resume:
 else:
     accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
-accelerator.print("🏋️ | Started training...")
-
 smallest_loss = float('inf')
+accelerator.print("🏋️ | Started training...")
 
 try:
     for epoch in range(start_epoch, args.epochs):
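
The resume branch above would restore the same keys save_ckpt writes; a minimal sketch of that load, assuming the checkpoint produced by the function in the previous hunk:

# Minimal resume sketch matching the checkpoint keys written by save_ckpt.
ckpt_path = os.path.join(models_dir, "last.pt")
if args.resume and os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
    optimizer_g.load_state_dict(ckpt["optG"])
    optimizer_d.load_state_dict(ckpt["optD"])
    scheduler_g.load_state_dict(ckpt["schedG"])
    scheduler_d.load_state_dict(ckpt["schedD"])
    start_epoch = ckpt["epoch"] + 1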
@@ -172,7 +178,6 @@ try:
             with accelerator.autocast():
                 d_loss = discriminator_train(
                     high_quality,
-                    low_quality.detach(),
                     discriminator,
                     generator_output.detach()
                 )
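
discriminator_train is defined elsewhere in the repo; with list outputs from the ensemble, a least-squares GAN discriminator loss is the standard HiFi-GAN recipe, roughly as follows (a sketch, assuming the three-argument call above):

import torch

def discriminator_train(real, discriminator, fake):
    """LSGAN discriminator loss summed over every sub-discriminator's scores."""
    real_scores, _ = discriminator(real)
    fake_scores, _ = discriminator(fake)
    loss = 0.0
    for score_r, score_f in zip(real_scores, fake_scores):
        # push real score maps toward 1 and fake score maps toward 0
        loss = loss + torch.mean((score_r - 1.0) ** 2) + torch.mean(score_f ** 2)
    return loss

Detaching generator_output keeps the generator graph out of the discriminator update.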
@@ -218,7 +223,6 @@ try:
             steps += 1
             progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
-
         # epoch averages & schedulers
         if steps == 0:
             accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
             break
@@ -230,9 +234,6 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
-        if smallest_loss > mean_g:
-            smallest_loss = mean_g
-            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 
 except Exception:
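
Passing mean_g to scheduler_g.step() implies a metric-driven scheduler; in torch that is ReduceLROnPlateau, so the construction elsewhere in the file presumably looks something like this (an assumption, the setup is not shown in this diff):

# Assumed setup consistent with scheduler_g.step(mean_g) taking a metric.
scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=5)
scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode="min", factor=0.5, patience=5)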