def test_csv_dataset_quotechar(self):
    # Based on issue #349
    example_data = [("text", "label"),
                    ('" hello world', "0"),
                    ('goodbye " world', "1"),
                    ('this is a pen " ', "0")]

    with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
        for example in example_data:
            f.write(six.b("{}\n".format(",".join(example))))

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        fields = {
            "label": ("label", data.Field(use_vocab=False,
                                          sequential=False)),
            "text": ("text", TEXT)
        }

        f.seek(0)

        # quotechar=None disables quote handling, so literal '"' characters
        # stay part of the text column instead of opening a quoted field.
        dataset = data.TabularDataset(
            path=f.name, format="csv",
            skip_header=False, fields=fields,
            csv_reader_params={"quotechar": None})

        TEXT.build_vocab(dataset)

        self.assertEqual(len(dataset), len(example_data) - 1)

        for i, example in enumerate(dataset):
            # Each example's text should be the raw row, lowercased and split.
            self.assertEqual(example.text,
                             example_data[i + 1][0].lower().split())

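A minimal standalone illustration (stdlib csv only; the row below is invented) of why the test passes quotechar=None: with the default quote character, a leading '"' opens a quoted field and the comma separating text from label is swallowed into the first column.

import csv

row = ['" hello world,0']
print(next(csv.reader(row)))                  # [' hello world,0']      one field, quote consumed
print(next(csv.reader(row, quotechar=None)))  # ['" hello world', '0']  quoting disabled, clean split
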
def test_dataset_split_arguments(self):
    num_examples, num_labels = 30, 3
    self.write_test_splitting_dataset(num_examples=num_examples,
                                      num_labels=num_labels)
    text_field = data.Field()
    label_field = data.LabelField()
    fields = [('text', text_field), ('label', label_field)]

    dataset = data.TabularDataset(
        path=self.test_dataset_splitting_path, format="csv", fields=fields)

    # Test default split ratio (0.7)
    expected_train_size = 21
    expected_test_size = 9

    train, test = dataset.split()
    assert len(train) == expected_train_size
    assert len(test) == expected_test_size

    # Test array arguments with same ratio
    split_ratio = [0.7, 0.3]
    train, test = dataset.split(split_ratio=split_ratio)
    assert len(train) == expected_train_size
    assert len(test) == expected_test_size

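The same split() API also accepts a three-way ratio and stratified sampling; a hedged sketch continuing from the 30-example dataset above (not part of the original test excerpt):

# The ratio list is read as (train, test, valid), while the splits are
# returned as (train, valid, test).
train, valid, test = dataset.split(split_ratio=[0.6, 0.3, 0.1])
assert (len(train), len(valid), len(test)) == (18, 3, 9)

# Stratified splitting keeps the label distribution balanced across splits.
train, test = dataset.split(split_ratio=0.7, stratified=True, strata_field='label')
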
def test_errors(self):
    # Ensure that trying to retrieve a key not in JSON data errors
    self.write_test_ppid_dataset(data_format="json")

    question_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    # "qeustion1" is misspelled on purpose: the key does not exist in the JSON
    # data, so constructing the dataset must raise ValueError.
    fields = {"qeustion1": ("q1", question_field),
              "question2": ("q2", question_field),
              "label": ("label", label_field)}

    with self.assertRaises(ValueError):
        data.TabularDataset(
            path=self.test_ppid_dataset_path, format="json", fields=fields)

def test_targetfield_specials(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    data_path = os.path.join(test_path, 'data/eng-fra.txt')
    field = TargetField()
    train = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=[('src', torchtext.data.Field()), ('trg', field)]
    )
    self.assertTrue(field.sos_id is None)
    self.assertTrue(field.eos_id is None)
    field.build_vocab(train)
    self.assertFalse(field.sos_id is None)
    self.assertFalse(field.eos_id is None)

def udpos_dataset(batch_size):
    # Setup fields with batch dimension first.
    # The '<bos>'/'<eos>' strings are assumed here as the boundary tokens.
    inputs = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)
    tags = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)

    # Download and load the default data.
    train, val, test = datasets.UDPOS.splits(
        fields=(('inputs_word', inputs), ('labels', tags), (None, None)))

    # Build vocab; the attribute names must match the field names declared above.
    inputs.build_vocab(train.inputs_word)
    tags.build_vocab(train.labels)

    # Get iterators
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train, val, test), batch_size=batch_size,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    train_iter.repeat = False
    return train_iter, val_iter, test_iter, inputs, tags

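A hedged usage sketch of the function above; the batch size is arbitrary and the batch attribute names simply follow the fields declared in udpos_dataset.

train_iter, val_iter, test_iter, inputs, tags = udpos_dataset(batch_size=32)
for batch in train_iter:
    words = batch.inputs_word   # (batch, seq_len), since batch_first=True
    labels = batch.labels       # (batch, seq_len) tag indices
    # model forward/backward pass would go here
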
if load_preprocessed:
    print("Loading preprocessed data...")
    src_field = torch.load(data_dir + '/source.pt')['field']
    trg_field = torch.load(data_dir + '/target.pt')['field']

    data_paths = glob.glob(data_dir + '/examples-train-*.pt')
    examples_train = torch.load(data_paths[0])
    examples_val = torch.load(data_dir + '/examples-val-0.pt')

    fields = [('src', src_field), ('trg', trg_field)]
    train = WMT32k(examples_train, fields, filter_pred=filter_pred)
    val = WMT32k(examples_val, fields, filter_pred=filter_pred)
else:
    # '<eos>' is assumed here as the end-of-sentence token.
    src_field = data.Field(tokenize=tokenize_de, batch_first=True,
                           pad_token=pad, lower=True, eos_token='<eos>')
    trg_field = data.Field(tokenize=tokenize_en, batch_first=True,
                           pad_token=pad, lower=True, eos_token='<eos>')

    print("Loading data... (this may take a while)")
    train, val, data_paths = \
        WMT32k.splits(exts=('.de', '.en'),
                      fields=(src_field, trg_field),
                      data_dir=data_dir,
                      filter_pred=filter_pred)

    print("Building vocabs... (this may take a while)")
    build_vocabs(src_field, trg_field, data_paths)

print("Creating iterators...")
train_iter, val_iter = common.BucketByLengthIterator.splits(
    (train, val),
    data_paths=data_paths)  # further iterator options (batch size, device, ...) omitted here

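filter_pred is used in both branches above but is not defined in this excerpt; a minimal sketch of the usual length filter, assuming an illustrative cutoff of 100 tokens per side:

def filter_pred(example):
    # Keep only pairs where both source and target fit the length budget.
    return len(example.src) <= 100 and len(example.trg) <= 100
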
from torchtext.data import Field, Dataset
from torchtext.vocab import Vocab
from collections import Counter, OrderedDict
from torch.autograd import Variable


class CharField(Field):
    vocab_cls = Vocab

    def __init__(self, **kwargs):
        super(CharField, self).__init__(**kwargs)
        if self.preprocessing is None:
            # Split every token into its list of characters.
            self.preprocessing = lambda x: [list(y) for y in x]

    def build_vocab(self, *args, **kwargs):
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                # Collect every column of the dataset that uses this field.
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                # Any non-Dataset argument is treated as an iterable of examples.
                sources.append(arg)

def process_labels(string):
    """
    Convert a label string of 0/1 characters into a list of floats,
    one entry per class.
    """
    return [float(x) for x in string]


class AAPD(TabularDataset):
    NAME = 'AAPD'
    NUM_CLASSES = 54
    IS_MULTILABEL = True

    TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    @classmethod
    def splits(cls, path, train=os.path.join('AAPD', 'train.tsv'),
               validation=os.path.join('AAPD', 'dev.tsv'),
               test=os.path.join('AAPD', 'test.tsv'), **kwargs):
        return super(AAPD, cls).splits(
            path, train=train, validation=validation, test=test,
            format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
        )

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
              unk_init=torch.Tensor.zero_):
        # The final parameter and this body are a hedged sketch of the usual
        # pattern: load pretrained vectors, build vocabularies, return bucketed
        # iterators over the three splits.
        if vectors is None:
            vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)
        train, val, test = cls.splits(path)
        cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)
        return BucketIterator.splits(
            (train, val, test), batch_size=batch_size, repeat=False,
            shuffle=shuffle, sort_within_batch=True, device=device)

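A hedged usage sketch of the class above; the data path and vector file names are illustrative, not taken from the original repository.

train_iter, dev_iter, test_iter = AAPD.iters(
    path='data', vectors_name='glove.6B.300d.txt', vectors_cache='.vector_cache',
    batch_size=32, device='cpu')
for batch in train_iter:
    text, lengths = batch.text   # include_lengths=True returns (token ids, lengths)
    labels = batch.label         # one 54-dimensional multi-hot vector per example
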
def __init__(self, corpus_path=None, src_corpus=None, tgt_corpus=None, src_vocab=None, tgt_vocab=None,
             batch_size=64, batch_type="sentence", max_length=60, n_valid_samples=1000, truncate=None):
    assert corpus_path is not None or (src_corpus is not None and tgt_corpus is not None)
    assert src_vocab is not None and tgt_vocab is not None
    self._batch_size = batch_size
    self._fixed_train_batches = None
    self._fixed_valid_batches = None
    self._max_length = max_length
    self._n_valid_samples = n_valid_samples

    # '<pad>' is assumed here as the padding token; the preprocessing hook wraps
    # every sequence in sentence-boundary markers.
    self._src_field = torchtext.data.Field(
        pad_token="<pad>", preprocessing=lambda seq: ["<s>"] + seq + ["</s>"])
    self._src_vocab = self._src_field.vocab = Vocab(src_vocab)
    self._tgt_field = torchtext.data.Field(
        pad_token="<pad>", preprocessing=lambda seq: ["<s>"] + seq + ["</s>"])
    self._tgt_vocab = self._tgt_field.vocab = Vocab(tgt_vocab)

    # Make data
    if corpus_path is not None:
        self._data = torchtext.data.TabularDataset(
            path=corpus_path, format='tsv',
            fields=[('src', self._src_field), ('tgt', self._tgt_field)],
            filter_pred=self._len_filter
        )
    else:
        self._data = BilingualDataset(src_corpus, tgt_corpus, self._src_field, self._tgt_field,
                                      filter_pred=self._len_filter)

    # Create training and valid dataset
    examples = self._data.examples
    if truncate is not None:
        assert isinstance(truncate, int)
        examples = examples[:truncate]
    n_train_samples = len(examples) - n_valid_samples

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = ("cpu")
# devices = [0, 1, 2, 3]

#####################
#   Data Loading    #
#####################
BOS_WORD = '<s>'
EOS_WORD = '</s>'
BLANK_WORD = "<blank>"  # padding token, value assumed
MIN_FREQ = 2

spacy_en = spacy.load('en')

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

TEXT = data.Field(tokenize=tokenize_en, init_token=BOS_WORD,
                  eos_token=EOS_WORD, pad_token=BLANK_WORD)

train = datasets.TranslationDataset(path=os.path.join(SRC_DIR, DATA),
                                    exts=('.train.src', '.train.trg'), fields=(TEXT, TEXT))
val = datasets.TranslationDataset(path=os.path.join(SRC_DIR, DATA),
                                  exts=('.val.src', '.val.trg'), fields=(TEXT, TEXT))

train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=True)
valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=False)

random_idx = random.randint(0, len(train) - 1)
print(train[random_idx].src)
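
batch_size_fn is passed to both iterators above but is not defined in this excerpt; a hedged sketch of the usual token-count batching function, where the +2 accounts for the BOS/EOS tokens added by the field above:

max_src_in_batch, max_tgt_in_batch = 0, 0

def batch_size_fn(new, count, sofar):
    # Track the longest source/target seen so far and report the batch "size"
    # measured in padded tokens rather than in sentences.
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch, max_tgt_in_batch = 0, 0
    max_src_in_batch = max(max_src_in_batch, len(new.src) + 2)
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)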