Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
)
],
)
args.split = ["validation"]
evaluate(
args,
model,
qm9_train_loader,
qm9_val_loader,
qm9_test_loader,
"cpu",
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
)
def forces_mae():
    """Build the mean-absolute-error metric reported under the name "forces".

    The metric is constructed with ``"_forces"`` and ``"dydx"`` as its two
    positional arguments, ``name="forces"`` and ``element_wise=True``.

    Returns:
        A new ``MeanAbsoluteError`` instance.
    """
    metric = MeanAbsoluteError(
        "_forces",
        "dydx",
        name="forces",
        element_wise=True,
    )
    return metric
# --- command-line arguments -------------------------------------------------
# ANI1.energy is both the default and the only accepted value for `property`.
parser = get_main_parser()
add_subparsers(
    parser,
    defaults=dict(property=ANI1.energy),
    choices=dict(property=[ANI1.energy]),
)
args = parser.parse_args()
train_args = setup_run(args)

# --- device -------------------------------------------------------------------
# CUDA is used only when `args.cuda` is truthy; otherwise the CPU is used.
device = torch.device("cuda" if args.cuda else "cpu")

# --- metrics ------------------------------------------------------------------
# One MAE and one RMSE metric, each given the selected property key for both
# of its positional arguments (MAE first, then RMSE).
metrics = [
    metric_cls(train_args.property, train_args.property)
    for metric_cls in (
        schnetpack.train.metrics.MeanAbsoluteError,
        schnetpack.train.metrics.RootMeanSquaredError,
    )
]

# --- dataset ------------------------------------------------------------------
logging.info("ANI1 will be loaded...")
# Only the selected property is loaded; triples are collected only when the
# chosen model is "wacsf".
ani1 = spk.datasets.ANI1(
    args.datapath,
    download=True,
    load_only=[train_args.property],
    collect_triples=(args.model == "wacsf"),
)
def get_metrics(args):
    """Create the evaluation metrics for the configured property.

    A mean-absolute-error and a root-mean-squared-error metric are always
    created for ``args.property``. When ``spk.utils.get_derivative(args)``
    returns something other than ``None``, an MAE and an RMSE for that value
    are appended as well, both constructed with ``element_wise=True``.

    Args:
        args: Parsed run arguments; must provide ``property`` plus whatever
            ``spk.utils.get_derivative`` reads from it.

    Returns:
        list: ``[MAE(property), RMSE(property)]``, optionally followed by
        ``[MAE(derivative), RMSE(derivative)]``.
    """
    metric_types = (
        spk.train.metrics.MeanAbsoluteError,
        spk.train.metrics.RootMeanSquaredError,
    )

    # Each metric receives the same key for both positional arguments.
    collected = [
        metric_type(args.property, args.property) for metric_type in metric_types
    ]

    derivative_key = spk.utils.get_derivative(args)
    if derivative_key is None:
        return collected

    # Derivative metrics are evaluated element-wise.
    collected.extend(
        metric_type(derivative_key, derivative_key, element_wise=True)
        for metric_type in metric_types
    )
    return collected
n_in=representation.n_atom_basis,
property="energy",
derivative="forces",
mean=means["energy"],
stddev=stddevs["energy"],
negative_dr=True,
)
]
# --- model --------------------------------------------------------------------
# Combine the representation with the previously built output modules.
model = schnetpack.atomistic.model.AtomisticModel(representation, output_modules)

# --- optimizer ----------------------------------------------------------------
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = Adam(params=model.parameters(), lr=1e-4)

# --- hooks --------------------------------------------------------------------
logging.info("build trainer")
# One MAE metric per property, passing the property key as both arguments.
metrics = [MeanAbsoluteError(prop, prop) for prop in properties]
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer),
]

# --- trainer ------------------------------------------------------------------
# The loss is built by `mse_loss` over the same properties tracked above.
loss = mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    optimizer=optimizer,
    loss_fn=loss,
    hooks=hooks,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# --- run training ---------------------------------------------------------------
logging.info("training")