Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _test_mean(self, onehot):
    """Check ShortTextEncoder with ``combine='mean'``.

    The encoder is prepared on generated sentences. The test then checks
    the encoder's autoencoder choice, that encoding yields one output per
    input sentence, and that decoding the mean-combined encodings raises
    ``ValueError``.

    :param onehot: if true, use a 99-word vocabulary and expect
        ``enc.cae.use_autoencoder`` to be false; otherwise use an
        800-word vocabulary and expect it to be true.
    """
    vocab_size = 99 if onehot else 800
    # NOTE(review): presumably sentences of 2-6 words drawn from
    # `vocab_size` distinct words -- confirm against generate_sentences.
    priming_data = generate_sentences(2, 6, vocab_size)
    # Test inputs are a random fifth of the priming sentences, so they were
    # also seen by prepare_encoder(). The sample is not seeded here, so it
    # differs between runs unless `random` is seeded elsewhere.
    test_data = random.sample(priming_data, len(priming_data) // 5)

    enc = ShortTextEncoder(combine='mean')
    enc.prepare_encoder(priming_data)

    # Expected: no autoencoder for the 99-word vocabulary, autoencoder for
    # the 800-word vocabulary.
    if onehot:
        assert not enc.cae.use_autoencoder
    else:
        assert enc.cae.use_autoencoder

    encoded_data = enc.encode(test_data)
    assert len(test_data) == len(encoded_data)

    # Decoding mean-combined encodings is expected to fail. The call never
    # returns, so its result is not bound to a name.
    with self.assertRaises(ValueError):
        enc.decode(encoded_data)
def _test_concat(self, onehot):
    """Check that ShortTextEncoder with ``combine='concat'`` round-trips.

    The encoder is prepared on generated sentences. The test then checks:
    the encoder's autoencoder choice; that encode/decode preserve the
    number of sentences; and that every decoded sentence, re-joined with
    single spaces, equals the original input sentence.

    :param onehot: if true, use a 99-word vocabulary and expect
        ``enc.cae.use_autoencoder`` to be false; otherwise use an
        800-word vocabulary and expect it to be true.
    """
    vocab_size = 99 if onehot else 800
    # NOTE(review): presumably sentences of 2-6 words drawn from
    # `vocab_size` distinct words -- confirm against generate_sentences.
    priming_data = generate_sentences(2, 6, vocab_size)
    # Test inputs are a random fifth of the priming sentences, so they were
    # also seen by prepare_encoder(). The sample is not seeded here, so it
    # differs between runs unless `random` is seeded elsewhere.
    test_data = random.sample(priming_data, len(priming_data) // 5)

    enc = ShortTextEncoder(combine='concat')
    enc.prepare_encoder(priming_data)

    # Expected: no autoencoder for the 99-word vocabulary, autoencoder for
    # the 800-word vocabulary.
    if onehot:
        assert not enc.cae.use_autoencoder
    else:
        assert enc.cae.use_autoencoder

    encoded_data = enc.encode(test_data)
    decoded_data = enc.decode(encoded_data)
    assert len(test_data) == len(encoded_data) == len(decoded_data)

    # Each decoded item is a sequence of string tokens. Joining with ' '
    # assumes the original sentences are single-space separated.
    for x_sent, y_sent in zip(
        test_data,
        [' '.join(x) for x in decoded_data]
    ):
        # Round-trip check: the reconstruction must match the input exactly.
        assert x_sent == y_sent
def _test_concat(self, onehot):
    """Check that ShortTextEncoder with ``combine='concat'`` round-trips.

    The encoder is prepared on generated sentences. The test then checks:
    the encoder's autoencoder choice; that encode/decode preserve the
    number of sentences; and that every decoded sentence, re-joined with
    single spaces, equals the original input sentence.

    :param onehot: if true, use a 99-word vocabulary and expect
        ``enc.cae.use_autoencoder`` to be false; otherwise use an
        800-word vocabulary and expect it to be true.
    """
    vocab_size = 99 if onehot else 800
    # NOTE(review): presumably sentences of 2-6 words drawn from
    # `vocab_size` distinct words -- confirm against generate_sentences.
    priming_data = generate_sentences(2, 6, vocab_size)
    # Test inputs are a random fifth of the priming sentences, so they were
    # also seen by prepare_encoder(). The sample is not seeded here, so it
    # differs between runs unless `random` is seeded elsewhere.
    test_data = random.sample(priming_data, len(priming_data) // 5)

    enc = ShortTextEncoder(combine='concat')
    enc.prepare_encoder(priming_data)

    # Expected: no autoencoder for the 99-word vocabulary, autoencoder for
    # the 800-word vocabulary.
    if onehot:
        assert not enc.cae.use_autoencoder
    else:
        assert enc.cae.use_autoencoder

    encoded_data = enc.encode(test_data)
    decoded_data = enc.decode(encoded_data)
    assert len(test_data) == len(encoded_data) == len(decoded_data)

    # Each decoded item is a sequence of string tokens. Joining with ' '
    # assumes the original sentences are single-space separated.
    for x_sent, y_sent in zip(
        test_data,
        [' '.join(x) for x in decoded_data]
    ):
        # Round-trip check: the reconstruction must match the input exactly.
        assert x_sent == y_sent
def _test_mean(self, onehot):
    """Check ShortTextEncoder with ``combine='mean'``.

    The encoder is prepared on generated sentences. The test then checks
    the encoder's autoencoder choice, that encoding yields one output per
    input sentence, and that decoding the mean-combined encodings raises
    ``ValueError``.

    :param onehot: if true, use a 99-word vocabulary and expect
        ``enc.cae.use_autoencoder`` to be false; otherwise use an
        800-word vocabulary and expect it to be true.
    """
    vocab_size = 99 if onehot else 800
    # NOTE(review): presumably sentences of 2-6 words drawn from
    # `vocab_size` distinct words -- confirm against generate_sentences.
    priming_data = generate_sentences(2, 6, vocab_size)
    # Test inputs are a random fifth of the priming sentences, so they were
    # also seen by prepare_encoder(). The sample is not seeded here, so it
    # differs between runs unless `random` is seeded elsewhere.
    test_data = random.sample(priming_data, len(priming_data) // 5)

    enc = ShortTextEncoder(combine='mean')
    enc.prepare_encoder(priming_data)

    # Expected: no autoencoder for the 99-word vocabulary, autoencoder for
    # the 800-word vocabulary.
    if onehot:
        assert not enc.cae.use_autoencoder
    else:
        assert enc.cae.use_autoencoder

    encoded_data = enc.encode(test_data)
    assert len(test_data) == len(encoded_data)

    # Decoding mean-combined encodings is expected to fail. The call never
    # returns, so its result is not bound to a name.
    with self.assertRaises(ValueError):
        enc.decode(encoded_data)