# Imports used across the excerpts below.
import bisect
import os
import pickle
from functools import lru_cache
from itertools import accumulate, tee
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, List, Tuple

# `Dataset`, `DatasetMixin`, and the `download` helper are project-local
# (lineflow) names that the excerpts below rely on.


class ZipDataset(Dataset):
    def __init__(self, *datasets: DatasetMixin) -> None:
        assert all(isinstance(d, DatasetMixin) for d in datasets)
        self._datasets = datasets
        self._length = None

    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        yield from zip(*self._datasets)

    def get_example(self, i: int) -> Tuple[Any, ...]:
        return tuple(d[i] for d in self._datasets)

    def __len__(self) -> int:
        if self._length is None:
            self._length = min(len(d) for d in self._datasets)
        return self._length
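# A minimal usage sketch for ZipDataset (not from the source): it pairs
# examples positionally and follows the shorter input, assuming DatasetMixin
# routes indexing to get_example(), as the d[i] calls above suggest.
tokens = Dataset(['a', 'b', 'c'])
labels = Dataset([0, 1])
pairs = ZipDataset(tokens, labels)
assert len(pairs) == 2
assert pairs[1] == ('b', 1)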
class ConcatDataset(Dataset):
    def __init__(self, *datasets: DatasetMixin) -> None:
        assert all(isinstance(d, DatasetMixin) for d in datasets)
        self._datasets = datasets
        self._offsets = None
        self._length = None
        self._ready = False

    def _prepare(self) -> None:
        if self._ready:
            return
        # Cumulative lengths let get_example() binary-search for the right
        # sub-dataset; the offsets rebase a global index into a local one.
        self._lengths = list(accumulate(len(d) for d in self._datasets))
        self._offsets = [0] + self._lengths[:-1]
        self._length = self._lengths[-1]
        self._ready = True

    def __iter__(self) -> Iterator[Any]:
        for d in self._datasets:
            yield from d

    def get_example(self, i: int) -> Any:
        self._prepare()
        j = bisect.bisect_right(self._lengths, i)
        return self._datasets[j][i - self._offsets[j]]

    def __len__(self) -> int:
        self._prepare()
        return self._length
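# Sketch of ConcatDataset's index arithmetic (not from the source):
# bisect_right over the cumulative lengths picks the sub-dataset, and the
# matching offset rebases the global index into a local one.
both = ConcatDataset(Dataset([1, 2, 3]), Dataset([4, 5]))
assert len(both) == 5
assert both[3] == 4  # i=3 lands in the second dataset at local index 0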
def lineflow_load(filename: str) -> Dataset:
    print(f'Loading data from {filename}...')
    with open(filename, 'rb') as f:
        dataset = pickle.load(f)
    return Dataset(dataset)
class CacheDataset(Dataset):
    def __init__(self, cache: List[Any]) -> None:
        super().__init__(cache)
        self._length = len(cache)
def lineflow_concat(*datasets: DatasetMixin) -> ConcatDataset:
    return ConcatDataset(*datasets)


def lineflow_zip(*datasets: DatasetMixin) -> ZipDataset:
    return ZipDataset(*datasets)
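# The helpers above are thin functional wrappers over the classes, e.g.:
combined = lineflow_concat(Dataset([1, 2]), Dataset([3]))      # ConcatDataset
zipped = lineflow_zip(Dataset([1, 2]), Dataset(['a', 'b']))    # ZipDataset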
def lineflow_filter(
        predicate: Callable[[Any], bool],
        dataset: Dataset) -> Iterator[Any]:
    # Reconstructed tail: the source cut off after the first parameter.
    return filter(predicate, dataset)
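# Usage sketch, assuming the reconstructed tail above; filter() yields a
# lazy iterator over matching examples (Dataset assumed iterable).
evens = list(lineflow_filter(lambda x: x % 2 == 0, Dataset([0, 1, 2, 3, 4])))
assert evens == [0, 2, 4]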
class MapDataset(Dataset):
    def __init__(self,
                 dataset: DatasetMixin,
                 map_func: Callable[[Any], Any]) -> None:
        assert callable(map_func)
        self._map_func = map_func
        super().__init__(dataset)

    def __iter__(self) -> Iterator[Any]:
        yield from map(self._map_func, self._dataset)

    def get_example(self, i: int) -> Any:
        return self._map_func(self._dataset[i])
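# MapDataset sketch (not from the source): map_func is applied lazily on
# iteration and per index in get_example().
upper = MapDataset(Dataset(['a', 'b']), str.upper)
assert list(upper) == ['A', 'B']
assert upper[0] == 'A'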
def get_wikitext(name: str):
    # NOTE: the opening of this function (setup of `root`, `list_creator`,
    # `easyfile_creator`, and `loader`) was missing from the source; only
    # the reconstructed signature is added here.
    assert name in ('wikitext-2', 'wikitext-103')
    if name == 'wikitext-2':
        creator = list_creator
    elif name == 'wikitext-103':
        creator = easyfile_creator
    pkl_path = os.path.join(root, f'{name.replace("-", "")}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)


cached_get_wikitext = lru_cache()(get_wikitext)


class WikiText2(Dataset):
    def __init__(self, split: str = 'train') -> None:
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(
                f"only 'train', 'dev' and 'test' are valid for 'split', "
                f"but got '{split}'.")
        raw = cached_get_wikitext('wikitext-2')
        super().__init__(raw[split])


class WikiText103(Dataset):
    def __init__(self, split: str = 'train') -> None:
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(
                f"only 'train', 'dev' and 'test' are valid for 'split', "
                f"but got '{split}'.")
        raw = cached_get_wikitext('wikitext-103')
        super().__init__(raw[split])
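# Loading sketch (not from the source): the first call downloads and pickles
# the corpus (network and disk access); later calls hit the lru_cache.
train = WikiText2(split='train')   # 'dev' and 'test' are also accepted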
def save(self, filename: str) -> 'CacheDataset':
    # NOTE: this is a method excerpt from the base Dataset class; its
    # signature line was missing from the source and is reconstructed here.
    path = Path(filename)
    if path.exists():
        print(f'Loading data from {filename}...')
        with path.open('rb') as f:
            cache = pickle.load(f)
    else:
        if not path.parent.exists():
            path.parent.mkdir(parents=True)
        print(f'Saving data to {filename}...')
        cache = list(self)
        with path.open('wb') as f:
            pickle.dump(cache, f)
    return CacheDataset(cache)
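# Round-trip sketch for save() and lineflow_load() above, assuming save() is
# bound as a method on Dataset (as its `self` parameter suggests); the path
# is a hypothetical example. The first call pickles the examples and returns
# a CacheDataset; lineflow_load() reloads the same pickle later.
cached = Dataset([1, 2, 3]).save('/tmp/example.pkl')
reloaded = lineflow_load('/tmp/example.pkl')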
class IterableDataset(Dataset):
    def __init__(self, iterable: Iterable) -> None:
        self._dataset = None
        self._length = None
        self._iterable = iterable
        self._ready = False

    def _prepare(self) -> None:
        if self._ready:
            return
        self._dataset = list(self._iterable)
        self._length = len(self._dataset)
        self._ready = True

    def __iter__(self) -> Iterator[Any]:
        if self._ready:
            yield from self._dataset
        else:
            # tee() lets us iterate without consuming the underlying
            # iterable, so a later _prepare() still sees every item.
            iterable, self._iterable = tee(self._iterable)
            yield from iterable

    def get_example(self, i: int) -> Any:
        self._prepare()
        return super().get_example(i)

    def __len__(self) -> int:
        self._prepare()
        return super().__len__()
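# IterableDataset sketch (not from the source): iterating before _prepare()
# tee()s the source so it is not consumed; indexing (assumed to route to
# get_example() via DatasetMixin) forces materialization.
it = IterableDataset(iter(range(3)))
assert list(it) == [0, 1, 2]   # served via tee(); nothing materialized yet
assert it[1] == 1              # _prepare() now builds the backing list
assert len(it) == 3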