Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# Train a symmetry-function (BehlerSFBlock) model on the QM9 U0 target;
# trainer output is written to "wacsf/".
import torch
import torch.nn.functional as F
from torch.optim import Adam

import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.datasets import *

# Load the QM9 dataset (downloaded if necessary), requesting only U0.
# NOTE(review): collect_triples=True is kept from the original setup;
# presumably the angular symmetry functions need atom triples — confirm.
data = QM9("qm9.db", properties=[QM9.U0], collect_triples=True)

# Split into train / validation / test (100000 train, 10000 validation;
# the remainder presumably becomes the test split).
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# Create the model: symmetry-function representation followed by
# element-wise output networks.  The original reached ElementalAtomwise via
# ``schnetpack.output_modules``, but the name ``schnetpack`` is never bound
# by ``import schnetpack.atomistic as atm``; use the bound alias instead.
reps = rep.BehlerSFBlock()
output = atm.ElementalAtomwise(n_in=reps.n_symfuncs)
model = atm.AtomisticModel(reps, output)

# Only hand parameters that require gradients to the optimizer
# (https://github.com/pytorch/pytorch/issues/679).
trainable_params = filter(lambda p: p.requires_grad, model.parameters())

# Create the trainer.
opt = Adam(trainable_params, lr=1e-4)


def loss(batch, result):
    """Return the MSE between the model's ``"y"`` output and the U0 target."""
    return F.mse_loss(result["y"], batch[QM9.U0])


trainer = spk.train.Trainer("wacsf/", model, loss, opt, loader, val_loader)

# Start training on the CPU.
trainer.train(torch.device("cpu"))
# Train a SchNet model on the QM9 U0 target; trainer output is written to
# "output/".
import torch
import torch.nn.functional as F
from torch.optim import Adam

import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.datasets import *

# Load the QM9 dataset (downloaded if necessary).  U0 is requested
# explicitly because the loss below reads batch[QM9.U0].
data = QM9("qm9.db", properties=[QM9.U0])

# Split into train / validation / test (100000 train, 10000 validation;
# the remainder presumably becomes the test split).
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# Create the model: SchNet representation + atom-wise output head.
# The name ``schnetpack`` is never bound by ``import schnetpack as spk`` or
# ``import schnetpack.representation as rep``, so use the ``atm`` alias.
reps = rep.SchNet()
output = atm.Atomwise(n_in=reps.n_atom_basis)
model = atm.AtomisticModel(reps, output)

# Create the trainer.
opt = Adam(model.parameters(), lr=1e-4)


def loss(batch, result):
    """Return the MSE between the model's ``"y"`` output and the U0 target."""
    return F.mse_loss(result["y"], batch[QM9.U0])


trainer = spk.train.Trainer("output/", model, loss, opt, loader, val_loader)

