:albemic: | More tests.

This commit is contained in:
NikkeDoy 2025-03-25 21:51:29 +02:00
parent 54338e55a9
commit f928d8c2cf
3 changed files with 14 additions and 12 deletions

View File

@ -28,7 +28,7 @@ class AttentionBlock(nn.Module):
return x * attention_weights return x * attention_weights
class SISUDiscriminator(nn.Module): class SISUDiscriminator(nn.Module):
def __init__(self, layers=64): #Increased base layer count def __init__(self, layers=4): #Increased base layer count
super(SISUDiscriminator, self).__init__() super(SISUDiscriminator, self).__init__()
self.model = nn.Sequential( self.model = nn.Sequential(
discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling

View File

@ -34,7 +34,7 @@ class ResidualInResidualBlock(nn.Module):
return x + residual return x + residual
class SISUGenerator(nn.Module): class SISUGenerator(nn.Module):
def __init__(self, layer=64, num_rirb=4): #increased base layer and rirb amounts def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
super(SISUGenerator, self).__init__() super(SISUGenerator, self).__init__()
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv1d(1, layer, kernel_size=7, padding=3), nn.Conv1d(1, layer, kernel_size=7, padding=3),

View File

@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file") help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") parser.add_argument("--debug", action="store_true", help="Print debug logs")
args = parser.parse_args() args = parser.parse_args()
@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels):
# Forward pass for fake samples (from generator output) # Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0]) generator_output = generator(low_quality[0])
mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output) #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
discriminator_decision = discriminator(generator_output) discriminator_decision = discriminator(generator_output)
adversarial_loss = criterion_g(discriminator_decision, real_labels) adversarial_loss = criterion_g(discriminator_decision, real_labels)
combined_loss = adversarial_loss + 0.5 * mfcc_l #combined_loss = adversarial_loss + 0.5 * mfcc_l
combined_loss.backward() adversarial_loss.backward()
optimizer_g.step() optimizer_g.step()
return (generator_output, combined_loss, adversarial_loss, mfcc_l) #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
return (generator_output, adversarial_loss)
debug = args.verbose debug = args.debug
# Initialize dataset and dataloader # Initialize dataset and dataloader
dataset_dir = './dataset/good' dataset_dir = './dataset/good'
@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)
# ========= SINGLE ========= # ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True) train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)
# Initialize models and move them to device # Initialize models and move them to device
generator = SISUGenerator() generator = SISUGenerator()
@ -157,12 +158,13 @@ def start_training():
# ========= GENERATOR ========= # ========= GENERATOR =========
generator.train() generator.train()
generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)
if debug: if debug:
print(d_loss, combined_loss, adversarial_loss, mfcc_l) print(d_loss, adversarial_loss)
scheduler_d.step(d_loss) scheduler_d.step(d_loss)
#scheduler_g.step(combined_loss) scheduler_g.step(adversarial_loss)
# ========= SAVE LATEST AUDIO ========= # ========= SAVE LATEST AUDIO =========
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])