Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
with jit_scope():
# Group the three input tensors into one list; the list is consumed outside this view.
input_tensors = [input_ids, input_mask, input_type_ids]
# Build the BERT graph in inference mode (is_training=False) from bert_config,
# with use_one_hot_embeddings disabled.
model = modeling.BertModel(
config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=input_type_ids,
use_one_hot_embeddings=False)
# Collect all trainable variables in the current graph.
tvars = tf.trainable_variables()
# Map the graph's variables to their counterparts in init_checkpoint, then
# initialize those variables from that checkpoint.
# NOTE(review): initialized_variable_names is not used in the visible code.
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,
init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
# mul_mask: multiply x by mask m. A trailing axis is added to m so that it
# broadcasts over x's last dimension.
# NOTE(review): presumably x is (batch, seq_len, hidden) and m is
# (batch, seq_len) — confirm against the call site.
mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
# masked_reduce_mean: sum the masked x over axis 1 and divide by the per-row
# sum of m. keepdims=True keeps the divisor broadcastable. The 1e-10 term
# prevents division by zero when a mask row is all zeros.
# NOTE(review): m must have a dtype compatible with x (e.g. float32). If the
# mask is an integer tensor, it needs a cast before these helpers are
# applied — verify at the call site.
masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
# 共享卷积核
with tf.variable_scope("pooling"):
# 如果只有一层,就只取对应那一层的weight
if len(layer_indexes) == 1:
encoder_layer = model.all_encoder_layers[layer_indexes[0]]
else:
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
all_layers = [model.all_encoder_layers[l] for l in layer_indexes]