Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def energy_rmse():
    """Create the root-mean-squared-error metric labelled "energy".

    The metric is constructed with the keys "_energy" and "y", in that
    order (presumably reference key first, then model-output key -- confirm
    against the signature of RootMeanSquaredError).

    Returns:
        The freshly constructed RootMeanSquaredError instance.
    """
    energy_metric = RootMeanSquaredError("_energy", "y", name="energy")
    return energy_metric
def forces_rmse():
    """Create the element-wise root-mean-squared-error metric labelled "forces".

    The metric is constructed with the keys "_forces" and "dydx", in that
    order (presumably reference key first, then model-output key -- confirm
    against the signature of RootMeanSquaredError), and with
    ``element_wise=True``.

    Returns:
        The freshly constructed RootMeanSquaredError instance.
    """
    forces_metric = RootMeanSquaredError(
        "_forces",
        "dydx",
        name="forces",
        element_wise=True,
    )
    return forces_metric
parser,
defaults=dict(property=ANI1.energy),
choices=dict(property=[ANI1.energy]),
)
# Parse the command-line options and hand them to setup_run (defined outside
# this excerpt) to obtain the training arguments used below.
args = parser.parse_args()
train_args = setup_run(args)
# set device: "cuda" when args.cuda is truthy, otherwise "cpu"
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics: one mean-absolute-error and one root-mean-squared-error
# metric, each constructed with train_args.property for both arguments
# NOTE(review): this excerpt refers to the package both as `schnetpack` (here)
# and as `spk` (dataset construction below); both names must be bound by
# imports outside this excerpt.
metrics = [
schnetpack.train.metrics.MeanAbsoluteError(
train_args.property, train_args.property
),
schnetpack.train.metrics.RootMeanSquaredError(
train_args.property, train_args.property
),
]
# build dataset from args.datapath
logging.info("ANI1 will be loaded...")
ani1 = spk.datasets.ANI1(
args.datapath,
# presumably fetches the data when it is not present at datapath -- confirm
download=True,
# only the trained property is requested from the dataset
load_only=[train_args.property],
# True only when the selected model is "wacsf"
collect_triples=args.model == "wacsf",
)
# get atomrefs for the trained property from the dataset
atomref = ani1.get_atomrefs(train_args.property)
def get_metrics(args):
    """Assemble the list of evaluation metrics for a run.

    A mean-absolute-error and a root-mean-squared-error metric are always
    created for ``args.property``. When ``spk.utils.get_derivative(args)``
    returns something other than ``None``, an element-wise MAE/RMSE pair for
    that derivative is added after them.

    Args:
        args: object exposing a ``property`` attribute; it is also passed
            unchanged to ``spk.utils.get_derivative``.

    Returns:
        list: the metric objects, property metrics first, derivative
        metrics (if any) last.
    """
    target = args.property
    collected = [
        spk.train.metrics.MeanAbsoluteError(target, target),
        spk.train.metrics.RootMeanSquaredError(target, target),
    ]

    derivative = spk.utils.get_derivative(args)
    if derivative is None:
        # no derivative configured: property metrics are all there is
        return collected

    # derivative metrics are evaluated element-wise
    collected.extend(
        [
            spk.train.metrics.MeanAbsoluteError(
                derivative, derivative, element_wise=True
            ),
            spk.train.metrics.RootMeanSquaredError(
                derivative, derivative, element_wise=True
            ),
        ]
    )
    return collected
choices=dict(property=[MD17.energy, MD17.forces]),
)
# Parse the command-line options and hand them to setup_run (defined outside
# this excerpt) to obtain the training arguments used below.
args = parser.parse_args()
train_args = setup_run(args)
# set device: "cuda" when args.cuda is truthy, otherwise "cpu"
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics: MAE and RMSE for MD17.energy, followed by MAE and RMSE for
# MD17.forces; only the force metrics are created with element_wise=True
metrics = [
schnetpack.train.metrics.MeanAbsoluteError(MD17.energy, MD17.energy),
schnetpack.train.metrics.RootMeanSquaredError(MD17.energy, MD17.energy),
schnetpack.train.metrics.MeanAbsoluteError(
MD17.forces, MD17.forces, element_wise=True
),
schnetpack.train.metrics.RootMeanSquaredError(
MD17.forces, MD17.forces, element_wise=True
),
]
# build dataset for the molecule chosen on the command line
logging.info("MD17 will be loaded...")
md17 = MD17(
args.datapath,
args.molecule,
# presumably fetches the data when it is not present at datapath -- confirm
download=True,
# True only when the selected model is "wacsf"
collect_triples=args.model == "wacsf",
)
# get atomrefs for the trained property from the dataset
atomref = md17.get_atomrefs(train_args.property)