Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def split_dataset_n(dataset: Dataset,
                    n: int,
                    indices: List[int] = None) -> List[SubDataset]:
    """Split ``dataset`` into ``n`` contiguous pieces of equal size.

    With ``size = len(dataset) // n``, piece ``k`` is a ``SubDataset`` built
    with bounds ``k * size`` and ``(k + 1) * size``. No bound exceeds
    ``n * size``, so when ``len(dataset)`` is not a multiple of ``n`` the
    trailing ``len(dataset) % n`` positions are presumably left out of every
    piece (depends on ``SubDataset`` treating bounds as half-open; confirm).

    Args:
        dataset: The dataset to split.
        n: Number of pieces to produce; a negative ``n`` yields ``[]``.
        indices: Forwarded unchanged to every ``SubDataset``. Presumably an
            index remapping such as the shuffled list built by the random
            variant of this splitter; ``None`` by default.

    Returns:
        A list of ``n`` ``SubDataset`` objects.

    Raises:
        ZeroDivisionError: If ``n`` is 0.
    """
    size = len(dataset) // n
    # (start, stop) bound pairs for each piece, generated lazily in order.
    bounds = ((k * size, (k + 1) * size) for k in range(n))
    return [SubDataset(dataset, start, stop, indices) for start, stop in bounds]
def split_dataset_n_random(dataset: Dataset,
                           n: int,
                           seed=None) -> List[SubDataset]:
    """Split ``dataset`` into ``n`` equally sized pieces over a shuffled order.

    The positions ``0 .. len(dataset) - 1`` are shuffled with a generator
    seeded by ``seed``. With ``size = len(dataset) // n``, piece ``i`` is a
    ``SubDataset`` built with bounds ``i * size`` and ``(i + 1) * size`` over
    that shuffled index list. No bound exceeds ``n * size``, so when
    ``len(dataset)`` is not a multiple of ``n`` the last ``len(dataset) % n``
    shuffled positions are presumably left out of every piece (depends on
    ``SubDataset`` treating bounds as half-open; confirm).

    Args:
        dataset: The dataset to split.
        n: Number of pieces to produce; a negative ``n`` yields ``[]``.
        seed: Seed for the shuffle. The same seed reproduces the same split;
            ``None`` gives a non-reproducible shuffle.

    Returns:
        A list of ``n`` ``SubDataset`` objects sharing one shuffled index list.

    Raises:
        ZeroDivisionError: If ``n`` is 0.
    """
    n_examples = len(dataset)
    sub_size = n_examples // n
    indices = list(range(n_examples))
    # Use a private generator rather than random.seed(seed) + random.shuffle.
    # For the same seed it yields the identical permutation, but it does not
    # reseed the process-wide RNG that unrelated code shares.
    random.Random(seed).shuffle(indices)
    return [SubDataset(dataset, sub_size * i, sub_size * (i + 1), indices)
            for i in range(n)]
for i, y in enumerate(lines):
self.assertEqual(data[i], y)
self.assertEqual(len(data), len(lines))
self.assertEqual(data._length, len(lines))
# check if length is cached
self.assertEqual(len(data), len(lines))
self.assertIsInstance(data._dataset, easyfile.TextFile)
data = data.map(str.split)
for x, y in zip(data, lines):
self.assertEqual(x, y.split())
self.assertIsInstance(data, lineflow.core.MapDataset)
self.assertIsInstance(data._dataset, TextDataset)
def test_raises_value_error_with_invalid_split(self):
    """Squad rejects an unknown split name with ValueError."""
    bogus_split = 'invalid_split'
    self.assertRaises(ValueError, Squad, split=bogus_split)
def test_loads_v2_each_split(self):
    """Squad(version=2) loads each split with the expected example count."""
    # (split name, expected length), checked in order: train first, then dev.
    expected_sizes = (('train', 130_319), ('dev', 11_873))
    for split_name, size in expected_sizes:
        loaded = Squad(split=split_name, version=2)
        self.assertEqual(len(loaded), size)
def test_loads_v1_each_split(self):
    """Squad(version=1) loads each split with the expected example count."""
    # (split name, expected length), checked in order: train first, then dev.
    cases = (('train', 87_599), ('dev', 10_570))
    for split_name, expected in cases:
        self.assertEqual(len(Squad(split=split_name, version=1)), expected)
def test_loads_v1_each_split(self):
    """Squad(version=1) yields 87,599 train and 10,570 dev examples."""
    # NOTE(review): a method named test_loads_v1_each_split with this same body
    # appears earlier in this listing. If both live in the same TestCase, this
    # later definition silently replaces the earlier one; confirm and drop one.
    self.assertEqual(len(Squad(split='train', version=1)), 87_599)
    self.assertEqual(len(Squad(split='dev', version=1)), 10_570)
def test_loads_v2_each_split(self):
    """Squad(version=2) yields 130,319 train and 11,873 dev examples."""
    # NOTE(review): a method named test_loads_v2_each_split with this same body
    # appears earlier in this listing. If both live in the same TestCase, this
    # later definition silently replaces the earlier one; confirm and drop one.
    self.assertEqual(len(Squad(split='train', version=2)), 130_319)
    self.assertEqual(len(Squad(split='dev', version=2)), 11_873)
def test_raises_value_error_with_invalid_split(self):
    """CnnDailymail rejects an unknown split name with ValueError."""
    bogus_split = 'invalid_split'
    self.assertRaises(ValueError, CnnDailymail, split=bogus_split)
def test_loads_each_split(self):
    """CnnDailymail loads train, dev and test with the expected sizes."""
    # (split name, expected length), checked in the order train, dev, test.
    expected_sizes = (
        ('train', 287_227),
        ('dev', 13_368),
        ('test', 11_490),
    )
    for split_name, size in expected_sizes:
        loaded = CnnDailymail(split=split_name)
        self.assertEqual(len(loaded), size)