How to use the `lineflow.core.Dataset` class in lineflow

To help you get started, we’ve selected a few lineflow examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

From the GitHub repository tofunlp/lineflow, file `lineflow/datasets/wikitext.py` (view on GitHub):
return download.cache_or_load_file(pkl_path, creator, loader)


# Memoized wrapper around get_wikitext: it runs at most once per distinct corpus
# name while that entry stays cached (maxsize=128 is lru_cache's own default).
cached_get_wikitext = lru_cache(maxsize=128)(get_wikitext)


class WikiText2(Dataset):
    """One split of the WikiText-2 corpus exposed as a ``Dataset``.

    The raw corpus comes from ``cached_get_wikitext('wikitext-2')`` and is
    indexed by split name (presumably a mapping keyed by 'train'/'dev'/'test'
    — confirm against get_wikitext).
    """

    def __init__(self, split: str = 'train') -> None:
        """Load the requested split.

        Args:
            split: one of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the values above.
        """
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")

        raw = cached_get_wikitext('wikitext-2')
        # An initializer must not return a value; just delegate to the parent.
        super().__init__(raw[split])


class WikiText103(Dataset):
    """One split of the WikiText-103 corpus exposed as a ``Dataset``.

    The raw corpus comes from ``cached_get_wikitext('wikitext-103')`` and is
    indexed by split name (presumably a mapping keyed by 'train'/'dev'/'test'
    — confirm against get_wikitext).
    """

    def __init__(self, split: str = 'train') -> None:
        """Load the requested split.

        Args:
            split: one of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the values above.
        """
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")

        raw = cached_get_wikitext('wikitext-103')
        # An initializer must not return a value; just delegate to the parent.
        super().__init__(raw[split])
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
def __iter__(self) -> Iterator[Any]:
        """Yield every example of each concatenated dataset, first to last."""
        for child in self._datasets:
            yield from child

    def get_example(self, i: int) -> Any:
        """Return the ``i``-th example of the concatenation.

        A negative ``i`` counts back from the end of the whole concatenation,
        not from the end of the first child dataset.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        self._prepare()
        if i < 0:
            # Without this, bisect maps every negative index to the first
            # dataset and silently returns an element counted from its end.
            i += self._length
        if not 0 <= i < self._length:
            raise IndexError('dataset index out of range')
        # _lengths holds cumulative end positions, so bisect_right picks the
        # first dataset whose end lies strictly after i (skipping empty ones).
        j = bisect.bisect_right(self._lengths, i)
        return self._datasets[j][i - self._offsets[j]]

    def __len__(self) -> int:
        """Return the total number of examples across all concatenated datasets."""
        self._prepare()
        total = self._length
        return total


class ZipDataset(Dataset):
    """Dataset that pairs up examples of several datasets position by position.

    Iteration behaves like the built-in ``zip``. The length is that of the
    shortest child dataset, and 0 when no datasets are given.
    """

    def __init__(self, *datasets: DatasetMixin) -> None:
        """Store the datasets to zip.

        Args:
            datasets: datasets to combine; each must be a ``DatasetMixin``.
                (For ``*args`` the annotation describes one element.)

        Raises:
            AssertionError: if any argument is not a ``DatasetMixin``.
        """
        assert all(isinstance(d, DatasetMixin) for d in datasets)
        self._datasets = datasets
        # Computed lazily by __len__ and then reused.
        self._length = None

    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        """Yield one tuple per position until the shortest dataset runs out."""
        yield from zip(*self._datasets)

    def get_example(self, i: int) -> Tuple[Any, ...]:
        """Return the tuple of examples found at position ``i`` in every dataset.

        A negative ``i`` counts back from ``len(self)``, which is the length of
        the shortest dataset, so every member of the tuple comes from the same
        position.

        Raises:
            IndexError: if a negative ``i`` is below ``-len(self)``. Positive
                out-of-range indices are left to the child datasets to reject.
        """
        if i < 0:
            # Passing a raw negative index to each child would count back from
            # each child's own end and misalign datasets of unequal length.
            i += len(self)
            if i < 0:
                raise IndexError('dataset index out of range')
        return tuple(d[i] for d in self._datasets)

    def __len__(self) -> int:
        """Return the length of the shortest dataset (0 when none were given)."""
        if self._length is None:
            # default=0 mirrors zip() with no arguments, which yields nothing;
            # a bare min() would raise ValueError on an empty sequence.
            self._length = min((len(d) for d in self._datasets), default=0)
        return self._length
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
if self._ready:
            yield from self._dataset
        else:
            iterable, self._iterable = tee(self._iterable)
            yield from iterable

    def get_example(self, i: int) -> Any:
        """Materialize the wrapped iterable if needed, then defer to the parent lookup."""
        self._prepare()
        return super().get_example(i)

    def __len__(self) -> int:
        """Materialize the wrapped iterable if needed, then defer to the parent length."""
        self._prepare()
        return super().__len__()


class ConcatDataset(Dataset):
    """Dataset that chains several datasets end to end.

    Cumulative lengths, start offsets and the total length are computed
    lazily by ``_prepare`` the first time they are needed.
    """

    def __init__(self, *datasets: DatasetMixin) -> None:
        """Store the datasets to concatenate.

        Args:
            datasets: datasets to chain, in order; each must be a
                ``DatasetMixin``. (For ``*args`` the annotation describes one
                element.)

        Raises:
            AssertionError: if any argument is not a ``DatasetMixin``.
        """
        assert all(isinstance(d, DatasetMixin) for d in datasets)

        self._datasets = datasets
        # The following are filled in by _prepare().
        self._lengths = None   # cumulative end position of each dataset
        self._offsets = None   # start position of each dataset
        self._length = None    # total number of examples
        self._ready = False

    def _prepare(self) -> None:
        """Compute cumulative lengths, start offsets and the total length once."""
        if self._ready:
            return
        self._lengths = list(accumulate(len(d) for d in self._datasets))
        # Each dataset starts where the previous one ends.
        self._offsets = [0] + self._lengths[:-1]
        # With no datasets the cumulative list is empty: the total is 0 instead
        # of an IndexError from self._lengths[-1].
        self._length = self._lengths[-1] if self._lengths else 0
        self._ready = True
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
def lineflow_load(filename: str) -> Dataset:
    """Unpickle ``filename`` and wrap the loaded object in a ``Dataset``.

    A progress message is printed before reading.

    Warning:
        ``pickle.load`` can execute arbitrary code while loading; only pass
        files that come from a trusted source.
    """
    print(f'Loading data from {filename}...')
    with open(filename, 'rb') as fp:
        payload = pickle.load(fp)
    return Dataset(payload)
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
dataset: DatasetMixin,
                 map_func: Callable[[Any], Any]) -> None:
        assert callable(map_func)

        self._map_func = map_func

        super(MapDataset, self).__init__(dataset)

    def __iter__(self) -> Iterator[Any]:
        """Yield ``map_func(example)`` for each example of the wrapped dataset."""
        mapped = map(self._map_func, self._dataset)
        yield from mapped

    def get_example(self, i: int) -> Any:
        """Fetch example ``i`` of the wrapped dataset and apply ``map_func`` to it."""
        fn = self._map_func
        return fn(self._dataset[i])


class CacheDataset(Dataset):
    """Dataset backed by an already-materialized list of examples.

    The length is recorded up front from ``len(cache)``.
    """

    def __init__(self, cache: List[Any]) -> None:
        super().__init__(cache)
        size = len(cache)
        self._length = size


def lineflow_concat(*datasets: DatasetMixin) -> ConcatDataset:
    """Return a ``ConcatDataset`` that chains ``datasets`` end to end.

    Args:
        datasets: the datasets to concatenate, in order. For ``*args`` the
            annotation describes a single element, so it is ``DatasetMixin``
            rather than ``List[DatasetMixin]``.
    """
    return ConcatDataset(*datasets)


def lineflow_zip(*datasets: DatasetMixin) -> ZipDataset:
    """Return a ``ZipDataset`` that pairs ``datasets`` position by position.

    Args:
        datasets: the datasets to zip. For ``*args`` the annotation describes a
            single element, so it is ``DatasetMixin`` rather than
            ``List[DatasetMixin]``.
    """
    return ZipDataset(*datasets)


def lineflow_filter(
        predicate: Callable[[Any], bool],
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
self._datasets = datasets
        self._length = None

    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        """Yield one tuple per position until the shortest dataset runs out.

        The annotation uses ``Tuple[Any, ...]``; ``Tuple[Any]`` would mean a
        1-tuple, while one element is produced per zipped dataset.
        """
        yield from zip(*self._datasets)

    def get_example(self, i: int) -> Tuple[Any, ...]:
        """Return the tuple of examples found at position ``i`` in every dataset.

        A negative ``i`` counts back from ``len(self)``, which is the length of
        the shortest dataset, so every member of the tuple comes from the same
        position.

        Raises:
            IndexError: if a negative ``i`` is below ``-len(self)``. Positive
                out-of-range indices are left to the child datasets to reject.
        """
        if i < 0:
            # Passing a raw negative index to each child would count back from
            # each child's own end and misalign datasets of unequal length.
            i += len(self)
            if i < 0:
                raise IndexError('dataset index out of range')
        return tuple(d[i] for d in self._datasets)

    def __len__(self) -> int:
        """Return the length of the shortest zipped dataset (0 when there are none)."""
        if self._length is None:
            # default=0 mirrors zip() with no arguments, which yields nothing;
            # a bare min() would raise ValueError on an empty sequence.
            self._length = min((len(d) for d in self._datasets), default=0)
        return self._length


class MapDataset(Dataset):
    """Dataset that applies ``map_func`` to each example of a wrapped dataset.

    The function is applied on access, both during iteration and in
    ``get_example``.
    """

    def __init__(self,
                 dataset: DatasetMixin,
                 map_func: Callable[[Any], Any]) -> None:
        """Wrap ``dataset`` with ``map_func``.

        Args:
            dataset: the source dataset.
            map_func: callable applied to each example.

        Raises:
            AssertionError: if ``map_func`` is not callable.
        """
        assert callable(map_func)

        self._map_func = map_func

        super().__init__(dataset)

    def __iter__(self) -> Iterator[Any]:
        """Yield ``map_func(example)`` for each example of the wrapped dataset."""
        mapped = map(self._map_func, self._dataset)
        yield from mapped

    def get_example(self, i: int) -> Any:
        """Fetch example ``i`` of the wrapped dataset and apply ``map_func`` to it."""
        fn = self._map_func
        return fn(self._dataset[i])
From the GitHub repository tofunlp/lineflow, file `lineflow/datasets/wikitext.py` (view on GitHub):
assert name == 'wikitext-2' or name == 'wikitext-103'

    if name == 'wikitext-2':
        creator = list_creator
    elif name == 'wikitext-103':
        creator = easyfile_creator

    pkl_path = os.path.join(root, f'{name.replace("-", "")}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)


# Memoized wrapper around get_wikitext: it runs at most once per distinct corpus
# name while that entry stays cached (maxsize=128 is lru_cache's own default).
cached_get_wikitext = lru_cache(maxsize=128)(get_wikitext)


class WikiText2(Dataset):
    """One split of the WikiText-2 corpus exposed as a ``Dataset``.

    The raw corpus comes from ``cached_get_wikitext('wikitext-2')`` and is
    indexed by split name (presumably a mapping keyed by 'train'/'dev'/'test'
    — confirm against get_wikitext).
    """

    def __init__(self, split: str = 'train') -> None:
        """Load the requested split.

        Args:
            split: one of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the values above.
        """
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")

        raw = cached_get_wikitext('wikitext-2')
        # An initializer must not return a value; just delegate to the parent.
        super().__init__(raw[split])


class WikiText103(Dataset):
    """One split of the WikiText-103 corpus exposed as a ``Dataset``.

    The raw corpus comes from ``cached_get_wikitext('wikitext-103')`` and is
    indexed by split name (presumably a mapping keyed by 'train'/'dev'/'test'
    — confirm against get_wikitext).
    """

    def __init__(self, split: str = 'train') -> None:
        """Load the requested split.

        Args:
            split: one of ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the values above.
        """
        if split not in {'train', 'dev', 'test'}:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")

        raw = cached_get_wikitext('wikitext-103')
        # An initializer must not return a value; just delegate to the parent.
        super().__init__(raw[split])
From the GitHub repository tofunlp/lineflow, file `lineflow/core.py` (view on GitHub):
path = Path(filename)
        if path.exists():
            print(f'Loading data from {filename}...')
            with path.open('rb') as f:
                cache = pickle.load(f)
        else:
            if not path.parent.exists():
                path.parent.mkdir(parents=True)
            print(f'Saving data to {filename}...')
            cache = list(self)
            with path.open('wb') as f:
                pickle.dump(cache, f)
        return CacheDataset(cache)


class IterableDataset(Dataset):
    def __init__(self, iterable: Iterable) -> None:
        """Wrap ``iterable`` without consuming it.

        The items are only materialized into a list by ``_prepare``. This
        initializer does not call the parent class's ``__init__``.
        """
        self._iterable = iterable
        self._ready = False
        # Both are populated by _prepare().
        self._dataset = None
        self._length = None

    def _prepare(self) -> None:
        """Materialize the wrapped iterable into a list; later calls do nothing."""
        if not self._ready:
            items = list(self._iterable)
            self._dataset = items
            self._length = len(items)
            self._ready = True

    def __iter__(self) -> Iterator[Any]:
        if self._ready:
            yield from self._dataset