Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
check_point_path = os.path.join(self.model_folder, 'bert_model.ckpt')
logger.debug('load bert model from %s' % check_point_path)
bert_model = keras_bert.load_trained_model_from_checkpoint(config_path,
check_point_path,
seq_len=seq_len,
output_layer_num=self.layer_nums,
training=self.training,
trainable=self.trainable)
self._model = tf.keras.Model(bert_model.inputs, bert_model.output)
bert_seq_len = int(bert_model.output.shape[1])
if bert_seq_len < seq_len:
logger.warning(f"Sequence length limit set to {bert_seq_len} by pre-trained model")
self.sequence_length = bert_seq_len
self.embedding_size = int(bert_model.output.shape[-1])
output_features = NonMaskingLayer()(bert_model.output)
self.embed_model = tf.keras.Model(bert_model.inputs, output_features)
logger.debug(f'seq_len: {self.sequence_length}')
def __init__(self, **kwargs):
    """Initialize the layer and set ``supports_masking`` to True.

    Args:
        **kwargs: Forwarded unchanged to the parent class initializer.
    """
    super(NonMaskingLayer, self).__init__(**kwargs)
    # Assigned after the parent initializer on purpose: assuming the parent is
    # a Keras ``Layer``, some Keras versions set ``supports_masking = False``
    # inside ``Layer.__init__``, which would overwrite an earlier assignment.
    self.supports_masking = True
def build(self, input_shape):
    """Nothing to build: this layer creates no weights, so ``input_shape`` is unused."""
def compute_mask(self, inputs, input_mask=None):
    """Stop mask propagation at this layer.

    Any incoming mask is discarded and ``None`` is returned, so the mask is
    not passed on to the next layers.
    """
    del inputs, input_mask  # intentionally unused
    return None
def call(self, x, mask=None):
    """Identity forward pass: return ``x`` unchanged and ignore ``mask``."""
    del mask  # the mask is deliberately not applied
    return x
# Register the layer class under the key 'NonMaskingLayer'.
# NOTE(review): presumably this mapping is passed as ``custom_objects`` when
# loading saved Keras models so the custom layer can be resolved by name —
# confirm against where ``custom_objects`` is defined and used.
custom_objects['NonMaskingLayer'] = NonMaskingLayer
if __name__ == "__main__":
    # Running this module directly only prints a placeholder greeting.
    greeting = "Hello world"
    print(greeting)