Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def run_display_test(self, kwargs):
    """
    Run ``display`` with the given option overrides and check its load summary.

    :param kwargs:
        option defaults applied to the parser returned by ``setup_args()``
        before parsing.

    The test passes when the captured stdout contains the
    ``[ loaded N episodes with a total of M examples ]`` line, where N and M
    come from a world created with the same options.
    """
    f = io.StringIO()
    # Capture everything printed while building the world and running display().
    with redirect_stdout(f):
        parser = setup_args()
        parser.set_defaults(**kwargs)
        opt = parser.parse_args()
        agent = RepeatLabelAgent(opt)
        world = create_task(opt, agent)
        display(opt)
    str_output = f.getvalue()
    expected = '[ loaded {} episodes with a total of {} examples ]'.format(
        world.num_episodes(), world.num_examples()
    )
    # assertIn (unlike assertTrue(x in y)) reports the expected line and the
    # captured output on failure, in addition to the custom message.
    self.assertIn(
        expected,
        str_output,
        'Wizard of Wikipedia failed with following args: {}'.format(opt),
    )
from parlai.core.teachers import FbDialogTeacher
from .build import build
import copy
import os
def _path(opt, filtered):
    """
    Return the MCTest data file path for the current datatype and split.

    ``build(opt)`` is invoked first so the data exists on disk. The file name
    is the datatype with any ``:suffix`` removed, followed by ``filtered``
    and ``.txt``, located under ``<datapath>/MCTest``.
    """
    build(opt)
    # 'train:stream' -> 'train'; a datatype without ':' is used unchanged.
    base_type = opt['datatype'].partition(':')[0]
    file_name = base_type + filtered + '.txt'
    return os.path.join(opt['datapath'], 'MCTest', file_name)
class Task160Teacher(FbDialogTeacher):
    """FbDialog teacher reading the MCTest '160' data file for the current datatype."""

    def __init__(self, opt, shared=None):
        # Deep-copy so setting 'datafile' does not leak into the caller's opt.
        own_opt = copy.deepcopy(opt)
        own_opt['datafile'] = _path(own_opt, '160')
        super().__init__(own_opt, shared)
class Task500Teacher(FbDialogTeacher):
    """FbDialog teacher reading the MCTest '500' data file for the current datatype."""

    def __init__(self, opt, shared=None):
        # Deep-copy so setting 'datafile' does not leak into the caller's opt.
        own_opt = copy.deepcopy(opt)
        own_opt['datafile'] = _path(own_opt, '500')
        super().__init__(own_opt, shared)
class DefaultTeacher(Task500Teacher):
    """Default MCTest teacher; behaves exactly like Task500Teacher."""
def build(opt):
    """
    Make sure the BookTest data is present under ``<datapath>/BookTest``.

    If the directory is not marked as built for the current version, an
    older built copy (if any) is removed, every entry in ``RESOURCES`` is
    downloaded into a fresh directory, and the directory is marked as built.
    """
    dpath = os.path.join(opt['datapath'], 'BookTest')
    version = None

    if build_data.built(dpath, version_string=version):
        return  # already built for this version; nothing to do

    print('[building data: ' + dpath + ']')
    if build_data.built(dpath):
        # A copy from an older version exists: clear it before rebuilding.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    for resource in RESOURCES:
        resource.download_file(dpath)

    build_data.mark_done(dpath, version_string=version)
def test_qangaroo(self):
    """
    Build the QAngaroo default teacher on 'train' and check its first act.

    The parsed opt and the teacher's first reply are handed to ``check``;
    ``self.TMP_PATH`` is removed afterwards.
    """
    from parlai.core.params import ParlaiParser
    from parlai.tasks.qangaroo.agents import DefaultTeacher

    parsed = ParlaiParser().parse_args(args=self.args)
    parsed['datatype'] = 'train'
    first_reply = DefaultTeacher(parsed).act()
    check(parsed, first_reply)
    shutil.rmtree(self.TMP_PATH)
def test_file_inference(self):
    """
    Test --inference with older model files.

    Loading the unittest transformer_generator2 model with no extra flags
    must yield ``inference == 'greedy'``; adding ``--beam-size 5`` must
    yield ``inference == 'beam'``.
    """
    testing_utils.download_unittest_models()
    model_flags = ['--model-file', 'zoo:unittest/transformer_generator2/model']
    # (command-line args, extra parse_args kwargs, expected inference mode)
    scenarios = [
        (list(model_flags), {}, 'greedy'),
        (model_flags + ['--beam-size', '5'], {'print_args': False}, 'beam'),
    ]
    for cli_args, parse_kwargs, expected_mode in scenarios:
        with testing_utils.capture_output():
            parser = ParlaiParser(True, True)
            parsed = parser.parse_args(cli_args, **parse_kwargs)
            agent = create_agent(parsed, True)
            self.assertEqual(agent.opt['inference'], expected_mode)
