# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment of a unittest fixture — the enclosing TestCase class
# header is not visible in this chunk, and the original indentation was lost
# in extraction.
def setUp(self):
# Fixture: a lineflow Dataset wrapping the integers 0..99 — TODO confirm
# Dataset accepts any iterable; it is imported from lineflow elsewhere.
self.data = Dataset(range(100))
def setUp(self):
    """Build the 0..99 dataset fixture and the expected window chunks.

    ``expected`` holds consecutive, non-overlapping tuples of at most
    ``window_size`` items; the final tuple may be shorter (here: length 1,
    since 100 = 33 * 3 + 1).
    """
    self.data = Dataset(range(100))
    size = 3
    values = list(range(100))
    # Slice the sequence into consecutive chunks of `size` items each.
    self.expected = [tuple(values[i:i + size])
                     for i in range(0, len(values), size)]
    self.window_size = size
# NOTE(review): orphaned tail of a download-and-cache helper; its `def` line
# is not visible in this chunk (pkl_path/creator/loader are defined above it
# in the full file).
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoized wrapper: lru_cache() ensures each corpus key is materialized at
# most once per process.
cached_get_text_classification_dataset = lru_cache()(get_text_classification_dataset)
class AgNews(Dataset):
    """The AG News text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('ag_news')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
class SogouNews(Dataset):
    """The Sogou News text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        valid_splits = {'test', 'train'}
        if split not in valid_splits:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        super().__init__(cached_get_text_classification_dataset('sogou_news')[split])
class Dbpedia(Dataset):
    """The DBpedia text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        # Validate before touching the (potentially expensive) cached loader.
        if split not in {'train', 'test'}:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        data = cached_get_text_classification_dataset('dbpedia')
        super(Dbpedia, self).__init__(data[split])
# NOTE(review): tail of a `creator(path)` closure — its `def` line is not
# visible in this chunk. Pickles the freshly built dataset to `path` and
# returns it so the cache helper can hand it straight back.
with io.open(path, 'wb') as f:
pickle.dump(dataset, f)
return dataset
def loader(path):
    """Deserialize and return the dataset pickled at *path*."""
    # io.open is an alias of the builtin open; kept for symmetry with the
    # pickling writer used elsewhere in this module.
    with io.open(path, 'rb') as pickled:
        return pickle.load(pickled)
# NOTE(review): tail of get_penn_treebank — presumably `root` is its cache
# directory parameter; verify against the full function upstream.
pkl_path = os.path.join(root, 'ptb.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoize so the PTB pickle is loaded at most once per process.
cached_get_penn_treebank = lru_cache()(get_penn_treebank)
class PennTreebank(Dataset):
    """The Penn Treebank corpus ('train', 'dev' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'dev', 'test'):
            raw = cached_get_penn_treebank()
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")
# NOTE(review): tail of a `creator(path)` closure for the en-ja corpus cache;
# its `def` line is outside this chunk. Serializes the built dataset to
# `path`, then returns it.
with io.open(path, 'wb') as f:
pickle.dump(dataset, f)
return dataset
def loader(path):
    """Read the cached dataset back from its pickle file at *path*."""
    with io.open(path, 'rb') as fp:
        return pickle.load(fp)
