| Added app.py script so the model can be used.

This commit is contained in:
2025-06-06 22:10:06 +03:00
parent a135c765da
commit 2ded03713d
3 changed files with 88 additions and 9 deletions

View File

@ -69,14 +69,14 @@ debug = args.debug
# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir, device)
models_dir = "models"
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)
audio_output_dir = "output"
audio_output_dir = "./output"
os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True)
train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
# ========= MODELS =========
@ -85,17 +85,17 @@ generator = SISUGenerator()
discriminator = SISUDiscriminator()
epoch: int = args.epoch
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
if args.continue_training:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch = epoch_from_file["epoch"] + 1
else:
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
elif args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
else:
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
epoch = epoch_from_file["epoch"] + 1
generator = generator.to(device)
discriminator = discriminator.to(device)