Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_matproj(dbpath, cutoff, api_key, dataset_properties):
    """Create a MaterialsProject dataset restricted to the given properties.

    Args:
        dbpath (str): path to the local database
        cutoff (float): cutoff radius
        api_key (str): personal api_key for materialsproject.org
        dataset_properties (list): properties of the dataset

    Returns:
        AtomsData object
    """
    # The property list is forwarded via the ``properties`` keyword.
    dataset = MaterialsProject(
        dbpath,
        cutoff,
        api_key,
        properties=dataset_properties,
    )
    return dataset
def download(dbpath, dataset, folds, cutoff, api_key, molecules):
    """Instantiate the dataset class(es) for the requested benchmark dataset.

    The dataset objects are constructed only for their side effects and are
    not returned. Presumably construction downloads/builds the local
    database (TODO confirm in the dataset classes).

    Args:
        dbpath (str): database path. For MD17 it is used as a prefix: each
            molecule is given the path ``dbpath + molecule + '.db'``.
        dataset (str): dataset name, matched case-insensitively against
            QM9, ISO17, ANI1, MD17 and MATPROJ.
        folds (iterable): ISO17 folds; one ISO17 object is built per fold
            (used for ISO17 only).
        cutoff (float): cutoff radius (used for MATPROJ only).
        api_key (str): personal api_key for materialsproject.org
            (used for MATPROJ only).
        molecules (iterable of str): MD17 molecule names (used for MD17 only).

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """
    name = dataset.upper()
    if name == 'QM9':
        QM9(dbpath)
    elif name == 'ISO17':
        for fold in folds:
            ISO17(dbpath, fold)
    elif name == 'ANI1':
        ANI1(dbpath)
    elif name == 'MD17':
        for molecule in molecules:
            MD17(dbpath + molecule + '.db', dataset=molecule)
    elif name == 'MATPROJ':
        MaterialsProject(dbpath, cutoff, api_key)
    else:
        # Name the offending input so the caller can see what was rejected.
        raise NotImplementedError(
            "Unsupported dataset {!r}; expected one of "
            "QM9, ISO17, ANI1, MD17, MATPROJ".format(dataset)
        )
