# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def qm9_dataset(qm9_dbpath):
    """Load the QM9 dataset from an ASE database file.

    Args:
        qm9_dbpath (str): path to the ``qm9.db`` file.

    Returns:
        QM9: the schnetpack QM9 dataset wrapper (presumably downloads the
        data if the db file is missing — TODO confirm against schnetpack).
    """
    # BUG FIX: the original printed os.path.exists(qm9_dbpath) to stdout —
    # a debug leftover. Emit a log warning only when the path is missing.
    if not os.path.exists(qm9_dbpath):
        logging.warning("QM9 database not found at %s", qm9_dbpath)
    return QM9(qm9_dbpath)
import argparse
import logging
import os

from torch.optim import Adam

import schnetpack as spk
import schnetpack.atomistic.model
from schnetpack.datasets import QM9
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train import build_mse_loss
from schnetpack.train.metrics import MeanAbsoluteError
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "qm9_model"
# BUG FIX: exist_ok=True so re-running the script does not crash with
# FileExistsError when the model directory is already present.
os.makedirs(model_dir, exist_ok=True)
properties = [QM9.U0]

# data preparation: load QM9 restricted to the U0 (internal energy) target
logging.info("get dataset")
dataset = QM9("data/qm9.db", load_only=[QM9.U0])
# split.npz persists the train/val/test split so it is reproducible
train, val, test = spk.train_test_split(
    dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
train_loader = spk.AtomsLoader(train, batch_size=64, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics: per-atom mean/stddev of the target, corrected by single-atom
# reference energies, used later to standardize model outputs
atomrefs = dataset.get_atomref(properties)
means, stddevs = train_loader.get_statistics(
    properties, divide_by_atoms=True, single_atom_ref=atomrefs
)
from schnetpack.datasets import QM9
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train.metrics import MeanAbsoluteError
from schnetpack.train import build_mse_loss
# NOTE(review): this chunk duplicates the training setup earlier in the
# file verbatim — the file looks like a splice of several scripts; one of
# the two copies should be removed once ownership is clarified.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "qm9_model"
# BUG FIX: exist_ok=True so a pre-existing directory does not raise.
os.makedirs(model_dir, exist_ok=True)
properties = [QM9.U0]

# data preparation
logging.info("get dataset")
dataset = QM9("data/qm9.db", load_only=[QM9.U0])
train, val, test = spk.train_test_split(
    dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
train_loader = spk.AtomsLoader(train, batch_size=64, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics
atomrefs = dataset.get_atomref(properties)
means, stddevs = train_loader.get_statistics(
    properties, divide_by_atoms=True, single_atom_ref=atomrefs
)
# model build
logging.info("build model")
representation = spk.SchNet(n_interactions=6)
# NOTE(review): the original output_modules list was truncated at a splice
# point in this file; the Atomwise module below is reconstructed from the
# standard schnetpack QM9 example — confirm against the upstream script.
output_modules = [
    spk.atomistic.Atomwise(
        property=QM9.U0,
        mean=means[QM9.U0],
        stddev=stddevs[QM9.U0],
        atomref=atomrefs[QM9.U0],
    )
]
__all__ = [
"divide_by_atoms",
"pooling_mode",
"get_divide_by_atoms",
"get_pooling_mode",
"get_negative_dr",
"get_derivative",
"get_contributions",
"get_module_str",
"get_environment_provider",
]
# Whether a property is extensive (scales with the number of atoms) and must
# therefore be normalized per atom when computing dataset statistics;
# intensive properties (orbital energies, band gaps) map to False.
divide_by_atoms = {
    QM9.mu: True,
    QM9.alpha: True,
    QM9.homo: False,
    QM9.lumo: False,
    QM9.gap: False,
    QM9.r2: True,
    QM9.zpve: True,
    QM9.U0: True,
    QM9.U: True,
    QM9.H: True,
    QM9.G: True,
    QM9.Cv: True,
    ANI1.energy: True,
    MD17.energy: True,
    MaterialsProject.EformationPerAtom: False,
    MaterialsProject.EPerAtom: False,
    MaterialsProject.BandGap: False,
    # BUG FIX: the original dict was truncated at a splice point (no closing
    # brace); the final two entries and the brace are restored from the
    # duplicate fragment near the bottom of this file.
    MaterialsProject.TotalMagnetization: True,
    OrganicMaterialsDatabase.BandGap: False,
}
# Shared parser fragment: add_help=False so it can be attached to dataset
# sub-parsers via ``parents=[data_parser]`` without a duplicate -h option.
data_parser = argparse.ArgumentParser(add_help=False)
data_parser.add_argument(
    "--environment_provider",
    default="simple",
    type=str,
    choices=["simple", "ase", "torch"],
    help="Environment provider for dataset. (default: %(default)s)",
)
# qm9
qm9_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
qm9_parser.add_argument(
    "--property",
    type=str,
    help="Database property to be predicted (default: %(default)s)",
    default=QM9.U0,
    choices=[
        QM9.A,
        QM9.B,
        QM9.C,
        QM9.mu,
        QM9.alpha,
        QM9.homo,
        QM9.lumo,
        QM9.gap,
        QM9.r2,
        QM9.zpve,
        QM9.U0,
        QM9.U,
        QM9.H,
        QM9.G,
        QM9.Cv,
    ],
)
# BUG FIX: the original used ``type=bool, default="False"`` — argparse's
# ``type=bool`` converts ANY non-empty string (including "False") to True,
# so the flag could never be disabled. A store_true flag is the correct,
# conventional form (absent -> False, present -> True). The second
# add_argument header was also lost at a splice point and is restored here.
qm9_parser.add_argument(
    "--remove_uncharacterized",
    action="store_true",
    help="Remove uncharacterized molecules from QM9",
)
md17_parser = data_subparsers.add_parser("md17", help="MD17 datasets")
md17_parser.add_argument(
    "molecule", help="Molecule dataset", choices=dset.MD17.existing_datasets
)
md17_parser.add_argument("dbpath", help="Destination path")

args = main_parser.parse_args()

# BUG FIX: the second branch was an independent ``if`` — for dataset "qm9"
# the first branch ran, then the trailing ``else`` of the second ``if``
# still printed "Unknown dataset!". Chain the checks with elif.
if args.dataset == "qm9":
    dset.QM9(args.dbpath, True)
elif args.dataset == "md17":
    dset.MD17(args.dbpath, args.molecule)
else:
    print("Unknown dataset!")
def get_dataset(args, logging=None):
    """
    Get dataset from arguments.

    Args:
        args (argparse.Namespace): parsed arguments; expected attributes:
            dataset, datapath, property, model, remove_uncharacterized,
            num_heavy_atoms (ANI1 only) — TODO confirm against the parsers.
        logging: logger; NOTE this parameter shadows the stdlib ``logging``
            module inside this function.

    Returns:
        spk.data.AtomsData: dataset

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == "qm9":
        if logging:
            logging.info("QM9 will be loaded...")
        qm9 = spk.datasets.QM9(
            args.datapath,
            download=True,
            load_only=[args.property],
            # wacsf models need angular (triple) features
            collect_triples=args.model == "wacsf",
            remove_uncharacterized=args.remove_uncharacterized,
        )
        return qm9
    elif args.dataset == "ani1":
        if logging:
            logging.info("ANI1 will be loaded...")
        ani1 = spk.datasets.ANI1(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            num_heavy_atoms=args.num_heavy_atoms,
        )
        return ani1
    # NOTE(review): the original function was truncated mid-call at a splice
    # point; the close of the ANI1 call, ``return ani1`` and this explicit
    # failure branch are reconstructed — confirm against upstream schnetpack.
    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))
def download(dbpath, dataset, folds, cutoff, api_key, molecules):
    """Trigger the download of a benchmark dataset by instantiating it.

    Args:
        dbpath: destination database path (path prefix for MD17).
        dataset: dataset name, case-insensitive: QM9, ISO17, ANI1,
            MD17 or MATPROJ.
        folds: ISO17 folds to fetch (ISO17 only).
        cutoff: cutoff radius forwarded to MaterialsProject.
        api_key: Materials Project API key.
        molecules: MD17 molecule names (MD17 only).

    Raises:
        NotImplementedError: for any unrecognized dataset name.
    """
    name = dataset.upper()
    if name == 'QM9':
        QM9(dbpath)
    elif name == 'ISO17':
        # one database per requested fold
        for fold in folds:
            ISO17(dbpath, fold)
    elif name == 'ANI1':
        ANI1(dbpath)
    elif name == 'MD17':
        # one database file per molecule
        for molecule in molecules:
            MD17(dbpath + molecule + '.db', dataset=molecule)
    elif name == 'MATPROJ':
        MaterialsProject(dbpath, cutoff, api_key)
    else:
        raise NotImplementedError
# NOTE(review): orphaned fragment — these lines duplicate the TAIL of the
# ``divide_by_atoms`` dict defined earlier in this file (the earlier copy
# is truncated; this one carries the two final entries and the closing
# brace). As it stands this is not valid standalone Python; the two copies
# should be merged and this fragment removed.
QM9.G: True,
QM9.Cv: True,
ANI1.energy: True,
MD17.energy: True,
MaterialsProject.EformationPerAtom: False,
MaterialsProject.EPerAtom: False,
MaterialsProject.BandGap: False,
MaterialsProject.TotalMagnetization: True,
OrganicMaterialsDatabase.BandGap: False,
}
# Per-property pooling mode for the atomwise output network: "sum" for
# extensive properties (energies, polarizability, heat capacity, ...),
# "avg" for intensive ones (orbital energies, band gaps, formation energy
# per atom).
pooling_mode = {
    QM9.mu: "sum",
    QM9.alpha: "sum",
    QM9.homo: "avg",
    QM9.lumo: "avg",
    QM9.gap: "avg",
    QM9.r2: "sum",
    QM9.zpve: "sum",
    QM9.U0: "sum",
    QM9.U: "sum",
    QM9.H: "sum",
    QM9.G: "sum",
    QM9.Cv: "sum",
    ANI1.energy: "sum",
    MD17.energy: "sum",
    MaterialsProject.EformationPerAtom: "avg",
    MaterialsProject.EPerAtom: "avg",
    MaterialsProject.BandGap: "avg",
    MaterialsProject.TotalMagnetization: "sum",
    OrganicMaterialsDatabase.BandGap: "avg",
}
def qm9():
    """Default settings for QM9 dataset."""
    # NOTE(review): this looks like a sacred-style config function — the
    # local assignments below are presumably harvested by an @ex.config
    # decorator that is not visible in this chunk. Confirm: as a plain
    # function these locals have no effect and the function returns None.
    dbpath = "./data/qm9.db"
    dataset = "QM9"
    # maps generic schnetpack property names onto QM9 column names
    property_mapping = {
        Properties.energy: QM9.U0,
        Properties.dipole_moment: QM9.mu,
        Properties.iso_polarizability: QM9.alpha,
    }