diff --git a/training.py b/training.py index 1be713c..ab2b1e5 100644 --- a/training.py +++ b/training.py @@ -87,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24) +train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24) # ========= MODELS ========= @@ -104,7 +104,7 @@ if args.continue_training: discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) else: generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True)) epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") epoch = epoch_from_file["epoch"] + 1