Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def oov(self, augs):
    """Check that each augmenter copes with an out-of-vocabulary token.

    Every augmenter in ``augs`` is run on a single long unknown token and on
    that token followed by `` the``:

    * INSERT augmenters must yield more space-separated words than the input;
    * SUBSTITUTE augmenters must yield the same number of words;
    * any other action is reported as a test failure.

    Unless the augmenter's ``model_type`` is ``'roberta'``, the output must
    also not contain the model's ``SUBWORD_PREFIX``.

    :param augs: iterable of augmenters exposing ``augment``, ``action``,
        ``model_type`` and ``model.SUBWORD_PREFIX``.
    :raises AssertionError: (via the ``unittest`` helpers) when any check
        fails, including an action that is neither INSERT nor SUBSTITUTE.
    """
    # NOTE(review): a long run of a single letter is presumably absent from
    # every model vocabulary, forcing the out-of-vocabulary path — confirm.
    unknown_token = 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
    texts = [
        unknown_token,
        unknown_token + ' the',
    ]

    for aug in augs:
        for text in texts:
            self.assertLess(0, len(text))
            augmented_text = aug.augment(text)
            # Words are counted by splitting on single spaces, matching how
            # the inputs above were built.
            if aug.action == Action.INSERT:
                self.assertLess(len(text.split(' ')), len(augmented_text.split(' ')))
            elif aug.action == Action.SUBSTITUTE:
                self.assertEqual(len(text.split(' ')), len(augmented_text.split(' ')))
            else:
                # fail() records a test failure rather than an unexpected error.
                self.fail('Augmenter is neither INSERT nor SUBSTITUTE')

            # NOTE(review): RoBERTa is exempted, presumably because its
            # tokenizer marks subwords differently — confirm.
            if aug.model_type not in ('roberta',):
                # assertNotIn reports both operands on failure, unlike assertTrue.
                self.assertNotIn(aug.model.SUBWORD_PREFIX, augmented_text)
def __init__(self, sampling_rate, duration=3, direction='random',
             shift_max=3, shift_direction='both',
             name='Shift_Aug', verbose=0):
    """Configure a shift augmenter (SUBSTITUTE action, running on CPU).

    :param sampling_rate: forwarded to ``get_model``.
    :param duration: forwarded to ``get_model`` unless the deprecated
        ``shift_max`` is given a value other than its default ``3``.
    :param direction: forwarded to ``get_model`` unless the deprecated
        ``shift_direction`` is given a value other than its default ``'both'``.
    :param shift_max: deprecated since 0.0.12; use ``duration``.
    :param shift_direction: deprecated since 0.0.12; use ``direction``.
    :param name: augmenter name passed to the base class.
    :param verbose: verbosity level passed to the base class.
    """
    super().__init__(
        action=Action.SUBSTITUTE, name=name, device='cpu', verbose=verbose)

    # A non-default legacy argument overrides its replacement, after a notice.
    use_legacy_direction = shift_direction != 'both'
    if use_legacy_direction:
        print(WarningMessage.DEPRECATED.format('shift_direction', '0.0.12', 'direction'))

    use_legacy_duration = shift_max != 3
    if use_legacy_duration:
        print(WarningMessage.DEPRECATED.format('shift_max', '0.0.12', 'duration'))

    self.model = self.get_model(
        sampling_rate,
        shift_max if use_legacy_duration else duration,
        shift_direction if use_legacy_direction else direction,
    )
def __init__(self, name='Split_Aug', aug_min=1, aug_max=10, aug_p=0.3, min_char=4, stopwords=None,
             tokenizer=None, reverse_tokenizer=None, verbose=0):
    """Configure a split augmenter (SPLIT action, running on CPU).

    Every argument except ``min_char`` is forwarded unchanged to the base
    class constructor; ``min_char`` is stored as ``self.min_char``.
    """
    base_kwargs = dict(
        action=Action.SPLIT,
        name=name,
        aug_p=aug_p,
        aug_min=aug_min,
        aug_max=aug_max,
        stopwords=stopwords,
        tokenizer=tokenizer,
        reverse_tokenizer=reverse_tokenizer,
        device='cpu',
        verbose=verbose,
    )
    super().__init__(**base_kwargs)
    # NOTE(review): presumably the minimum word length eligible for splitting;
    # it is not read in this constructor — confirm in the split logic.
    self.min_char = min_char
def __init__(self, sampling_rate=None, zone=(0.2, 0.8), coverage=0.1, duration=None,
             crop_range=(0.2, 0.8), crop_factor=2, name='Crop_Aug', verbose=0):
    """Configure a crop augmenter (DELETE action, running on CPU).

    :param sampling_rate: forwarded to ``get_model``.
    :param zone: forwarded to ``get_model`` unless the deprecated
        ``crop_range`` is given a value other than its default ``(0.2, 0.8)``.
    :param coverage: forwarded to ``get_model``.
    :param duration: forwarded to ``get_model``.
    :param crop_range: deprecated since 0.0.12; use ``zone``. A non-default
        value now overrides ``zone`` instead of being silently ignored.
    :param crop_factor: deprecated since 0.0.12; a notice is printed but the
        value is not used.
    :param name: augmenter name passed to the base class.
    :param verbose: verbosity level passed to the base class.
    """
    super().__init__(
        action=Action.DELETE, name=name, device='cpu', verbose=verbose)

    # Resolve deprecated aliases before building the model so that a legacy
    # value actually takes effect rather than only producing a warning.
    if crop_range != (0.2, 0.8):
        print(WarningMessage.DEPRECATED.format('crop_range', '0.0.12', 'zone'))
        zone = crop_range
    if crop_factor != 2:
        # The named replacement ('temperature') is not a parameter of this
        # constructor, so there is nothing to map the old value onto.
        print(WarningMessage.DEPRECATED.format('crop_factor', '0.0.12', 'temperature'))

    self.model = self.get_model(sampling_rate, zone, coverage, duration)
def __init__(self, sampling_rate, zone=(0.2, 0.8), coverage=1., duration=None,
             factor=(-10, 10), pitch_range=(-10, 10), name='Pitch_Aug', verbose=0):
    """Configure a pitch augmenter (SUBSTITUTE action, running on CPU).

    :param sampling_rate: forwarded to ``get_model``.
    :param zone: forwarded to ``get_model``.
    :param coverage: forwarded to ``get_model``.
    :param duration: forwarded to ``get_model``.
    :param factor: forwarded to ``get_model`` unless the deprecated
        ``pitch_range`` is given a value other than its default ``(-10, 10)``.
    :param pitch_range: deprecated since 0.0.12; use ``factor``.
    :param name: augmenter name passed to the base class.
    :param verbose: verbosity level passed to the base class.
    """
    super().__init__(
        action=Action.SUBSTITUTE, name=name, device='cpu', verbose=verbose)

    # A non-default legacy range overrides ``factor``, after a notice.
    use_legacy_range = pitch_range != (-10, 10)
    if use_legacy_range:
        print(WarningMessage.DEPRECATED.format('pitch_range', '0.0.12', 'factor'))
    effective_factor = pitch_range if use_legacy_range else factor

    self.model = self.get_model(sampling_rate, zone, coverage, duration, effective_factor)