How to use the schnetpack.train.metrics.MeanAbsoluteError function in schnetpack

To help you get started, we've selected a few schnetpack examples based on popular ways this metric is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: atomistic-machine-learning / schnetpack / tests / test_scripts_utils.py (view on GitHub)
metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
            ],
        )
        args.split = ["validation"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
GitHub: atomistic-machine-learning / schnetpack / tests / test_metrics.py (view on GitHub)
def forces_mae():
    """Create the element-wise MAE metric used for force predictions.

    The metric is built over the ``_forces`` key (first positional argument)
    and the ``dydx`` model output, is reported under the name ``forces``, and
    has ``element_wise=True`` set.
    """
    metric = MeanAbsoluteError(
        "_forces",
        "dydx",
        name="forces",
        element_wise=True,
    )
    return metric
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_ani1.py (view on GitHub)
# parse arguments
    # The ANI1 energy is both the default and the only allowed choice for the
    # ``property`` option of the sub-parsers.
    parser = get_main_parser()
    add_subparsers(
        parser,
        defaults=dict(property=ANI1.energy),
        choices=dict(property=[ANI1.energy]),
    )
    args = parser.parse_args()
    # NOTE(review): setup_run presumably prepares the run directory and returns
    # the effective training arguments — confirm against its definition.
    train_args = setup_run(args)

    # set device
    # Use CUDA only when the --cuda flag was set; otherwise run on the CPU.
    device = torch.device("cuda" if args.cuda else "cpu")

    # define metrics
    # MAE and RMSE for the trained property. The same key is passed as the
    # first argument (presumably the target) and as the model output name.
    metrics = [
        schnetpack.train.metrics.MeanAbsoluteError(
            train_args.property, train_args.property
        ),
        schnetpack.train.metrics.RootMeanSquaredError(
            train_args.property, train_args.property
        ),
    ]

    # build dataset
    logging.info("ANI1 will be loaded...")
    # NOTE(review): download=True presumably fetches the data when it is not
    # already at args.datapath, and load_only presumably restricts loading to
    # the trained property — verify against spk.datasets.ANI1.
    # collect_triples is enabled only when the "wacsf" model is selected.
    ani1 = spk.datasets.ANI1(
        args.datapath,
        download=True,
        load_only=[train_args.property],
        collect_triples=args.model == "wacsf",
    )
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / training.py (view on GitHub)
def get_metrics(args):
    """Assemble the evaluation metrics for a training run.

    Args:
        args: parsed command-line arguments. ``args.property`` names the
            trained property, and ``args`` is also passed to
            ``spk.utils.get_derivative``.

    Returns:
        list: MAE and RMSE metrics for ``args.property``. When
        ``spk.utils.get_derivative(args)`` returns something other than
        ``None``, element-wise MAE and RMSE metrics for that derivative are
        appended, in the same MAE-then-RMSE order.
    """
    mae_cls = spk.train.metrics.MeanAbsoluteError
    rmse_cls = spk.train.metrics.RootMeanSquaredError

    prop = args.property
    metric_list = [mae_cls(prop, prop), rmse_cls(prop, prop)]

    # Derivative metrics (e.g. forces) are only added when a derivative is set.
    deriv = spk.utils.get_derivative(args)
    if deriv is None:
        return metric_list

    metric_list.extend(
        cls(deriv, deriv, element_wise=True) for cls in (mae_cls, rmse_cls)
    )
    return metric_list
GitHub: atomistic-machine-learning / schnetpack / src / examples / ethanol_script.py (view on GitHub)
n_in=representation.n_atom_basis,
        property="energy",
        derivative="forces",
        mean=means["energy"],
        stddev=stddevs["energy"],
        negative_dr=True,
    )
]
model = schnetpack.atomistic.model.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(params=model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
hooks = [CSVHook(log_path=model_dir, metrics=metrics), ReduceLROnPlateauHook(optimizer)]

# trainer
loss = mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# run training
logging.info("training")