How to use the text2vec.utils.non_masking_layer.NonMaskingLayer class in text2vec

To help you get started, we’ve selected a few text2vec examples, based on popular ways it is used in public projects.


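From shibing624/text2vec, file text2vec/embeddings/bert_embedding.py. The BERT embedding first loads the pre-trained checkpoint with keras_bert. It then wraps the model output in NonMaskingLayer and builds embed_model from the result: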
# excerpt from a method of the BERT embedding class; config_path and seq_len are defined earlier
check_point_path = os.path.join(self.model_folder, 'bert_model.ckpt')
logger.debug('load bert model from %s' % check_point_path)
bert_model = keras_bert.load_trained_model_from_checkpoint(
    config_path,
    check_point_path,
    seq_len=seq_len,
    output_layer_num=self.layer_nums,
    training=self.training,
    trainable=self.trainable,
)

self._model = tf.keras.Model(bert_model.inputs, bert_model.output)
# the checkpoint may support a shorter sequence than requested
bert_seq_len = int(bert_model.output.shape[1])
if bert_seq_len < seq_len:
    logger.warning(f"Sequence length limit set to {bert_seq_len} by pre-trained model")
    self.sequence_length = bert_seq_len
self.embedding_size = int(bert_model.output.shape[-1])
# strip the Keras mask from the BERT output before building the embedding model
output_features = NonMaskingLayer()(bert_model.output)
self.embed_model = tf.keras.Model(bert_model.inputs, output_features)
logger.debug(f'seq_len: {self.sequence_length}')
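Below is a self-contained sketch of the same wiring. It assumes TensorFlow 2.x with tf.keras, and it is not text2vec's code. A hypothetical stand-in encoder replaces the keras_bert BERT model: an Embedding layer with mask_zero=True, so token id 0 is treated as padding and produces a mask. NonMaskingLayer is re-declared here so the block runs on its own; in text2vec it is imported from text2vec.utils.non_masking_layer. The sizes SEQ_LEN, VOCAB_SIZE and HIDDEN are illustrative only.

```python
import numpy as np
import tensorflow as tf


class NonMaskingLayer(tf.keras.layers.Layer):
    """Re-declared for this demo: identity on the data, returns no mask."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True  # accept an incoming mask without error

    def compute_mask(self, inputs, mask=None):
        return None  # do not pass the mask on

    def call(self, x, mask=None):
        return x  # identity


SEQ_LEN, VOCAB_SIZE, HIDDEN = 8, 100, 16  # hypothetical sizes

token_ids = tf.keras.Input(shape=(SEQ_LEN,), dtype="int32")
# Hypothetical stand-in for the BERT encoder: masks positions where the id is 0.
encoder_output = tf.keras.layers.Embedding(VOCAB_SIZE, HIDDEN, mask_zero=True)(token_ids)

# Mirrors the snippet: one model on the raw output, one behind NonMaskingLayer.
base_model = tf.keras.Model(token_ids, encoder_output)
output_features = NonMaskingLayer()(encoder_output)
embed_model = tf.keras.Model(token_ids, output_features)

ids = np.array([[5, 7, 9, 0, 0, 0, 0, 0]], dtype="int32")
vectors = np.asarray(embed_model(ids))
print(vectors.shape)  # (1, 8, 16)
```

Both models return the same vectors. The only intended difference is that the symbolic output of NonMaskingLayer carries no mask. Anything stacked on embed_model is therefore constructed without one.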
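From shibing624/text2vec, file text2vec/utils/non_masking_layer.py. NonMaskingLayer is an identity layer:
- call returns its input unchanged.
- supports_masking = True lets it accept a mask from the previous layer.
- compute_mask returns None, so the mask is not passed on.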
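import tensorflow as tf

# Assumption: text2vec imports its Layer base class and the shared
# custom_objects registry elsewhere; tf.keras and a plain dict stand in here.
custom_objects = {}


class NonMaskingLayer(tf.keras.layers.Layer):
    """Identity layer that stops the Keras mask from reaching later layers."""

    def __init__(self, **kwargs):
        super(NonMaskingLayer, self).__init__(**kwargs)
        # set after the base __init__, which may reset supports_masking;
        # True lets the layer accept an incoming mask instead of raising
        self.supports_masking = True

    def build(self, input_shape):
        # no weights to create
        pass

    def compute_mask(self, inputs, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        # identity: the input tensor is returned unchanged
        return x


# register the layer so saved models that contain it can be deserialized
custom_objects['NonMaskingLayer'] = NonMaskingLayer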


if __name__ == "__main__":
    print("Hello world")
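Dropping the mask has a practical consequence: later steps can no longer tell padding positions from real tokens. The plain-Python sketch below contrasts two ways of averaging over time steps. The helper average_pool and the numbers are hypothetical and not part of text2vec.
- The first call is a mask-aware mean, the kind of computation Keras's GlobalAveragePooling1D performs when it receives a mask.
- The second call is a plain mean over every position.

```python
from typing import List, Optional, Sequence


def average_pool(
    states: Sequence[Sequence[float]], mask: Optional[Sequence[bool]] = None
) -> List[float]:
    """Average per-token vectors over time, skipping steps where mask is False."""
    if mask is None:
        mask = [True] * len(states)  # no mask: every position counts
    kept = [vec for vec, keep in zip(states, mask) if keep]
    dim = len(states[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# two real tokens followed by two padding positions (hypothetical values;
# an encoder's outputs at padding positions are generally not zero vectors)
states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
mask = [True, True, False, False]

print(average_pool(states, mask))  # [2.0, 3.0]
print(average_pool(states))        # [4.0, 5.0]
```

Without the mask, the padding vectors pull the average away from the real tokens. If the output of a model built behind NonMaskingLayer is pooled later, padding has to be accounted for some other way.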