Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
"discriminator": {
"name": ACGANDiscriminator,
"args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [
MinimaxGeneratorLoss(),
MinimaxDiscriminatorLoss(),
AuxiliaryClassifierGeneratorLoss(),
AuxiliaryClassifierDiscriminatorLoss(),
]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
"discriminator": {
"name": ConditionalGANDiscriminator,
"args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
"discriminator": {
"name": DCGANDiscriminator,
"args": {"in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
trainer = ParallelTrainer(
network_configuration,
losses,
args.list_gpus,
epochs=args.epochs,
sample_size=args.sample_size,
checkpoints=args.checkpoint,
retain_checkpoints=1,
recon=args.reconstructions,
)
else:
if args.cpu == 1:
device = torch.device("cpu")
else:
device = torch.device("cuda:0")
trainer = Trainer(
network_configuration,
losses,
device=device,
epochs=args.epochs,
sample_size=args.sample_size,
checkpoints=args.checkpoint,
retain_checkpoints=1,
recon=args.reconstructions,
)
train_dataset = dataset(
root=args.data_dir,
train=True,
download=True,
transform=transformations)
trainer = ParallelTrainer(
network_config,
losses_list,
args.list_gpus,
epochs=args.epochs,
sample_size=args.sample_size,
checkpoints=args.checkpoint,
retain_checkpoints=1,
recon=args.reconstructions,
)
else:
if args.cpu == 1:
device = torch.device("cpu")
else:
device = torch.device("cuda:0")
trainer = Trainer(
network_config,
losses_list,
device=device,
epochs=args.epochs,
sample_size=args.sample_size,
checkpoints=args.checkpoint,
retain_checkpoints=1,
recon=args.reconstructions,
)
# Transforms to get Binarized MNIST
dataset = dsets.MNIST(
root=args.data_dir,
train=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),