How to use the schnetpack.atomistic.output_modules module in schnetpack

To help you get started, we’ve selected a few schnetpack examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / model.py (View on GitHub)
contributions = spk.utils.get_contributions(args)
    stress = spk.utils.get_stress(args)
    if args.dataset == "md17" and not args.ignore_forces:
        derivative = spk.datasets.MD17.forces
    output_module_str = spk.utils.get_module_str(args)
    if output_module_str == "dipole_moment":
        return spk.atomistic.output_modules.DipoleMoment(
            args.features,
            predict_magnitude=True,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "electronic_spatial_extent":
        return spk.atomistic.output_modules.ElectronicSpatialExtent(
            args.features,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "atomwise":
        return spk.atomistic.output_modules.Atomwise(
            args.features,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            mean=mean[args.property],
            stddev=stddev[args.property],
            atomref=atomref[args.property],
            property=args.property,
            derivative=derivative,
            negative_dr=negative_dr,
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_md17.py (View on GitHub)
logging.info("calculate statistics...")
        mean, stddev = get_statistics(
            split_path,
            train_loader,
            train_args,
            atomref,
            logging=logging,
            per_atom=True,
        )

        # build representation
        representation = get_representation(args, train_loader)

        # build output module
        if args.model == "schnet":
            output_module = spk.atomistic.output_modules.Atomwise(
                args.features,
                aggregation_mode=args.aggregation_mode,
                mean=mean[args.property],
                stddev=stddev[args.property],
                atomref=atomref[args.property],
                property=args.property,
                derivative="forces",
                negative_dr=True,
            )
        elif args.model == "wascf":
            elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
            output_module = ElementalAtomwise(
                representation.n_symfuncs,
                n_hidden=args.n_nodes,
                n_layers=args.n_layers,
                mean=mean[args.property],
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_matproj.py (View on GitHub)
args, dataset=mp, split_path=split_path, logging=logging
    )

    if args.mode == "train":
        # get statistics
        logging.info("calculate statistics...")
        mean, stddev = get_statistics(
            split_path, train_loader, train_args, atomref, logging=logging
        )

        # build representation
        representation = get_representation(train_args, train_loader=train_loader)

        # build output module
        if args.model == "schnet":
            output_module = schnetpack.atomistic.output_modules.Atomwise(
                args.features,
                aggregation_mode=args.aggregation_mode,
                mean=mean[args.property],
                stddev=stddev[args.property],
                atomref=None,
                train_embeddings=True,
                property=args.property,
            )
        else:
            raise NotImplementedError

        # build AtomisticModel
        model = get_model(
            representation=representation,
            output_modules=output_module,
            parallelize=args.parallel,
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_omdb.py (View on GitHub)
args, dataset=omdb, split_path=split_path, logging=logging
    )

    if args.mode == "train":
        # get statistics
        logging.info("calculate statistics...")
        mean, stddev = get_statistics(
            split_path, train_loader, train_args, atomref, logging=logging
        )

        # build representation
        representation = get_representation(train_args, train_loader=train_loader)

        # build output module
        if args.model == "schnet":
            output_module = schnetpack.atomistic.output_modules.Atomwise(
                args.features,
                aggregation_mode=args.aggregation_mode,
                mean=mean[args.property],
                stddev=stddev[args.property],
                atomref=atomref[args.property],
                train_embeddings=True,
                property=args.property,
            )
        else:
            raise NotImplementedError

        # build AtomisticModel
        model = get_model(
            representation=representation,
            output_modules=output_module,
            parallelize=args.parallel,
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_qm9.py (View on GitHub)
atomref=atomref[args.property],
                    property=args.property,
                )
        elif args.model == "wacsf":
            elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
            if args.property == QM9.mu:
                output_module = schnetpack.atomistic.output_modules.ElementalDipoleMoment(
                    representation.n_symfuncs,
                    n_hidden=args.n_nodes,
                    n_layers=args.n_layers,
                    predict_magnitude=True,
                    elements=elements,
                    property=args.property,
                )
            else:
                output_module = schnetpack.atomistic.output_modules.ElementalAtomwise(
                    representation.n_symfuncs,
                    n_hidden=args.n_nodes,
                    n_layers=args.n_layers,
                    aggregation_mode=args.aggregation_mode,
                    mean=mean[args.property],
                    stddev=stddev[args.property],
                    atomref=atomref[args.property],
                    elements=elements,
                    property=args.property,
                )
        else:
            raise NotImplementedError

        # build AtomisticModel
        model = get_model(
            representation=representation,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / model.py (View on GitHub)
def get_output_module(args, representation, mean, stddev, atomref):
    derivative = spk.utils.get_derivative(args)
    negative_dr = spk.utils.get_negative_dr(args)
    contributions = spk.utils.get_contributions(args)
    stress = spk.utils.get_stress(args)
    if args.dataset == "md17" and not args.ignore_forces:
        derivative = spk.datasets.MD17.forces
    output_module_str = spk.utils.get_module_str(args)
    if output_module_str == "dipole_moment":
        return spk.atomistic.output_modules.DipoleMoment(
            args.features,
            predict_magnitude=True,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "electronic_spatial_extent":
        return spk.atomistic.output_modules.ElectronicSpatialExtent(
            args.features,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "atomwise":
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / model.py (View on GitHub)
predict_magnitude=True,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "electronic_spatial_extent":
        return spk.atomistic.output_modules.ElectronicSpatialExtent(
            args.features,
            mean=mean[args.property],
            stddev=stddev[args.property],
            property=args.property,
            contributions=contributions,
        )
    elif output_module_str == "atomwise":
        return spk.atomistic.output_modules.Atomwise(
            args.features,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            mean=mean[args.property],
            stddev=stddev[args.property],
            atomref=atomref[args.property],
            property=args.property,
            derivative=derivative,
            negative_dr=negative_dr,
            contributions=contributions,
            stress=stress,
        )
    elif output_module_str == "polarizability":
        return spk.atomistic.output_modules.Polarizability(
            args.features,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            property=args.property,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / utils / script_utils / model.py (View on GitHub)
return spk.atomistic.output_modules.Polarizability(
            args.features,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            property=args.property,
        )
    elif output_module_str == "isotropic_polarizability":
        return spk.atomistic.output_modules.Polarizability(
            args.features,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            property=args.property,
            isotropic=True,
        )
    # wacsf modules
    elif output_module_str == "elemental_dipole_moment":
        elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
        return spk.atomistic.output_modules.ElementalDipoleMoment(
            representation.n_symfuncs,
            n_hidden=args.n_nodes,
            n_layers=args.n_layers,
            predict_magnitude=True,
            elements=elements,
            property=args.property,
        )
    elif output_module_str == "elemental_atomwise":
        elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
        return spk.atomistic.output_modules.ElementalAtomwise(
            representation.n_symfuncs,
            n_hidden=args.n_nodes,
            n_layers=args.n_layers,
            aggregation_mode=spk.utils.get_pooling_mode(args),
            mean=mean[args.property],
            stddev=stddev[args.property],
GitHub: atomistic-machine-learning / schnetpack / src / scripts / schnetpack_qm9.py (View on GitHub)
# build representation
        representation = get_representation(args, train_loader=train_loader)

        # build output module
        if args.model == "schnet":
            if args.property == QM9.mu:
                output_module = schnetpack.atomistic.output_modules.DipoleMoment(
                    args.features,
                    predict_magnitude=True,
                    mean=mean[args.property],
                    stddev=stddev[args.property],
                    property=args.property,
                )
            else:
                output_module = schnetpack.atomistic.output_modules.Atomwise(
                    args.features,
                    aggregation_mode=args.aggregation_mode,
                    mean=mean[args.property],
                    stddev=stddev[args.property],
                    atomref=atomref[args.property],
                    property=args.property,
                )
        elif args.model == "wacsf":
            elements = frozenset((atomic_numbers[i] for i in sorted(args.elements)))
            if args.property == QM9.mu:
                output_module = schnetpack.atomistic.output_modules.ElementalDipoleMoment(
                    representation.n_symfuncs,
                    n_hidden=args.n_nodes,
                    n_layers=args.n_layers,
                    predict_magnitude=True,
                    elements=elements,