⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN
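The MultiPeriodDiscriminator itself lands in code outside this diff. For context, a minimal sketch of the HiFi-GAN pattern the title refers to, assuming the usual prime periods (2, 3, 5, 7, 11); channel widths, kernel sizes, and names here are illustrative, not this repo's actual module:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PeriodDiscriminator(nn.Module):
    """Folds 1-D audio into a (T/p, p) grid so structure repeating every
    p samples lines up along one axis, then applies 2-D convolutions."""
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 32, (5, 1), stride=(3, 1), padding=(2, 0)),
            nn.Conv2d(32, 128, (5, 1), stride=(3, 1), padding=(2, 0)),
            nn.Conv2d(128, 512, (5, 1), stride=(3, 1), padding=(2, 0)),
            nn.Conv2d(512, 1024, (5, 1), stride=(1, 1), padding=(2, 0)),
        ])
        self.conv_post = nn.Conv2d(1024, 1, (3, 1), padding=(1, 0))

    def forward(self, x):  # x: (batch, 1, time)
        b, c, t = x.shape
        if t % self.period != 0:  # right-pad so time is divisible by the period
            pad = self.period - (t % self.period)
            x = F.pad(x, (0, pad), mode="reflect")
            t = t + pad
        x = x.view(b, c, t // self.period, self.period)
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps

class MultiPeriodDiscriminator(nn.Module):
    """One PeriodDiscriminator per prime period, as in HiFi-GAN."""
    def __init__(self, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.discriminators = nn.ModuleList(PeriodDiscriminator(p) for p in periods)

    def forward(self, y_real, y_fake):
        real_outs, fake_outs = [], []
        for d in self.discriminators:
            out_r, _ = d(y_real)
            out_f, _ = d(y_fake)
            real_outs.append(out_r)
            fake_outs.append(out_f)
        return real_outs, fake_outs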
--- a/training.py
+++ b/training.py
@@ -3,7 +3,6 @@ import datetime
 import os
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from accelerate import Accelerator
@@ -23,7 +22,7 @@ parser.add_argument(
     "--epochs", type=int, default=5000, help="Number of training epochs"
 )
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")  # Increased workers slightly
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
 parser.add_argument(
     "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
@@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
-
-criterion_d = nn.MSELoss()
 
 # ---------------------------
 # Prepare accelerator
 # ---------------------------
@@ -131,23 +128,25 @@ def save_ckpt(path, epoch):
 start_epoch = 0
 if args.resume:
     ckpt_path = os.path.join(models_dir, "last.pt")
-    ckpt = torch.load(ckpt_path)
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path)
 
-    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
-    accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
-    optimizer_g.load_state_dict(ckpt["optG"])
-    optimizer_d.load_state_dict(ckpt["optD"])
-    scheduler_g.load_state_dict(ckpt["schedG"])
-    scheduler_d.load_state_dict(ckpt["schedD"])
+        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
+        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
+        optimizer_g.load_state_dict(ckpt["optG"])
+        optimizer_d.load_state_dict(ckpt["optD"])
+        scheduler_g.load_state_dict(ckpt["schedG"])
+        scheduler_d.load_state_dict(ckpt["schedD"])
 
-    start_epoch = ckpt.get("epoch", 1)
-    accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
-
-real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
-fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
+        start_epoch = ckpt.get("epoch", 1)
+        accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
 
 accelerator.print("🏋️ | Started training...")
 
+smallest_loss = float('inf')
+
 try:
     for epoch in range(start_epoch, args.epochs):
         generator.train()
@@ -164,11 +163,6 @@ try:
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
         ) in enumerate(progress_bar):
-            batch_size = high_quality.size(0)
-
-            real_labels = real_buf[:batch_size].to(accelerator.device)
-            fake_labels = fake_buf[:batch_size].to(accelerator.device)
-
             with accelerator.autocast():
                 generator_output = generator(low_quality)
 
@@ -179,10 +173,7 @@ try:
             d_loss = discriminator_train(
                 high_quality,
                 low_quality.detach(),
-                real_labels,
-                fake_labels,
                 discriminator,
-                criterion_d,
                 generator_output.detach()
             )
 
@@ -197,10 +188,8 @@ try:
             g_total, g_adv = generator_train(
                 low_quality,
                 high_quality,
-                real_labels,
                 generator,
                 discriminator,
-                criterion_d,
                 generator_output
             )
 
@@ -241,6 +230,9 @@ try:
         scheduler_g.step(mean_g)
 
         save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        if smallest_loss > mean_g:
+            smallest_loss = mean_g
+            save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch)
         accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
 
 except Exception:
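The removals above (real_buf/fake_buf, real_labels/fake_labels, and criterion_d = nn.MSELoss()) all follow from the switch: HiFi-GAN-style discriminator training uses least-squares adversarial losses whose targets (1 for real, 0 for fake) are implicit in the formula, so no label tensors or external criterion are needed. A minimal sketch of that pattern, assuming the updated discriminator_train/generator_train wrap something like the following (function names are illustrative, not this repo's actual helpers):

import torch

def discriminator_loss(real_outs, fake_outs):
    # LSGAN objective: push scores for real audio toward 1, generated toward 0.
    loss = 0.0
    for score_real, score_fake in zip(real_outs, fake_outs):
        loss = loss + torch.mean((score_real - 1.0) ** 2) + torch.mean(score_fake ** 2)
    return loss

def generator_adversarial_loss(fake_outs):
    # The generator is rewarded when the discriminator scores its output as real.
    loss = 0.0
    for score_fake in fake_outs:
        loss = loss + torch.mean((score_fake - 1.0) ** 2)
    return loss

Here real_outs/fake_outs are the per-period score lists a multi-period discriminator returns, so each sub-discriminator contributes one term to the sum.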