# Assumed imports: the exact module paths for ContinuumSetLoader, Disjoint,
# MnistFellowship and check_task_sequences_files depend on the project layout
# and are hypothetical here.
import os

import pytest
import torch
from torch.utils import data

from continuum_loader import ContinuumSetLoader      # hypothetical path
from scenarios import Disjoint, MnistFellowship      # hypothetical path
from tests.utils import check_task_sequences_files   # hypothetical path


def test_DataLoader_init(get_fake_dataset):
    fake_dataset = get_fake_dataset
    dataset = ContinuumSetLoader(fake_dataset)
    if not dataset.current_task == 0:
        raise AssertionError("current_task should be initialized to 0")
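
# The tests in this file also rely on a few module-level names that are not
# shown in the original snippet: a `get_fake_dataset` fixture, a `dataset_size`
# constant, and a `dir_data` folder. A minimal sketch, assuming the loader
# accepts a (data, labels) tuple of tensors, is given below; the sizes, shapes,
# and label range are illustrative assumptions, not the project's actual values.
dataset_size = 100       # assumed number of samples in the fake dataset
dir_data = "./Archives"  # assumed root folder for generated task sequences


@pytest.fixture
def get_fake_dataset():
    # fake data: `dataset_size` random 28x28 images with integer labels in [0, 9]
    x = torch.randn(dataset_size, 1, 28, 28)
    y = torch.randint(0, 10, (dataset_size,))
    return x, y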

def test_disjoint_vanilla_test(dataset, n_tasks):
    # no need to download the dataset again for this test (if it already exists)
    input_folder = os.path.join(dir_data, 'Data')
    Disjoint(path=input_folder, dataset=dataset, tasks_number=n_tasks, download=False, train=False)
    check_task_sequences_files(scenario="Disjoint", folder=dir_data, n_tasks=n_tasks, dataset=dataset, train=False)
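
# `dataset` and `n_tasks` in the Disjoint tests are pytest parameters, normally
# supplied from a conftest.py. One possible way to generate them is sketched
# below; the concrete dataset names and task counts are illustrative
# assumptions.
def pytest_generate_tests(metafunc):
    if "dataset" in metafunc.fixturenames:
        metafunc.parametrize("dataset", ["MNIST", "fashion"])
    if "n_tasks" in metafunc.fixturenames:
        metafunc.parametrize("n_tasks", [1, 5, 10])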

def test_DataLoader_init_label_size(get_fake_dataset):
    """
    Test that the label dictionary has the expected size.
    :param get_fake_dataset:
    :return:
    """
    fake_dataset = get_fake_dataset
    dataset = ContinuumSetLoader(fake_dataset)
    if not len(dataset.labels) == dataset_size:
        raise AssertionError("the label dictionary should contain dataset_size entries")

def test_DataLoader_with_torch(get_fake_dataset):
    """
    Test that the dataset can be iterated with torch.utils.data.DataLoader.
    :param get_fake_dataset:
    :return:
    """
    fake_dataset = get_fake_dataset
    dataset = ContinuumSetLoader(fake_dataset)
    train_loader = data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=6)
    # drawing a single batch is enough to check that iteration works
    next(iter(train_loader))

def test_DataLoader_with_torch_loader(get_fake_dataset):
    """
    Test that torch.utils.data.DataLoader yields data of the expected types.
    :param get_fake_dataset:
    :return:
    """
    fake_dataset = get_fake_dataset
    dataset = ContinuumSetLoader(fake_dataset)
    train_loader = data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=6)
    # checking a single batch is enough
    for batch, label in train_loader:
        if not isinstance(label, torch.LongTensor):
            raise AssertionError("labels should be torch.LongTensor")
        if not isinstance(batch, torch.FloatTensor):
            raise AssertionError("data batches should be torch.FloatTensor")
        break

def test_DataLoader_init_label_is_dict(get_fake_dataset):
    """
    Test that the labels attribute is really a dictionary.
    :param get_fake_dataset:
    :return:
    """
    fake_dataset = get_fake_dataset
    dataset = ContinuumSetLoader(fake_dataset)
    if not isinstance(dataset.labels, dict):
        raise AssertionError("labels should be a dictionary")

def test_download(tmpdir, dataset):
    # construction only: download=False assumes the data is already available
    continuum = Disjoint(path=tmpdir, dataset=dataset, tasks_number=1, download=False, train=True)
    if continuum is None:
        raise AssertionError("Object construction has failed")

def test_disjoint_vanilla_train(dataset, n_tasks):
    # no need to download the dataset again for this test (if it already exists)
    input_folder = os.path.join(dir_data, 'Data')
    Disjoint(path=input_folder, dataset=dataset, tasks_number=n_tasks, download=False, train=True)
    check_task_sequences_files(scenario="Disjoint", folder=dir_data, n_tasks=n_tasks, dataset=dataset, train=True)

def test_mnist_fellowship():
    # no need to download the dataset again for this test (if it already exists)
    input_folder = os.path.join(dir_data, 'Data')
    MnistFellowship(path=input_folder, merge=False, download=False, train=True)
    check_task_sequences_files(scenario="mnist_fellowship",
                               folder=dir_data,
                               n_tasks=3,
                               dataset="mnist_fellowship",
                               train=True)

def test_mnist_fellowship_merge():
    # no need to download the dataset again for this test (if it already exists)
    input_folder = os.path.join(dir_data, 'Data')
    MnistFellowship(path=input_folder, merge=True, download=False, train=True)
    check_task_sequences_files(scenario="mnist_fellowship_merge",
                               folder=dir_data,
                               n_tasks=3,
                               dataset="mnist_fellowship",
                               train=True)
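
# These tests are intended to be run with pytest from the repository root,
# e.g. `pytest -x tests/` (the exact test folder is an assumption).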