# Start training on the CPU.
trainer.train(torch.device("cpu"))
# NOTE(review): spliced fragment — the next two lines close a call whose
# opening (and the names it assigns) does not appear in this file, so they
# cannot run as-is.
args, dataset=ani1, split_path=split_path, logging=logging
)
# NOTE(review): the body of this `if` has lost its indentation, so this
# section does not parse; it is also cut off inside the
# ElementalAtomwise(...) call further down.
if args.mode == "train":
# get statistics: get_statistics yields a (mean, stddev) pair that the
# output modules below index by args.property
logging.info("calculate statistics...")
mean, stddev = get_statistics(
split_path, train_loader, train_args, atomref, logging=logging
)
# build representation from the training arguments
representation = get_representation(train_args, train_loader=train_loader)
# build output module: Atomwise for "schnet", ElementalAtomwise for "wacsf"
if args.model == "schnet":
output_modules = schnetpack.atomistic.Atomwise(
args.features,
mean=mean[args.property],
stddev=stddev[args.property],
atomref=atomref[args.property],
aggregation_mode=args.aggregation_mode,
property="energy",
)
elif args.model == "wacsf":
# NOTE(review): atomic_numbers presumably maps element symbols to atomic
# numbers (e.g. ase.data.atomic_numbers) — confirm; it is not defined here.
elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
output_modules = schnetpack.atomistic.ElementalAtomwise(
n_in=representation.n_symfuncs,
n_hidden=args.n_nodes,
n_layers=args.n_layers,
mean=mean[args.property],
stddev=stddev[args.property],
aggregation_mode=args.aggregation_mode,
# NOTE(review): splice point — the ElementalAtomwise(...) call above is
# never completed; the next two lines appear to be the tail of a different
# call (positional `dataset, 1000, 100` plus a "split.npz" file path).
dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
import logging

from torch.optim import Adam

import schnetpack as spk
from schnetpack.datasets import QM9
from schnetpack.metrics import MeanAbsoluteError
from schnetpack.train import CSVHook, ReduceLROnPlateauHook

# NOTE(review): `train`, `val`, `dataset`, `properties` and `model_dir` are
# not defined in this section; they are expected from earlier in the script.

# Data loaders (only the training loader is shuffled).
train_loader = spk.AtomsLoader(train, batch_size=64, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics: atom reference values from the dataset, then mean/stddev of
# each property from the training loader (divide_by_atoms / single_atom_ref
# presumably make these per-atom and relative to the references — confirm).
atomrefs = dataset.get_atomref(properties)
means, stddevs = train_loader.get_statistics(
    properties, divide_by_atoms=True, single_atom_ref=atomrefs
)

# model build: 6-interaction SchNet + one atom-wise head for U0.
logging.info("build model")
representation = spk.SchNet(n_interactions=6)
output_modules = [
    spk.atomistic.Atomwise(
        n_in=representation.n_atom_basis,
        property=QM9.U0,
        mean=means[QM9.U0],
        stddev=stddevs[QM9.U0],
        atomref=atomrefs[QM9.U0],
    )
]
# Only the `spk` alias is bound; the original `schnetpack.AtomisticModel`
# referenced an unbound name.
model = spk.atomistic.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# hooks: CSV logging of one MAE metric per property, plus LR reduction
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer),
]
# Train a SchNet model on the QM9 U0 target; trainer output is written to
# "output/".
import torch
import torch.nn.functional as F
from torch.optim import Adam

import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.datasets import *

# Load the QM9 dataset (downloaded if necessary).  U0 is requested
# explicitly because the loss below reads batch[QM9.U0].
data = QM9("qm9.db", properties=[QM9.U0])

# Split into train / validation / test (100000 train, 10000 validation;
# the remainder presumably becomes the test split).
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)

# Create the model: SchNet representation + atom-wise output head.
# The name ``schnetpack`` is never bound by ``import schnetpack as spk`` or
# ``import schnetpack.representation as rep``, so use the ``atm`` alias.
reps = rep.SchNet()
output = atm.Atomwise(n_in=reps.n_atom_basis)
model = atm.AtomisticModel(reps, output)

# Create the trainer.
opt = Adam(model.parameters(), lr=1e-4)


def loss(batch, result):
    """Return the MSE between the model's ``"y"`` output and the U0 target."""
    return F.mse_loss(result["y"], batch[QM9.U0])


trainer = spk.train.Trainer("output/", model, loss, opt, loader, val_loader)

# Start training on the CPU.
trainer.train(torch.device("cpu"))
# build representation
# NOTE(review): `args`, `train_args`, `train_loader`, `mean`, `stddev`,
# `atomref`, `get_representation`, `atomic_numbers` and the name `schnetpack`
# must come from the surrounding script, which is not part of this section.
representation = get_representation(train_args, train_loader=train_loader)

# build output module: Atomwise for "schnet", ElementalAtomwise for "wacsf";
# any other model name is rejected.
if args.model == "schnet":
    output_modules = schnetpack.atomistic.Atomwise(
        args.features,
        mean=mean[args.property],
        stddev=stddev[args.property],
        atomref=atomref[args.property],
        aggregation_mode=args.aggregation_mode,
        property="energy",
    )
elif args.model == "wacsf":
    # A frozenset is unordered, so sorting args.elements first was wasted
    # work.  atomic_numbers presumably maps element symbols to atomic
    # numbers (e.g. ase.data.atomic_numbers) — confirm.
    elements = frozenset(atomic_numbers[i] for i in args.elements)
    output_modules = schnetpack.atomistic.ElementalAtomwise(
        n_in=representation.n_symfuncs,
        n_hidden=args.n_nodes,
        n_layers=args.n_layers,
        mean=mean[args.property],
        stddev=stddev[args.property],
        aggregation_mode=args.aggregation_mode,
        atomref=atomref[args.property],
        elements=elements,
        property="energy",
    )
else:
    raise NotImplementedError("Model {} is not known".format(args.model))
# build AtomisticModel
model = get_model(
representation=representation,