def test_gpt2_bpe_tokenize(self):
    with testing_utils.capture_output():
        opt = Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'})
        agent = DictionaryAgent(opt)
    self.assertEqual(
        # grinning face emoji
        agent.gpt2_tokenize(u'Hello, ParlAI! \U0001f600'),
        [
            'Hello',
            ',',
            r'\xc4\xa0Par',
            'l',
            'AI',
            '!',
            r'\xc4\xa0\xc3\xb0\xc5\x81\xc4\xba',
            r'\xc4\xa2',
        ],
    )
    self.assertEqual(
        # assumed completion of the truncated call: vectorizing and then
        # de-vectorizing should round-trip back to the original string
        agent.vec2txt(agent.txt2vec(u'Hello, ParlAI! \U0001f600')),
        u'Hello, ParlAI! \U0001f600',
    )
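# A standalone sketch of what the test above exercises (assumes ParlAI is
# installed and that './data' is a writable datapath where the GPT-2 BPE
# vocabulary can be downloaded; both values come from the test itself).
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt

gpt2_dict = DictionaryAgent(Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'}))
print(gpt2_dict.gpt2_tokenize('Hello, ParlAI!'))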
def __init__(self, opt, shared=None):
    """
    Initialize DictionaryAgent.
    """
    self.opt = copy.deepcopy(opt)
    self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)
    self.null_token = opt.get('dict_nulltoken', DictionaryAgent.default_null)
    self.end_token = opt.get('dict_endtoken', DictionaryAgent.default_end)
    self.unk_token = opt.get('dict_unktoken', DictionaryAgent.default_unk)
    self.start_token = opt.get('dict_starttoken', DictionaryAgent.default_start)
    self.max_ngram_size = opt.get(
        'dict_max_ngram_size', DictionaryAgent.default_maxngram
    )
    self.tokenizer = opt.get('dict_tokenizer', DictionaryAgent.default_tok)
    self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
    self.maxtokens = opt.get('dict_maxtokens', DictionaryAgent.default_maxtokens)
    self.textfields = opt.get(
        'dict_textfields', DictionaryAgent.default_textfields
    ).split(",")

    try:
        self.tokenizer_fun = getattr(self, self.tokenizer + '_tokenize')
    except AttributeError:
        raise AttributeError(
            'tokenizer type {} not yet supported'.format(self.tokenizer)
        )

    if shared:
        self.freq = shared.get('freq', {})
        self.tok2ind = shared.get('tok2ind', {})
        self.ind2tok = shared.get('ind2tok', {})
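# A minimal usage sketch of the constructor above (assumes ParlAI is installed
# and that a bare Opt with just 'dict_tokenizer' is enough; 'split' is the
# simple whitespace-based tokenizer, so no data files are required).
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt

d = DictionaryAgent(Opt({'dict_tokenizer': 'split'}))
d.add_to_dict(d.tokenize('hello parlai world'))
vec = d.txt2vec('hello world')
print(vec)             # token indices
print(d.vec2txt(vec))  # 'hello world'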
    '--dict-initpath',
    hidden=True,
    help='path to a saved dictionary to load tokens / counts from to '
    'seed the dictionary with initial tokens and/or frequencies',
)
dictionary.add_argument(
    '--dict-language',
    default=DictionaryAgent.default_lang,
    hidden=True,
    help='sets language for the punkt sentence tokenizer',
)
dictionary.add_argument(
    '--dict-max-ngram-size',
    type=int,
    hidden=True,
    default=DictionaryAgent.default_maxngram,
    help='looks for ngrams of up to this size. this is ignored when '
    'building the dictionary. note: this takes approximate '
    'runtime of len(sentence)^max_ngram_size',
)
dictionary.add_argument(
    '--dict-minfreq',
    default=DictionaryAgent.default_minfreq,
    type=int,
    help='minimum frequency of words to include them in sorted '
    'dict or minimum frequency of bpe codecs',
    hidden=True,
)
dictionary.add_argument(
    '--dict-maxtokens',
    default=DictionaryAgent.default_maxtokens,
    type=int,
)
return
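# Sketch of how the --dict-* flags defined above surface on the command line
# (assumes a ParlAI install; the specific flag values are illustrative only).
from parlai.core.dict import DictionaryAgent
from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_parlai_args=True, add_model_args=False)
DictionaryAgent.add_cmdline_args(parser)
opt = parser.parse_args(['--dict-minfreq', '2', '--dict-maxtokens', '5000'])
print(opt['dict_minfreq'], opt['dict_maxtokens'])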
if skip_if_built and os.path.isfile(opt['dict_file']):
    # Dictionary already built, skip all loading or setup
    print("[ dictionary already built .]")
    return None

if is_distributed():
    raise ValueError('Dictionaries should be pre-built before distributed train.')

if opt.get('dict_class'):
    # Custom dictionary class
    dictionary = str2class(opt['dict_class'])(opt)
else:
    # Default dictionary class
    dictionary = DictionaryAgent(opt)

if os.path.isfile(opt['dict_file']):
    # Dictionary already built, return loaded dictionary agent
    print("[ dictionary already built .]")
    return dictionary

ordered_opt = copy.deepcopy(opt)
cnt = 0
# we use train set to build dictionary
ordered_opt['numthreads'] = 1
ordered_opt['batchsize'] = 1
ordered_opt['image_mode'] = 'none'
ordered_opt['pytorch_teacher_batch_sort'] = False
if ordered_opt['task'] == 'pytorch_teacher' or not ordered_opt['task']:
    pytorch_teacher_task = ordered_opt.get('pytorch_teacher_task', '')
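# Sketch of the custom-dictionary branch above: --dict-class is an import
# string of the form 'module.path:ClassName' that str2class resolves to a
# class (the import location of str2class is an assumption here).
from parlai.core.opt import Opt
from parlai.core.params import str2class

dict_class = str2class('parlai.core.dict:DictionaryAgent')
custom_dictionary = dict_class(Opt({'dict_tokenizer': 'split'}))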
)
dict_loop.add_argument(
    '--dict-include-test',
    default=False,
    type='bool',
    help='Include test set in dictionary building for task.',
    hidden=hidden,
)
dict_loop.add_argument(
    '-ltim', '--log-every-n-secs', type=float, default=2, hidden=hidden
)
partial, _ = parser.parse_known_args(nohelp=True)
if vars(partial).get('dict_class'):
    str2class(vars(partial).get('dict_class')).add_cmdline_args(parser)
else:
    DictionaryAgent.add_cmdline_args(parser)
return parser


def dictionary_class():
    return DictionaryAgent
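# Sketch of driving the dictionary-building script programmatically with the
# setup_args/build_dict pair shown above (the task name and dict-file path are
# placeholders; any ParlAI task works).
from parlai.scripts.build_dict import build_dict, setup_args

parser = setup_args()
opt = parser.parse_args(['--task', 'babi:task10k:1', '--dict-file', '/tmp/babi.dict'])
build_dict(opt)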
    self.fixedCands_txt = shared['fixedCands_txt']
    self.fixedCands2 = shared['fixedCands2']
    self.fixedCands_txt2 = shared['fixedCands_txt2']
else:
    print("[ creating KvmemnnAgent ]")
    # this is not a shared instance of this class, so do full init
    self.threadindex = -1
    torch.set_num_threads(1)

    if (opt['dict_file'] is None and opt.get('model_file')) or os.path.isfile(
        opt['model_file'] + '.dict'
    ):
        # set default dict-file if not set
        opt['dict_file'] = opt['model_file'] + '.dict'

    # load dictionary and basic tokens & vectors
    self.dict = DictionaryAgent(opt)
    if 'loss' not in opt:
        opt['loss'] = 'cosine'
    self.model = Kvmemnn(opt, len(self.dict), self.dict)
    if opt.get('model_file') and os.path.isfile(opt['model_file']):
        self.load(opt['model_file'])
    self.model.share_memory()

    self.fixedCands = False
    self.fixedX = None
    path = opt['model_file'] + '.candspair'
    if os.path.isfile(path) and opt.get('loadcands') != False:
        print("[loading candidates: " + path + "*]")
        fc = load_cands(path)
        fcs = []
        for c in fc:
            fcs.append(Variable(torch.LongTensor(self.parse(c)).unsqueeze(0)))
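# Standalone sketch of the candidate-encoding pattern above, using plain
# tensors (torch.autograd.Variable is legacy in modern PyTorch) and
# DictionaryAgent.txt2vec in place of the agent's self.parse helper, which is
# assumed to do the same text-to-indices conversion.
import torch
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt

d = DictionaryAgent(Opt({'dict_tokenizer': 'split'}))
cands = ['hello there', 'general kenobi']
for c in cands:
    d.add_to_dict(d.tokenize(c))
encoded = [torch.LongTensor(d.txt2vec(c)).unsqueeze(0) for c in cands]
print([t.shape for t in encoded])  # each candidate becomes a 1 x num_tokens tensor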
def build_dict(opt):
    if not opt.get('dict_file'):
        print('Tried to build dictionary but `--dict-file` is not set. Set ' +
              'this param so the dictionary can be saved.')
        return
    print('[ setting up dictionary. ]')
    if os.path.isfile(opt['dict_file']):
        # Dictionary already built
        print("[ dictionary already built .]")
        return
    if opt.get('dict_class'):
        # Custom dictionary class
        dictionary = str2class(opt['dict_class'])(opt)
    else:
        # Default dictionary class
        dictionary = DictionaryAgent(opt)
    ordered_opt = copy.deepcopy(opt)
    cnt = 0
    # we use train set to build dictionary
    ordered_opt['datatype'] = 'train:ordered'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    world_dict = create_task(ordered_opt, dictionary)
    # pass examples to dictionary
    for _ in world_dict:
        cnt += 1
        if cnt > opt['dict_maxexs'] and opt['dict_maxexs'] > 0:
            print('Processed {} exs, moving on.'.format(opt['dict_maxexs']))
            # don't wait too long...
            break
        world_dict.parley()
    print('[ dictionary built. ]')


def dictionary_class():
    return DictionaryAgent
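# Sketch of the dictionary_class() hook: callers ask the agent which dictionary
# to construct, so a subclass can substitute its own. MyAgent and
# CustomDictionaryAgent are hypothetical names used only for illustration.
from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt


class CustomDictionaryAgent(DictionaryAgent):
    """Hypothetical subclass; tokenization hooks would be overridden here."""


class MyAgent:
    @staticmethod
    def dictionary_class():
        return CustomDictionaryAgent


dictionary = MyAgent.dictionary_class()(Opt({'dict_tokenizer': 'split'}))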