🐛 | Fix loading wrong model.
This commit is contained in:
@ -87,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
|
|||||||
|
|
||||||
# ========= SINGLE =========
|
# ========= SINGLE =========
|
||||||
|
|
||||||
train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24)
|
train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24)
|
||||||
|
|
||||||
|
|
||||||
# ========= MODELS =========
|
# ========= MODELS =========
|
||||||
@ -104,7 +104,7 @@ if args.continue_training:
|
|||||||
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
|
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
|
||||||
else:
|
else:
|
||||||
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
||||||
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
|
||||||
|
|
||||||
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
|
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
|
||||||
epoch = epoch_from_file["epoch"] + 1
|
epoch = epoch_from_file["epoch"] + 1
|
||||||
|
Reference in New Issue
Block a user