Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def one_block(a, b, c, has_ff=True):
    """Run one MEGNet block over three input tensors.

    Args:
        a, b, c: The three feature inputs of the block (presumably atom,
            bond and state features -- confirm against the caller).
        has_ff: When true, each input is first passed through ``ff``;
            otherwise the inputs are forwarded unchanged.

    Returns:
        A 3-tuple with the first three outputs of the ``MEGNetLayer``.
        Each output is wrapped in ``Dropout(dropout)`` (called with
        ``training=dropout_training``) when ``dropout`` is truthy.

    Note:
        ``ff``, ``n1``, ``n2``, ``act``, ``reg``, ``x4``..``x7``,
        ``dropout`` and ``dropout_training`` are not defined in this
        function. They are read from the enclosing scope, presumably the
        model-building function this was nested in (verify).
    """
    # Optional per-input transform, applied in a, b, c order so that any
    # auto-generated layer names come out in the same sequence.
    if has_ff:
        features = [ff(t) for t in (a, b, c)]
    else:
        features = [a, b, c]

    # Three separate [n1, n1, n2] lists are passed, one per positional slot.
    layer = MEGNetLayer(
        [n1, n1, n2], [n1, n1, n2], [n1, n1, n2],
        pool_method='mean', activation=act, kernel_regularizer=reg)
    out = layer(features + [x4, x5, x6, x7])

    # Only the first three outputs of the layer are carried forward.
    picked = [out[0], out[1], out[2]]
    if dropout:
        picked = [Dropout(dropout)(t, training=dropout_training) for t in picked]
    return tuple(picked)
def one_block(a, b, c, has_ff=True, block_index=0):
    """Run one MEGNet block whose layers carry names derived from *block_index*.

    Args:
        a, b, c: The three feature inputs of the block. The layer names used
            below label them atom, bond and state, respectively.
        has_ff: When true, each input is first passed through ``ff`` with a
            ``name_prefix`` of ``block_<i>_{atom,bond,state}_ff``; otherwise
            the inputs are forwarded unchanged.
        block_index: Integer formatted (``%d``) into every layer name so that
            repeated blocks get distinct names.

    Returns:
        A 3-tuple with the first three outputs of the ``MEGNetLayer`` (named
        ``megnet_<i>``). Each output is wrapped in a ``Dropout`` named
        ``dropout_{atom,bond,state}_<i>`` (called with
        ``training=dropout_training``) when ``dropout`` is truthy.

    Note:
        ``ff``, ``n1``, ``n2``, ``act``, ``reg``, ``x4``..``x7``,
        ``dropout`` and ``dropout_training`` come from the enclosing scope,
        presumably the model-building function this was nested in (verify).
        Layer names are kept exactly as before; changing them could break
        loading saved weights by name.
    """
    raw = (a, b, c)
    if has_ff:
        ff_name_formats = ('block_%d_atom_ff', 'block_%d_bond_ff', 'block_%d_state_ff')
        features = [
            ff(t, name_prefix=fmt % block_index)
            for t, fmt in zip(raw, ff_name_formats)
        ]
    else:
        features = list(raw)

    out = MEGNetLayer(
        [n1, n1, n2], [n1, n1, n2], [n1, n1, n2],
        pool_method='mean', activation=act, kernel_regularizer=reg,
        name='megnet_%d' % block_index)(features + [x4, x5, x6, x7])

    # Only the first three outputs of the layer are carried forward.
    results = [out[0], out[1], out[2]]
    if dropout:
        dropout_name_formats = ('dropout_atom_%d', 'dropout_bond_%d', 'dropout_state_%d')
        results = [
            Dropout(dropout, name=fmt % block_index)(t, training=dropout_training)
            for t, fmt in zip(results, dropout_name_formats)
        ]
    return tuple(results)
reverse=True,
)
model_file = os.path.join(
warm_start, "kfold_{}".format(fold), "model", model_list[-1]
)
# Load model from file
if learning_rate is None:
full_model = load_model(
model_file,
custom_objects={
"softplus2": softplus2,
"Set2Set": Set2Set,
"mean_squared_error_with_scale": mean_squared_error_with_scale,
"MEGNetLayer": MEGNetLayer,
},
)
learning_rate = K.get_value(full_model.optimizer.lr)
# Set up model
model = MEGNetModel(
100,
2,
nblocks=args.n_blocks,
nvocal=95,
npass=args.n_pass,
lr=learning_rate,
loss=args.loss,
graph_convertor=cg,
is_classification=True if args.type == "classification" else False,
nfeat_node=None if embedding_file is None else 16,