How to use the schnetpack.data.AtomsLoader function in schnetpack

To help you get started, we’ve selected a few schnetpack.data.AtomsLoader examples, drawn from popular ways it is used in public projects.
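
Before diving into the excerpts, a minimal sketch of the common pattern may help: wrap a dataset in an AtomsLoader and iterate over batches of property tensors. This assumes the schnetpack 0.x/1.x API used in the snippets below, where AtomsLoader subclasses torch.utils.data.DataLoader; the database path and batch size here are placeholders.

import schnetpack as spk

# placeholder path to an ASE database prepared with schnetpack
dataset = spk.data.AtomsData('path/to/dataset.db')
loader = spk.data.AtomsLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # each batch is a dict of tensors (positions, atomic numbers,
    # stored target properties), stacked along the batch dimension
    print({name: tensor.shape for name, tensor in batch.items()})
    break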


github atomistic-machine-learning/schnetpack/tests/fixtures/data.py
def train_loader(train, batch_size):
    # wrap the training split in an AtomsLoader (second positional arg: batch_size)
    return spk.data.AtomsLoader(train, batch_size)
github atomistic-machine-learning/schnetpack/tests/test_data.py
def test_loader(example_asedata, batch_size):
    loader = schnetpack.data.AtomsLoader(example_asedata, batch_size)
    # every tensor in a batch shares the same leading (batch) dimension
    for batch in loader:
        for entry in batch.values():
            assert entry.shape[0] == min(batch_size, len(loader.dataset))

    # get_statistics returns per-property mean and stddev as dicts of tensors
    mu, std = loader.get_statistics("energy")
    assert mu["energy"] == torch.FloatTensor([5.0])
    assert std["energy"] == torch.FloatTensor([0.0])
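
As the assertions above suggest, get_statistics computes the per-property mean and standard deviation over the whole dataset and returns them as dictionaries of single-element tensors keyed by property name. These statistics are typically used to standardize targets; a hedged sketch, assuming the schnetpack 0.x/1.x API where spk.atomistic.Atomwise accepts mean and stddev:

means, stddevs = loader.get_statistics('energy')
# standardize the network output with the dataset statistics
output_module = spk.atomistic.Atomwise(
    n_in=128, mean=means['energy'], stddev=stddevs['energy']
)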
github atomistic-machine-learning/schnetpack/tests/fixtures/data.py
def test_loader(test, batch_size):
    # wrap the test split in an AtomsLoader
    return spk.data.AtomsLoader(test, batch_size)
github atomistic-machine-learning/schnetpack/src/schnetpack/sacred/dataloader_ingredient.py
    # local seed
    np.random.seed(_seed)

    # interpret fractional sizes as a share of the full dataset
    if num_train < 1:
        num_train = int(num_train * len(dataset))
    if num_val < 1:
        num_val = int(num_val * len(dataset))

    train, val, test = train_test_split(dataset, num_train, num_val)

    # shuffle only the training split (the third positional argument is shuffle)
    train_loader = AtomsLoader(
        train, batch_size, True, pin_memory=True, num_workers=num_workers
    )
    val_loader = AtomsLoader(
        val, batch_size, False, pin_memory=True, num_workers=num_workers
    )
    test_loader = AtomsLoader(
        test, batch_size, False, pin_memory=True, num_workers=num_workers
    )

    # single-atom reference values for each mapped target property
    atomrefs = {
        p: dataset.get_atomref(tgt)
        for p, tgt in property_map.items()
        if tgt is not None
    }

    return train_loader, val_loader, test_loader, atomrefs
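
Since AtomsLoader passes its arguments through to torch.utils.data.DataLoader, the bare True/False above is the shuffle flag. The keyword form, shown here as an equivalent sketch, makes that explicit:

train_loader = AtomsLoader(
    train, batch_size=batch_size, shuffle=True,
    pin_memory=True, num_workers=num_workers,
)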
github atomistic-machine-learning/G-SchNet/gschnet_qm9_script.py
        copyfile(args.split_path, split_path)

        logging.info('create splits...')
        data_train, data_val, data_test = qm9.create_splits(*train_args.split,
                                                            split_file=split_path)

        logging.info('load data...')
        # set up collate function according to args
        collate = lambda x: \
            collate_atoms(x,
                          draw_samples=args.draw_random_samples,
                          label_width_scaling=train_args.label_width_factor,
                          max_dist=train_args.max_distance,
                          n_bins=train_args.num_distance_bins)

        train_loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size,
                                            sampler=RandomSampler(data_train),
                                            num_workers=4, pin_memory=True,
                                            collate_fn=collate)
        val_loader = spk.data.AtomsLoader(data_val, batch_size=args.batch_size,
                                          num_workers=2, pin_memory=True,
                                          collate_fn=collate)

    # construct the model
    if args.mode == 'train' or args.checkpoint >= 0:
        model = get_model(train_args, parallelize=args.parallel).to(device)
    logging.info(f'running on {device}')

    # load model or checkpoint for evaluation or generation
    if args.mode in ['eval', 'generate']:
        if args.checkpoint < 0:  # load best model
            logging.info(f'restoring best model')
            model = torch.load(os.path.join(args.modelpath, 'best_model')).to(device)
        else:
            logging.info(f'restoring checkpoint {args.checkpoint}')
            chkpt = os.path.join(args.modelpath, 'checkpoints',
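
Because AtomsLoader accepts the standard DataLoader hooks (sampler, collate_fn, num_workers, pin_memory), a custom collate function like collate_atoms above just has to turn a list of per-example property dicts into one batched dict. A minimal sketch of that contract, assuming all examples have equal tensor shapes (real collate functions such as collate_atoms pad variable-size molecules instead):

import torch

def collate_minimal(examples):
    # examples: list of dicts mapping property names to tensors;
    # stack each property along a new leading batch dimension
    return {key: torch.stack([ex[key] for ex in examples])
            for key in examples[0]}

loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size,
                              collate_fn=collate_minimal)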