How to use the schnetpack.datasets.ANI1 class in schnetpack

To help you get started, we’ve selected a few schnetpack examples based on popular ways the ANI1 dataset class is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: atomistic-machine-learning/schnetpack — src/sacred_scripts/download.py (view on GitHub)
def download(dbpath, dataset, folds, cutoff, api_key, molecules):
    """Create the local database(s) for one supported dataset.

    The function relies solely on the side effects of constructing the
    dataset objects; the created instances are not kept or returned.

    Args:
        dbpath (str): database path passed to the dataset constructor. For
            MD17 the molecule name and ``".db"`` are appended directly to
            this string (no path separator is inserted).
        dataset (str): dataset name, matched case-insensitively against
            QM9, ISO17, ANI1, MD17 and MATPROJ.
        folds (iterable): ISO17 folds; one ISO17 object is built per fold.
        cutoff: forwarded to ``MaterialsProject`` (MATPROJ only).
        api_key: forwarded to ``MaterialsProject`` (MATPROJ only).
        molecules (iterable): MD17 molecule names; one database per molecule.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """
    key = dataset.upper()
    if key == "QM9":
        QM9(dbpath)
    elif key == "ISO17":
        for fold in folds:
            ISO17(dbpath, fold)
    elif key == "ANI1":
        ANI1(dbpath)
    elif key == "MD17":
        for molecule in molecules:
            MD17(dbpath + molecule + ".db", dataset=molecule)
    elif key == "MATPROJ":
        MaterialsProject(dbpath, cutoff, api_key)
    else:
        # Name the rejected input so the caller can see what went wrong.
        raise NotImplementedError(
            "Unsupported dataset {!r}; expected one of QM9, ISO17, ANI1, "
            "MD17, MATPROJ.".format(dataset)
        )
GitHub: atomistic-machine-learning/schnetpack — src/scripts/schnetpack_ani1.py (view on GitHub)
get_main_parser,
    add_subparsers,
    get_loaders,
    get_statistics,
)


# Root logging level is read from the LOGLEVEL environment variable
# (falls back to "INFO" when it is unset).
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))


