Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_oov(self):
    """Check TF-IDF INSERT/SUBSTITUTE augmenters on text with an unknown token.

    ``MODEL_DIR`` (read from the environment) is passed directly as the
    TF-IDF model path. For every augmenter/sample pair:

    * INSERT must increase the space-separated word count and change the text;
    * SUBSTITUTE must keep the word count; the sample consisting only of the
      unknown token must come back unchanged, any other sample must change.
    """
    oov = 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
    samples = [oov, oov + ' the']

    model_dir = os.environ.get("MODEL_DIR")
    augmenters = [
        naw.TfIdfAug(model_path=model_dir, action=mode)
        for mode in (Action.INSERT, Action.SUBSTITUTE)
    ]

    def n_words(s):
        # Word count measured by splitting on single spaces.
        return len(s.split(' '))

    for aug in augmenters:
        for sample in samples:
            self.assertLess(0, len(sample))
            out = aug.augment(sample)

            if aug.action == Action.INSERT:
                self.assertLess(n_words(sample), n_words(out))
                self.assertNotEqual(sample, out)
            elif aug.action == Action.SUBSTITUTE:
                self.assertEqual(n_words(sample), n_words(out))
                # The bare unknown token is expected to pass through untouched;
                # the mixed sample is expected to be altered.
                check = self.assertEqual if sample == oov else self.assertNotEqual
                check(sample, out)
def test_substitute(self):
    """Check that the FastText SUBSTITUTE augmenter alters a plain sentence.

    The vector file ``wiki-news-300d-1M.vec`` is resolved under the
    ``MODEL_DIR`` environment variable, which must be set.
    """
    texts = [
        'The quick brown fox jumps over the lazy dog'
    ]
    # Guard against an empty fixture (which would make the loop pass
    # vacuously) before paying for the expensive model load.
    self.assertLess(0, len(texts))

    model_dir = os.environ.get("MODEL_DIR")
    self.assertIsNotNone(model_dir, 'MODEL_DIR environment variable must be set')

    # os.path.join yields a correct path whether or not MODEL_DIR ends with a
    # path separator; plain string concatenation did not.
    aug = naw.FasttextAug(
        model_path=os.path.join(model_dir, 'wiki-news-300d-1M.vec'),
        action=Action.SUBSTITUTE)

    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        self.assertNotEqual(text, augmented_text)
# Word-embedding augmenters shared by the tests: one Word2vec, FastText and
# GloVe instance for each of the INSERT and SUBSTITUTE actions. Vector files
# are resolved under the MODEL_DIR environment variable; os.path.join keeps
# the path correct whether or not MODEL_DIR ends with a path separator.
cls.insert_augmenters = [
    naw.Word2vecAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'GoogleNews-vectors-negative300.bin'),
        action=Action.INSERT),
    naw.FasttextAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'wiki-news-300d-1M.vec'),
        action=Action.INSERT),
    naw.GloVeAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'glove.6B.50d.txt'),
        action=Action.INSERT)
]
cls.substitute_augmenters = [
    naw.Word2vecAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'GoogleNews-vectors-negative300.bin'),
        action=Action.SUBSTITUTE),
    naw.FasttextAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'wiki-news-300d-1M.vec'),
        action=Action.SUBSTITUTE),
    naw.GloVeAug(
        model_path=os.path.join(os.environ.get("MODEL_DIR"), 'glove.6B.50d.txt'),
        action=Action.SUBSTITUTE)
]
def __init__(self, zone=(0.2, 0.8), coverage=1.,
             color='white', noises=None, name='Noise_Aug', noise_factor=0.01, verbose=0):
    """Initialise a noise augmenter with the SUBSTITUTE action on CPU.

    :param zone: forwarded to ``get_model`` (first positional argument).
    :param coverage: forwarded to ``get_model``.
    :param color: forwarded to ``get_model``; defaults to ``'white'``.
    :param noises: forwarded to ``get_model``.
    :param name: augmenter name handed to the base class.
    :param noise_factor: deprecated. A value other than 0.01 only prints a
        deprecation message; the value itself is not used here.
    :param verbose: verbosity level handed to the base class.
    """
    super().__init__(action=Action.SUBSTITUTE, name=name, device='cpu', verbose=verbose)

    uses_deprecated_factor = noise_factor != 0.01
    if uses_deprecated_factor:
        # NOTE(review): reported via print rather than the warnings module.
        print(WarningMessage.DEPRECATED.format('noise_factor', '0.0.12', ''))

    self.model = self.get_model(zone, coverage, color, noises)
