def test_build_vocab_from_dataset(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
    ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
    dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

    CHARS.build_vocab(dataset, min_freq=2)

    expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
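    # Hedged illustration of the min_freq=2 cut-off above: "c" occurs only once,
    # so it is dropped from stoi/itos while remaining in the raw frequency counts.
    assert "c" not in CHARS.vocab.stoi
    assert CHARS.vocab.freqs["c"] == 1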
datafields = [("text",TEXT),("label",LABEL)]
# Load data from pd.DataFrame into torchtext.data.Dataset
train_df = self.get_pandas_df(train_file)
train_examples = [data.Example.fromlist(i, datafields) for i in train_df.values.tolist()]
train_data = data.Dataset(train_examples, datafields)
test_df = self.get_pandas_df(test_file)
test_examples = [data.Example.fromlist(i, datafields) for i in test_df.values.tolist()]
test_data = data.Dataset(test_examples, datafields)
# If validation file exists, load it. Otherwise get validation data from training data
if val_file:
val_df = self.get_pandas_df(val_file)
val_examples = [data.Example.fromlist(i, datafields) for i in val_df.values.tolist()]
val_data = data.Dataset(val_examples, datafields)
else:
train_data, val_data = train_data.split(split_ratio=0.8)
TEXT.build_vocab(train_data, vectors=Vectors(w2v_file))
self.word_embeddings = TEXT.vocab.vectors
self.vocab = TEXT.vocab
self.train_iterator = data.BucketIterator(
(train_data),
batch_size=self.config.batch_size,
sort_key=lambda x: len(x.text),
repeat=False,
shuffle=True)
self.val_iterator, self.test_iterator = data.BucketIterator.splits(
(val_data, test_data),
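# --- Hedged usage sketch (separate from the loader method above; `loader` is a
# hypothetical instance of the class that owns these iterators) ---
# Legacy torchtext BucketIterator yields Batch objects with one attribute per field;
# batch.text may be a (tensor, lengths) tuple if TEXT was built with include_lengths=True.
for batch in loader.train_iterator:
    text, labels = batch.text, batch.label
    # model forward / loss / backward would go here
    break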
def merge_vocabs(vocabs, specials, vocab_size=None):
    """
    Merge individual vocabularies (assumed to be generated from disjoint
    documents) into a larger vocabulary.

    Args:
        vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
        specials: list of special tokens (e.g. unk/pad) to keep in the merged vocabulary
        vocab_size: `int` the final vocabulary size. `None` for no limit.
    Return:
        `torchtext.vocab.Vocab`
    """
    merged = sum([vocab.freqs for vocab in vocabs], Counter())
    return torchtext.vocab.Vocab(merged,
                                 specials=specials,
                                 max_size=vocab_size)
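# A minimal usage sketch for merge_vocabs (hypothetical token counts; assumes the
# legacy torchtext.vocab.Vocab API used above):
from collections import Counter
import torchtext

src_vocab = torchtext.vocab.Vocab(Counter({"hello": 3, "world": 1}))
tgt_vocab = torchtext.vocab.Vocab(Counter({"hello": 2, "there": 5}))
merged = merge_vocabs([src_vocab, tgt_vocab], specials=["<unk>", "<pad>"])
# frequencies are summed across the input vocabularies: "hello" -> 3 + 2 = 5
print(merged.freqs["hello"], len(merged))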
class NMTDataset(torchtext.data.Dataset):
    # @staticmethod
    # def sort_key(ex):
    #     return data.interleave_keys(len(ex.src), len(ex.tgt))

    def sort_key(self, ex):
        """ Sort using length of source sentences. """
        # Default to a balanced sort, prioritizing tgt len match.
        # TODO: make this configurable.
        if hasattr(ex, "tgt"):
            return -len(ex.src), -len(ex.tgt)
        return -len(ex.src)

    def __init__(self, src_path, tgt_path, fields, **kwargs):
        make_example = torchtext.data.Example.fromlist
# numericalize each nested example with the inner field and copy it into the
# preallocated batch tensors
for i, arr in enumerate(arrs):
    if self.nesting_field.include_lengths:
        arr = tuple(arr)
    numericalized_ex = self.nesting_field.numericalize(
        arr, device=device, train=train)
    if self.nesting_field.include_lengths:
        numericalized_ex, lengths_ex = numericalized_ex
        lengths[i, :len(lengths_ex)] = lengths_ex
    numericalized[i, :, :] = numericalized_ex.data

if self.nesting_field.include_lengths:
    return Variable(numericalized), lengths
else:
    return Variable(numericalized)
class SequenceTaggingDataset(data.Dataset):
    """Defines a dataset for sequence tagging. Examples in this dataset
    contain paired lists -- paired list of words and tags.

    For example, in the case of part-of-speech tagging, an example is of the
    form
    [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT]

    See torchtext/test/sequence_tagging.py on how to use this class.
    """

    @staticmethod
    def sort_key(example):
        for attr in dir(example):
            if not callable(getattr(example, attr)) and \
                    not attr.startswith("__"):
                return len(getattr(example, attr))
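# A hedged usage sketch: relies on the constructor of the full torchtext
# SequenceTaggingDataset (path, fields, separator), which is not shown in this
# excerpt; the file name is hypothetical. Each column of a CoNLL-style file
# becomes one field, here words and POS tags.
WORD = data.Field(lower=True)
TAG = data.Field()
train = SequenceTaggingDataset(path="train.pos.txt",
                               fields=[("word", WORD), ("tag", TAG)],
                               separator="\t")
WORD.build_vocab(train)
TAG.build_vocab(train)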
src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

examples = []
with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
        io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
    for src_line, trg_line in zip(src_file, trg_file):
        src_line, trg_line = src_line.strip(), trg_line.strip()
        if src_line != '' and trg_line != '':
            examples.append(data.Example.fromlist(
                [src_line, trg_line], fields))

# cache the freshly built examples so later runs can skip the parsing step
with open(cache_file, 'wb') as fp:
    pickle.dump(examples, file=fp)

data.Dataset.__init__(self, examples, fields, **kwargs)
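# Hedged companion sketch: a later run could reuse the pickled examples instead of
# re-parsing the parallel files (cache_file is whatever path was used above).
import os
import pickle

def load_cached_examples(cache_file):
    # returns the cached list of torchtext Examples, or None if no cache exists yet
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fp:
            return pickle.load(fp)
    return None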
opt.max_token_seq_len = data['settings'].max_len
opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]
opt.src_vocab_size = len(data['vocab']['src'].vocab)
opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

#========= Preparing Model =========#
if opt.embs_share_weight:
    assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
        'To share word embeddings, the src/trg word2idx tables must be the same.'

fields = {'src': data['vocab']['src'], 'trg': data['vocab']['trg']}

train = Dataset(examples=data['train'], fields=fields)
val = Dataset(examples=data['valid'], fields=fields)

train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

return train_iterator, val_iterator
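# A hedged consumption sketch for the iterators returned above: each legacy-torchtext
# Batch exposes one attribute per field, so the training loop reads the padded id
# tensors via batch.src and batch.trg (tensor layout depends on the Field's
# batch_first setting).
for batch in train_iterator:
    src_seq = batch.src   # padded LongTensor of source token ids
    trg_seq = batch.trg   # padded LongTensor of target token ids
    # model forward / loss / backward would go here
    break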
def get_dataset(json_path, cache_root="./data/cache", mode="train", tokenizer=None):
    examples, fields = build_examples_from_json(json_path, mode, tokenizer)

    # examples_file_name = "%s%s.examples" % (mode, "_DEBUG" if DEBUG else "")
    # example_file = os.path.join(cache_root, examples_file_name)
    #
    # if os.path.exists(example_file):
    #     print("loading examples from %s" % example_file)
    #     examples = pickle.load(open(example_file, "rb"))
    # else:
    #     print("building examples %s" % example_file)
    #     examples = extract_method(fields, json_path)
    #     pickle.dump(examples, open(example_file, "wb"))

    squad_dataset = data.Dataset(examples, fields)

    # build vocab
    squad_dataset.fields["passage"].build_vocab(squad_dataset, [x.question for x in squad_dataset.examples],
                                                wv_type='glove.840B', wv_dir="./data/embedding/glove_word/",
                                                unk_init="zero")
    squad_dataset.fields["question"].vocab = squad_dataset.fields["passage"].vocab
    #
    # squad_dataset.fields["question"].build_vocab(squad_dataset, [x.passage for x in squad_dataset.examples],
    #                                              wv_type='glove.840B', wv_dir="./data/embedding/glove_word/",
    #                                              unk_init="zero")

    dataset_file_name = "%s%s.dataset" % (mode, "_DEBUG" if DEBUG else "")
    # squad_dataset.fields["answer_text"].build_vocab(squad_dataset, [x.passage for x in squad_dataset.examples], wv_type='glove.840B',
import os
import glob
import io

from .. import data


class IMDB(data.Dataset):

    urls = ['http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz']
    name = 'imdb'
    dirname = 'aclImdb'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)
    def __init__(self, path, text_field, label_field, **kwargs):
        """Create an IMDB dataset instance given a path and fields.

        Arguments:
            path: Path to the dataset's highest level directory
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
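# A hedged usage sketch: relies on the splits() classmethod of the full torchtext
# IMDB class, which is defined further down in the real source and not shown in
# this excerpt.
TEXT = data.Field(lower=True, include_lengths=True)
LABEL = data.LabelField()
train, test = IMDB.splits(TEXT, LABEL)   # downloads and extracts aclImdb on first use
TEXT.build_vocab(train, max_size=25000)
LABEL.build_vocab(train)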
if not args.no_augment:
    sentences = augmentation(sentences, pos_dict)
else:
    sentences = [text for text, _ in input_tsv]

# Load teacher model
model = BertForSequenceClassification.from_pretrained(args.model).to(device)
tokenizer = BertTokenizer.from_pretrained(args.model, do_lower_case=True)

# Assign labels with teacher
teacher_field = data.Field(sequential=True, tokenize=tokenizer.tokenize, lower=True,
                           include_lengths=True, batch_first=True)
fields = [("text", teacher_field)]
if not args.no_augment:
    examples = [data.Example.fromlist([" ".join(words)], fields) for words in sentences]
else:
    examples = [data.Example.fromlist([text], fields) for text in sentences]
augmented_dataset = data.Dataset(examples, fields)
teacher_field.vocab = BertVocab(tokenizer.vocab)
new_labels = BertTrainer(model, device, batch_size=args.batch_size).infer(augmented_dataset)

# Write to file
with open(args.output, "w") as f:
    f.write("sentence\tscores\n")
    for sentence, rating in zip(sentences, new_labels):
        if not args.no_augment:
            text = " ".join(sentence)
        else:
            text = sentence
        f.write("%s\t%.6f %.6f\n" % (text, *rating))