:albemic: | More tests.

This commit is contained in:
NikkeDoy 2025-03-25 21:51:29 +02:00
parent 54338e55a9
commit f928d8c2cf
3 changed files with 14 additions and 12 deletions

View File

@ -28,7 +28,7 @@ class AttentionBlock(nn.Module):
return x * attention_weights
class SISUDiscriminator(nn.Module):
def __init__(self, layers=64): #Increased base layer count
def __init__(self, layers=4): #Increased base layer count
super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling

View File

@ -34,7 +34,7 @@ class ResidualInResidualBlock(nn.Module):
return x + residual
class SISUGenerator(nn.Module):
def __init__(self, layer=64, num_rirb=4): #increased base layer and rirb amounts
def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
super(SISUGenerator, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(1, layer, kernel_size=7, padding=3),

View File

@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
parser.add_argument("--debug", action="store_true", help="Print debug logs")
args = parser.parse_args()
@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels):
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0])
mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
#mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
discriminator_decision = discriminator(generator_output)
adversarial_loss = criterion_g(discriminator_decision, real_labels)
combined_loss = adversarial_loss + 0.5 * mfcc_l
#combined_loss = adversarial_loss + 0.5 * mfcc_l
combined_loss.backward()
adversarial_loss.backward()
optimizer_g.step()
return (generator_output, combined_loss, adversarial_loss, mfcc_l)
#return (generator_output, combined_loss, adversarial_loss, mfcc_l)
return (generator_output, adversarial_loss)
debug = args.verbose
debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)
# Initialize models and move them to device
generator = SISUGenerator()
@ -157,12 +158,13 @@ def start_training():
# ========= GENERATOR =========
generator.train()
generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
#generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)
if debug:
print(d_loss, combined_loss, adversarial_loss, mfcc_l)
print(d_loss, adversarial_loss)
scheduler_d.step(d_loss)
#scheduler_g.step(combined_loss)
scheduler_g.step(adversarial_loss)
# ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])