:alembic: | Real-time testing...
--- a/training.py
+++ b/training.py
@@ -43,11 +43,11 @@ print(f"Using device: {device}")
 
 # Parameters
 sample_rate = 44100
-n_fft = 2048
-hop_length = 256
+n_fft = 128
+hop_length = 128
 win_length = n_fft
-n_mels = 128
-n_mfcc = 20 # If using MFCC
+n_mels = 40
+n_mfcc = 13 # If using MFCC
 
 mfcc_transform = T.MFCC(
     sample_rate,
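
For reference, the smaller analysis settings would typically be wired into torchaudio's MFCC transform via `melkwargs` (a minimal sketch; the keyword arguments of the actual `T.MFCC(...)` call are outside this hunk, so treat the exact wiring as an assumption):

```python
import torchaudio.transforms as T

sample_rate = 44100
n_fft = 128        # down from 2048: finer time resolution, much coarser frequency resolution
hop_length = 128   # down from 256: hop now equals the window, so frames no longer overlap
win_length = n_fft
n_mels = 40        # down from 128
n_mfcc = 13        # down from 20

# melkwargs is forwarded to the underlying MelSpectrogram
# (hypothetical wiring, matching the visible T.MFCC(...) call in the diff).
mfcc_transform = T.MFCC(
    sample_rate=sample_rate,
    n_mfcc=n_mfcc,
    melkwargs={
        "n_fft": n_fft,
        "hop_length": hop_length,
        "win_length": win_length,
        "n_mels": n_mels,
    },
)
```

Note that `n_fft = 128` leaves only `n_fft // 2 + 1 = 65` frequency bins, so 40 mel filters is a tight fit at 44.1 kHz; torchaudio warns if some of the resulting filterbanks come out all-zero.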
@@ -76,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
 
 
 # ========= MODELS =========
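
Cutting `batch_size` from 64 to 1 makes each loader iteration yield a single dataset item. One plausible motivation (the dataset class is outside this diff) is that items holding clip lists of varying length cannot be stacked across items by the default collate function. A minimal sketch of the resulting batch layout, using a hypothetical stand-in dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyAudioPairs(Dataset):
    """Hypothetical stand-in: each item is a pair of ([clips...], sample_rate)
    tuples, with a different number of clips per item."""
    def __len__(self):
        return 3

    def __getitem__(self, i):
        clips = [torch.randn(44100) for _ in range(i + 1)]
        return (clips, 44100), (clips, 44100)  # (high_quality, low_quality)

loader = DataLoader(ToyAudioPairs(), batch_size=1, shuffle=True)

for high_quality_data, low_quality_data in loader:
    # default_collate wraps each clip in a leading batch dimension of 1 and
    # turns the sample rate into a 1-element tensor:
    print(len(high_quality_data[0]), high_quality_data[0][0].shape, high_quality_data[1])
```

This also lines up with `batch_size = high_quality_data[0][0].size(0)` in the next hunk: with `batch_size=1` in the loader, the collated leading dimension is always 1.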
@@ -115,61 +115,69 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min'
 def start_training():
     generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
-        low_quality_audio = (torch.empty((1)), 1)
-        high_quality_audio = (torch.empty((1)), 1)
-        ai_enhanced_audio = (torch.empty((1)), 1)
+        high_quality_audio = ([torch.empty((1))], 1)
+        low_quality_audio = ([torch.empty((1))], 1)
+        ai_enhanced_audio = ([torch.empty((1))], 1)
 
         times_correct = 0
 
         # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
-        # for high_quality_clip, low_quality_clip in train_data_loader:
-            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
-            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
+        for high_quality_data, low_quality_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
+            ## Data structure:
+            # [[float..., float..., float...], sample_rate]
 
             # ========= LABELS =========
-            batch_size = high_quality_clip[0].size(0)
+            batch_size = high_quality_data[0][0].size(0)
             real_labels = torch.ones(batch_size, 1).to(device)
             fake_labels = torch.zeros(batch_size, 1).to(device)
 
-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            d_loss = discriminator_train(
-                high_quality_sample,
-                low_quality_sample,
-                real_labels,
-                fake_labels,
-                discriminator,
-                generator,
-                criterion_d,
-                optimizer_d
-            )
+            high_quality_audio = high_quality_data
+            low_quality_audio = low_quality_data
 
-            # ========= GENERATOR =========
-            generator.train()
-            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
-                low_quality_sample,
-                high_quality_sample,
-                real_labels,
-                generator,
-                discriminator,
-                criterion_d,
-                optimizer_g,
-                device,
-                mel_transform,
-                stft_transform,
-                mfcc_transform
-            )
+            ai_enhanced_outputs = []
 
-            if debug:
-                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
-            scheduler_d.step(d_loss.detach())
-            scheduler_g.step(adversarial_loss.detach())
+            for high_quality_sample, low_quality_sample in tqdm.tqdm(zip(high_quality_data[0], low_quality_data[0]), desc=f"Processing audio clip.. Length: {len(high_quality_data[0])}"):
+                # ========= DISCRIMINATOR =========
+                discriminator.train()
+                d_loss = discriminator_train(
+                    high_quality_sample,
+                    low_quality_sample,
+                    real_labels,
+                    fake_labels,
+                    discriminator,
+                    generator,
+                    criterion_d,
+                    optimizer_d
+                )
+
+                # ========= GENERATOR =========
+                generator.train()
+                generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
+                    low_quality_sample,
+                    high_quality_sample,
+                    real_labels,
+                    generator,
+                    discriminator,
+                    criterion_d,
+                    optimizer_g,
+                    device,
+                    mel_transform,
+                    stft_transform,
+                    mfcc_transform
+                )
+
+                ai_enhanced_outputs.append(generator_output)
+
+                if debug:
+                    print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
+                scheduler_d.step(d_loss.detach())
+                scheduler_g.step(adversarial_loss.detach())
 
             # ========= SAVE LATEST AUDIO =========
-            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
-            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
-            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
+            high_quality_audio = (torch.cat(high_quality_data[0]), high_quality_data[1])
+            low_quality_audio = (torch.cat(low_quality_data[0]), low_quality_data[1])
+            ai_enhanced_audio = (torch.cat(ai_enhanced_outputs), high_quality_data[1])
 
             new_epoch = generator_epoch+epoch
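
Stripped of the loss bookkeeping, the restructured epoch has a two-level shape: the outer loop walks `(clip_list, sample_rate)` batches, the inner loop trains on one clip at a time, and the per-clip generator outputs are concatenated for the save step. A control-flow skeleton of that structure (with `train_step` as a hypothetical stand-in for the `discriminator_train`/`generator_train` pair, which this sketch does not reimplement):

```python
import torch

def run_epoch(train_data_loader, train_step):
    """Skeleton of the new loop; train_step(hq_clip, lq_clip) -> enhanced clip."""
    for high_quality_data, low_quality_data in train_data_loader:
        # Data structure per batch: ([tensor, tensor, ...], sample_rate)
        ai_enhanced_outputs = []
        for hq_clip, lq_clip in zip(high_quality_data[0], low_quality_data[0]):
            ai_enhanced_outputs.append(train_step(hq_clip, lq_clip))

        # ========= SAVE LATEST AUDIO =========
        # Stitch the per-clip pieces back into single tensors, as the diff does.
        high_quality_audio = (torch.cat(high_quality_data[0]), high_quality_data[1])
        low_quality_audio = (torch.cat(low_quality_data[0]), low_quality_data[1])
        ai_enhanced_audio = (torch.cat(ai_enhanced_outputs), high_quality_data[1])
    return high_quality_audio, low_quality_audio, ai_enhanced_audio
```

One consequence worth noting: the schedulers now step once per clip rather than once per batch, so `ReduceLROnPlateau` sees many more, and noisier, loss readings per epoch.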