# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sequence_length(self, val: Union[int, str]):
    """Setter for the sequence length.

    Accepts an int, the string ``'auto'`` (kept as-is; length is derived
    later), or the string ``'variable'`` (stored as ``None``).

    Raises:
        ValueError: for any other string.
    """
    if isinstance(val, str):
        if val == 'variable':
            # Variable-length sequences are represented internally as None.
            val = None
        elif val == 'auto':
            logger.debug("Sequence length will auto set at 95% of sequence length")
        else:
            raise ValueError("sequence_length must be an int or 'auto' or 'variable'")
    self.processor.sequence_length = val
# NOTE(review): detached method body — the enclosing ``def`` (a word2vec
# loader taking ``model_dict``) is missing from this chunk, so ``self`` and
# ``model_dict`` are unresolved here. Restore the original signature before use.
tar_filename = model_dict.get('tar_filename')
self.w2v_kwargs = {'binary': model_dict.get('binary')}
url = model_dict.get('url')
untar_filename = model_dict.get('untar_filename')
self.w2v_path = os.path.join(text2vec.USER_DATA_DIR, untar_filename)
# Download and extract the archive only when the vectors are not cached yet.
if not os.path.exists(self.w2v_path):
    get_file(
        tar_filename, url, extract=True,
        cache_dir=text2vec.USER_DIR,
        cache_subdir=text2vec.USER_DATA_DIR,
        verbose=1
    )
t0 = time.time()
w2v = KeyedVectors.load_word2vec_format(self.w2v_path, **self.w2v_kwargs)
# L2-normalize the vectors in place. NOTE(review): ``init_sims`` was removed
# in gensim 4.x — confirm the pinned gensim version still provides it.
w2v.init_sims(replace=True)
logger.debug('load w2v from %s, spend %s s' % (self.w2v_path, time.time() - t0))
# Reserve ids 0-3 for the special tokens; real vocabulary starts at id 4.
token2idx = {
    self.processor.token_pad: 0,
    self.processor.token_unk: 1,
    self.processor.token_bos: 2,
    self.processor.token_eos: 3
}
for token in w2v.index2word:
    token2idx[token] = len(token2idx)
# Embedding matrix: rows 0 (pad), 2 (bos) and 3 (eos) stay all-zero, row 1
# (unk) is random, rows 4+ copy the pre-trained vectors in index2word order —
# matching the ids assigned in the loop above.
vector_matrix = np.zeros((len(token2idx), w2v.vector_size))
vector_matrix[1] = np.random.rand(w2v.vector_size)
vector_matrix[4:] = w2v.vectors
self.embedding_size = w2v.vector_size
self.w2v_vector_matrix = vector_matrix
def _build_token2idx_from_bert(self):
    """Load the BERT vocabulary and wire token<->index maps into the processor.

    Downloads and extracts the pre-trained archive into
    ``text2vec.USER_DATA_DIR`` when ``vocab.txt`` is not already present,
    then reads the vocabulary (one token per line; the 0-based line number
    is the token id, per the BERT convention).

    Side effects: may reset ``self.model_folder`` (after a download); sets
    ``self.bert_token2idx``, ``self.tokenizer``, ``self.processor.token2idx``
    and ``self.processor.idx2token``.
    """
    dict_path = os.path.join(self.model_folder, 'vocab.txt')
    if not os.path.exists(dict_path):
        # Fall back to the default Chinese BERT when model_folder is not a
        # known key. NOTE(review): assumes pre_trained_models has a URL for
        # every model_key_map value — confirm, else url is None here.
        model_name = self.model_key_map.get(self.model_folder, 'chinese_L-12_H-768_A-12')
        url = self.pre_trained_models.get(model_name)
        get_file(
            model_name + ".zip", url, extract=True,
            cache_dir=text2vec.USER_DIR,
            cache_subdir=text2vec.USER_DATA_DIR,
            verbose=1
        )
        self.model_folder = os.path.join(text2vec.USER_DATA_DIR, model_name)
        dict_path = os.path.join(self.model_folder, 'vocab.txt')
    logger.debug(f'load vocab.txt from {dict_path}')
    token2idx = {}
    # Built-in open() handles encodings in Python 3; codecs.open is legacy.
    with open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            token2idx[line.strip()] = len(token2idx)
    self.bert_token2idx = token2idx
    self.tokenizer = keras_bert.Tokenizer(token2idx)
    self.processor.token2idx = self.bert_token2idx
    self.processor.idx2token = {idx: token for token, idx in token2idx.items()}
