Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Lookup table from a dataset property key to a pooling mode string.
# Every value is either "sum" or "avg".
# NOTE(review): presumably this selects how per-atom contributions are
# aggregated into the predicted property — confirm against the call sites.
pooling_mode = {
    # QM9 properties
    QM9.mu: "sum",
    QM9.alpha: "sum",
    QM9.homo: "avg",
    QM9.lumo: "avg",
    QM9.gap: "avg",
    QM9.r2: "sum",
    QM9.zpve: "sum",
    QM9.U0: "sum",
    QM9.U: "sum",
    QM9.H: "sum",
    QM9.G: "sum",
    QM9.Cv: "sum",
    # ANI1 / MD17 energies
    ANI1.energy: "sum",
    MD17.energy: "sum",
    # Materials Project properties
    MaterialsProject.EformationPerAtom: "avg",
    MaterialsProject.EPerAtom: "avg",
    MaterialsProject.BandGap: "avg",
    MaterialsProject.TotalMagnetization: "sum",
    # Organic Materials Database properties
    OrganicMaterialsDatabase.BandGap: "avg",
}
def get_divide_by_atoms(args):
    """
    Resolve the 'divide_by_atoms' flag from the run arguments.

    Args:
        args: run arguments; ``args.dataset`` is always read, then either
            ``args.property`` (known datasets) or ``args.aggregation_mode``
            (``"custom"`` dataset).

    Returns:
        bool: for the ``"custom"`` dataset, whether the aggregation mode is
        ``"sum"``; otherwise the entry of the module-level
        ``divide_by_atoms`` table for ``args.property``.
    """
    # Known datasets are answered straight from the lookup table.
    if args.dataset != "custom":
        return divide_by_atoms[args.property]
    # Custom datasets have no table entry; derive the flag from the
    # requested aggregation mode instead.
    return args.aggregation_mode == "sum"
def md17():
    """
    Default settings for MD17 dataset.

    Adds:
        molecule (str): name of the molecule that is included in MD17

    The body only binds local names (``dbpath``, ``dataset``, ``molecule``,
    ``property_mapping``); none of them is read again inside the function.
    """
    # NOTE(review): presumably these locals are collected by a configuration
    # framework through a decorator that is not visible here — confirm before
    # renaming or removing any of them.
    dbpath = "./data"
    dataset = "MD17"
    molecule = "aspirin"
    # Maps the generic energy/forces property keys to the MD17 keys.
    property_mapping = {Properties.energy: MD17.energy, Properties.forces: MD17.forces}
return qm9
elif args.dataset == "ani1":
if logging:
logging.info("ANI1 will be loaded...")
ani1 = spk.datasets.ANI1(
args.datapath,
download=True,
load_only=[args.property],
collect_triples=args.model == "wacsf",
num_heavy_atoms=args.num_heavy_atoms,
)
return ani1
elif args.dataset == "md17":
if logging:
logging.info("MD17 will be loaded...")
md17 = spk.datasets.MD17(
args.datapath,
args.molecule,
download=True,
collect_triples=args.model == "wacsf",
)
return md17
elif args.dataset == "matproj":
if logging:
logging.info("Materials project will be loaded...")
mp = spk.datasets.MaterialsProject(
args.datapath,
args.cutoff,
apikey=args.apikey,
download=True,
load_only=[args.property],
)
def get_md17(dbpath, molecule, dataset_properties):
    """
    Build the MD17 dataset object for a single molecule.

    Args:
        dbpath (str): path to the local database
        molecule (str): name of a molecule that is contained in the MD17 dataset
        dataset_properties (list): properties of the dataset

    Returns:
        AtomsData object
    """
    md17_data = MD17(
        dbpath,
        molecule=molecule,
        properties=dataset_properties,
    )
    return md17_data
type=bool,
default="False",
)
# "md17" sub-command: a molecule name (restricted to the names listed in
# dset.MD17.existing_datasets) followed by the destination path.
md17_parser = data_subparsers.add_parser("md17", help="MD17 datasets")
md17_parser.add_argument(
    "molecule", help="Molecule dataset", choices=dset.MD17.existing_datasets
)
md17_parser.add_argument("dbpath", help="Destination path")

args = main_parser.parse_args()

# The dataset choices are mutually exclusive, so they form one
# if/elif/else chain. With two separate `if` statements a "qm9" run would
# also fall into the `else` of the second one and report
# "Unknown dataset!" even though it was handled.
if args.dataset == "qm9":
    dset.QM9(args.dbpath, True)
elif args.dataset == "md17":
    dset.MD17(args.dbpath, args.molecule)
else:
    print("Unknown dataset!")
def get_output_module(args, representation, mean, stddev, atomref):
derivative = spk.utils.get_derivative(args)
negative_dr = spk.utils.get_negative_dr(args)
contributions = spk.utils.get_contributions(args)
stress = spk.utils.get_stress(args)
if args.dataset == "md17" and not args.ignore_forces:
derivative = spk.datasets.MD17.forces
output_module_str = spk.utils.get_module_str(args)
if output_module_str == "dipole_moment":
return spk.atomistic.output_modules.DipoleMoment(
args.features,
predict_magnitude=True,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "electronic_spatial_extent":
return spk.atomistic.output_modules.ElectronicSpatialExtent(
args.features,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
# Lookup table from a dataset property key to the boolean 'divide_by_atoms'
# flag; get_divide_by_atoms() indexes it with args.property for every
# dataset other than "custom".
# The entries set to True are the same properties that pooling_mode maps
# to "sum"; the entries set to False are the ones it maps to "avg".
divide_by_atoms = {
    # QM9 properties
    QM9.mu: True,
    QM9.alpha: True,
    QM9.homo: False,
    QM9.lumo: False,
    QM9.gap: False,
    QM9.r2: True,
    QM9.zpve: True,
    QM9.U0: True,
    QM9.U: True,
    QM9.H: True,
    QM9.G: True,
    QM9.Cv: True,
    # ANI1 / MD17 energies
    ANI1.energy: True,
    MD17.energy: True,
    # Materials Project properties
    MaterialsProject.EformationPerAtom: False,
    MaterialsProject.EPerAtom: False,
    MaterialsProject.BandGap: False,
    MaterialsProject.TotalMagnetization: True,
    # Organic Materials Database properties
    OrganicMaterialsDatabase.BandGap: False,
}
pooling_mode = {
QM9.mu: "sum",
QM9.alpha: "sum",
QM9.homo: "avg",
QM9.lumo: "avg",
QM9.gap: "avg",
QM9.r2: "sum",
QM9.zpve: "sum",
QM9.U0: "sum",