Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@classmethod
def setUpClass(cls):
    """Build fixtures shared by every test in the class.

    Loads the ``.env`` file two directories above this module, then reads
    ``TEST_DIR`` from the environment to locate the sample WAV under
    ``$TEST_DIR/res/audio``. The WAV is loaded with ``AudioLoader.load_audio``,
    which returns the audio and its sampling rate. The method then creates the
    textual augmenters and the audio augmenters the tests iterate over.

    Raises:
        RuntimeError: if ``TEST_DIR`` is not set after loading ``.env``.
    """
    env_config_path = os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', '..', '.env'))
    load_dotenv(env_config_path)

    test_dir = os.environ.get("TEST_DIR")
    if test_dir is None:
        # Without this check os.path.join(None, ...) fails with an opaque
        # TypeError that does not name the missing variable.
        raise RuntimeError(
            "TEST_DIR is not set; define it in {} or in the environment".format(
                env_config_path))

    # https://freewavesamples.com/yamaha-v50-rock-beat-120-bpm
    cls.sample_wav_file = os.path.join(
        test_dir, 'res', 'audio', 'Yamaha-V50-Rock-Beat-120bpm.wav'
    )
    cls.audio, cls.sampling_rate = AudioLoader.load_audio(cls.sample_wav_file)

    cls.textual_augs = [
        nac.RandomCharAug(),
        naw.ContextualWordEmbsAug(),
        nas.ContextualWordEmbsForSentenceAug()
    ]

    # CropAug needs the sampling rate of the loaded sample.
    cls.audio_augs = [
        naa.CropAug(sampling_rate=cls.sampling_rate),
        naa.SpeedAug(),
    ]
def execute_by_device(self, device):
    """Run the ContextualWordEmbsAug check helpers against ``device``.

    For each entry of ``self.model_paths``, an "insert" augmenter and a
    "substitute" augmenter are built with ``force_reload=True`` on ``device``.
    Each check helper is then called in a fixed order with the insert
    augmenter, the substitute augmenter, or a list holding both. Afterwards
    the test asserts that at least one model path was configured, so an empty
    ``model_paths`` cannot pass.

    NOTE(review): a second ``execute_by_device`` appears later in this file;
    if both live in the same class, the later one shadows this one -- confirm.
    """
    for model_path in self.model_paths:
        shared = dict(model_path=model_path, force_reload=True, device=device)
        insert_aug = naw.ContextualWordEmbsAug(action="insert", **shared)
        substitute_aug = naw.ContextualWordEmbsAug(action="substitute", **shared)

        # (check, which augmenter(s) it receives); the order is significant.
        plan = (
            (self.oov, "both"),
            (self.insert, "insert"),
            (self.substitute, "substitute"),
            (self.substitute_stopwords, "substitute"),
            (self.subword, "both"),
            (self.not_substitute_unknown_word, "substitute"),
            (self.top_k, "both"),
            (self.top_p, "both"),
            (self.top_k_top_p, "both"),
            (self.no_top_k_top_p, "both"),
            (self.max_length, "both"),
            (self.empty_replacement, "substitute"),
        )
        for check, target in plan:
            if target == "insert":
                check(insert_aug)
            elif target == "substitute":
                check(substitute_aug)
            else:
                # Build a fresh list for every call so helpers never share one.
                check([insert_aug, substitute_aug])

    self.assertLess(0, len(self.model_paths))
def test_skip_punctuation(self):
    """Word augmenters must return punctuation-only input unchanged.

    Covers an insert-mode ContextualWordEmbsAug, an AntonymAug and a
    substitute-mode TfIdfAug. The TF-IDF model is loaded from the directory
    named by the ``MODEL_DIR`` environment variable.
    """
    punctuation_only = '. . . . ! ? # @'
    augmenters = (
        naw.ContextualWordEmbsAug(action='insert'),
        naw.AntonymAug(),
        naw.TfIdfAug(model_path=os.environ.get("MODEL_DIR"), action="substitute"),
    )
    for augmenter in augmenters:
        result = augmenter.augment(punctuation_only)
        self.assertEqual(punctuation_only, result)
