Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def setUp(self):
    # Fixture: one base sequence of 100 integers, zipped with itself
    # ``self.n`` times.
    self.n = 5
    self.base = range(100)
    columns = [self.base] * self.n
    self.data = ZipDataset(*columns)
def test_returns_zip_dataset(self):
    # ``self.data`` must be an instance of ZipDataset.
    fixture = self.data
    self.assertIsInstance(fixture, ZipDataset)
def test_zips_multiple_files(self):
    # The same file is opened twice in 'zip' mode, so every item is
    # expected to pair a line with itself.
    source = self.fp
    expected_lines = self.lines
    data = TextDataset([source.name, source.name], mode='zip')

    # Sequential iteration yields (line, line) tuples.
    for item, line in zip(data, expected_lines):
        self.assertTupleEqual(item, (line, line))

    # Indexed access yields the same tuples.
    for index, line in enumerate(expected_lines):
        self.assertTupleEqual(data[index], (line, line))

    # Size is checked through len() first and the private attribute second,
    # in that order.
    expected_size = len(expected_lines)
    self.assertEqual(len(data), expected_size)
    self.assertEqual(data._length, expected_size)

    # The wrapped dataset is a ZipDataset; mapping wraps a TextDataset.
    self.assertIsInstance(data._dataset, lineflow.core.ZipDataset)
    mapped = data.map(lambda x: x)
    self.assertIsInstance(mapped._dataset, TextDataset)
with io.open(path, 'wb') as f:
pickle.dump(dataset, f)
return dataset
def loader(path):
    """Return the object unpickled from the binary file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so ``path``
    should only ever point at a trusted cache file.
    """
    with io.open(path, mode='rb') as pickled:
        restored = pickle.load(pickled)
    return restored
pkl_path = os.path.join(root, 'cnndm.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoized wrapper around get_wmt14: repeated calls with the same arguments
# reuse the cached result (128 entries is lru_cache's default size).
cached_get_wmt14 = lru_cache(maxsize=128)(get_wmt14)
class Wmt14(ZipDataset):
    """ZipDataset over one split of the data returned by ``cached_get_wmt14``."""

    def __init__(self, split: str = 'train') -> None:
        """Zip together the datasets stored under ``split``.

        Args:
            split: One of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        if split in {'train', 'dev', 'test'}:
            splits = cached_get_wmt14()
            selected = splits[split]
            super().__init__(*selected)
        else:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")
def lineflow_zip(*datasets: DatasetMixin) -> ZipDataset:
    """Combine ``datasets`` into a single ``ZipDataset``.

    The annotation on a variadic parameter describes each individual
    argument, so it is ``DatasetMixin`` rather than a list of them.

    Args:
        *datasets: Datasets forwarded positionally, in order, to
            ``ZipDataset``.

    Returns:
        The result of ``ZipDataset(*datasets)``.
    """
    return ZipDataset(*datasets)
with io.open(path, 'wb') as f:
pickle.dump(dataset, f)
return dataset
def loader(path):
    """Read the file at ``path`` in binary mode and return the unpickled object.

    NOTE(review): unpickling runs code embedded in the file; only load
    cache files from a trusted location.
    """
    with io.open(path, mode='rb') as cache_file:
        cached = pickle.load(cache_file)
    return cached
pkl_path = os.path.join(root, 'cnndm.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoized wrapper around get_cnn_dailymail: repeated calls with the same
# arguments reuse the cached result (128 entries is lru_cache's default size).
cached_get_cnn_dailymail = lru_cache(maxsize=128)(get_cnn_dailymail)
class CnnDailymail(ZipDataset):
    """ZipDataset over one split of the data returned by ``cached_get_cnn_dailymail``."""

    def __init__(self, split: str = 'train') -> None:
        """Zip together the datasets stored under ``split``.

        Args:
            split: One of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        if split in {'train', 'dev', 'test'}:
            splits = cached_get_cnn_dailymail()
            selected = splits[split]
            super().__init__(*selected)
        else:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")
def __init__(self,
             paths: Union[str, List[str]],
             encoding: str = 'utf-8',
             mode: str = 'zip') -> None:
    """Wrap one or several text files as a single dataset.

    Args:
        paths: A single file path, or a list of file paths.
        encoding: Encoding passed to ``easyfile.TextFile`` for every file.
        mode: How a list of files is combined: ``'zip'`` builds a
            ``ZipDataset``, ``'concat'`` builds a ``ConcatDataset``.
            Ignored when ``paths`` is a single string.

    Raises:
        ValueError: If ``paths`` is a list and ``mode`` is neither
            ``'zip'`` nor ``'concat'``.
        TypeError: If ``paths`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(paths, str):
        dataset = easyfile.TextFile(paths, encoding)
    elif isinstance(paths, list):
        # Validate the mode before any file is opened.
        if mode not in ('zip', 'concat'):
            raise ValueError(f"only 'zip' and 'concat' are valid for 'mode', but '{mode}' is given.")
        files = [easyfile.TextFile(p, encoding) for p in paths]
        dataset = ZipDataset(*files) if mode == 'zip' else ConcatDataset(*files)
    else:
        # Without this branch ``dataset`` would be unbound below and the
        # caller would get an opaque UnboundLocalError.
        raise TypeError(
            f"only 'str' and 'list' are valid for 'paths', but '{type(paths).__name__}' is given.")
    super().__init__(dataset)