# NOTE(review): this span appears pasted from a different (word2vec) loader:
# ``w2v`` is not defined in the surrounding BERT method, and the bare
# ``Tokenizer()`` on the last line conflicts with the ``keras_bert.Tokenizer``
# created above. Kept verbatim; must be reconciled with the original file.
vector_matrix = np.zeros((len(token2idx), w2v.vector_size))
# Row 1 (unk) random, rows 4+ copy pre-trained vectors; rows 0/2/3 stay zero.
vector_matrix[1] = np.random.rand(w2v.vector_size)
vector_matrix[4:] = w2v.vectors
self.embedding_size = w2v.vector_size
self.w2v_vector_matrix = vector_matrix
self.w2v_token2idx = token2idx
# Top-50 words kept for the debug log below only.
self.w2v_top_words = w2v.index2entity[:50]
self.w2v_model_loaded = True
self.w2v = w2v
self.processor.token2idx = self.w2v_token2idx
self.processor.idx2token = dict([(value, key) for key, value in self.w2v_token2idx.items()])
logger.debug('word count : {}'.format(len(self.w2v_vector_matrix)))
logger.debug('emb size : {}'.format(self.embedding_size))
logger.debug('Top 50 word : {}'.format(self.w2v_top_words))
self.tokenizer = Tokenizer()
def optimize_graph(layer_indexes=[-2], config_name='', ckpt_name='', max_seq_len=128, output_dir=''):
    """Build and freeze a BERT graph optimized for inference; returns the graph file path.

    NOTE(review): mutable default ``layer_indexes=[-2]`` is shared across
    calls, and the body never reads it — confirm intent. Several names
    (``seq_len``, ``self``, ``modeling``, ``bert_graph``, ``output_tensors``,
    ``dtypes``) are undefined in this chunk: the body looks interleaved with
    method code from another class and cannot run as-is — reconcile with the
    original source before relying on any comment below.
    """
    try:
        # we don't need GPU for optimizing the graph
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
        # allow_soft_placement: let TF pick a runnable device automatically
        config = tf.ConfigProto(allow_soft_placement=True)
        config_fp = config_name
        init_checkpoint = ckpt_name
        logger.info('model config: %s' % config_fp)
        # Load the BERT config file
        with tf.gfile.GFile(config_fp, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))
        logger.info('build graph...')
        # input placeholders, not sure if they are friendly to XLA
        input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_type_ids')
        jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
        with jit_scope():
            input_tensors = [input_ids, input_mask, input_type_ids]
        # NOTE(review): from here to the 'optimize...' log the code belongs to
        # a keras-bert embedding method (uses ``self``/``seq_len``) — foreign
        # to this function.
        if isinstance(seq_len, tuple):
            seq_len = seq_len[0]
        config_path = os.path.join(self.model_folder, 'bert_config.json')
        check_point_path = os.path.join(self.model_folder, 'bert_model.ckpt')
        logger.debug('load bert model from %s' % check_point_path)
        bert_model = keras_bert.load_trained_model_from_checkpoint(config_path,
                                                                   check_point_path,
                                                                   seq_len=seq_len,
                                                                   output_layer_num=self.layer_nums,
                                                                   training=self.training,
                                                                   trainable=self.trainable)
        self._model = tf.keras.Model(bert_model.inputs, bert_model.output)
        bert_seq_len = int(bert_model.output.shape[1])
        # The checkpoint caps the usable sequence length.
        if bert_seq_len < seq_len:
            logger.warning(f"Sequence length limit set to {bert_seq_len} by pre-trained model")
            self.sequence_length = bert_seq_len
        self.embedding_size = int(bert_model.output.shape[-1])
        output_features = NonMaskingLayer()(bert_model.output)
        self.embed_model = tf.keras.Model(bert_model.inputs, output_features)
        logger.debug(f'seq_len: {self.sequence_length}')
        logger.info('optimize...')
        # NOTE(review): ``bert_graph``, ``output_tensors`` and ``dtypes`` are
        # used before any assignment visible here — this call cannot succeed
        # as written.
        bert_graph = optimize_for_inference(
            bert_graph,
            [n.name[:-2] for n in input_tensors],
            [n.name[:-2] for n in output_tensors],
            [dtype.as_datatype_enum for dtype in dtypes],
            False)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        graph_file = os.path.join(output_dir, 'graph.txt')
        logger.info('write graph to file: %s' % graph_file)
        # The serialized GraphDef is binary despite the .txt extension.
        with tf.gfile.GFile(graph_file, 'wb') as f:
            f.write(bert_graph.SerializeToString())
        return graph_file
    except Exception as e:
        # Best-effort: any failure is logged and the function returns None.
        logger.error('fail to optimize the graph!')
        logger.error(e)
# NOTE(review): detached duplicate of the keras-bert model-building code seen
# above (inside optimize_graph). The enclosing ``def`` is missing from this
# chunk; ``config_path``, ``check_point_path``, ``seq_len`` and ``self`` are
# unresolved here — restore the original method signature before use.
bert_model = keras_bert.load_trained_model_from_checkpoint(config_path,
                                                           check_point_path,
                                                           seq_len=seq_len,
                                                           output_layer_num=self.layer_nums,
                                                           training=self.training,
                                                           trainable=self.trainable)
# Full model over the raw BERT outputs.
self._model = tf.keras.Model(bert_model.inputs, bert_model.output)
bert_seq_len = int(bert_model.output.shape[1])
# The pre-trained checkpoint caps the usable sequence length.
if bert_seq_len < seq_len:
    logger.warning(f"Sequence length limit set to {bert_seq_len} by pre-trained model")
    self.sequence_length = bert_seq_len
self.embedding_size = int(bert_model.output.shape[-1])
# NonMaskingLayer presumably drops the Keras mask from the output — confirm
# against its definition elsewhere in the project.
output_features = NonMaskingLayer()(bert_model.output)
self.embed_model = tf.keras.Model(bert_model.inputs, output_features)
logger.debug(f'seq_len: {self.sequence_length}')