import schnetpack as spk


def qm9_splits(qm9_dataset, qm9_split):
    # unpack the (n_train, n_val) split tuple into train/val/test subsets
    return spk.data.train_test_split(qm9_dataset, *qm9_split)
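# Minimal usage sketch for the helper above; the database path and the split
# sizes are illustrative assumptions, not values from the original snippet.
from schnetpack.datasets import QM9

data = QM9("qm9.db")
train, val, test = qm9_splits(data, (100000, 10000))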
import pytest

import schnetpack
import schnetpack.representation as rep


def test_triples_exception(example_asedata, batch_size):
    # example_asedata and batch_size are pytest fixtures
    loader = schnetpack.data.AtomsLoader(example_asedata, batch_size)
    reps = rep.BehlerSFBlock(n_radial=22, n_angular=5, elements=frozenset(range(100)))
    for batch in loader:
        # the batch lacks the triple (angular) neighbor information, so
        # evaluating the symmetry functions must raise an HDNNException
        with pytest.raises(rep.HDNNException):
            reps(batch)
        break
import schnetpack as spk
from torch.utils.data.sampler import RandomSampler


# NOTE: the def line was missing from this excerpt; the name and signature
# below are reconstructed from the docstring and the body.
def get_loaders(args, dataset, split_path, logging=None):
    """
    Build train, validation and test dataloaders from the total dataset.

    Args:
        args (argparse.Namespace): parsed script arguments
        dataset (spk.AtomsData): total dataset
        split_path (str): path to split file
        logging: logger

    Returns:
        (spk.AtomsLoader, spk.AtomsLoader, spk.AtomsLoader): dataloaders for train,
            val and test
    """
    if logging is not None:
        logging.info("create splits...")

    # create or load dataset splits depending on args.mode
    if args.mode == "train":
        data_train, data_val, data_test = spk.data.train_test_split(
            dataset, *args.split, split_file=split_path
        )
    else:
        data_train, data_val, data_test = spk.data.train_test_split(
            dataset, split_file=split_path
        )

    if logging is not None:
        logging.info("load data...")

    # build dataloaders (the excerpt was truncated after the first loader;
    # the val/test loaders and the return statement are completed analogously)
    train_loader = spk.data.AtomsLoader(
        data_train,
        batch_size=args.batch_size,
        sampler=RandomSampler(data_train),
        num_workers=4,
    )
    val_loader = spk.data.AtomsLoader(data_val, batch_size=args.batch_size)
    test_loader = spk.data.AtomsLoader(data_test, batch_size=args.batch_size)
    return train_loader, val_loader, test_loader
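# Hypothetical usage sketch for get_loaders: the Namespace fields mirror the
# arguments the function reads; split sizes and file names are assumptions.
from argparse import Namespace

from schnetpack.datasets import QM9

args = Namespace(mode="train", split=[100000, 10000], batch_size=100)
dataset = QM9("qm9.db")
train_loader, val_loader, test_loader = get_loaders(args, dataset, "split.npz")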
import torch
import torch.nn.functional as F
from torch.optim import Adam
import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.datasets import *
# load qm9 dataset and download if necessary
data = QM9("qm9.db", properties=[QM9.U0], collect_triples=True)
# split in train and val
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)
# create model
reps = rep.BehlerSFBlock()
output = atm.ElementalAtomwise(reps.n_symfuncs)
model = atm.AtomisticModel(reps, output)
# filter for trainable parameters (https://github.com/pytorch/pytorch/issues/679)
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
# create trainer
opt = Adam(trainable_params, lr=1e-4)
loss = lambda b, p: F.mse_loss(p["y"], b[QM9.U0])
trainer = spk.train.Trainer("wacsf/", model, loss, opt, loader, val_loader)
# start training
trainer.train(torch.device("cpu"))
import schnetpack.atomistic.output_modules
import torch
import torch.nn.functional as F
from torch.optim import Adam
import schnetpack as spk
import schnetpack.representation as rep
from schnetpack.datasets import *
# load qm9 dataset and download if necessary
data = QM9("qm9.db")
# split in train and val
train, val, test = data.create_splits(100000, 10000)
loader = spk.data.AtomsLoader(train, batch_size=100, num_workers=4)
val_loader = spk.data.AtomsLoader(val)
# create model
reps = rep.SchNet()
output = schnetpack.atomistic.Atomwise(n_in=reps.n_atom_basis)
model = schnetpack.atomistic.AtomisticModel(reps, output)
# create trainer
opt = Adam(model.parameters(), lr=1e-4)
loss = lambda b, p: F.mse_loss(p["y"], b[QM9.U0])
trainer = spk.train.Trainer("output/", model, loss, opt, loader, val_loader)
# start training
trainer.train(torch.device("cpu"))