Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@testing_utils.skipUnlessGPU
@unittest.skipIf(SKIP_TESTS, "Fairseq not installed")
def test_labelcands(self):
    """Check that the fairseq LSTM ranks label candidates accurately.

    Trains on ``integration_tests:candidate`` with ``rank_candidates=True``
    and ``skip_generation=True``, then requires valid hits@1 to exceed 0.95.
    The captured training log is attached to the failure message.
    """
    stdout, valid, _ = testing_utils.train_model(
        dict(
            task='integration_tests:candidate',
            model='fairseq',
            arch='lstm_wiseman_iwslt_de_en',
            lr=LR,
            batchsize=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            rank_candidates=True,
            skip_generation=True,
        )
    )
    # assertGreater (rather than assertTrue on a comparison) reports the
    # offending value in addition to the custom message when it fails.
    self.assertGreater(
        valid['hits@1'],
        0.95,
        "valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout),
    )
def test_display_data(self):
    """Check that ImageTeacher examples are all distinct in every split.

    For each image mode, the teacher is displayed against a fresh temporary
    datapath. In each of the train, valid and test outputs, at least one
    label line must be found and no label line may appear twice.
    """
    # Train output tags labels as `[labels: ...]`; valid and test outputs
    # use `[eval_labels: ...]`.
    split_patterns = (
        r"\[labels: .*\]",
        r"\[eval_labels: .*\]",
        r"\[eval_labels: .*\]",
    )

    def _test_display_output(opt):
        output = testing_utils.display_data(opt)
        per_split = [
            re.findall(pattern, output[idx])
            for idx, pattern in enumerate(split_patterns)
        ]
        for split_text, labels in zip(output, per_split):
            self.assertGreater(len(labels), 0, 'DisplayData failed')
            # A repeated label string means two examples share identical labels.
            self.assertEqual(len(labels), len(set(labels)), split_text)

    with testing_utils.tempdir() as tmpdir:
        os.makedirs(os.path.join(tmpdir, 'ImageTeacher'))
        opt = {'task': 'integration_tests:ImageTeacher', 'datapath': tmpdir}
        for mode in ['resnet152', 'no_image_model']:
            opt['image_mode'] = mode
            _test_display_output(opt)
def test_download_multiprocess(self):
    """Download one valid and two invalid URLs in parallel and check results.

    Each download result is unpacked as (filename, http status, error). The
    filenames must equal ``self.dest_filenames``, and the statuses must be
    200 for the real archive and 403 for both ``.BAD`` URLs. Captured
    stdout is appended to each failure message.
    """
    urls = [
        'https://parl.ai/downloads/mnist/mnist.tar.gz',
        'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
        'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
    ]
    with testing_utils.capture_output() as captured:
        download_results = build_data.download_multiprocess(
            urls, self.datapath, dest_filenames=self.dest_filenames
        )
    stdout = captured.getvalue()

    # Transpose the per-URL triples into per-field tuples; errors are unused.
    filenames, statuses, _errors = zip(*download_results)
    self.assertEqual(
        filenames, self.dest_filenames, f'output filenames not correct\n{stdout}'
    )
    self.assertEqual(
        statuses, (200, 403, 403), f'output http statuses not correct\n{stdout}'
    )
def test_labelcands(self):
    """Check that the fairseq LSTM ranks label candidates accurately.

    Trains on ``integration_tests:candidate`` with ``rank_candidates=True``
    and ``skip_generation=True``, then requires valid hits@1 to exceed 0.95.
    The captured training log is attached to the failure message.
    """
    stdout, valid, _ = testing_utils.train_model(
        dict(
            task='integration_tests:candidate',
            model='fairseq',
            arch='lstm_wiseman_iwslt_de_en',
            lr=LR,
            batchsize=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            rank_candidates=True,
            skip_generation=True,
        )
    )
    # assertGreater (rather than assertTrue on a comparison) reports the
    # offending value in addition to the custom message when it fails.
    self.assertGreater(
        valid['hits@1'],
        0.95,
        "valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout),
    )
def test_retrieval(self):
    """Evaluate with RETRIEVAL_OPTIONS and enforce minimum test metrics.

    Checks accuracy >= 0.86, hits@5 >= 0.98 and hits@10 >= 0.99, in that
    order, stopping at the first failure. The eval log is attached to the
    failure message.
    """
    stdout, _, test = testing_utils.eval_model(RETRIEVAL_OPTIONS)
    # (metric key, minimum acceptable value, failure-message template)
    expectations = [
        ('accuracy', 0.86, 'test acc = {}\nLOG:\n{}'),
        ('hits@5', 0.98, 'test hits@5 = {}\nLOG:\n{}'),
        ('hits@10', 0.99, 'test hits@10 = {}\nLOG:\n{}'),
    ]
    for metric, floor, template in expectations:
        self.assertGreaterEqual(
            test[metric], floor, template.format(test[metric], stdout)
        )
