⚗️ | More architectural changes
training.py (33 changed lines)
@@ -1,4 +1,5 @@
 import argparse
+import datetime
 import os

 import torch
@@ -52,7 +53,7 @@ accelerator.print("✅ | Compiling done!")
 # Dataset / DataLoader
 # ---------------------------
 accelerator.print("📊 | Fetching dataset...")
-dataset = AudioDataset("./dataset")
+dataset = AudioDataset("./dataset", 8192)

 sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
 pin_memory = torch.cuda.is_available() and not args.no_pin_memory
@@ -93,7 +94,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer_d, mode="min", factor=0.5, patience=5
 )

-criterion_g = nn.BCEWithLogitsLoss()
 criterion_d = nn.MSELoss()

 # ---------------------------
@@ -143,12 +143,8 @@ if args.resume:
     start_epoch = ckpt.get("epoch", 1)
     accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")

-real_buf = torch.full(
-    (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32
-)
-fake_buf = torch.zeros(
-    (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32
-)
+real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32)
+fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32)

 accelerator.print("🏋️ | Started training...")

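Note: the label buffers above are allocated once at the loader batch size and then sliced to the actual batch size inside the loop (`real_buf[:batch_size]`). A minimal standalone sketch of that pattern, with a hypothetical batch size and CPU device standing in for the values the script takes from its DataLoader and accelerator:

import torch

loader_batch_size = 16                 # hypothetical; the script derives this from its DataLoader setup
device = torch.device("cpu")           # the script uses accelerator.device

# Allocate once for the largest batch, reuse every step.
real_buf = torch.full((loader_batch_size, 1), 1, device=device, dtype=torch.float32)
fake_buf = torch.zeros((loader_batch_size, 1), device=device, dtype=torch.float32)

# Inside the loop, slice to the actual batch size (the last batch may be smaller).
batch_size = 12
real_labels = real_buf[:batch_size]    # a view, no new allocation
fake_labels = fake_buf[:batch_size]
assert real_labels.shape == (12, 1) and fake_labels.sum() == 0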
@@ -157,35 +153,45 @@ try:
         generator.train()
         discriminator.train()

+        discriminator_time = 0
+        generator_time = 0

         running_d, running_g, steps = 0.0, 0.0, 0

+        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")

         for i, (
             (high_quality, low_quality),
             (high_sample_rate, low_sample_rate),
-        ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")):
+        ) in enumerate(progress_bar):
             batch_size = high_quality.size(0)

             real_labels = real_buf[:batch_size].to(accelerator.device)
             fake_labels = fake_buf[:batch_size].to(accelerator.device)

+            with accelerator.autocast():
+                generator_output = generator(low_quality)

             # --- Discriminator ---
+            d_time = datetime.datetime.now()
             optimizer_d.zero_grad(set_to_none=True)
             with accelerator.autocast():
                 d_loss = discriminator_train(
                     high_quality,
-                    low_quality,
+                    low_quality.detach(),
                     real_labels,
                     fake_labels,
                     discriminator,
-                    generator,
                     criterion_d,
+                    generator_output.detach()
                 )

             accelerator.backward(d_loss)
             torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
             optimizer_d.step()
+            discriminator_time = (datetime.datetime.now() - d_time).microseconds

             # --- Generator ---
+            g_time = datetime.datetime.now()
             optimizer_g.zero_grad(set_to_none=True)
             with accelerator.autocast():
                 g_total, g_adv = generator_train(
@@ -195,11 +201,13 @@ try:
                     generator,
                     discriminator,
                     criterion_d,
+                    generator_output
                 )

             accelerator.backward(g_total)
             torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
             optimizer_g.step()
+            generator_time = (datetime.datetime.now() - g_time).microseconds

             d_val = accelerator.gather(d_loss.detach()).mean()
             g_val = accelerator.gather(g_total.detach()).mean()
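The core restructuring in the two hunks above: the generator now runs a single forward pass per batch, and that `generator_output` is handed to both helpers, detached for the discriminator update so the D loss cannot backpropagate into the generator, and undetached for the generator update. A minimal sketch of the pattern with stand-in `nn.Linear` models (the real `discriminator_train` / `generator_train` helpers and audio models are project-specific and not shown in this diff):

import torch
import torch.nn as nn

generator = nn.Linear(8, 8)       # stand-in for the project's generator
discriminator = nn.Linear(8, 1)   # stand-in for the project's discriminator
criterion_d = nn.MSELoss()
opt_g = torch.optim.Adam(generator.parameters())
opt_d = torch.optim.Adam(discriminator.parameters())

low_quality, high_quality = torch.randn(4, 8), torch.randn(4, 8)
real_labels, fake_labels = torch.ones(4, 1), torch.zeros(4, 1)

# One generator forward pass, reused by both updates.
generator_output = generator(low_quality)

# Discriminator step: detach so no gradients flow back into the generator.
opt_d.zero_grad(set_to_none=True)
d_loss = (criterion_d(discriminator(high_quality), real_labels)
          + criterion_d(discriminator(generator_output.detach()), fake_labels))
d_loss.backward()
opt_d.step()

# Generator step: the same output, not detached, drives the adversarial term.
opt_g.zero_grad(set_to_none=True)
g_loss = criterion_d(discriminator(generator_output), real_labels)
g_loss.backward()
opt_g.step()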
@@ -219,6 +227,7 @@ try:
             )

             steps += 1
+            progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")

         # epoch averages & schedulers
         if steps == 0:
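The timing added in this commit uses `(datetime.datetime.now() - t).microseconds`, which is only the fractional-second component of the timedelta (0-999999) rather than the full elapsed time, so steps longer than a second would wrap in the progress-bar readout. A small sketch of the same reporting pattern via `tqdm.set_description`, using `time.perf_counter()` for the full duration (the sleeps are placeholders for the D and G updates):

import time
import tqdm

progress_bar = tqdm.tqdm(range(100), desc="Epoch 1 | D 0μs | G 0μs")

for step in progress_bar:
    t0 = time.perf_counter()
    time.sleep(0.001)                      # placeholder for the discriminator update
    d_us = int((time.perf_counter() - t0) * 1e6)

    t0 = time.perf_counter()
    time.sleep(0.001)                      # placeholder for the generator update
    g_us = int((time.perf_counter() - t0) * 1e6)

    # Same readout as the commit, but with the full elapsed time in microseconds.
    progress_bar.set_description(f"Epoch 1 | D {d_us}μs | G {g_us}μs")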