From f928d8c2cf3540c1f22f13f70715863220766eb9 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Tue, 25 Mar 2025 21:51:29 +0200 Subject: [PATCH] :alembic: | More tests. --- discriminator.py | 2 +- generator.py | 2 +- training.py | 22 ++++++++++++---------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/discriminator.py b/discriminator.py index 1608199..58b95f0 100644 --- a/discriminator.py +++ b/discriminator.py @@ -28,7 +28,7 @@ class AttentionBlock(nn.Module): return x * attention_weights class SISUDiscriminator(nn.Module): - def __init__(self, layers=64): #Increased base layer count + def __init__(self, layers=4): #Decreased base layer count super(SISUDiscriminator, self).__init__() self.model = nn.Sequential( discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling diff --git a/generator.py b/generator.py index 04ac5b4..950530a 100644 --- a/generator.py +++ b/generator.py @@ -34,7 +34,7 @@ class ResidualInResidualBlock(nn.Module): return x + residual class SISUGenerator(nn.Module): - def __init__(self, layer=64, num_rirb=4): #increased base layer and rirb amounts + def __init__(self, layer=4, num_rirb=4): #decreased base layer count super(SISUGenerator, self).__init__() self.conv1 = nn.Sequential( nn.Conv1d(1, layer, kernel_size=7, padding=3), diff --git a/training.py b/training.py index 63fc5b8..814fcda 100644 --- a/training.py +++ b/training.py @@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None, help="Path to the discriminator model file") parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") -parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") +parser.add_argument("--debug", action="store_true", help="Print debug logs") args = parser.parse_args() @@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels): # 
Forward pass for fake samples (from generator output) generator_output = generator(low_quality[0]) - mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output) + #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output) discriminator_decision = discriminator(generator_output) adversarial_loss = criterion_g(discriminator_decision, real_labels) - combined_loss = adversarial_loss + 0.5 * mfcc_l + #combined_loss = adversarial_loss + 0.5 * mfcc_l - combined_loss.backward() + adversarial_loss.backward() optimizer_g.step() - return (generator_output, combined_loss, adversarial_loss, mfcc_l) + #return (generator_output, combined_loss, adversarial_loss, mfcc_l) + return (generator_output, adversarial_loss) -debug = args.verbose +debug = args.debug # Initialize dataset and dataloader dataset_dir = './dataset/good' @@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True) # Initialize models and move them to device generator = SISUGenerator() @@ -157,12 +158,13 @@ def start_training(): # ========= GENERATOR ========= generator.train() - generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) + #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) + generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels) if debug: - print(d_loss, combined_loss, adversarial_loss, mfcc_l) + print(d_loss, adversarial_loss) scheduler_d.step(d_loss) - #scheduler_g.step(combined_loss) + scheduler_g.step(adversarial_loss) # ========= SAVE LATEST AUDIO ========= high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])