def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    input_type_ids = features["input_type_ids"]
    # compile the forward pass with XLA
    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    with jit_scope():
        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids)
        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT modes are supported: %s" % (mode))
        tvars = tf.trainable_variables()
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
        all_layers = model.get_all_encoder_layers()
        predictions = {
            "unique_id": unique_ids,
        }
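This model_fn only serves prediction (note the explicit PREDICT check). A minimal wiring sketch for feature extraction follows; run_config (a tf.contrib.tpu.RunConfig), input_fn, and the batch size are assumptions, not shown in the snippet:

# minimal sketch, assuming run_config and input_fn are defined elsewhere
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn, config=run_config, use_tpu=False,
    predict_batch_size=32)
for result in estimator.predict(input_fn, yield_single_examples=True):
    unique_id = int(result["unique_id"])  # key matches the predictions dict above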
def create_model(self, bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
    """
    Build the fine-tuning model.
    :param is_training: True for training, False for inference
    :param input_ids: per-token vocabulary indices of the input sentence
    :param input_mask: mask marking real tokens vs. padding
    :param segment_ids: segment list; tokens of the first sentence are 0, tokens of the second are 1, e.g. [0, 0, 0, ..., 1, 1]
    :param labels: whether the two sentences are similar; 0: not similar, 1: similar
    :param num_labels: number of label classes
    :param use_one_hot_embeddings: whether to use one-hot embedding lookups (recommended on TPU)
    :return:
    """
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)
    # If you want to use the token-level output, use model.get_sequence_output()
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
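The snippet stops after defining output_weights, yet the caller below unpacks (total_loss, per_example_loss, logits, probabilities). A plausible continuation, sketched here after the standard BERT run_classifier head (output_bias, the dropout rate, and the loss math are assumptions, not shown in the source):

    # sketch of the usual classifier head (assumed, not in the snippet above)
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())
    with tf.variable_scope("loss"):
        if is_training:
            # dropout on the pooled output, as in the reference implementation
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        total_loss = tf.reduce_mean(per_example_loss)
    return (total_loss, per_example_loss, logits, probabilities)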
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
label_ids = features["label_ids"]
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
(total_loss, per_example_loss, logits, probabilities) = self.create_model(
bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
num_labels, use_one_hot_embeddings)
tvars = tf.trainable_variables()
initialized_variable_names = {}
if init_checkpoint:
(assignment_map, initialized_variable_names) \
= modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, False)
output_spec = EstimatorSpec(
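As a usage note, this model_fn slots into a standard tf.estimator training loop; a minimal wiring sketch, assuming run_config, train_input_fn, and num_train_steps are defined elsewhere:

# minimal sketch, assuming run_config and train_input_fn exist
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)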
def optimize_graph(layer_indexes=[-2], config_name='', ckpt_name='', max_seq_len=128, output_dir=''):
    try:
        # we don't need GPU for optimizing the graph
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
        # allow_soft_placement: let TensorFlow pick an available device automatically
        config = tf.ConfigProto(allow_soft_placement=True)
        config_fp = config_name
        init_checkpoint = ckpt_name
        logger.info('model config: %s' % config_fp)
        # load the BERT config file
        with tf.gfile.GFile(config_fp, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))
        logger.info('build graph...')
        # input placeholders, not sure if they are friendly to XLA
        input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_type_ids')
        jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
        with jit_scope():
            input_tensors = [input_ids, input_mask, input_type_ids]
            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=input_type_ids,
                use_one_hot_embeddings=False)
            # collect all trainable variables, map them to the checkpoint, and initialize
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            # zero out hidden states at padded positions before pooling
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
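mul_mask broadcasts a float mask over the hidden dimension so padded positions contribute nothing. A masked mean pooling built on it could look like the following sketch; masked_reduce_mean, the float cast, and the epsilon are assumptions:

            # sketch: average token vectors, ignoring padded positions (assumed helper)
            float_mask = tf.cast(input_mask, tf.float32)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
            pooled = masked_reduce_mean(
                model.get_all_encoder_layers()[layer_indexes[0]], float_mask)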