def train_loader(train, batch_size):
    """Build an AtomsLoader over the training split."""
    return spk.data.AtomsLoader(train, batch_size)
def test_loader(example_asedata, batch_size):
    loader = schnetpack.data.AtomsLoader(example_asedata, batch_size)
    # every tensor in a batch shares the leading batch dimension,
    # capped by the dataset length
    for batch in loader:
        for entry in batch.values():
            assert entry.shape[0] == min(batch_size, len(loader.dataset))
    # per-property mean and standard deviation over the whole dataset
    mu, std = loader.get_statistics("energy")
    assert mu["energy"] == torch.FloatTensor([5.0])
    assert std["energy"] == torch.FloatTensor([0.0])
def test_loader(test, batch_size):
    """Build an AtomsLoader over the held-out test split."""
    return spk.data.AtomsLoader(test, batch_size)
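A minimal usage sketch of the loader pattern these fixtures and tests exercise. The database path, batch size, and the "energy" property name are assumptions for illustration, not taken from the snippets above:

import schnetpack as spk

# any AtomsData-backed dataset works here; "my_dataset.db" is made up
dataset = spk.data.AtomsData("my_dataset.db")
loader = spk.data.AtomsLoader(dataset, batch_size=8, shuffle=True)

# each batch is a dict of property name -> padded tensor,
# with the batch size as the leading dimension
batch = next(iter(loader))
for name, tensor in batch.items():
    print(name, tuple(tensor.shape))

# per-property mean / stddev dicts, as asserted in the test above
mu, std = loader.get_statistics("energy")
print(mu["energy"], std["energy"])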
        dataset (schnetpack.data.AtomsData): dataset object

    Returns:
        dataloaders for train, validation and test, plus per-property atomrefs
    """
    # local seed
    np.random.seed(_seed)
    # values below 1 are interpreted as fractions of the dataset
    if num_train < 1:
        num_train = int(num_train * len(dataset))
    if num_val < 1:
        num_val = int(num_val * len(dataset))
    train, val, test = train_test_split(dataset, num_train, num_val)
    # shuffle only the training split; pin memory for faster host-to-device copies
    train_loader = AtomsLoader(
        train, batch_size, True, pin_memory=True, num_workers=num_workers
    )
    val_loader = AtomsLoader(
        val, batch_size, False, pin_memory=True, num_workers=num_workers
    )
    test_loader = AtomsLoader(
        test, batch_size, False, pin_memory=True, num_workers=num_workers
    )
    # single-atom reference values for each mapped target property
    atomrefs = {
        p: dataset.get_atomref(tgt)
        for p, tgt in property_map.items()
        if tgt is not None
    }
    return train_loader, val_loader, test_loader, atomrefs
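The helper above treats num_train and num_val values below 1 as fractions of the dataset rather than absolute counts. A hypothetical call; the function's name and exact signature are cut off in this excerpt, so create_loaders and the keyword names are stand-ins:

# with 10,000 structures, num_train=0.8 resolves to int(0.8 * 10000) = 8000
train_loader, val_loader, test_loader, atomrefs = create_loaders(
    dataset,                        # a schnetpack.data.AtomsData instance
    num_train=0.8,                  # fraction -> 80% of the dataset
    num_val=0.1,                    # fraction -> 10% of the dataset
    batch_size=32,
    num_workers=2,
    property_map={"energy": "U0"},  # model property -> dataset target
)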
copyfile(args.split_path, split_path)
logging.info('create splits...')
data_train, data_val, data_test = qm9.create_splits(*train_args.split,
                                                    split_file=split_path)
logging.info('load data...')

# set up collate function according to args
collate = lambda x: \
    collate_atoms(x,
                  draw_samples=args.draw_random_samples,
                  label_width_scaling=train_args.label_width_factor,
                  max_dist=train_args.max_distance,
                  n_bins=train_args.num_distance_bins)
# shuffle training data via the sampler; validation keeps dataset order
train_loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size,
                                    sampler=RandomSampler(data_train),
                                    num_workers=4, pin_memory=True,
                                    collate_fn=collate)
val_loader = spk.data.AtomsLoader(data_val, batch_size=args.batch_size,
                                  num_workers=2, pin_memory=True,
                                  collate_fn=collate)

# construct the model
if args.mode == 'train' or args.checkpoint >= 0:
    model = get_model(train_args, parallelize=args.parallel).to(device)
    logging.info(f'running on {device}')

# load model or checkpoint for evaluation or generation
if args.mode in ['eval', 'generate']:
    if args.checkpoint < 0:  # load best model
        logging.info('restoring best model')
        model = torch.load(os.path.join(args.modelpath, 'best_model')).to(device)
    else:
        logging.info(f'restoring checkpoint {args.checkpoint}')
        chkpt = os.path.join(args.modelpath, 'checkpoints',
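The lambda above only exists to freeze extra keyword arguments into a collate_fn. The same effect can be had with functools.partial, which pickles cleanly (useful with num_workers > 0 on platforms that spawn workers). A minimal sketch: collate_atoms is stubbed here and the argument values are made up; the real function in the script above pads and stacks per-molecule tensors. AtomsLoader accepts collate_fn the same way, since it builds on torch.utils.data.DataLoader:

from functools import partial
from torch.utils.data import DataLoader

def collate_atoms(samples, draw_samples=0, label_width_scaling=1.0,
                  max_dist=15.0, n_bins=300):
    """Stub for the script's collate function; the real one returns a
    dict of batched tensors built from the list of samples."""
    return samples

# freeze the configuration once, instead of capturing args in a lambda
collate = partial(collate_atoms, draw_samples=5, label_width_scaling=0.1,
                  max_dist=15.0, n_bins=300)

dummy_data = list(range(100))  # stand-in for data_train
loader = DataLoader(dummy_data, batch_size=32, collate_fn=collate)
print(next(iter(loader))[:4])  # first four items of the first batch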