Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
contributions = spk.utils.get_contributions(args)
stress = spk.utils.get_stress(args)
if args.dataset == "md17" and not args.ignore_forces:
derivative = spk.datasets.MD17.forces
output_module_str = spk.utils.get_module_str(args)
if output_module_str == "dipole_moment":
return spk.atomistic.output_modules.DipoleMoment(
args.features,
predict_magnitude=True,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "electronic_spatial_extent":
return spk.atomistic.output_modules.ElectronicSpatialExtent(
args.features,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "atomwise":
return spk.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=spk.utils.get_pooling_mode(args),
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
property=args.property,
derivative=derivative,
negative_dr=negative_dr,
logging.info("calculate statistics...")
mean, stddev = get_statistics(
split_path,
train_loader,
train_args,
atomref,
logging=logging,
per_atom=True,
)
# build representation
representation = get_representation(args, train_loader)
# build output module
if args.model == "schnet":
output_module = spk.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=args.aggregation_mode,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
property=args.property,
derivative="forces",
negative_dr=True,
)
elif args.model == "wascf":
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
output_module = ElementalAtomwise(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
mean=mean[args.property],
args, dataset=mp, split_path=split_path, logging=logging
)
if args.mode == "train":
# get statistics
logging.info("calculate statistics...")
mean, stddev = get_statistics(
split_path, train_loader, train_args, atomref, logging=logging
)
# build representation
representation = get_representation(train_args, train_loader=train_loader)
# build output module
if args.model == "schnet":
output_module = schnetpack.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=args.aggregation_mode,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=None,
train_embeddings=True,
property=args.property,
)
else:
raise NotImplementedError
# build AtomisticModel
model = get_model(
representation=representation,
output_modules=output_module,
parallelize=args.parallel,
args, dataset=omdb, split_path=split_path, logging=logging
)
if args.mode == "train":
# get statistics
logging.info("calculate statistics...")
mean, stddev = get_statistics(
split_path, train_loader, train_args, atomref, logging=logging
)
# build representation
representation = get_representation(train_args, train_loader=train_loader)
# build output module
if args.model == "schnet":
output_module = schnetpack.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=args.aggregation_mode,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
train_embeddings=True,
property=args.property,
)
else:
raise NotImplementedError
# build AtomisticModel
model = get_model(
representation=representation,
output_modules=output_module,
parallelize=args.parallel,
atomref=atomref[args.property],
property=args.property,
)
elif args.model == "wacsf":
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
if args.property == QM9.mu:
output_module = schnetpack.atomistic.output_modules.ElementalDipoleMoment(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
predict_magnitude=True,
elements=elements,
property=args.property,
)
else:
output_module = schnetpack.atomistic.output_modules.ElementalAtomwise(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
aggregation_mode=args.aggregation_mode,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
elements=elements,
property=args.property,
)
else:
raise NotImplementedError
# build AtomisticModel
model = get_model(
representation=representation,
def get_output_module(args, representation, mean, stddev, atomref):
derivative = spk.utils.get_derivative(args)
negative_dr = spk.utils.get_negative_dr(args)
contributions = spk.utils.get_contributions(args)
stress = spk.utils.get_stress(args)
if args.dataset == "md17" and not args.ignore_forces:
derivative = spk.datasets.MD17.forces
output_module_str = spk.utils.get_module_str(args)
if output_module_str == "dipole_moment":
return spk.atomistic.output_modules.DipoleMoment(
args.features,
predict_magnitude=True,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "electronic_spatial_extent":
return spk.atomistic.output_modules.ElectronicSpatialExtent(
args.features,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "atomwise":
predict_magnitude=True,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "electronic_spatial_extent":
return spk.atomistic.output_modules.ElectronicSpatialExtent(
args.features,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
contributions=contributions,
)
elif output_module_str == "atomwise":
return spk.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=spk.utils.get_pooling_mode(args),
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
property=args.property,
derivative=derivative,
negative_dr=negative_dr,
contributions=contributions,
stress=stress,
)
elif output_module_str == "polarizability":
return spk.atomistic.output_modules.Polarizability(
args.features,
aggregation_mode=spk.utils.get_pooling_mode(args),
property=args.property,
return spk.atomistic.output_modules.Polarizability(
args.features,
aggregation_mode=spk.utils.get_pooling_mode(args),
property=args.property,
)
elif output_module_str == "isotropic_polarizability":
return spk.atomistic.output_modules.Polarizability(
args.features,
aggregation_mode=spk.utils.get_pooling_mode(args),
property=args.property,
isotropic=True,
)
# wacsf modules
elif output_module_str == "elemental_dipole_moment":
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
return spk.atomistic.output_modules.ElementalDipoleMoment(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
predict_magnitude=True,
elements=elements,
property=args.property,
)
elif output_module_str == "elemental_atomwise":
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
return spk.atomistic.output_modules.ElementalAtomwise(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
aggregation_mode=spk.utils.get_pooling_mode(args),
mean=mean[args.property],
stddev=stddev[args.property],
# build representation
representation = get_representation(args, train_loader=train_loader)
# build output module
if args.model == "schnet":
if args.property == QM9.mu:
output_module = schnetpack.atomistic.output_modules.DipoleMoment(
args.features,
predict_magnitude=True,
mean=mean[args.property],
stddev=stddev[args.property],
property=args.property,
)
else:
output_module = schnetpack.atomistic.output_modules.Atomwise(
args.features,
aggregation_mode=args.aggregation_mode,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
property=args.property,
)
elif args.model == "wacsf":
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
if args.property == QM9.mu:
output_module = schnetpack.atomistic.output_modules.ElementalDipoleMoment(
representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
predict_magnitude=True,
elements=elements,