def __init__(self, model_path='.', action=Action.SUBSTITUTE,
             name='WordEmbs_Aug', aug_min=1, aug_p=0.3, aug_n=5, n_gram_separator='_',
             stopwords=None, tokenizer=None, reverse_tokenizer=None, verbose=0):
    """Initialise a word-embeddings augmenter and load its model.

    ``action``, ``name``, ``aug_p``, ``aug_min``, ``stopwords``,
    ``tokenizer``, ``reverse_tokenizer`` and ``verbose`` are forwarded
    unchanged to the base class.

    :param model_path: stored on ``self.model_path`` before the model loads.
    :param aug_n: stored on ``self.aug_n`` before the model loads.
    :param n_gram_separator: stored on ``self.n_gram_separator`` after the
        model has been loaded.
    """
    super().__init__(
        action=action,
        name=name,
        aug_min=aug_min,
        aug_p=aug_p,
        stopwords=stopwords,
        tokenizer=tokenizer,
        reverse_tokenizer=reverse_tokenizer,
        verbose=verbose,
    )

    # Set before get_model(), which presumably reads them — TODO confirm.
    self.model_path, self.aug_n = model_path, aug_n
    self.model = self.get_model(force_reload=False)
    self.n_gram_separator = n_gram_separator
def __init__(self, model_path='bert-base-uncased', tokenizer_path='bert-base-uncased', action=Action.SUBSTITUTE,
             name='Bert_Aug', aug_min=1, aug_max=10, aug_p=0.3, aug_n=5, stopwords=None, device='cpu', verbose=0):
    """Initialise a BERT-backed augmenter and load its model on ``device``.

    ``action``, ``name``, ``aug_p``, ``aug_min``, ``aug_max``, ``stopwords``
    and ``verbose`` are forwarded to the base class. After loading,
    ``self.tokenizer`` is taken from the loaded model's tokenizer and
    ``self.reverse_tokenizer`` is bound to ``self._reverse_tokenizer``.
    """
    super().__init__(
        action=action, name=name,
        aug_min=aug_min, aug_max=aug_max, aug_p=aug_p,
        stopwords=stopwords,
        # None here; self.tokenizer is overwritten from the loaded model below.
        tokenizer=None,
        verbose=verbose)

    self.model_path, self.tokenizer_path, self.aug_n, self.device = (
        model_path, tokenizer_path, aug_n, device)

    self.model = self.get_model(device=device, force_reload=False)
    self.tokenizer = self.model.tokenizer.tokenize
    self.reverse_tokenizer = self._reverse_tokenizer
def __init__(self, mask_factor, name='TimeMasking_Aug', verbose=0):
    """Initialise a time-masking augmenter (SUBSTITUTE action, CPU device).

    :param mask_factor: passed as the only argument to ``get_model``.
    :param name: augmenter name handed to the base class.
    :param verbose: verbosity level handed to the base class.
    """
    parent = super(TimeMaskingAug, self)
    parent.__init__(action=Action.SUBSTITUTE, name=name, device='cpu', verbose=verbose)
    self.model = self.get_model(mask_factor)
def __init__(self, name='OCR_Aug', aug_char_min=1, aug_char_max=10, aug_char_p=0.3,
             aug_word_p=0.3, aug_word_min=1, aug_word_max=10, stopwords=None,
             tokenizer=None, reverse_tokenizer=None, verbose=0):
    """Initialise the OCR augmenter with the SUBSTITUTE action on CPU.

    Character-level (``aug_char_*``) and word-level (``aug_word_*``)
    settings, stopwords, tokenizer hooks and verbosity are forwarded
    unchanged to the base class; the model is then built by calling
    ``get_model()`` with no arguments.
    """
    super().__init__(
        action=Action.SUBSTITUTE,
        name=name,
        device='cpu',
        verbose=verbose,
        # character-level settings
        aug_char_min=aug_char_min,
        aug_char_max=aug_char_max,
        aug_char_p=aug_char_p,
        # word-level settings
        aug_word_min=aug_word_min,
        aug_word_max=aug_word_max,
        aug_word_p=aug_word_p,
        stopwords=stopwords,
        tokenizer=tokenizer,
        reverse_tokenizer=reverse_tokenizer,
    )
    self.model = self.get_model()
def __init__(self, zone=(0.2, 0.8), coverage=1., duration=None,
             factor=(0.5, 2),
             speed_range=(0.5, 2), name='Speed_Aug', verbose=0):
    """Initialise a speed augmenter with the SUBSTITUTE action on CPU.

    :param zone: forwarded to ``get_model``.
    :param coverage: forwarded to ``get_model``.
    :param duration: forwarded to ``get_model``.
    :param factor: forwarded to ``get_model`` unless ``speed_range`` overrides it.
    :param speed_range: deprecated alias of ``factor``. Any value other than
        ``(0.5, 2)`` prints a deprecation message and replaces ``factor``.
    :param name: augmenter name handed to the base class.
    :param verbose: verbosity level handed to the base class.
    """
    super().__init__(action=Action.SUBSTITUTE, name=name, device='cpu', verbose=verbose)

    legacy_range_given = speed_range != (0.5, 2)
    if legacy_range_given:
        print(WarningMessage.DEPRECATED.format('speed_range', '0.0.12', 'factor'))
        # The deprecated argument still takes effect by overriding factor.
        factor = speed_range

    self.model = self.get_model(zone, coverage, duration, factor)