def test_empty_input_for_insert(self):
    """Whitespace-only input must come back as ``''`` or ``' '``.

    Checks an insert-mode ContextualWordEmbsAug and a substitute-mode
    TfIdfAug. The TF-IDF model is loaded from the directory named by the
    ``MODEL_DIR`` environment variable. ``assertIn`` is used so a failure
    shows the value actually returned and which augmenter produced it.
    """
    text = ' '
    augs = [
        naw.ContextualWordEmbsAug(action="insert"),
        naw.TfIdfAug(model_path=os.environ.get("MODEL_DIR"), action="substitute")
    ]
    for aug in augs:
        augmented_text = aug.augment(text)
        # FIXME: standardize return
        # Augmenters currently differ on whether they strip the input,
        # so both the empty string and the original single space are accepted.
        self.assertIn(augmented_text, ('', ' '), msg=type(aug).__name__)
def execute_by_device(self, device):
    """Run the ContextualWordEmbsAug check helpers against ``device``.

    Fails immediately if ``self.model_paths`` is empty, the same non-empty
    requirement asserted by the other ``execute_by_device`` in this file.
    Then, for each model path, an "insert" augmenter and a "substitute"
    augmenter are built with ``force_reload=True`` on ``device``. Each check
    helper is called in a fixed order with the insert augmenter, the
    substitute augmenter, or a list holding both.
    """
    # Fail fast on a misconfigured fixture: with no model paths the loop
    # below runs zero checks and the test would otherwise pass vacuously.
    self.assertLess(0, len(self.model_paths))

    for model_path in self.model_paths:
        insert_aug = naw.ContextualWordEmbsAug(
            model_path=model_path, action="insert", force_reload=True, device=device)
        substitute_aug = naw.ContextualWordEmbsAug(
            model_path=model_path, action="substitute", force_reload=True, device=device)

        self.oov([insert_aug, substitute_aug])
        self.insert(insert_aug)
        self.substitute(substitute_aug)
        self.substitute_stopwords(substitute_aug)
        self.subword([insert_aug, substitute_aug])
        self.not_substitute_unknown_word(substitute_aug)
        self.top_k([insert_aug, substitute_aug])
        self.top_p([insert_aug, substitute_aug])
        self.top_k_top_p([insert_aug, substitute_aug])
        self.no_top_k_top_p([insert_aug, substitute_aug])
        self.max_length([insert_aug, substitute_aug])
        self.empty_replacement(substitute_aug)
def test_reset_model(self):
    """A force-reloaded augmenter must pick up new sampling settings.

    For every configured model path, a baseline insert augmenter is built
    with ``top_p=0.5``. A second one is built with ``temperature``, ``top_k``
    and ``top_p`` each one higher than the baseline model's values. The
    reloaded model must expose exactly those bumped values.

    NOTE(review): an identical ``test_reset_model`` appears elsewhere in this
    file; if both sit in the same class the later one shadows this one.
    """
    for model_path in self.model_paths:
        baseline = naw.ContextualWordEmbsAug(
            model_path=model_path, action="insert", force_reload=True, top_p=0.5)
        base_model = baseline.model
        expected_temperature = base_model.temperature + 1
        expected_top_k = base_model.top_k + 1
        expected_top_p = base_model.top_p + 1

        reloaded = naw.ContextualWordEmbsAug(
            model_path=model_path, action="insert", force_reload=True,
            temperature=expected_temperature,
            top_k=expected_top_k,
            top_p=expected_top_p)
        reloaded_model = reloaded.model

        self.assertEqual(expected_temperature, reloaded_model.temperature)
        self.assertEqual(expected_top_k, reloaded_model.top_k)
        self.assertEqual(expected_top_p, reloaded_model.top_p)
def test_reset_model(self):
    """Rebuilding with ``force_reload=True`` must apply the new sampling knobs.

    For every configured model path, each of ``temperature``, ``top_k`` and
    ``top_p`` is read from a baseline insert augmenter (built with
    ``top_p=0.5``) and incremented by one. The incremented values are passed
    to a second insert augmenter, and its model must report them back.

    NOTE(review): this is a byte-for-byte duplicate of another
    ``test_reset_model`` in this file; if both are in one class, only this
    later definition is collected -- consider removing one.
    """
    sampling_knobs = ('temperature', 'top_k', 'top_p')
    for model_path in self.model_paths:
        first = naw.ContextualWordEmbsAug(
            model_path=model_path, action="insert", force_reload=True, top_p=0.5)
        bumped = {knob: getattr(first.model, knob) + 1 for knob in sampling_knobs}

        second = naw.ContextualWordEmbsAug(
            model_path=model_path, action="insert", force_reload=True, **bumped)

        for knob, expected in bumped.items():
            self.assertEqual(expected, getattr(second.model, knob))