⚗️ | More architectural changes

2025-11-18 21:34:59 +02:00
parent 3f23242d6f
commit 782a3bab28
8 changed files with 245 additions and 254 deletions


@@ -1,4 +1,5 @@
 import argparse
+import datetime
 import os
 import torch
@@ -52,7 +53,7 @@ accelerator.print("✅ | Compiling done!")
 # Dataset / DataLoader
 # ---------------------------
 accelerator.print("📊 | Fetching dataset...")
-dataset = AudioDataset("./dataset")
+dataset = AudioDataset("./dataset", 8192)
 sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
 pin_memory = torch.cuda.is_available() and not args.no_pin_memory
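`AudioDataset` now takes a second argument, 8192, presumably a fixed per-item length in samples so every element in a batch shares one shape. A minimal sketch of that idea, assuming crop-or-pad semantics; `FixedLengthAudioDataset` and the torchaudio loading are illustrative stand-ins, not the repo's actual `AudioDataset`:

import os
import torch
import torchaudio
from torch.utils.data import Dataset

class FixedLengthAudioDataset(Dataset):
    def __init__(self, root, length):
        # length: number of samples each item is cropped/padded to (e.g. 8192)
        self.length = length
        self.files = sorted(os.path.join(root, f) for f in os.listdir(root))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.files[idx])
        if waveform.size(-1) >= self.length:
            waveform = waveform[..., : self.length]  # crop long clips
        else:
            pad = self.length - waveform.size(-1)
            waveform = torch.nn.functional.pad(waveform, (0, pad))  # zero-pad short ones
        return waveform, sample_rate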
@@ -93,7 +94,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )
-criterion_g = nn.BCEWithLogitsLoss()
 criterion_d = nn.MSELoss()
 # ---------------------------
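With `criterion_g` (BCE) removed, a single `nn.MSELoss()` now drives both networks, which amounts to least-squares-GAN-style objectives. A small self-contained sketch of how MSE plays the adversarial role, with random tensors standing in for discriminator outputs:

import torch
import torch.nn as nn

criterion = nn.MSELoss()

# Hypothetical scores standing in for discriminator outputs.
d_real = torch.rand(16, 1)  # D(x) on real audio
d_fake = torch.rand(16, 1)  # D(G(z)) on generated audio

# LSGAN-style losses: push real scores toward 1, fake scores toward 0.
d_loss = criterion(d_real, torch.ones_like(d_real)) + criterion(d_fake, torch.zeros_like(d_fake))
g_loss = criterion(d_fake, torch.ones_like(d_fake))  # generator wants fakes scored as real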
@@ -143,12 +143,8 @@ if args.resume:
     start_epoch = ckpt.get("epoch", 1)
     accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
-real_buf = torch.full(
-    (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
-)
-fake_buf = torch.zeros(
-    (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
-)
+real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
+fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)
 accelerator.print("🏋️ | Started training...")
@@ -157,35 +153,45 @@ try:
         generator.train()
         discriminator.train()
+        discriminator_time = 0
+        generator_time = 0
         running_d, running_g, steps = 0.0, 0.0, 0
+        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
         for i, (
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
-        ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
+        ) in enumerate(progress_bar):
             batch_size = high_quality.size(0)
             real_labels = real_buf[:batch_size].to(accelerator.device)
             fake_labels = fake_buf[:batch_size].to(accelerator.device)
+            with accelerator.autocast():
+                generator_output = generator(low_quality)
             # --- Discriminator ---
+            d_time = datetime.datetime.now()
             optimizer_d.zero_grad(set_to_none=True)
             with accelerator.autocast():
                 d_loss = discriminator_train(
                     high_quality,
-                    low_quality,
+                    low_quality.detach(),
                     real_labels,
                     fake_labels,
                     discriminator,
-                    generator,
                     criterion_d,
+                    generator_output.detach()
                 )
             accelerator.backward(d_loss)
             torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
             optimizer_d.step()
+            discriminator_time = (datetime.datetime.now() - d_time).microseconds
             # --- Generator ---
+            g_time = datetime.datetime.now()
             optimizer_g.zero_grad(set_to_none=True)
             with accelerator.autocast():
                 g_total, g_adv = generator_train(
@@ -195,11 +201,13 @@ try:
                     generator,
                     discriminator,
                     criterion_d,
+                    generator_output
                 )
             accelerator.backward(g_total)
             torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
             optimizer_g.step()
+            generator_time = (datetime.datetime.now() - g_time).microseconds
             d_val = accelerator.gather(d_loss.detach()).mean()
             g_val = accelerator.gather(g_total.detach()).mean()
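The core change across these two hunks: the generator forward pass now runs once per step and its output is reused, detached for the discriminator update so no gradient reaches the generator, still attached for the generator update. A sketch of helpers consistent with the new call sites; the parameter lists (especially `generator_train`'s leading arguments, which this diff truncates) are assumptions, not the repo's actual functions:

import torch

def discriminator_train(high_quality, low_quality, real_labels, fake_labels,
                        discriminator, criterion, generator_output):
    # generator_output arrives already .detach()-ed, so this update cannot
    # backpropagate into the generator. low_quality is unused in this sketch
    # but kept for signature parity with the call site above.
    real_loss = criterion(discriminator(high_quality), real_labels)
    fake_loss = criterion(discriminator(generator_output), fake_labels)
    return real_loss + fake_loss

def generator_train(low_quality, high_quality, real_labels,
                    generator, discriminator, criterion, generator_output):
    # Same forward result, still attached to the generator's graph:
    # adversarial term plus a reconstruction term. generator itself is no
    # longer called here, which is the point of precomputing its output.
    adv = criterion(discriminator(generator_output), real_labels)
    recon = torch.nn.functional.l1_loss(generator_output, high_quality)
    return adv + recon, adv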
@@ -219,6 +227,7 @@ try:
             )
             steps += 1
+            progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
         # epoch averages & schedulers
         if steps == 0:
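For the new timing display: each update is bracketed by `datetime.datetime.now()` calls and the result is pushed into the tqdm bar via `set_description`. A runnable sketch of the pattern; note that `timedelta.microseconds`, which the diff reads, is only the sub-second component:

import datetime
import tqdm

progress_bar = tqdm.tqdm(range(100), desc="Epoch 1 | D 0μs | G 0μs")
for step in progress_bar:
    d_time = datetime.datetime.now()
    sum(range(10_000))  # stand-in for the discriminator update
    elapsed = datetime.datetime.now() - d_time
    # .microseconds is the 0-999999 sub-second field and wraps past one second;
    # elapsed.total_seconds() * 1e6 would give the full duration instead.
    discriminator_time = elapsed.microseconds
    progress_bar.set_description(f"Epoch 1 | D {discriminator_time}μs | G 0μs")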