How to use the schnetpack.train module in SchNetPack

To help you get started, we’ve selected a few SchNetPack examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: atomistic-machine-learning/schnetpack — tests/test_scripts_utils.py (view on GitHub)
lr_decay=0.5,
            lr_min=1e-6,
            logger="csv",
            modelpath=modeldir,
            log_every_n_epochs=2,
            max_steps=30,
            checkpoint_interval=1,
            keep_n_checkpoints=1,
            dataset="qm9",
        )
        trainer = get_trainer(
            args, schnet, qm9_train_loader, qm9_val_loader, metrics=None
        )
        assert trainer._model == schnet
        hook_types = [type(hook) for hook in trainer.hooks]
        assert schnetpack.train.hooks.CSVHook in hook_types
        assert schnetpack.train.hooks.TensorboardHook not in hook_types
        assert schnetpack.train.hooks.MaxEpochHook in hook_types
        assert schnetpack.train.hooks.ReduceLROnPlateauHook in hook_types
GitHub: atomistic-machine-learning/schnetpack — tests/test_scripts_utils.py (view on GitHub)
)
        mean = {args.property: None}
        model = get_model(
            args, train_loader=qm9_train_loader, mean=mean, stddev=mean, atomref=mean
        )

        os.makedirs(modeldir, exist_ok=True)
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
            ],
        )
        assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))
        args.split = ["train"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
GitHub: atomistic-machine-learning/schnetpack — src/scripts/schnetpack_md17.py (view on GitHub)
add_md17_arguments(parser)
    add_subparsers(
        parser,
        defaults=dict(property=MD17.energy, elements=["H", "C", "O"]),
        choices=dict(property=[MD17.energy, MD17.forces]),
    )
    args = parser.parse_args()
    train_args = setup_run(args)

    # set device
    device = torch.device("cuda" if args.cuda else "cpu")

    # define metrics
    metrics = [
        schnetpack.train.metrics.MeanAbsoluteError(MD17.energy, MD17.energy),
        schnetpack.train.metrics.RootMeanSquaredError(MD17.energy, MD17.energy),
        schnetpack.train.metrics.MeanAbsoluteError(
            MD17.forces, MD17.forces, element_wise=True
        ),
        schnetpack.train.metrics.RootMeanSquaredError(
            MD17.forces, MD17.forces, element_wise=True
        ),
    ]

    # build dataset
    logging.info("MD17 will be loaded...")
    md17 = MD17(
        args.datapath,
        args.molecule,
        download=True,
        collect_triples=args.model == "wacsf",
    )
GitHub: atomistic-machine-learning/schnetpack — src/examples/qm9_wACSF.py (view on GitHub)
# split the dataset into train / validation / test subsets
# (NOTE(review): the two sizes are presumably the train and validation counts —
# confirm against the create_splits documentation)
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# create model: Behler symmetry-function representation + elemental atomwise output
reps = rep.BehlerSFBlock()
output = spk.output_modules.ElementalAtomwise(reps.n_symfuncs)
model = atm.AtomisticModel(reps, output)

# filter for trainable parameters (https://github.com/pytorch/pytorch/issues/679)
# A list (rather than a lazy `filter` iterator) can be inspected or reused safely.
trainable_params = [p for p in model.parameters() if p.requires_grad]


def loss(batch, result):
    """Return the MSE between the model output ``"y"`` and the batch's QM9.U0 target.

    Args:
        batch: input batch; indexed with ``QM9.U0`` to get the reference values.
        result: model output; indexed with ``"y"`` to get the predictions.
    """
    return F.mse_loss(result["y"], batch[QM9.U0])


# create trainer
opt = Adam(trainable_params, lr=1e-4)
trainer = spk.train.Trainer("wacsf/", model, loss, opt, loader, val_loader)

# start training
trainer.train(torch.device("cpu"))