def filter_init(ex_val1, ex_val2, ex_val3):
    # Shared text field for both text columns plus a non-sequential label field
    # (assumes `from torchtext import data`).
    text_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    fields = [("text1", text_field), ("text2", text_field),
              ("label", label_field)]
    # One Example per raw value list, matching the field order above.
    example1 = data.Example.fromlist(ex_val1, fields)
    example2 = data.Example.fromlist(ex_val2, fields)
    example3 = data.Example.fromlist(ex_val3, fields)
    examples = [example1, example2, example3]
    dataset = data.Dataset(examples, fields)
    text_field.build_vocab(dataset)
    return dataset, text_field
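# Minimal calling sketch (hypothetical values; assumes `from torchtext import data`):
# each argument is a [text1, text2, label] list matching the field order above.
dataset, text_field = filter_init(
    ["hello world", "goodbye world", "greeting"],
    ["nice day", "bad day", "weather"],
    ["open door", "close door", "command"],
)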
def test_serialization(self):
    nesting_field = data.Field(batch_first=True)
    field = data.NestedField(nesting_field)
    ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
    ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
    dataset = data.Dataset([ex1, ex2], [("words", field)])
    field.build_vocab(dataset)
    # Expected padded character-level data: "<w>"/"</w>" mark word boundaries,
    # "<s>"/"</s>" mark sentence boundaries, and "<cpad>" is the character pad token.
    examples_data = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ["<cpad>"] * 7,
        ],
    ]
def __init__(self, path, fields, separator="\t", **kwargs):
    # Each non-empty line holds a word sequence and its slot-label sequence,
    # separated by `separator` (assumes `import codecs` and `from torchtext import data`).
    examples = []
    with codecs.open(path, 'r', encoding='utf-8') as input_file:
        for line in input_file:
            line = line.strip()
            if len(line) != 0:
                words, labels = line.split(separator)
                columns = []
                columns.append(words.split(' '))
                columns.append([label for label in labels.split(' ') if label])
                examples.append(data.Example.fromlist(columns, fields))
    super(ATISDataset, self).__init__(examples, fields, **kwargs)
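# Hypothetical construction sketch (field names and file path are illustrative;
# assumes the "<words>\t<slot labels>" layout handled above):
TEXT = data.Field(sequential=True)
SLOTS = data.Field(sequential=True, unk_token=None)
train = ATISDataset("atis.train.tsv", fields=[("text", TEXT), ("slots", SLOTS)])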
# Compute overlap features for one question/answer pair, wrap them in an
# Example, and score the pair with the trained model.
for term in answer.split():
    if term not in term_idfs:
        term_idfs[term] = 0.0
overlap = compute_overlap([question], [answer])
idf_weighted_overlap = compute_idf_weighted_overlap([question], [answer], term_idfs)
overlap_no_stopwords = \
    compute_overlap(stopped([question]), stopped([answer]))
idf_weighted_overlap_no_stopwords = \
    compute_idf_weighted_overlap(stopped([question]), stopped([answer]), term_idfs)
# The four external features are passed to the model as one space-separated string field.
ext_feats = str(overlap[0]) + " " + str(idf_weighted_overlap[0]) + " " + \
    str(overlap_no_stopwords[0]) + " " + str(idf_weighted_overlap_no_stopwords[0])
fields = [('question', self.QUESTION), ('answer', self.ANSWER), ('ext_feat', self.EXTERNAL)]
example = data.Example.fromlist([question, answer, ext_feats], fields)
# Pad and numericalize single-example batches with each field.
this_question = self.QUESTION.numericalize(self.QUESTION.pad([example.question]), self.gpu)
this_answer = self.ANSWER.numericalize(self.ANSWER.pad([example.answer]), self.gpu)
this_external = self.EXTERNAL.numericalize(self.EXTERNAL.pad([example.ext_feat]), self.gpu)
self.model.eval()
scores = self.model(this_question, this_answer, this_external)
# Keep the model score at class index 2 together with the answer text.
scores_sentences.append((scores[:, 2].cpu().data.numpy()[0].tolist(), answer))
return scores_sentences
def get_features(self, dataset):
    # Drop neutral rows, load pretrained embeddings, and attach integer labels.
    dataset = remove_neutral(dataset)
    vectors = Vectors(dataset.word_embeddings)
    self.get_labels(dataset)
    fields = {'text': ('text', self.text_field), 'label': ('label', self.label_field)}
    text = dataset.data_table[dataset.text_column].to_numpy()
    labels = dataset.data_table['label'].to_numpy()
    # One Example per row, built from a dict keyed by the field names above.
    examples = [Example.fromdict(
        data={'text': text[x], 'label': labels[x]}, fields=fields) for x in range(labels.shape[0])]
    # Reuse the {name: (name, Field)} mapping so the Dataset receives (name, Field) pairs.
    torch_dataset = TorchtextDataset(examples, fields)
    # Build the vocabularies only once, on the first call.
    try:
        self.text_field.vocab
    except AttributeError:
        self.text_field.build_vocab(torch_dataset, vectors=vectors)
        self.label_field.build_vocab(torch_dataset)
    loader = BucketIterator(torch_dataset, batch_size=25)
    return loader
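# Hypothetical usage of the returned iterator (`classifier` and `dataset` are
# illustrative names; batch attributes follow the 'text'/'label' field names):
loader = classifier.get_features(dataset)
for batch in loader:
    text_batch, label_batch = batch.text, batch.label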
def read_examples(paths, fields, data_dir, mode, filter_pred, num_shard):
    # Write Examples round-robin into `num_shard` pickle shards on disk.
    data_path_fmt = data_dir + '/examples-' + mode + '-{}.pt'
    data_paths = [data_path_fmt.format(i) for i in range(num_shard)]
    writers = [open(data_path, 'wb') for data_path in data_paths]
    shard = 0
    for path in paths:
        print("Preprocessing {}".format(path))
        with io.open(path, mode='r', encoding='utf-8') as trg_file:
            for trg_line in tqdm(trg_file, ascii=True):
                trg_line = trg_line.strip()
                if trg_line == '':
                    continue
                example = data.Example.fromlist([trg_line], fields)
                if not filter_pred(example):
                    continue
                pickle.dump(example, writers[shard])
                shard = (shard + 1) % num_shard
    for writer in writers:
        writer.close()
    # Reload pickled objects, and save them again as a list.
    common.pickles_to_torch(data_paths)
    examples = torch.load(data_paths[0])
    return examples, data_paths
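# Hypothetical call (file name, field, and filter predicate are illustrative;
# assumes `from torchtext import data` plus the io/pickle/torch/tqdm imports):
TRG = data.Field(sequential=True)
examples, shard_paths = read_examples(
    ["corpus.train.txt"], [("trg", TRG)], "data", "train",
    filter_pred=lambda ex: len(ex.trg) <= 100, num_shard=4)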
        # Save training progress (`self.i`) and the optimizer state, then drop the temp file.
        torch.save([self.i, self.opt.state_dict()], self.path + '.states')
        os.remove(self.path + '.temp')

    def __getattr__(self, key):
        # Expose tracked metrics as attributes, e.g. `best.bleu`.
        if key in self.metrics:
            return self.metrics[key]
        raise AttributeError(key)

    def __repr__(self):
        return ("BEST: " +
                ', '.join(f'{metric}: {getattr(self, metric):.3f}'
                          for metric, value in self.metrics.items()
                          if value is not None))
class CacheExample(data.Example):

    @classmethod
    def fromsample(cls, data_lists, names):
        # Attach each data list to the Example under the matching field name.
        ex = cls()
        for value, name in zip(data_lists, names):
            setattr(ex, name, value)
        return ex
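
# Hypothetical usage sketch: wrap pre-tokenized source/target sentences in one Example.
ex = CacheExample.fromsample([["ein", "haus"], ["a", "house"]], ["src", "trg"])
print(ex.src, ex.trg)
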
class Cache:

    def __init__(self, size=10000, fields=("src", "trg")):
        self.cache = []
        self.maxsize = size

    def demask(self, data, mask):
def __init__(self, path, text_field, only_supporting=False, **kwargs):
    # One shared text field for the story, query, and answer columns.
    fields = [('story', text_field), ('query', text_field), ('answer', text_field)]
    self.sort_key = lambda x: len(x.query)
    with open(path, 'r', encoding="utf-8") as f:
        # _parse yields (story, query, answer) triplets from the bAbI file.
        triplets = self._parse(f, only_supporting)
        examples = [Example.fromlist(triplet, fields) for triplet in triplets]
    super(BABI20, self).__init__(examples, fields, **kwargs)
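# Hypothetical construction sketch (the task file path is illustrative; assumes
# `from torchtext.data import Field` alongside the Example import above):
TEXT = Field(sequential=True)
train = BABI20("qa1_single-supporting-fact_train.txt", TEXT)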
def __init__(self, src_path, tgt_path, fields, **kwargs):
    # Pair source and target lines one-to-one and build an Example per pair
    # (assumes `import codecs` and `import torchtext`).
    make_example = torchtext.data.Example.fromlist
    with codecs.open(src_path, encoding="utf8", errors='ignore') as src_f, \
            codecs.open(tgt_path, encoding="utf8", errors='ignore') as tgt_f:
        examples = []
        for src, tgt in zip(src_f, tgt_f):
            src, tgt = src.strip(), tgt.strip()
            examples.append(make_example([src, tgt], fields))
    super(NMTDataset, self).__init__(examples, fields, **kwargs)
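# Hypothetical construction sketch (paths and field setup are illustrative):
SRC = torchtext.data.Field(tokenize=str.split)
TGT = torchtext.data.Field(tokenize=str.split, init_token="<s>", eos_token="</s>")
train = NMTDataset("train.src", "train.tgt",
                   fields=[("src", SRC), ("tgt", TGT)])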
string = re.sub(r"\s{2,}", " ", string)
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
return string.strip()
text_field.preprocessing = data.Pipeline(clean_str)
fields = [('text', text_field), ('label', label_field)]
if examples is None:
path = self.dirname if path is None else path
examples = []
with open(os.path.join(path, 'train_pos.tsv'), errors='ignore') as f:
examples += [
data.Example.fromlist([line, line, 'pos'], fields) for line in f]
with open(os.path.join(path, 'train_neg.tsv'), errors='ignore') as f:
examples += [
data.Example.fromlist([line, line, 'neg'], fields) for line in f]
super(dataset, self).__init__(examples, fields, **kwargs)
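# Hypothetical follow-up once the dataset above is built (`train_split` stands in
# for an instance of it; iterator settings are illustrative):
text_field.build_vocab(train_split)
label_field.build_vocab(train_split)
train_iter = data.BucketIterator(train_split, batch_size=32,
                                 sort_key=lambda ex: len(ex.text))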