# Scraper artifact (advertising banner), kept as a comment so the file can parse:
# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def download(dbpath, dataset, folds, cutoff, api_key, molecules):
    """Download a benchmark dataset into a local database.

    Constructing each dataset class triggers the download as a side effect;
    nothing is returned.

    Args:
        dbpath (str): path to the local database (for MD17 it is used as a
            filename prefix — one database per molecule).
        dataset (str): dataset name, one of ``QM9``, ``ISO17``, ``ANI1``,
            ``MD17``, ``MATPROJ`` (case-insensitive).
        folds (list): ISO17 folds to fetch (only used for ISO17).
        cutoff (float): environment cutoff for MaterialsProject
            (only used for MATPROJ).
        api_key (str): Materials Project API key (only used for MATPROJ).
        molecules (list): molecule names to fetch (only used for MD17).

    Raises:
        NotImplementedError: if ``dataset`` is not a known dataset name.
    """
    dataset = dataset.upper()
    if dataset == 'QM9':
        QM9(dbpath)
    elif dataset == 'ISO17':
        for fold in folds:
            ISO17(dbpath, fold)
    elif dataset == 'ANI1':
        ANI1(dbpath)
    elif dataset == 'MD17':
        for molecule in molecules:
            # plain concatenation is intentional: dbpath is a prefix,
            # not a directory, so os.path.join would change the target path
            MD17(dbpath + molecule + '.db', dataset=molecule)
    elif dataset == 'MATPROJ':
        MaterialsProject(dbpath, cutoff, api_key)
    else:
        # name the offending dataset instead of raising bare
        raise NotImplementedError("Unknown dataset: {}".format(dataset))
get_main_parser,
add_subparsers,
get_loaders,
get_statistics,
)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
if __name__ == "__main__":
# parse arguments
parser = get_main_parser()
add_subparsers(
parser,
defaults=dict(property=ANI1.energy),
choices=dict(property=[ANI1.energy]),
)
args = parser.parse_args()
train_args = setup_run(args)
# set device
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
metrics = [
schnetpack.train.metrics.MeanAbsoluteError(
train_args.property, train_args.property
),
schnetpack.train.metrics.RootMeanSquaredError(
train_args.property, train_args.property
),
QM9.G,
QM9.Cv,
],
)
qm9_parser.add_argument(
"--remove_uncharacterized",
help="Remove uncharacterized molecules from QM9 (default: %(default)s)",
action="store_true",
)
ani1_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
ani1_parser.add_argument(
"--property",
type=str,
help="Database property to be predicted (default: %(default)s)",
default=ANI1.energy,
choices=[ANI1.energy],
)
ani1_parser.add_argument(
"--num_heavy_atoms",
type=int,
help="Number of heavy atoms that will be loaded into the database."
" (default: %(default)s)",
default=8,
)
matproj_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
matproj_parser.add_argument(
"--property",
type=str,
help="Database property to be predicted" " (default: %(default)s)",
default=MaterialsProject.EformationPerAtom,
choices=[
def get_ani1(dbpath, num_heavy_atoms, dataset_properties):
    """Construct and return an ANI1 dataset handle.

    Args:
        dbpath (str): path to the local database file.
        num_heavy_atoms (int): maximum number of heavy atoms per molecule.
        dataset_properties (list): properties of the dataset to load.

    Returns:
        AtomsData object (an ``ANI1`` instance).
    """
    return ANI1(
        dbpath,
        num_heavy_atoms=num_heavy_atoms,
        properties=dataset_properties,
    )
def ani1():
    """Default settings for ANI1 dataset.

    NOTE(review): nothing is returned — these local assignments are
    presumably captured by a configuration framework (this looks like a
    sacred ``@ex.config`` function); confirm against the caller before
    renaming or removing any of the variables below.

    Adds:
        num_heavy_atoms (int): maximum number of heavy atoms
    """
    # default location of the local ANI1 database and dataset identifier
    dbpath = "./data/ani1.db"
    dataset = "ANI1"
    # keep the default molecules small: at most 2 heavy atoms
    num_heavy_atoms = 2
    # map the framework's generic energy property onto the ANI1-specific key
    property_mapping = {Properties.energy: ANI1.energy}
"""
if args.dataset == "qm9":
if logging:
logging.info("QM9 will be loaded...")
qm9 = spk.datasets.QM9(
args.datapath,
download=True,
load_only=[args.property],
collect_triples=args.model == "wacsf",
remove_uncharacterized=args.remove_uncharacterized,
)
return qm9
elif args.dataset == "ani1":
if logging:
logging.info("ANI1 will be loaded...")
ani1 = spk.datasets.ANI1(
args.datapath,
download=True,
load_only=[args.property],
collect_triples=args.model == "wacsf",
num_heavy_atoms=args.num_heavy_atoms,
)
return ani1
elif args.dataset == "md17":
if logging:
logging.info("MD17 will be loaded...")
md17 = spk.datasets.MD17(
args.datapath,
args.molecule,
download=True,
collect_triples=args.model == "wacsf",
)
add_subparsers,
get_loaders,
get_statistics,
)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
if __name__ == "__main__":
# parse arguments
parser = get_main_parser()
add_subparsers(
parser,
defaults=dict(property=ANI1.energy),
choices=dict(property=[ANI1.energy]),
)
args = parser.parse_args()
train_args = setup_run(args)
# set device
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
metrics = [
schnetpack.train.metrics.MeanAbsoluteError(
train_args.property, train_args.property
),
schnetpack.train.metrics.RootMeanSquaredError(
train_args.property, train_args.property
),
]
# set device
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
metrics = [
schnetpack.train.metrics.MeanAbsoluteError(
train_args.property, train_args.property
),
schnetpack.train.metrics.RootMeanSquaredError(
train_args.property, train_args.property
),
]
# build dataset
logging.info("ANI1 will be loaded...")
ani1 = spk.datasets.ANI1(
args.datapath,
download=True,
load_only=[train_args.property],
collect_triples=args.model == "wacsf",
)
# get atomrefs
atomref = ani1.get_atomrefs(train_args.property)
# splits the dataset in test, val, train sets
split_path = os.path.join(args.modelpath, "split.npz")
train_loader, val_loader, test_loader = get_loaders(
args, dataset=ani1, split_path=split_path, logging=logging
)
if args.mode == "train":