def test_vocab_specials_first(self):
    c = Counter("a a b b c c".split())
    # add specials into vocabulary at first
    v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'])
    expected_itos = ['<pad>', '<eos>', 'a', 'b']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.itos, expected_itos)
    self.assertEqual(dict(v.stoi), expected_stoi)
    # add specials into vocabulary at last
    v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'], specials_first=False)
    expected_itos = ['a', 'b', '<pad>', '<eos>']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.itos, expected_itos)
    self.assertEqual(dict(v.stoi), expected_stoi)
def test_errors(self):
    c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
    with self.assertRaises(ValueError):
        # Test proper error raised when using unknown string alias
        vocab.Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<eos>'],
                    vectors=["fasttext.english.300d"])
        vocab.Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<eos>'],
                    vectors="fasttext.english.300d")
    with self.assertRaises(ValueError):
        # Test proper error is raised when vectors argument is
        # non-string or non-Vectors
        vocab.Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<eos>'],
                    vectors={"word": [1, 2, 3]})
def create_vocab(config, datasets, fields):
    vocab_path = os.path.join(config.save_path, "vocabulary.pth")
    common_vocab = config.compositional_args and config.compositional_rels
    if os.path.exists(vocab_path):
        arg_counter, rel_counter = torch.load(vocab_path)
    else:
        arg_counter, rel_counter = vocab_from_instances(datasets[0], datasets[1])
        torch.save((arg_counter, rel_counter), vocab_path)
    subject_field, object_field, relation_field = fields
    arg_specials = list(OrderedDict.fromkeys(
        tok for tok in [subject_field.unk_token, subject_field.pad_token,
                        subject_field.init_token, subject_field.eos_token]
        if tok is not None))
    rel_specials = list(OrderedDict.fromkeys(
        tok for tok in [relation_field.unk_token, relation_field.pad_token,
                        relation_field.init_token, relation_field.eos_token]
        if tok is not None))
    max_vocab_size = getattr(config, "max_vocab_size", None)
    arg_vocab = Vocab(arg_counter, specials=arg_specials, vectors='glove.6B.200d',
                      vectors_cache='/glove', max_size=max_vocab_size)
    rel_vocab = arg_vocab if common_vocab else Vocab(
        rel_counter, specials=rel_specials, vectors='glove.6B.200d',
        vectors_cache='/glove', max_size=max_vocab_size)
    subject_field.vocab, object_field.vocab, relation_field.vocab = arg_vocab, arg_vocab, rel_vocab
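# Illustrative only: once create_vocab has attached the vocab objects, numericalization
# and embedding lookup go through them; the entity string below is hypothetical.
idx = subject_field.vocab.stoi["some_entity"]   # OOV strings fall back to the unk index
emb = subject_field.vocab.vectors[idx]          # matching glove.6B.200d row, shape (200,)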
from onmt.io.DatasetBase import UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD
from onmt.io.TextDataset import TextDataset
from onmt.io.ImageDataset import ImageDataset
from onmt.io.AudioDataset import AudioDataset


def _getstate(self):
    return dict(self.__dict__, stoi=dict(self.stoi))


def _setstate(self, state):
    self.__dict__.update(state)
    self.stoi = defaultdict(lambda: 0, self.stoi)


torchtext.vocab.Vocab.__getstate__ = _getstate
torchtext.vocab.Vocab.__setstate__ = _setstate
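# The patch above exists because Vocab.stoi is a defaultdict built around a lambda,
# which pickle cannot serialize; a minimal round-trip sketch (file name illustrative):
from collections import Counter
import torch
import torchtext

v = torchtext.vocab.Vocab(Counter("a a b".split()))
torch.save(v, "vocab.pt")        # works: __getstate__ stores stoi as a plain dict
v2 = torch.load("vocab.pt")      # __setstate__ restores the default-to-unk behaviour
assert v2.stoi["never_seen_token"] == 0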
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input. Options are [text|img|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.
    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.
    """
"""Detect old-style vocabs (``List[Tuple[str, torchtext.data.Vocab]]``).
Args:
vocab: some object loaded from a *.vocab.pt file
Returns:
Whether ``vocab`` is a list of pairs where the second object
is a :class:`torchtext.vocab.Vocab` object.
This exists because previously only the vocab objects from the fields
were saved directly, not the fields themselves, and the fields needed to
be reconstructed at training and translation time.
"""
return isinstance(vocab, list) and \
any(isinstance(v[1], Vocab) for v in vocab)
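# Hedged usage sketch: checkpoints from older releases store (name, Vocab) pairs, so the
# helper above is typically used as a guard before converting them; the path is illustrative.
import torch

loaded = torch.load("data.vocab.pt")
if old_style_vocab(loaded):
    loaded = dict(loaded)   # e.g. turn the list of pairs into a name -> Vocab mapping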
def _dynamic_dict(example, src_field, tgt_field):
    """
    Args:
        example (dict): An example dictionary with a ``"src"`` key and
            maybe a ``"tgt"`` key. (This argument changes in place!)
        src_field (torchtext.data.Field): Field object.
        tgt_field (torchtext.data.Field): Field object.

    Returns:
        torchtext.data.Vocab and ``example``, changed as described.
    """
    src = src_field.tokenize(example["src"])
    # make a small vocab containing just the tokens in the source sequence
    unk = src_field.unk_token
    pad = src_field.pad_token
    src_ex_vocab = Vocab(Counter(src), specials=[unk, pad])
    unk_idx = src_ex_vocab.stoi[unk]
    # Map source tokens to indices in the dynamic dict.
    src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
    example["src_map"] = src_map

    if "tgt" in example:
        tgt = tgt_field.tokenize(example["tgt"])
        mask = torch.LongTensor(
            [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
        example["alignment"] = mask
    return src_ex_vocab, example
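# A toy illustration of the per-example copy vocab built above, assuming whitespace
# tokens and OpenNMT-style '<unk>'/'<blank>' specials (both are assumptions here):
from collections import Counter
import torch
from torchtext.vocab import Vocab

src = ["the", "cat", "sat", "the"]
ex_vocab = Vocab(Counter(src), specials=['<unk>', '<blank>'])
src_map = torch.LongTensor([ex_vocab.stoi[w] for w in src])
print(ex_vocab.itos)   # ['<unk>', '<blank>', 'the', 'cat', 'sat']
print(src_map)         # tensor([2, 3, 4, 2]) -- indices into the example-local vocab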
def get_vocab_imdb(data):
    """Get the vocab for the IMDB data set for sentiment analysis."""
    tokenized_data = get_tokenized_imdb(data)
    counter = collections.Counter([tk for st in tokenized_data for tk in st])
    return text.vocab.Vocab(counter, min_freq=5)
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
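# Hedged usage sketch for the helper above; the dataset key and ngrams value are
# illustrative and assume the URLS mapping contains them:
train_dataset, test_dataset = _setup_datasets("AG_NEWS", ngrams=2)
print(len(train_dataset.get_vocab()), len(train_dataset.get_labels()))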
if pretrained_embeddings == 'glove':
    vectors = GloVe(name=name, dim=dim, cache=expanduser('~/cache'))
elif pretrained_embeddings == 'fasttext':
    FastText.__getitem__ = CaseInsensitiveVectors.__getitem__
    FastText.cache = CaseInsensitiveVectors.cache
    vectors = FastText(language=language,
                       cache=expanduser('~/cache'))
# extend the vocab with words from the val/test sets that have embeddings in the
# pre-trained vectors.
# A production version would do this dynamically at inference time.
counter = Counter()
sentences.build_vocab(val_data, test_data)
for word in sentences.vocab.stoi:
    if word in vectors.stoi or word.lower() in vectors.stoi or \
            re.sub(r'\d', '0', word.lower()) in vectors.stoi:
        counter[word] = 1
eval_vocab = Vocab(counter)
print("%i/%i eval/test words in pretrained" % (len(counter),
                                               len(sentences.vocab.stoi)))
sentences.build_vocab(train_data)
prev_vocab_size = len(sentences.vocab.stoi)
sentences.vocab.extend(eval_vocab)
new_vocab_size = len(sentences.vocab.stoi)
print('New vocab size: %i (was %i)' % (new_vocab_size, prev_vocab_size))
sentences.vocab.load_vectors(vectors)
embedding_dim = sentences.vocab.vectors.shape[1]
artifact_dir = _run.info['artifact_dir']
vocab_dict = {'sentences': sentences.vocab,
'tags': tags.vocab,
'letters': letter.vocab}
torch.save(vocab_dict, open(join(artifact_dir, 'vocab.pt'), 'wb+'))
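# The saved mapping can later be restored and re-attached to freshly built fields; a
# minimal sketch under the same assumptions (artifact path and field names as above):
vocab_dict = torch.load(join(artifact_dir, 'vocab.pt'))
sentences.vocab = vocab_dict['sentences']
tags.vocab = vocab_dict['tags']
letter.vocab = vocab_dict['letters']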
format='tsv',
fields=[
('context', TEXT),
('generated', TEXT),
('gold', TEXT),
])
#TEXT.build_vocab(train)
# Read in the LM dictionary.
print('Building the dictionary')
with open(args.dic, 'rb') as dic_file:
    dictionary = pickle.load(dic_file)
# Reconstruct the dictionary in torchtext.
counter = Counter({'<unk>': 0, '<pad>': 0})
TEXT.vocab = vocab.Vocab(counter, specials=['<unk>', '<pad>'])
TEXT.vocab.itos = dictionary.idx2word
TEXT.vocab.stoi = defaultdict(vocab._default_unk_index, dictionary.word2idx)
TEXT.vocab.load_vectors('glove.6B.%dd' % args.embedding_dim)
itos = TEXT.vocab.itos if args.p else None
print('Vocab size %d' % len(TEXT.vocab))
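# A small sanity check for the manually rebuilt vocab above (purely illustrative):
w = TEXT.vocab.itos[10]
assert TEXT.vocab.stoi[w] == 10                              # itos/stoi agree
assert TEXT.vocab.vectors.size(0) == len(TEXT.vocab.itos)    # one GloVe row per token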
train_iter = data.Iterator(dataset=train, batch_size=args.batch_size,
                           sort_key=lambda x: len(x.context), sort=True, repeat=False)
valid_iter = data.Iterator(dataset=valid, batch_size=args.batch_size,
                           sort_key=lambda x: len(x.context), sort=True, repeat=False)
print('Initializing the model')
if args.load_model != '':
    with open(args.load_model, 'rb') as f:
        model = torch.load(f).cuda()