Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
else:
model_list.sort(
key=lambda m_file: float(m_file.split("_")[3].replace(".hdf5", "")),
reverse=True,
)
model_file = os.path.join(
warm_start, "kfold_{}".format(fold), "model", model_list[-1]
)
# Load model from file
if learning_rate is None:
full_model = load_model(
model_file,
custom_objects={
"softplus2": softplus2,
"Set2Set": Set2Set,
"mean_squared_error_with_scale": mean_squared_error_with_scale,
"MEGNetLayer": MEGNetLayer,
},
)
learning_rate = K.get_value(full_model.optimizer.lr)
# Set up model
model = MEGNetModel(
100,
2,
nblocks=args.n_blocks,
nvocal=95,
npass=args.n_pass,
lr=learning_rate,
loss=args.loss,
def make_megnet_model(nfeat_edge: int = None,
nfeat_global: int = None,
nfeat_node: int = None,
nblocks: int = 3,
n1: int = 64,
n2: int = 32,
n3: int = 16,
nvocal: int = 95,
embedding_dim: int = 16,
nbvocal: int = None,
bond_embedding_dim: int = None,
ngvocal: int = None,
global_embedding_dim: int = None,
npass: int = 3,
ntarget: int = 1,
act: Callable = softplus2,
is_classification: bool = False,
l2_coef: float = None,
dropout: float = None,
dropout_on_predict: bool = False
) -> Model:
"""Make a MEGNet Model
Args:
nfeat_edge: (int) number of bond features
nfeat_global: (int) number of state features
nfeat_node: (int) number of atom features
nblocks: (int) number of MEGNetLayer blocks
n1: (int) number of hidden units in layer 1 in MEGNetLayer
n2: (int) number of hidden units in layer 2 in MEGNetLayer
n3: (int) number of hidden units in layer 3 in MEGNetLayer
nvocal: (int) number of total element
nfeat_global: int = None,
nfeat_node: int = None,
nblocks: int = 3,
lr: float = 1e-3,
n1: int = 64,
n2: int = 32,
n3: int = 16,
nvocal: int = 95,
embedding_dim: int = 16,
nbvocal: int = None,
bond_embedding_dim: int = None,
ngvocal: int = None,
global_embedding_dim: int = None,
npass: int = 3,
ntarget: int = 1,
act: Callable = softplus2,
is_classification: bool = False,
loss: str = "mse",
metrics: List[str] = None,
l2_coef: float = None,
dropout: float = None,
graph_converter: StructureGraph = None,
target_scaler: Scaler = DummyScaler(),
optimizer_kwargs: Dict = None,
dropout_on_predict: bool = False
):
"""
Args:
nfeat_edge: (int) number of bond features
nfeat_global: (int) number of state features
nfeat_node: (int) number of atom features
nblocks: (int) number of MEGNetLayer blocks
nfeat_global: int = None,
nfeat_node: int = None,
nblocks: int = 3,
lr: float = 1e-3,
n1: int = 64,
n2: int = 32,
n3: int = 16,
nvocal: int = 95,
embedding_dim: int = 16,
nbvocal: int = None,
bond_embedding_dim: int = None,
ngvocal: int = None,
global_embedding_dim: int = None,
npass: int = 3,
ntarget: int = 1,
act: Callable = softplus2,
is_classification: bool = False,
loss: str = "mse",
metrics: List[str] = None,
l2_coef: float = None,
dropout: float = None,
graph_converter: StructureGraph = None,
target_scaler: Scaler = DummyScaler(),
optimizer_kwargs: Dict = None,
dropout_on_predict: bool = False
):
"""
Args:
nfeat_edge: (int) number of bond features
nfeat_global: (int) number of state features
nfeat_node: (int) number of atom features
nblocks: (int) number of MEGNetLayer blocks
def make_megnet_model(nfeat_edge: int = None,
nfeat_global: int = None,
nfeat_node: int = None,
nblocks: int = 3,
n1: int = 64,
n2: int = 32,
n3: int = 16,
nvocal: int = 95,
embedding_dim: int = 16,
nbvocal: int = None,
bond_embedding_dim: int = None,
ngvocal: int = None,
global_embedding_dim: int = None,
npass: int = 3,
ntarget: int = 1,
act: Callable = softplus2,
is_classification: bool = False,
l2_coef: float = None,
dropout: float = None,
dropout_on_predict: bool = False
) -> Model:
"""Make a MEGNet Model
Args:
nfeat_edge: (int) number of bond features
nfeat_global: (int) number of state features
nfeat_node: (int) number of atom features
nblocks: (int) number of MEGNetLayer blocks
n1: (int) number of hidden units in layer 1 in MEGNetLayer
n2: (int) number of hidden units in layer 2 in MEGNetLayer
n3: (int) number of hidden units in layer 3 in MEGNetLayer
nvocal: (int) number of total element