How to use the megnet.models.MEGNetModel.from_file function in megnet

To help you get started, we’ve selected a few megnet examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: materialsvirtuallab/megnet — megnet/utils/models.py (View on GitHub)
def __init__(self, target_name):
        """Set up the model and output scaler for one QM9 target.

        Args:
            target_name: name of the target. It is used both as the stem of
                the ``.hdf5`` file under ``QM9_MODELDIR`` and as the key into
                ``SCALER``.
        """
        weights_file = pjoin(QM9_MODELDIR, target_name + ".hdf5")
        self.model = MEGNetModel.from_file(weights_file)
        # Replace the graph converter's atom converter. The class name suggests
        # it maps atomic numbers to atom types — confirm against its definition.
        self.model.graph_converter.atom_converter = AtomNumberToTypeconverter()

        # Per-target scaling parameters; look the entry up once and reuse it.
        scaler_cfg = SCALER[target_name]
        self.scaler = Scaler(scaler_cfg['mean'],
                             scaler_cfg['std'],
                             scaler_cfg['is_per_atom'])
GitHub: materialsvirtuallab/megnet — megnet/utils/models.py (View on GitHub)
def load_model(model_name: str) -> GraphModel:
    """
    Load a model by its user-friendly name, as listed in
    megnet.utils.models.AVAILABLE_MODELS.

    Args:
        model_name (str): model name; must be one of AVAILABLE_MODELS.

    Returns:
        GraphModel: the model returned by MEGNetModel.from_file for the
        path that MODEL_MAPPING associates with model_name.

    Raises:
        ValueError: if model_name is not in AVAILABLE_MODELS.
    """
    # Guard clause: reject unknown names before touching the filesystem.
    if model_name not in AVAILABLE_MODELS:
        raise ValueError('model name %s not in available model list %s' % (model_name, AVAILABLE_MODELS))
    # NOTE(review): membership is checked against AVAILABLE_MODELS but the path
    # comes from MODEL_MAPPING; if the two ever diverge this raises KeyError.
    return MEGNetModel.from_file(MODEL_MAPPING[model_name])
GitHub: materialsvirtuallab/megnet — megnet/utils/descriptor.py (View on GitHub)
def __init__(self, model_name: str = DEFAULT_MODEL, use_cache: bool = True):
        if isinstance(model_name, str):
            model = MEGNetModel.from_file(model_name)
        elif isinstance(model_name, GraphModel):
            model = model_name
        else:
            raise ValueError('model_name only support str or GraphModel object')

        layers = model.layers
        important_prefix = ['meg', 'set', 'concatenate']
        all_names = [i.name for i in layers if any([i.name.startswith(j) for j in important_prefix])]
        valid_outputs = [i.output for i in layers if any([i.name.startswith(j) for j in important_prefix])]

        outputs = []
        valid_names = []
        for i, j in zip(all_names, valid_outputs):
            if isinstance(j, list):
                for k, l in enumerate(j):
                    valid_names.append(i + '_%d' % k)