Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def trainer(
    modeldir,
    atomistic_model,
    properties,
    lr,
    train_loader,
    val_loader,
    keep_n_checkpoints,
    checkpoint_interval,
    hooks,
):
    """Assemble a ``spk.train.Trainer`` for ``atomistic_model``.

    The loss function is produced by ``spk.train.build_mse_loss(properties)``
    and the optimizer is Adam over the model's parameters with learning
    rate ``lr``. ``validation_interval`` is fixed to 1 and
    ``loss_is_normalized`` is fixed to ``True``.

    Args:
        modeldir: Passed to the trainer as ``model_path``.
        atomistic_model: Model to train; its ``parameters()`` feed the optimizer.
        properties: Passed to ``spk.train.build_mse_loss``.
        lr: Learning rate for Adam.
        train_loader: Loader passed as ``train_loader``.
        val_loader: Loader passed as ``validation_loader``.
        keep_n_checkpoints: Forwarded unchanged to the trainer.
        checkpoint_interval: Forwarded unchanged to the trainer.
        hooks: Forwarded unchanged to the trainer.

    Returns:
        The constructed ``spk.train.Trainer`` instance.
    """
    # Build loss first, then optimizer, matching the original argument order.
    mse_loss_fn = spk.train.build_mse_loss(properties)
    adam = torch.optim.Adam(atomistic_model.parameters(), lr=lr)

    return spk.train.Trainer(
        model_path=modeldir,
        model=atomistic_model,
        loss_fn=mse_loss_fn,
        optimizer=adam,
        train_loader=train_loader,
        validation_loader=val_loader,
        keep_n_checkpoints=keep_n_checkpoints,
        checkpoint_interval=checkpoint_interval,
        validation_interval=1,
        hooks=hooks,
        loss_is_normalized=True,
    )
negative_dr=True,
)
]
# Combine the representation and the output modules into one model.
model = schnetpack.atomistic.model.AtomisticModel(representation, output_modules)

# Adam over all model parameters with a fixed learning rate.
optimizer = Adam(params=model.parameters(), lr=1e-4)

# One MeanAbsoluteError metric per property, logged through the CSV hook;
# the second hook is handed the optimizer built above.
logging.info("build trainer")
metrics = [MeanAbsoluteError(prop, prop) for prop in properties]
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer),
]

# Loss built from the property names, then the trainer itself.
loss = mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# Train on the CPU for 1000 epochs.
logging.info("training")
trainer.train(device="cpu", n_epochs=1000)
# Open the QM9 database file.
data = QM9("qm9.db")

# Split into training, validation and test subsets.
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# SchNet representation followed by an atomwise output module.
reps = rep.SchNet()
output = schnetpack.atomistic.Atomwise(n_in=reps.n_atom_basis)
model = schnetpack.atomistic.AtomisticModel(reps, output)

# Optimizer and loss for the trainer.
opt = Adam(model.parameters(), lr=1e-4)


def loss(batch, result):
    """Return the MSE between ``result["y"]`` and ``batch[QM9.U0]``."""
    return F.mse_loss(result["y"], batch[QM9.U0])


trainer = spk.train.Trainer(
    model_path="output/",
    model=model,
    loss_fn=loss,
    optimizer=opt,
    train_loader=loader,
    validation_loader=val_loader,
)

# Run training on the CPU.
trainer.train(torch.device("cpu"))
every_n_epochs=args.log_every_n_epochs,
)
hooks.append(logger)
elif args.logger == "tensorboard":
logger = spk.train.TensorboardHook(
os.path.join(args.modelpath, "log"),
metrics,
every_n_epochs=args.log_every_n_epochs,
)
hooks.append(logger)
# setup loss function
# NOTE(review): get_loss_fn is defined outside this excerpt; it receives the
# same `args` object used for the trainer settings below.
loss_fn = get_loss_fn(args)
# setup trainer
# The model path, checkpoint interval and number of kept checkpoints are all
# read from `args`; `hooks` is the list the logger hook was appended to above.
trainer = spk.train.Trainer(
args.modelpath,
model,
loss_fn,
optimizer,
train_loader,
val_loader,
checkpoint_interval=args.checkpoint_interval,
keep_n_checkpoints=args.keep_n_checkpoints,
hooks=hooks,
)
# hand the configured trainer back to the caller
return trainer
# loss for the type predictions: element-wise loss against the
# '_type_labels' entry of the batch, summed over the last dimension and
# then averaged over all remaining entries
# NOTE(review): out_type and loss_layer are defined before this excerpt
loss_type = loss_layer(out_type, batch['_type_labels'])
loss_type = torch.sum(loss_type, -1)
loss_type = torch.mean(loss_type)
# loss for distance predictions (KLD)
mask_dist = batch['_dist_mask']
# N is the sum of all mask entries (presumably the number of unmasked
# distance predictions -- TODO confirm the mask is 0/1)
N = torch.sum(mask_dist)
out_dist = norm_layer(result['distance_predictions'])
loss_dist = loss_layer(out_dist, batch['_labels'])
loss_dist = torch.sum(loss_dist, -1)
# masked sum divided by max(N, 1), so an all-zero mask does not divide by zero
loss_dist = torch.sum(loss_dist * mask_dist) / torch.max(N, torch.ones_like(N))
# total loss is the unweighted sum of both terms
return loss_type + loss_dist
# Build the trainer; at most 10 checkpoints are kept and the checkpoint
# interval is read from args.
trainer = spk.train.Trainer(
    args.modelpath,
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    hooks=hooks,
    checkpoint_interval=args.checkpoint_every_n_epochs,
    keep_n_checkpoints=10,
)

# When starting from a pre-trained model (e.g. for fine-tuning), restart
# the trainer's epoch and step counters from zero.
if args.pretrained_path is not None:
    logging.info('starting from pre-trained model...')
    trainer.epoch = 0
    trainer.step = 0
atomref=atomrefs[QM9.U0],
)
]
# Combine the representation and the output modules into one model.
model = schnetpack.AtomisticModel(representation, output_modules)

# Adam over all model parameters with a fixed learning rate.
optimizer = Adam(model.parameters(), lr=1e-4)

# One MeanAbsoluteError metric per property, logged through the CSV hook;
# the second hook is handed the optimizer built above.
logging.info("build trainer")
metrics = [MeanAbsoluteError(prop, prop) for prop in properties]
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer),
]

# Loss built from the property names, then the trainer itself.
loss = build_mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# Train on the CPU; no epoch limit is passed here.
logging.info("training")
trainer.train(device="cpu")