if __name__ == "__main__":
    # parse arguments
    parser = get_main_parser()
    add_subparsers(
        parser,
        defaults=dict(property=ANI1.energy),
        choices=dict(property=[ANI1.energy]),
    )
    args = parser.parse_args()
    train_args = setup_run(args)

    # set device
    device = torch.device("cuda" if args.cuda else "cpu")

    # define metrics
    metrics = [
        schnetpack.train.metrics.MeanAbsoluteError(
            train_args.property, train_args.property
        ),
        schnetpack.train.metrics.RootMeanSquaredError(
            train_args.property, train_args.property
        ),
GitHub: atomistic-machine-learning/schnetpack — src/schnetpack/utils/script_utils/parsing.py (view on GitHub)
QM9.G,
            QM9.Cv,
        ],
    )
    qm9_parser.add_argument(
        "--remove_uncharacterized",
        help="Remove uncharacterized molecules from QM9 (default: %(default)s)",
        action="store_true",
    )

    ani1_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
    ani1_parser.add_argument(
        "--property",
        type=str,
        help="Database property to be predicted (default: %(default)s)",
        default=ANI1.energy,
        choices=[ANI1.energy],
    )
    ani1_parser.add_argument(
        "--num_heavy_atoms",
        type=int,
        help="Number of heavy atoms that will be loaded into the database."
        " (default: %(default)s)",
        default=8,
    )
    matproj_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
    matproj_parser.add_argument(
        "--property",
        type=str,
        help="Database property to be predicted" " (default: %(default)s)",
        default=MaterialsProject.EformationPerAtom,
        choices=[
GitHub: atomistic-machine-learning/schnetpack — src/schnetpack/sacred/dataset_ingredients.py (view on GitHub)
def get_ani1(dbpath, num_heavy_atoms, dataset_properties):
    """Build the ANI1 dataset backed by the database at ``dbpath``.

    Args:
        dbpath (str): path to the local database
        num_heavy_atoms (int): max number of heavy atoms per molecule
        dataset_properties (list): properties of the dataset; forwarded to
            ``ANI1`` as its ``properties`` keyword

    Returns:
        AtomsData object
    """
    options = dict(num_heavy_atoms=num_heavy_atoms, properties=dataset_properties)
    return ANI1(dbpath, **options)
GitHub: atomistic-machine-learning/schnetpack — src/schnetpack/sacred/dataset_ingredients.py (view on GitHub)
def ani1():
    """
    Default settings for ANI1 dataset.

    Adds:
        dbpath (str): path of the local ANI1 database file
        dataset (str): dataset identifier ("ANI1")
        num_heavy_atoms (int): maximum number of heavy atoms
        property_mapping (dict): maps ``Properties.energy`` to ``ANI1.energy``
    """
    # NOTE(review): this looks like a sacred config scope (decorator not
    # visible here), in which case the local variable NAMES below become the
    # configuration entries — do not rename them or add extra locals.
    dbpath = "./data/ani1.db"
    dataset = "ANI1"
    # NOTE(review): differs from the argparse default of 8 used by the ANI1
    # CLI parser (--num_heavy_atoms); confirm which default is intended.
    num_heavy_atoms = 2
    property_mapping = {Properties.energy: ANI1.energy}
GitHub: atomistic-machine-learning/schnetpack — src/schnetpack/utils/script_utils/data.py (view on GitHub)
"""
    if args.dataset == "qm9":
        if logging:
            logging.info("QM9 will be loaded...")
        qm9 = spk.datasets.QM9(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            remove_uncharacterized=args.remove_uncharacterized,
        )
        return qm9
    elif args.dataset == "ani1":
        if logging:
            logging.info("ANI1 will be loaded...")
        ani1 = spk.datasets.ANI1(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            num_heavy_atoms=args.num_heavy_atoms,
        )
        return ani1
    elif args.dataset == "md17":
        if logging:
            logging.info("MD17 will be loaded...")
        md17 = spk.datasets.MD17(
            args.datapath,
            args.molecule,
            download=True,
            collect_triples=args.model == "wacsf",
        )
GitHub: atomistic-machine-learning/schnetpack — src/scripts/schnetpack_ani1.py (view on GitHub)
add_subparsers,
    get_loaders,
    get_statistics,
)


# Root logging level is read from the LOGLEVEL environment variable
# (falls back to "INFO" when it is unset).
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))


if __name__ == "__main__":
    # parse arguments; ANI1.energy is both the default and the only allowed
    # value for the "property" option of the subparsers
    parser = get_main_parser()
    add_subparsers(
        parser,
        defaults=dict(property=ANI1.energy),
        choices=dict(property=[ANI1.energy]),
    )
    args = parser.parse_args()
    # NOTE(review): setup_run is defined elsewhere; presumably it prepares the
    # run and returns the effective training arguments — confirm.
    train_args = setup_run(args)

    # set device: CUDA when args.cuda is truthy, otherwise CPU
    device = torch.device("cuda" if args.cuda else "cpu")

    # define metrics: mean absolute error and root-mean-squared error on the
    # selected property. The property name is passed twice — presumably as
    # the target key and the model-output key; confirm against the metric
    # constructor signatures.
    metrics = [
        schnetpack.train.metrics.MeanAbsoluteError(
            train_args.property, train_args.property
        ),
        schnetpack.train.metrics.RootMeanSquaredError(
            train_args.property, train_args.property
        ),
    ]
GitHub: atomistic-machine-learning/schnetpack — src/scripts/schnetpack_ani1.py (view on GitHub)
# set device
    device = torch.device("cuda" if args.cuda else "cpu")

    # define metrics
    metrics = [
        schnetpack.train.metrics.MeanAbsoluteError(
            train_args.property, train_args.property
        ),
        schnetpack.train.metrics.RootMeanSquaredError(
            train_args.property, train_args.property
        ),
    ]

    # build dataset
    logging.info("ANI1 will be loaded...")
    ani1 = spk.datasets.ANI1(
        args.datapath,
        download=True,
        load_only=[train_args.property],
        collect_triples=args.model == "wacsf",
    )

    # get atomrefs
    atomref = ani1.get_atomrefs(train_args.property)

    # splits the dataset in test, val, train sets
    split_path = os.path.join(args.modelpath, "split.npz")
    train_loader, val_loader, test_loader = get_loaders(
        args, dataset=ani1, split_path=split_path, logging=logging
    )

    if args.mode == "train":