How to use the schnetpack.train.Trainer class in schnetpack

To help you get started, we’ve selected a few schnetpack.train.Trainer examples, based on popular ways it is used in public projects.
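Across these snippets the pattern is the same: Trainer wires together a model, a loss function, an optimizer, and train/validation data loaders, then drives the training loop and checkpointing. As a quick orientation, here is a minimal sketch assembled from the parameter names used in the examples below; model, train_loader, and val_loader are placeholders you would build yourself, and the "energy" property name is an assumption.

import torch
import schnetpack as spk

# minimal sketch: assumes `model` is an AtomisticModel and the loaders are
# spk.data.AtomsLoader instances built beforehand; "energy" is a placeholder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = spk.train.build_mse_loss(["energy"])

trainer = spk.train.Trainer(
    "model_dir/",  # model_path: checkpoints and the best model are stored here
    model,
    loss_fn,
    optimizer,
    train_loader,
    val_loader,
)
trainer.train(device="cpu", n_epochs=100)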


github atomistic-machine-learning / schnetpack / tests / fixtures / train.py

import torch
import schnetpack as spk


# assembles a fully configured Trainer from the pieces passed in
def trainer(
    modeldir,
    atomistic_model,
    properties,
    lr,
    train_loader,
    val_loader,
    keep_n_checkpoints,
    checkpoint_interval,
    hooks,
):
    return spk.train.Trainer(
        model_path=modeldir,
        model=atomistic_model,
        loss_fn=spk.train.build_mse_loss(properties),
        optimizer=torch.optim.Adam(atomistic_model.parameters(), lr=lr),
        train_loader=train_loader,
        validation_loader=val_loader,
        keep_n_checkpoints=keep_n_checkpoints,
        checkpoint_interval=checkpoint_interval,
        validation_interval=1,
        hooks=hooks,
        loss_is_normalized=True,
    )
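Given its location under tests/fixtures/, this function is presumably registered as a pytest fixture, so a test can request a fully configured Trainer just by naming it as a parameter. A hypothetical smoke test:

# hypothetical test: pytest would inject the `trainer` fixture defined above
def test_training_smoke(trainer):
    trainer.train(device="cpu", n_epochs=1)  # one epoch exercises the loop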

github atomistic-machine-learning / schnetpack / src / examples / ethanol_script.py

# (snippet truncated: tail of an output-module list defined earlier in the script)
        negative_dr=True,
    )
]
model = schnetpack.atomistic.model.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(params=model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
hooks = [CSVHook(log_path=model_dir, metrics=metrics), ReduceLROnPlateauHook(optimizer)]

# trainer
loss = mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# run training
logging.info("training")
trainer.train(device="cpu", n_epochs=1000)
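The hard-coded device="cpu" generalizes easily. A small sketch, assuming the same Trainer.train signature used above:

import torch

# pick the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer.train(device=device, n_epochs=1000)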

github atomistic-machine-learning / schnetpack / src / examples / qm9_schnet.py

# imports assumed by this excerpt
import torch
import torch.nn.functional as F
from torch.optim import Adam

import schnetpack
import schnetpack as spk
import schnetpack.representation as rep
from schnetpack.datasets import QM9

# load the QM9 dataset
data = QM9("qm9.db")

# split in train and val
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# create model
reps = rep.SchNet()
output = schnetpack.atomistic.Atomwise(n_in=reps.n_atom_basis)
model = schnetpack.atomistic.AtomisticModel(reps, output)

# create trainer
opt = Adam(model.parameters(), lr=1e-4)
loss = lambda b, p: F.mse_loss(p["y"], b[QM9.U0])
trainer = spk.train.Trainer("output/", model, loss, opt, loader, val_loader)

# start training
trainer.train(torch.device("cpu"))
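Since every argument here is positional, it helps to spell out how they map onto the keyword names used in the other snippets; this keyword form is equivalent:

trainer = spk.train.Trainer(
    model_path="output/",
    model=model,
    loss_fn=loss,
    optimizer=opt,
    train_loader=loader,
    validation_loader=val_loader,
)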

github atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / training.py

            # (snippet truncated: tail of a logging-hook constructor)
            every_n_epochs=args.log_every_n_epochs,
        )
        hooks.append(logger)
    elif args.logger == "tensorboard":
        logger = spk.train.TensorboardHook(
            os.path.join(args.modelpath, "log"),
            metrics,
            every_n_epochs=args.log_every_n_epochs,
        )
        hooks.append(logger)

    # setup loss function
    loss_fn = get_loss_fn(args)

    # setup trainer
    trainer = spk.train.Trainer(
        args.modelpath,
        model,
        loss_fn,
        optimizer,
        train_loader,
        val_loader,
        checkpoint_interval=args.checkpoint_interval,
        keep_n_checkpoints=args.keep_n_checkpoints,
        hooks=hooks,
    )
    return trainer
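This helper assembles the trainer from parsed command-line arguments. A hypothetical caller (the helper's name, the args object, and the cuda flag are assumptions, not part of the snippet):

import torch

# hypothetical: `get_trainer` is the enclosing helper, `args` comes from argparse
trainer = get_trainer(args, model, train_loader, val_loader, metrics)
device = torch.device("cuda" if getattr(args, "cuda", False) else "cpu")
trainer.train(device=device)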

github atomistic-machine-learning / G-SchNet / gschnet_qm9_script.py

        # (snippet truncated: `out_type` and `loss_layer` are defined earlier)
        loss_type = loss_layer(out_type, batch['_type_labels'])
        loss_type = torch.sum(loss_type, -1)
        loss_type = torch.mean(loss_type)

        # loss for distance predictions (KLD)
        mask_dist = batch['_dist_mask']
        N = torch.sum(mask_dist)
        out_dist = norm_layer(result['distance_predictions'])
        loss_dist = loss_layer(out_dist, batch['_labels'])
        loss_dist = torch.sum(loss_dist, -1)
        # masked mean: max(N, 1) guards against an all-zero distance mask
        loss_dist = torch.sum(loss_dist * mask_dist) / torch.max(N, torch.ones_like(N))

        return loss_type + loss_dist

    # initialize trainer
    trainer = spk.train.Trainer(args.modelpath,
                                model,
                                loss,
                                optimizer,
                                train_loader,
                                val_loader,
                                hooks=hooks,
                                checkpoint_interval=args.checkpoint_every_n_epochs,
                                keep_n_checkpoints=10)

    # reset optimizer and hooks if starting from pre-trained model (e.g. for
    # fine-tuning)
    if args.pretrained_path is not None:
        logging.info('starting from pre-trained model...')
        # reset epoch and step
        trainer.epoch = 0
        trainer.step = 0
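When fine-tuning, more state than the epoch counter can linger. A sketch of a fuller reset, assuming the Trainer keeps its optimizer and best validation loss as plain attributes (as in schnetpack 0.x):

# hypothetical fuller reset for fine-tuning
trainer.best_loss = float("inf")  # forget the pre-training optimum
for group in trainer.optimizer.param_groups:
    group["lr"] = 1e-5  # assumed smaller learning rate for fine-tuning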

github atomistic-machine-learning / schnetpack / src / examples / qm9_tutorial.py

# (snippet truncated: tail of an output-module list defined earlier in the script)
        atomref=atomrefs[QM9.U0],
    )
]
model = schnetpack.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
hooks = [CSVHook(log_path=model_dir, metrics=metrics), ReduceLROnPlateauHook(optimizer)]

# trainer
loss = build_mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# run training
logging.info("training")
trainer.train(device="cpu")
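Without an explicit n_epochs, trainer.train keeps going until a hook ends the run, so open-ended runs usually add a stopping criterion when the hooks are built. A sketch using the hook names from the snippet above (the patience value and the stop_after_min flag are assumptions about schnetpack 0.x defaults):

# assumed: stop_after_min ends training once the learning rate has decayed
# to its minimum
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer, patience=25, stop_after_min=True),
]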