# NOTE(review): the opening assignment of this mapping is missing. The same
# entries appear later in this file as `divide_by_atoms = {`, so this looks
# like a truncated splice. TODO restore the head or drop the duplicate.
QM9.mu: True,
QM9.alpha: True,
QM9.homo: False,
QM9.lumo: False,
QM9.gap: False,
QM9.r2: True,
QM9.zpve: True,
QM9.U0: True,
QM9.U: True,
QM9.H: True,
QM9.G: True,
QM9.Cv: True,
ANI1.energy: True,
MD17.energy: True,
MaterialsProject.EformationPerAtom: False,
MaterialsProject.EPerAtom: False,
MaterialsProject.BandGap: False,
MaterialsProject.TotalMagnetization: True,
OrganicMaterialsDatabase.BandGap: False,
}
# Per-property pooling choice ("sum" or "avg") keyed by dataset property.
# NOTE(review): this literal is never closed. It stops after QM9.H and is
# followed directly by a `def`, so it looks like a truncated splice.
# TODO restore the remaining entries and the closing brace.
pooling_mode = {
QM9.mu: "sum",
QM9.alpha: "sum",
QM9.homo: "avg",
QM9.lumo: "avg",
QM9.gap: "avg",
QM9.r2: "sum",
QM9.zpve: "sum",
QM9.U0: "sum",
QM9.U: "sum",
QM9.H: "sum",
def matproj():
    """
    Default settings for Materials Project dataset.
    Adds:
        cutoff (float): cutoff radius
        api_key (str): personal api key from https://materialsproject.org
    """
    # NOTE(review): this function only binds local names and returns None.
    # It looks like a config-style function whose locals are read by a
    # framework (e.g. a sacred named config). TODO confirm; if so, keep these
    # local names stable.
    dbpath = "./data/matproj.db"
    dataset = "MATPROJ"
    cutoff = 5.0
    # Empty by default; presumably the user must supply a real key before any
    # download. TODO confirm.
    api_key = ""
    property_mapping = {Properties.energy_contributions: MaterialsProject.EPerAtom}
# set device
# NOTE(review): `args` and `train_args` are used throughout this run of
# statements but are only assigned further down in this file
# (`args = parser.parse_args()`, `train_args = setup_run(args)`). In this
# order the script would raise NameError, which looks like a reordered
# splice. TODO confirm the intended order.
# CUDA is used when args.cuda is truthy, otherwise the CPU.
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
# MAE and RMSE both receive train_args.property for their first two
# positional arguments.
# NOTE(review): both the `schnetpack.` and `spk.` prefixes are used here.
# They are presumably the same package under an alias; TODO unify.
metrics = [
    schnetpack.train.metrics.MeanAbsoluteError(
        train_args.property, train_args.property
    ),
    schnetpack.train.metrics.RootMeanSquaredError(
        train_args.property, train_args.property
    ),
]
# build dataset
# Only the selected training property is passed via load_only.
mp = spk.datasets.MaterialsProject(
    args.datapath,
    args.cutoff,
    apikey=args.apikey,
    download=True,
    load_only=[train_args.property],
)
# get atomrefs
atomref = mp.get_atomrefs(train_args.property)
# splits the dataset in test, val, train sets
# The split file path is <args.modelpath>/split.npz; whether it is reused on
# reruns depends on get_loaders (TODO confirm).
split_path = os.path.join(args.modelpath, "split.npz")
train_loader, val_loader, test_loader = get_loaders(
    args, dataset=mp, split_path=split_path, logging=logging
)
# parse arguments
parser = get_main_parser()
add_matproj_arguments(parser)
# Defaults and the allowed `property` values are forwarded to add_subparsers.
# NOTE(review): this `property` choice list matches the one given to
# matproj_parser's --property elsewhere in this file; keep the two in sync.
add_subparsers(
    parser,
    defaults=dict(
        property=MaterialsProject.EformationPerAtom,
        features=64,
        aggregation_mode="mean",
    ),
    choices=dict(
        property=[
            MaterialsProject.EformationPerAtom,
            MaterialsProject.EPerAtom,
            MaterialsProject.BandGap,
            MaterialsProject.TotalMagnetization,
        ]
    ),
)
args = parser.parse_args()
train_args = setup_run(args)
# set device
# NOTE(review): this re-binds `device` and `metrics` exactly as an earlier
# run of statements in this file does; one copy is probably redundant.
# TODO confirm.
# CUDA is used when args.cuda is truthy, otherwise the CPU.
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
# The list was left unclosed in the original; the closing bracket matches the
# identical metrics list defined earlier in this file.
metrics = [
    schnetpack.train.metrics.MeanAbsoluteError(
        train_args.property, train_args.property
    ),
    schnetpack.train.metrics.RootMeanSquaredError(
        train_args.property, train_args.property
    ),
]
# NOTE(review): fragment. The signature and earlier branches of this
# dataset-loading function (including the branch that binds `ani1`) are not
# present here. `logging` appears to be a parameter that may be falsy to
# silence the messages (presumably; TODO confirm). TODO restore the head.
return ani1
elif args.dataset == "md17":
if logging:
logging.info("MD17 will be loaded...")
# collect_triples is enabled only when args.model == "wacsf".
md17 = spk.datasets.MD17(
args.datapath,
args.molecule,
download=True,
collect_triples=args.model == "wacsf",
)
return md17
elif args.dataset == "matproj":
if logging:
logging.info("Materials project will be loaded...")
# Only args.property is passed via load_only.
mp = spk.datasets.MaterialsProject(
args.datapath,
args.cutoff,
apikey=args.apikey,
download=True,
load_only=[args.property],
)
return mp
elif args.dataset == "omdb":
if logging:
logging.info("Organic Materials Database will be loaded...")
omdb = spk.datasets.OrganicMaterialsDatabase(
args.datapath, args.cutoff, download=True, load_only=[args.property]
)
return omdb
elif args.dataset == "custom":
# NOTE(review): the body of the "custom" branch is truncated after
# `if logging:`. TODO restore.
if logging:
# Per-property boolean flags keyed by dataset property identifiers.
# NOTE(review): presumably True marks targets that are divided by the number
# of atoms (inferred from the name only). TODO confirm against the consumer.
# The pairs below keep the same insertion order as a dict literal would.
divide_by_atoms = dict(
    [
        (QM9.mu, True),
        (QM9.alpha, True),
        (QM9.homo, False),
        (QM9.lumo, False),
        (QM9.gap, False),
        (QM9.r2, True),
        (QM9.zpve, True),
        (QM9.U0, True),
        (QM9.U, True),
        (QM9.H, True),
        (QM9.G, True),
        (QM9.Cv, True),
        (ANI1.energy, True),
        (MD17.energy, True),
        (MaterialsProject.EformationPerAtom, False),
        (MaterialsProject.EPerAtom, False),
        (MaterialsProject.BandGap, False),
        (MaterialsProject.TotalMagnetization, True),
        (OrganicMaterialsDatabase.BandGap, False),
    ]
)
# NOTE(review): this dict literal is never closed. It stops after QM9.U
# (the other `pooling_mode` mapping in this file also lists QM9.H), and the
# `)` below does not match its `{`. This looks like a bad splice; TODO restore.
pooling_mode = {
QM9.mu: "sum",
QM9.alpha: "sum",
QM9.homo: "avg",
QM9.lumo: "avg",
QM9.gap: "avg",
QM9.r2: "sum",
QM9.zpve: "sum",
QM9.U0: "sum",
QM9.U: "sum",
# NOTE(review): stray `)`. It presumably closes a call whose opening line is
# missing from this chunk.
)
# ANI1 option: how many heavy atoms to load into the database (int, default 8).
ani1_parser.add_argument(
    "--num_heavy_atoms",
    default=8,
    type=int,
    help="Number of heavy atoms that will be loaded into the database."
    " (default: %(default)s)",
)
# Materials Project options: inherit the shared data options and add the
# target property and the API key. add_help=False because this parser is used
# as a parent.
matproj_parser = argparse.ArgumentParser(parents=[data_parser], add_help=False)
matproj_parser.add_argument(
    "--property",
    default=MaterialsProject.EformationPerAtom,
    choices=[
        MaterialsProject.EformationPerAtom,
        MaterialsProject.EPerAtom,
        MaterialsProject.BandGap,
        MaterialsProject.TotalMagnetization,
    ],
    type=str,
    help="Database property to be predicted" " (default: %(default)s)",
)
# No default key; the user must supply one when downloading.
matproj_parser.add_argument(
    "--apikey",
    default=None,
    help="API key for Materials Project (see https://materialsproject.org/open)",
)
md17_parser = argparse.ArgumentParser(add_help=False, parents=[data_parser])
md17_parser.add_argument(
"--property",
type=str,
help="Database property to be predicted" " (default: %(default)s)",
default=MD17.energy,