Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_embed_not_implemented(self):
    # The base-class ``embed`` must raise when a subclass does not override it.

    class Dummy(SentenceEmbedding):
        # Only ``__call__`` is overridden, so ``embed`` is inherited as-is.
        def __call__(self, text): ...

    def embed_sample():
        # Construction stays inside the asserted call, as in a `with` block.
        return Dummy(Tokenizer(), WordEmbedding()).embed('I am a dog.')

    self.assertRaises(NotImplementedError, embed_sample)
def __init__(self, tokenizer: Tokenizer, word_embedder: WordEmbedding) -> None:
    """Store the collaborators used to turn a sentence into a vector.

    Args:
        tokenizer: object exposing ``tokenize(sentence)``; kept as
            ``self.tokenizer``.
        word_embedder: object exposing ``get_word_vectors(tokens)``; kept as
            ``self.word_embedder``.
    """
    self.tokenizer, self.word_embedder = tokenizer, word_embedder
def embed(self, sentence: str) -> np.ndarray:
    """Return the vector representation of ``sentence``.

    Abstract hook: concrete sentence embeddings must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def __call__(self, sentence: str) -> np.ndarray:
    """Embed ``sentence`` when the instance is invoked like a function.

    Abstract hook: subclasses must override it.

    NOTE(review): this does not delegate to ``embed``; a subclass that
    overrides only ``embed`` will still raise here when called. Confirm that
    is intended.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
class MeanEmbedding(SentenceEmbedding):
    """Sentence embedding computed as the mean of the sentence's word vectors.

    The sentence is tokenized with ``self.tokenizer``. The tokens are looked
    up with ``self.word_embedder.get_word_vectors`` and the resulting vectors
    are averaged along axis 0.
    """

    def __init__(
            self,
            lang: str = 'en',
            tokenizer: "Tokenizer | None" = None,
            word_embedder: "WordEmbedding | None" = None) -> None:
        """Build a mean-of-word-vectors embedder.

        Args:
            lang: language code used to pick the default tokenizer and the
                default ``FasttextEmbedding``. The built-in tokenizer table
                covers ``'en'``, ``'fr'`` and ``'ja'``.
            tokenizer: tokenizer to use. When ``None``, a default is chosen
                from ``lang``.
            word_embedder: word-vector provider. When ``None``,
                ``FasttextEmbedding(lang)`` is created.

        Raises:
            KeyError: if ``tokenizer`` is ``None`` and ``lang`` has no
                default tokenizer.
        """
        if tokenizer is None:
            # Map to classes, not instances, so only the tokenizer for
            # ``lang`` is constructed. The Japanese tokenizer is no longer
            # built (and its dependencies no longer needed) for 'en'/'fr'.
            tokenizer_classes = {
                'en': SimpleTokenizer,
                'fr': SimpleTokenizer,
                'ja': JapaneseTokenizer,
            }
            tokenizer_class = tokenizer_classes.get(lang)
            if tokenizer_class is None:
                # Keep KeyError (what the old dict lookup raised) for
                # backward compatibility, but explain the failure.
                raise KeyError(
                    f"no default tokenizer for language {lang!r}; "
                    f"expected one of {sorted(tokenizer_classes)} "
                    f"or pass `tokenizer` explicitly")
            tokenizer = tokenizer_class()
        # `is None` rather than `or`: a falsy but valid object must not be
        # silently replaced by the default.
        if word_embedder is None:
            word_embedder = FasttextEmbedding(lang)
        super().__init__(tokenizer, word_embedder)

    def embed(self, sentence: str) -> np.ndarray:
        """Return the element-wise mean (axis 0) of the word vectors of ``sentence``.

        NOTE(review): if tokenization yields no tokens, ``np.mean`` over an
        empty axis likely emits a RuntimeWarning and returns NaN. Confirm how
        callers should handle empty input.
        """
        tokens = self.tokenizer.tokenize(sentence)
        vectors = self.word_embedder.get_word_vectors(tokens)
        return np.mean(vectors, axis=0)