:albemic: | More tests.

This commit is contained in:
2025-03-25 21:51:29 +02:00
parent 54338e55a9
commit f928d8c2cf
3 changed files with 14 additions and 12 deletions

View File

@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
args = parser.parse_args()
@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels):
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
#mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
discriminator_decision = discriminator(generator_output)
adversarial_loss = criterion_g(discriminator_decision, real_labels)
combined_loss = adversarial_loss + 0.5 * mfcc_l
#combined_loss = adversarial_loss + 0.5 * mfcc_l
combined_loss.backward()
adversarial_loss.backward()
optimizer_g.step()
return (generator_output, combined_loss, adversarial_loss, mfcc_l)
#return (generator_output, combined_loss, adversarial_loss, mfcc_l)
return (generator_output, adversarial_loss)
debug = args.verbose
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)
# Initialize models and move them to device
generator = SISUGenerator()
@ -157,12 +158,13 @@ def start_training():
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
#generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)
if debug:
print(d_loss, combined_loss, adversarial_loss, mfcc_l)
print(d_loss, adversarial_loss)
scheduler_d.step(d_loss)
#scheduler_g.step(combined_loss)
scheduler_g.step(adversarial_loss)
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])