diff --git a/training.py b/training.py index 380f738..c050b9c 100644 --- a/training.py +++ b/training.py @@ -119,7 +119,7 @@ if args.discriminator is not None: # Loss criterion_g = nn.MSELoss() -criterion_d = nn.BCELoss() +criterion_d = nn.BCEWithLogitsLoss() # Optimizers optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))