def get_acts_epochs_1_and_2(defaults):
# Run the world built from `defaults` for two consecutive epochs and collect
# world_data.acts[0] from every parley in each epoch.
# NOTE(review): relies on a module-level `parser` that is not visible here.
# NOTE(review): this fragment is truncated -- the second `sorted(` call below
# is never closed and no return is visible; presumably it returns both lists.
parser.set_defaults(**defaults)
opt = parser.parse_args()
build_dict(opt)
agent = create_agent(opt)
world_data = create_task(opt, agent)
acts_epoch_1 = []
acts_epoch_2 = []
# Epoch 1: parley until epoch_done(), recording acts[0] after each step.
while not world_data.epoch_done():
world_data.parley()
acts_epoch_1.append(world_data.acts[0])
# Reset the world, then record a second epoch the same way.
world_data.reset()
while not world_data.epoch_done():
world_data.parley()
acts_epoch_2.append(world_data.acts[0])
# Each recorded act is iterated as a collection (presumably a batch of
# messages -- confirm). Flatten it, keep only messages containing 'text', and
# sort by text, presumably so the epochs can be compared order-independently.
acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
acts_epoch_1 = sorted(
[b for b in acts_epoch_1 if 'text' in b], key=lambda x: x.get('text')
)
acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
acts_epoch_2 = sorted(
[b for b in acts_epoch_2 if 'text' in b], key=lambda x: x.get('text')
# NOTE(review): fragment of a larger test -- its `def` line and the enclosing
# scope that defines `tmp` (and imports json/agents) are not visible here.
with testing_utils.capture_output() as _:
# Write a placeholder model file, plus a JSON `.opt` file beside it that names
# tests.test_params:_ExampleUpgradeOptAgent as the model.
modfn = os.path.join(tmp, 'model')
with open(modfn, 'w') as f:
f.write('Test.')
optfn = modfn + '.opt'
base_opt = {
'model': 'tests.test_params:_ExampleUpgradeOptAgent',
'dict_file': modfn + '.dict',
'model_file': modfn,
}
with open(optfn, 'w') as f:
json.dump(base_opt, f)
# Parse only --model-file so the remaining options come from the `.opt`
# file written above, then build the agent from them.
# NOTE(review): presumably this exercises the agent's opt-upgrade hook -- confirm.
pp = ParlaiParser(True, True)
opt = pp.parse_args(['--model-file', modfn])
agents.create_agent(opt)
def test_gpt2_bpe_tokenize(self):
# Check DictionaryAgent.gpt2_tokenize on a string ending in an emoji.
# NOTE(review): truncated in this view -- the second assertEqual (starting at
# `agent.vec2txt(`) is cut off.
with testing_utils.capture_output():
opt = Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'})
agent = DictionaryAgent(opt)
self.assertEqual(
# grinning face emoji
agent.gpt2_tokenize(u'Hello, ParlAI! \U0001f600'),
# Expected tokens containing escapes are raw strings, i.e. literal
# backslash sequences rather than the decoded characters.
# NOTE(review): confirm this matches gpt2_tokenize's output format.
[
'Hello',
',',
r'\xc4\xa0Par',
'l',
'AI',
'!',
r'\xc4\xa0\xc3\xb0\xc5\x81\xc4\xba',
r'\xc4\xa2',
],
)
self.assertEqual(
agent.vec2txt(
def build(opt):
# Download and unpack the MCTest archive into <datapath>/MCTest and convert
# its splits with create_fb_format, unless already built for this version.
# NOTE(review): truncated in this view -- the train500 create_fb_format call
# (and anything after it, e.g. marking the data as built) is cut off.
dpath = os.path.join(opt['datapath'], 'MCTest')
version = None
if not build_data.built(dpath, version_string=version):
print('[building data: ' + dpath + ']')
if build_data.built(dpath):
# An older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)
# Download the data.
# NOTE(review): the URL is plain http; consider https if the host supports it.
fname = 'mctest.tar.gz'
url = 'http://parl.ai/downloads/mctest/' + fname
build_data.download(url, dpath, fname)
build_data.untar(dpath, fname)
# Assumes the archive unpacks into an 'mctest' subdirectory -- confirm.
dpext = os.path.join(dpath, 'mctest')
# Train/valid splits are converted with no answer file (None); the test
# split takes its answers from MCTestAnswers/mc160.test.ans.
create_fb_format(
dpath, 'train160', os.path.join(dpext, 'MCTest', 'mc160.train'), None
)
create_fb_format(
dpath, 'valid160', os.path.join(dpext, 'MCTest', 'mc160.dev'), None
)
create_fb_format(
dpath,
'test160',
os.path.join(dpext, 'MCTest', 'mc160.test'),
os.path.join(dpext, 'MCTestAnswers', 'mc160.test.ans'),
)
create_fb_format(
dpath, 'train500', os.path.join(dpext, 'MCTest', 'mc500.train'), None
def test_load_agent(self):
    """load_agent_module(OPTIONS['agent']) must resolve to RepeatLabelAgent."""
    self.assertEqual(load_agent_module(OPTIONS['agent']), RepeatLabelAgent)