    lr_decay=0.5,
    lr_min=1e-6,
    logger="csv",
    modelpath=modeldir,
    log_every_n_epochs=2,
    max_steps=30,
    checkpoint_interval=1,
    keep_n_checkpoints=1,
    dataset="qm9",
)
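# get_trainer should wrap the model and attach the hooks implied by args:
# CSV logging, a max-epoch stop and LR reduction on plateau (checked by the asserts below)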
trainer = get_trainer(
    args, schnet, qm9_train_loader, qm9_val_loader, metrics=None
)
assert trainer._model == schnet
hook_types = [type(hook) for hook in trainer.hooks]
assert schnetpack.train.hooks.CSVHook in hook_types
assert schnetpack.train.hooks.TensorboardHook not in hook_types
assert schnetpack.train.hooks.MaxEpochHook in hook_types
assert schnetpack.train.hooks.ReduceLROnPlateauHook in hook_types
mean = {args.property: None}
model = get_model(
    args, train_loader=qm9_train_loader, mean=mean, stddev=mean, atomref=mean
)
os.makedirs(modeldir, exist_ok=True)
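# evaluate the model on the train/validation/test splits; a summary is written
# to evaluation.txt in modeldir (checked below)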
evaluate(
    args,
    model,
    qm9_train_loader,
    qm9_val_loader,
    qm9_test_loader,
    "cpu",
    metrics=[
        schnetpack.train.metrics.MeanAbsoluteError(
            "energy_U0", model_output="energy_U0"
        )
    ],
)
assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))
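# restrict args.split so that evaluation is re-run on the training split only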
args.split = ["train"]
evaluate(
    args,
    model,
    qm9_train_loader,
    qm9_val_loader,
    qm9_test_loader,
    "cpu",
    metrics=[
        schnetpack.train.metrics.MeanAbsoluteError(
            "energy_U0", model_output="energy_U0"
        )
    ],
)
add_md17_arguments(parser)
add_subparsers(
    parser,
    defaults=dict(property=MD17.energy, elements=["H", "C", "O"]),
    choices=dict(property=[MD17.energy, MD17.forces]),
)
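# parse the command-line arguments; setup_run prepares the run (e.g. the model
# directory and the stored training arguments)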
args = parser.parse_args()
train_args = setup_run(args)
# set device
device = torch.device("cuda" if args.cuda else "cpu")
# define metrics
metrics = [
    schnetpack.train.metrics.MeanAbsoluteError(MD17.energy, MD17.energy),
    schnetpack.train.metrics.RootMeanSquaredError(MD17.energy, MD17.energy),
    # element_wise=True treats the forces as a per-atom property when accumulating the error
    schnetpack.train.metrics.MeanAbsoluteError(
        MD17.forces, MD17.forces, element_wise=True
    ),
    schnetpack.train.metrics.RootMeanSquaredError(
        MD17.forces, MD17.forces, element_wise=True
    ),
]
# build dataset
logging.info("MD17 will be loaded...")
md17 = MD17(
    args.datapath,
    args.molecule,
    download=True,
    # neighbor triples are needed for the angular symmetry functions of wACSF
    collect_triples=args.model == "wacsf",
)
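# split the dataset into train/validation/test and wrap the splits in AtomsLoaders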
train, val, test = md17.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)
# create model
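# BehlerSFBlock computes wACSF symmetry-function representations; ElementalAtomwise
# maps them to per-atom contributions with element-specific networks, and
# AtomisticModel combines representation and output into a single model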
reps = rep.BehlerSFBlock()
output = schnetpack.output_modules.ElementalAtomwise(reps.n_symfuncs)
model = atm.AtomisticModel(reps, output)
# filter for trainable parameters (https://github.com/pytorch/pytorch/issues/679)
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
# create trainer
opt = Adam(trainable_params, lr=1e-4)
# mean squared error between the model output "y" and the reference MD17 energy
loss = lambda b, p: F.mse_loss(p["y"], b[MD17.energy])
trainer = spk.train.Trainer("wacsf/", model, loss, opt, loader, val_loader)
# start training
trainer.train(torch.device("cpu"))