# NOTE(review): tail of get_small_parallel_enja — presumably `root` is its
# cache directory parameter; verify against the full function upstream.
pkl_path = os.path.join(root, 'enja.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoize so the en-ja pickle is loaded at most once per process.
cached_get_small_parallel_enja = lru_cache()(get_small_parallel_enja)
class SmallParallelEnJa(Dataset):
    """The small English-Japanese parallel corpus ('train'/'dev'/'test')."""

    def __init__(self, split: str = 'train') -> None:
        valid_splits = {'dev', 'test', 'train'}
        if split not in valid_splits:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")
        super().__init__(cached_get_small_parallel_enja()[split])
assert key in urls
if key in ('ag_news', 'dpbedia'):
creator = list_creator
else:
creator = easyfile_creator
pkl_path = os.path.join(root, f'{key}.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
cached_get_text_classification_dataset = lru_cache()(get_text_classification_dataset)
class AgNews(Dataset):
    """The AG News text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        valid_splits = {'test', 'train'}
        if split not in valid_splits:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        super().__init__(cached_get_text_classification_dataset('ag_news')[split])
class SogouNews(Dataset):
    """The Sogou News text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('sogou_news')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
# NOTE(review): orphaned duplicate of the AgNews.__init__ tail — an artifact
# of this file being a concatenation of snippets; the `def` line is missing.
raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
raw = cached_get_text_classification_dataset('ag_news')
super(AgNews, self).__init__(raw[split])
class SogouNews(Dataset):
    """The Sogou News text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        # Validate before touching the (potentially expensive) cached loader.
        if split not in {'train', 'test'}:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        data = cached_get_text_classification_dataset('sogou_news')
        super(SogouNews, self).__init__(data[split])
class Dbpedia(Dataset):
    """The DBpedia text-classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('dbpedia')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
class YelpReviewPolarity(Dataset):
    """The Yelp review polarity corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        valid_splits = {'test', 'train'}
        if split not in valid_splits:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        super().__init__(cached_get_text_classification_dataset('yelp_review_polarity')[split])
from typing import List, Tuple, Any, Iterator
import random
from lineflow import Dataset
class SubDataset(Dataset):
    """A view over a contiguous ``[start, end)`` slice of a base dataset.

    ``indices``, when given, remaps positions over the *whole* base dataset,
    and must therefore contain exactly ``len(dataset)`` entries.
    """

    def __init__(self,
                 dataset: Dataset,
                 start: int,
                 end: int,
                 indices: List[int] = None) -> None:
        # The requested window must lie inside the base dataset.
        if start < 0 or end > len(dataset):
            raise ValueError('subset overruns the base dataset.')
        self._dataset = dataset
        self._start = start
        self._end = end
        self._size = end - start
        if indices is not None and len(indices) != len(dataset):
            # BUG FIX: the second literal was missing its 'f' prefix, so
            # "{len(indices)}" and "{len(dataset)}" were emitted verbatim
            # instead of being interpolated into the error message.
            msg = (f'indices option must have the same length as the base '
                   f'dataset: len(indices) = {len(indices)} while len(dataset) = {len(dataset)}')
            raise ValueError(msg)
        # Default to the identity mapping over the base dataset.
        self._indices = indices or list(range(len(dataset)))
# NOTE(review): orphaned duplicate of the YelpReviewPolarity.__init__ tail —
# an artifact of this file's snippet concatenation; the `def` line is missing.
raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
raw = cached_get_text_classification_dataset('yelp_review_polarity')
super(YelpReviewPolarity, self).__init__(raw[split])
class YelpReviewFull(Dataset):
    """The Yelp full-rating review corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('yelp_review_full')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
class YahooAnswers(Dataset):
    """The Yahoo Answers classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        # Validate before touching the (potentially expensive) cached loader.
        if split not in {'train', 'test'}:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        data = cached_get_text_classification_dataset('yahoo_answers')
        super(YahooAnswers, self).__init__(data[split])
class AmazonReviewPolarity(Dataset):
    """The Amazon review polarity corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        valid_splits = {'test', 'train'}
        if split not in valid_splits:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        super().__init__(cached_get_text_classification_dataset('amazon_review_polarity')[split])
# NOTE(review): orphaned duplicate of the YelpReviewFull.__init__ tail — an
# artifact of this file's snippet concatenation; the `def` line is missing.
raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
raw = cached_get_text_classification_dataset('yelp_review_full')
super(YelpReviewFull, self).__init__(raw[split])
class YahooAnswers(Dataset):
    """The Yahoo Answers classification corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('yahoo_answers')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
class AmazonReviewPolarity(Dataset):
    """The Amazon review polarity corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        # Validate before touching the (potentially expensive) cached loader.
        if split not in {'train', 'test'}:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")
        data = cached_get_text_classification_dataset('amazon_review_polarity')
        super(AmazonReviewPolarity, self).__init__(data[split])
class AmazonReviewFull(Dataset):
    """The Amazon full-rating review corpus ('train' or 'test' split)."""

    def __init__(self, split: str = 'train') -> None:
        if split in ('train', 'test'):
            raw = cached_get_text_classification_dataset('amazon_review_full')
            super().__init__(raw[split])
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")