How to use the schnetpack.datasets.QM9 dataset class in schnetpack

To help you get started, we’ve selected a few schnetpack examples, based on popular ways it is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

github atomistic-machine-learning / schnetpack / tests / fixtures / qm9.py View on Github external
def qm9_dataset(qm9_dbpath):
    """Build a ``QM9`` dataset backed by the database at ``qm9_dbpath``.

    As a side effect, prints whether a file already exists at
    ``qm9_dbpath`` before the dataset object is constructed.
    """
    path_present = os.path.exists(qm9_dbpath)
    print(path_present)
    dataset = QM9(qm9_dbpath)
    return dataset
github atomistic-machine-learning / schnetpack / src / examples / qm9_tutorial.py View on Github external
from torch.optim import Adam
import os
import schnetpack as spk
import schnetpack.atomistic.model
from schnetpack.datasets import QM9
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train.metrics import MeanAbsoluteError
from schnetpack.train import build_mse_loss


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "qm9_model"
# exist_ok=True: re-running the tutorial must not crash with FileExistsError.
# NOTE(review): on a re-run the same split path below is passed again; confirm
# that train_test_split reuses an existing split file rather than overwriting it.
os.makedirs(model_dir, exist_ok=True)
properties = [QM9.U0]

# data preparation
logging.info("get dataset")
# Load exactly the properties we train on, kept in sync with ``properties``.
dataset = QM9("data/qm9.db", load_only=properties)
# NOTE(review): presumably 1000 training / 100 validation samples with the rest
# as test; confirm against train_test_split's signature.
train, val, test = spk.train_test_split(
    dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
train_loader = spk.AtomsLoader(train, batch_size=64, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics
atomrefs = dataset.get_atomref(properties)
means, stddevs = train_loader.get_statistics(
    properties, divide_by_atoms=True, single_atom_ref=atomrefs
)
github atomistic-machine-learning / schnetpack / src / examples / qm9_tutorial.py View on Github external
from schnetpack.datasets import QM9
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train.metrics import MeanAbsoluteError
from schnetpack.train import build_mse_loss


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "qm9_model"
# exist_ok=True: re-running the tutorial must not crash with FileExistsError.
# NOTE(review): on a re-run the same split path below is passed again; confirm
# that train_test_split reuses an existing split file rather than overwriting it.
os.makedirs(model_dir, exist_ok=True)
properties = [QM9.U0]

# data preparation
logging.info("get dataset")
# Load exactly the properties we train on, kept in sync with ``properties``.
dataset = QM9("data/qm9.db", load_only=properties)
# NOTE(review): presumably 1000 training / 100 validation samples with the rest
# as test; confirm against train_test_split's signature.
train, val, test = spk.train_test_split(
    dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
train_loader = spk.AtomsLoader(train, batch_size=64, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics
atomrefs = dataset.get_atomref(properties)
means, stddevs = train_loader.get_statistics(
    properties, divide_by_atoms=True, single_atom_ref=atomrefs
)

# model build
logging.info("build model")
# SchNet representation configured with n_interactions=6.
representation = spk.SchNet(n_interactions=6)
output_modules = [
github atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / settings.py View on Github external
# Public API of this module: the two property lookup tables
# (``divide_by_atoms``, ``pooling_mode``) and the ``get_*`` helpers.
# Kept as a plain list literal so linters, type checkers and IDEs can
# statically resolve the exported names.
__all__ = [
    "divide_by_atoms",
    "pooling_mode",
    "get_divide_by_atoms",
    "get_pooling_mode",
    "get_negative_dr",
    "get_derivative",
    "get_contributions",
    "get_module_str",
    "get_environment_provider",
]


divide_by_atoms = {
    QM9.mu: True,
    QM9.alpha: True,
    QM9.homo: False,
    QM9.lumo: False,
    QM9.gap: False,
    QM9.r2: True,
    QM9.zpve: True,
    QM9.U0: True,
    QM9.U: True,
    QM9.H: True,
    QM9.G: True,
    QM9.Cv: True,
    ANI1.energy: True,
    MD17.energy: True,
    MaterialsProject.EformationPerAtom: False,
    MaterialsProject.EPerAtom: False,
    MaterialsProject.BandGap: False,
github atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / parsing.py View on Github external
data_parser = argparse.ArgumentParser(add_help=False)
    # Shared data options. add_help=False because this parser is reused as a
    # parent (see qm9_parser below); otherwise the inherited -h/--help would
    # conflict with the child parser's own help option.
    data_parser.add_argument(
        "--environment_provider",
        type=str,
        default="simple",
        choices=["simple", "ase", "torch"],
        help="Environment provider for dataset. (default: %(default)s)",
    )

    # qm9: parser for QM9-specific options; inherits every data_parser
    # argument through parents=[data_parser].
    qm9_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
    qm9_parser.add_argument(
        "--property",
        type=str,
        help="Database property to be predicted (default: %(default)s)",
        default=QM9.U0,
        choices=[
            QM9.A,
            QM9.B,
            QM9.C,
            QM9.mu,
            QM9.alpha,
            QM9.homo,
            QM9.lumo,
            QM9.gap,
            QM9.r2,
            QM9.zpve,
            QM9.U0,
            QM9.U,
            QM9.H,
            QM9.G,
            QM9.Cv,
github atomistic-machine-learning / schnetpack / src / scripts / schnetpack_load.py View on Github external
"--remove_uncharacterized",
        help="Remove uncharacterized molecules from QM9",
        type=bool,
        default="False",
    )

    md17_parser = data_subparsers.add_parser("md17", help="MD17 datasets")
    md17_parser.add_argument(
        "molecule", help="Molecule dataset", choices=dset.MD17.existing_datasets
    )
    md17_parser.add_argument("dbpath", help="Destination path")

    args = main_parser.parse_args()

    # Dispatch on the selected sub-command. This must be a single
    # if/elif/else chain: with two independent ``if`` statements, a
    # successful "qm9" download would fall into the ``else`` of the md17
    # check and also print "Unknown dataset!".
    # NOTE(review): the --remove_uncharacterized option defined above is
    # parsed but never passed to QM9 here. Before wiring it in, fix its
    # definition: argparse runs string defaults through ``type``, so
    # ``type=bool, default="False"`` yields bool("False") == True.
    if args.dataset == "qm9":
        dset.QM9(args.dbpath, True)
    elif args.dataset == "md17":
        dset.MD17(args.dbpath, args.molecule)
    else:
        print("Unknown dataset!")
github atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / data.py View on Github external
def get_dataset(args, logging=None):
    """
    Get dataset from arguments.

    Args:
        args (argparse.Namespace): parsed arguments
        logging: logger

    Returns:
        spk.data.AtomsData: dataset

    """
    if args.dataset == "qm9":
        if logging:
            logging.info("QM9 will be loaded...")
        qm9 = spk.datasets.QM9(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            remove_uncharacterized=args.remove_uncharacterized,
        )
        return qm9
    elif args.dataset == "ani1":
        if logging:
            logging.info("ANI1 will be loaded...")
        ani1 = spk.datasets.ANI1(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            num_heavy_atoms=args.num_heavy_atoms,
github atomistic-machine-learning / schnetpack / src / sacred_scripts / download.py View on Github external
def download(dbpath, dataset, folds, cutoff, api_key, molecules):
    """Download/build the database for one of the supported datasets.

    The dataset constructors are called only for their side effect of
    creating the database; the resulting objects are not used.

    Args:
        dbpath: database path. For MD17 it is used as a filename prefix:
            each molecule is written to ``dbpath + molecule + ".db"``.
        dataset: dataset name, matched case-insensitively against
            QM9, ISO17, ANI1, MD17 and MATPROJ.
        folds: iterable of ISO17 folds (ISO17 only).
        cutoff: cutoff passed to MaterialsProject (MATPROJ only).
        api_key: API key passed to MaterialsProject (MATPROJ only).
        molecules: iterable of MD17 molecule names (MD17 only).

    Raises:
        NotImplementedError: if ``dataset`` is not one of the names above.
    """
    name = dataset.upper()
    if name == "QM9":
        QM9(dbpath)
    elif name == "ISO17":
        for fold in folds:
            ISO17(dbpath, fold)
    elif name == "ANI1":
        ANI1(dbpath)
    elif name == "MD17":
        for molecule in molecules:
            # Plain concatenation on purpose: dbpath acts as a filename prefix.
            MD17(dbpath + molecule + ".db", dataset=molecule)
    elif name == "MATPROJ":
        MaterialsProject(dbpath, cutoff, api_key)
    else:
        raise NotImplementedError(
            "Dataset {!r} is not supported; expected one of "
            "QM9, ISO17, ANI1, MD17, MATPROJ.".format(dataset)
        )
github atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / settings.py View on Github external
QM9.G: True,
    QM9.Cv: True,
    ANI1.energy: True,
    MD17.energy: True,
    MaterialsProject.EformationPerAtom: False,
    MaterialsProject.EPerAtom: False,
    MaterialsProject.BandGap: False,
    MaterialsProject.TotalMagnetization: True,
    OrganicMaterialsDatabase.BandGap: False,
}

# Pooling mode ("sum" or "avg") per target property — presumably how
# atom-wise outputs are aggregated; confirm against get_pooling_mode's users.
# Matches ``divide_by_atoms`` entry by entry: True <-> "sum", False <-> "avg".
# Built from an ordered list of pairs, which keeps the same key order and
# the same last-one-wins handling of duplicate keys as a dict display.
pooling_mode = dict(
    [
        (QM9.mu, "sum"),
        (QM9.alpha, "sum"),
        (QM9.homo, "avg"),
        (QM9.lumo, "avg"),
        (QM9.gap, "avg"),
        (QM9.r2, "sum"),
        (QM9.zpve, "sum"),
        (QM9.U0, "sum"),
        (QM9.U, "sum"),
        (QM9.H, "sum"),
        (QM9.G, "sum"),
        (QM9.Cv, "sum"),
        (ANI1.energy, "sum"),
        (MD17.energy, "sum"),
        (MaterialsProject.EformationPerAtom, "avg"),
        (MaterialsProject.EPerAtom, "avg"),
        (MaterialsProject.BandGap, "avg"),
        (MaterialsProject.TotalMagnetization, "sum"),
        (OrganicMaterialsDatabase.BandGap, "avg"),
    ]
)
github atomistic-machine-learning / schnetpack / src / schnetpack / sacred / dataset_ingredients.py View on Github external
def qm9():
    """Default settings for QM9 dataset.

    NOTE(review): this appears to be a sacred config function, in which case
    its local variable names become config entries. Keep ``dbpath``,
    ``dataset`` and ``property_mapping`` named exactly as they are.
    """
    # default location of the QM9 database file
    dbpath = "./data/qm9.db"
    # dataset identifier string
    dataset = "QM9"
    # maps generic ``Properties`` keys to the corresponding ``QM9`` property
    # constants (energy -> U0, dipole moment -> mu, polarizability -> alpha)
    property_mapping = {
        Properties.energy: QM9.U0,
        Properties.dipole_moment: QM9.mu,
        Properties.iso_polarizability: QM9.alpha,
    }