Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def qm9_splits(qm9_dataset, qm9_split):
    """Split the QM9 dataset fixture into subsets.

    ``qm9_split`` is unpacked positionally after the dataset into
    ``spk.data.train_test_split``; its return value is passed through
    unchanged.
    """
    split_args = (qm9_dataset, *qm9_split)
    return spk.data.train_test_split(*split_args)
def train_loader(train, batch_size):
    """Wrap the training subset in an ``AtomsLoader`` of the given batch size."""
    loader = spk.data.AtomsLoader(train, batch_size)
    return loader
def test_loader(example_asedata, batch_size):
    """Check batch shapes and ``energy`` statistics of an ``AtomsLoader``.

    Every tensor in a batch must have a leading dimension equal to the
    number of samples in that batch. All batches are full (``batch_size``)
    except possibly the last, which holds whatever remains of the dataset.
    Checking every batch against ``min(batch_size, len(dataset))`` would
    wrongly demand a full last batch when the dataset length is not a
    multiple of ``batch_size``.

    The ``energy`` mean/std are compared with ``torch.allclose`` instead of
    exact float equality, so rounding in the statistics computation cannot
    cause a spurious failure.
    """
    loader = schnetpack.data.AtomsLoader(example_asedata, batch_size)

    # Samples not yet seen. Assumes the loader does not drop the final
    # incomplete batch (TODO confirm AtomsLoader's drop_last default).
    remaining = len(loader.dataset)
    for batch in loader:
        expected = min(batch_size, remaining)
        for entry in batch.values():
            assert entry.shape[0] == expected
        remaining -= expected

    mu, std = loader.get_statistics("energy")
    # Normalize to float32 so allclose does not reject a float64 statistic.
    # The expected values mean the fixture presumably has constant energy 5.0.
    assert torch.allclose(
        torch.as_tensor(mu["energy"], dtype=torch.float32),
        torch.FloatTensor([5.0]),
    )
    assert torch.allclose(
        torch.as_tensor(std["energy"], dtype=torch.float32),
        torch.FloatTensor([0.0]),
    )
def test_loader(test, batch_size):
    """Wrap the test subset in an ``AtomsLoader`` of the given batch size."""
    loader = spk.data.AtomsLoader(test, batch_size)
    return loader
def qm9_dataset(qm9_dbpath):
    """Construct a ``QM9`` dataset backed by ``qm9_dbpath``.

    Before construction, prints whether the database path currently exists
    (debug output kept as-is).
    """
    db_exists = os.path.exists(qm9_dbpath)
    print(db_exists)
    dataset = QM9(qm9_dbpath)
    return dataset
# NOTE(review): fragment. The enclosing test function's header and the
# opening of the call these keyword arguments belong to are not visible in
# this snippet; presumably they build the ``args`` passed to get_trainer.
lr_decay=0.5,
lr_min=1e-6,
logger="csv",
modelpath=modeldir,
log_every_n_epochs=2,
max_steps=30,
checkpoint_interval=1,
keep_n_checkpoints=1,
dataset="qm9",
)
# Build a trainer for the SchNet model with the QM9 train/validation loaders
# and no metrics.
trainer = get_trainer(
args, schnet, qm9_train_loader, qm9_val_loader, metrics=None
)
# The trainer must hold the model object that was passed in.
assert trainer._model == schnet
# Collect the registered hook classes. With logger="csv" set above, the test
# expects a CSVHook and no TensorboardHook. It also expects a MaxEpochHook and
# a ReduceLROnPlateauHook to be registered.
hook_types = [type(hook) for hook in trainer.hooks]
assert schnetpack.train.hooks.CSVHook in hook_types
assert schnetpack.train.hooks.TensorboardHook not in hook_types
assert schnetpack.train.hooks.MaxEpochHook in hook_types
assert schnetpack.train.hooks.ReduceLROnPlateauHook in hook_types
# NOTE(review): fragment. This ``)`` closes a call whose opening is not
# visible here (presumably the construction of ``args``), and the enclosing
# test function header is also missing from this snippet.
)
# ``mean`` maps the target property to None. The same dict is passed as mean,
# stddev and atomref to get_model.
mean = {args.property: None}
model = get_model(
args, train_loader=qm9_train_loader, mean=mean, stddev=mean, atomref=mean
)
# Ensure the model directory exists; evaluate is expected to write
# evaluation.txt into it (asserted below).
os.makedirs(modeldir, exist_ok=True)
evaluate(
args,
model,
qm9_train_loader,
qm9_val_loader,
qm9_test_loader,
"cpu",
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
)
],
)
assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))
# Re-run evaluation with args.split set to ["train"]; presumably this
# restricts evaluation to the training split (confirm in evaluate).
args.split = ["train"]
evaluate(
args,
model,
qm9_train_loader,
qm9_val_loader,
qm9_test_loader,
"cpu",
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
# NOTE(review): garbled. The MeanAbsoluteError( call above is never closed,
# and a second ``metrics=[...]`` list starts inside it. This looks like a
# copy/paste artifact; verify against the original test source.
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
)
],
)
# Evaluate again with args.split set to ["validation"].
args.split = ["validation"]
evaluate(
args,
model,
qm9_train_loader,
qm9_val_loader,
qm9_test_loader,
"cpu",
metrics=[
schnetpack.train.metrics.MeanAbsoluteError(
"energy_U0", model_output="energy_U0"
)
# NOTE(review): truncated. The metrics list and the evaluate( call are not
# closed within this snippet.
def forces_mae():
    """Build an element-wise ``MeanAbsoluteError`` metric named ``"forces"``.

    The first two positional arguments are ``"_forces"`` and ``"dydx"``
    (presumably the target key and the model-output key; confirm against
    the MeanAbsoluteError signature).
    """
    metric = MeanAbsoluteError(
        "_forces",
        "dydx",
        name="forces",
        element_wise=True,
    )
    return metric
def test_main_file_parser(main_path, targets_main):
    """Parse an ORCA main output file and check each expected property.

    The parsed ``"atoms"`` entry is unpacked into atomic numbers (index 0)
    and positions (index 1) under the ``Properties.Z`` / ``Properties.R``
    keys, then removed. Atomic numbers must match exactly; every other
    property is compared with ``np.allclose``.
    """
    parser = OrcaMainFileParser(properties=OrcaMainFileParser.properties)
    parser.parse_file(main_path)
    parsed = parser.get_parsed()

    atoms = parsed["atoms"]
    parsed[Properties.Z] = atoms[0]
    parsed[Properties.R] = atoms[1]
    parsed.pop("atoms", None)

    for prop in targets_main:
        assert prop in parsed
        compare = np.array_equal if prop == Properties.Z else np.allclose
        assert compare(parsed[prop], targets_main[prop])