"""
# Download model, move to a new location
datapath = ParlaiParser().parse_args(print_args=False)['datapath']
try:
# remove unittest models if there before
shutil.rmtree(os.path.join(datapath, 'models/unittest'))
except FileNotFoundError:
pass
testing_utils.download_unittest_models()
zoo_path = 'zoo:unittest/seq2seq/model'
model_path = modelzoo_path(datapath, zoo_path)
os.remove(model_path + '.dict')
# Test that eval model fails
with self.assertRaises(RuntimeError):
testing_utils.eval_model(dict(task='babi:task1k:1', model_file=model_path))
try:
# remove unittest models if there after
shutil.rmtree(os.path.join(datapath, 'models/unittest'))
except FileNotFoundError:
pass
@classmethod
def setUpClass(cls):
    """Fetch the END2END task data once, before any test in the class runs.

    unittest invokes this hook as ``cls.setUpClass()`` with no arguments, so
    it must be a classmethod. Without the decorator, the call fails with a
    missing-argument TypeError. Displaying a single example is how this hook
    triggers the download.
    """
    # go ahead and download things here; stdout is captured while doing so
    with testing_utils.capture_output():
        parser = display_data.setup_args()
        parser.set_defaults(**END2END_OPTIONS)
        opt = parser.parse_args(print_args=False)
        # one example is all that is needed to fetch the data
        opt['num_examples'] = 1
        display_data.display_data(opt)
def test_released_model(self):
    """
    Evaluate the released self_feeding zoo model and compare test metrics.

    Runs test-only evaluation (validation skipped) of
    ``zoo:self_feeding/hh131k_hb60k_fb60k_st1k/model`` on ``self_feeding:all``
    and checks dia_acc, fee_acc and sat_f1 against pinned values.
    """
    eval_opt = {
        'model_file': 'zoo:self_feeding/hh131k_hb60k_fb60k_st1k/model',
        'task': 'self_feeding:all',
        'batchsize': 20,
    }
    _, _, test = testing_utils.eval_model(eval_opt, skip_valid=True)

    # (metric key, expected value, allowed absolute deviation)
    pinned = [
        ('dia_acc', 0.506, 0.001),
        ('fee_acc', 0.744, 0.001),
        ('sat_f1', 0.8343, 0.0001),
    ]
    for metric, expected, tolerance in pinned:
        self.assertAlmostEqual(test[metric], expected, delta=tolerance)
def get_agent(**kwargs):
    r"""
    Build a MockTorchAgent from command-line style keyword overrides.

    :param kwargs: options forwarded to ``parser.set_params(\*\*kwargs)``;
        ``no_cuda`` defaults to True when the caller does not supply it.
    :return: a ``MockTorchAgent`` constructed from the parsed opt, with
        output captured during construction.
    """
    # Only fill in no_cuda when absent; an explicit caller value wins.
    kwargs.setdefault('no_cuda', True)
    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    MockTorchAgent.add_cmdline_args(parser)
    parser.set_params(**kwargs)
    opt = parser.parse_args(print_args=False)
    with testing_utils.capture_output():
        return MockTorchAgent(opt)
inference='beam',
beam_size=5,
**args,
)
)
self.assertGreaterEqual(noblock_valid['f1'], 0.99)
# first confirm all is good without blocking
_, valid, test = testing_utils.eval_model(
dict(beam_context_block_ngram=-1, **args)
)
self.assertGreaterEqual(valid['f1'], 0.99)
self.assertGreaterEqual(valid['bleu-4'], 0.99)
# there's a special case for block == 1
_, valid, test = testing_utils.eval_model(
dict(beam_context_block_ngram=1, **args)
)
# bleu and f1 should be totally wrecked.
self.assertLess(valid['f1'], 0.01)
self.assertLess(valid['bleu-4'], 0.01)
# a couple general cases
_, valid, test = testing_utils.eval_model(
dict(beam_context_block_ngram=2, **args)
)
# should take a big hit here
self.assertLessEqual(valid['f1'], noblock_valid['f1'])
# bleu-1 should be relatively okay
self.assertLessEqual(valid['bleu-1'], noblock_valid['bleu-1'])
self.assertGreaterEqual(valid['bleu-1'], 0.50)
# and bleu-2 should be 0 at this point