Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_DataLoader_init(get_fake_dataset):
    """
    Check that a freshly built ContinuumSetLoader reports task 0 as current
    :param get_fake_dataset: fake dataset handed to ContinuumSetLoader
    :return:
    """
    loader = ContinuumSetLoader(get_fake_dataset)
    # Nothing more to verify once the loader points at the first task.
    if loader.current_task == 0:
        return
    raise AssertionError("Test fail")
def test_DataLoader_init_label_size(get_fake_dataset):
    """
    Check that the loader's label dictionary has the expected number of entries
    :param get_fake_dataset: fake dataset handed to ContinuumSetLoader
    :return:
    """
    loader = ContinuumSetLoader(get_fake_dataset)
    label_count = len(loader.labels)
    # dataset_size is defined outside this function (module level).
    if label_count != dataset_size:
        raise AssertionError("Test fail")
def test_DataLoader_with_torch(get_fake_dataset):
    """
    Check that a ContinuumSetLoader can be wrapped by torch.utils.data.DataLoader
    and iterated
    :param get_fake_dataset: fake dataset handed to ContinuumSetLoader
    :return:
    """
    loader = ContinuumSetLoader(get_fake_dataset)
    batches = data.DataLoader(loader, batch_size=10, shuffle=True, num_workers=6)
    # Fetching one batch and unpacking it into two parts is enough to show
    # the wrapping works; the values themselves are not inspected.
    for _samples, _targets in batches:
        break
def test_DataLoader_with_torch_loader(get_fake_dataset):
    """
    Check the tensor types produced when iterating through
    torch.utils.data.DataLoader: labels must be LongTensor, batches FloatTensor
    :param get_fake_dataset: fake dataset handed to ContinuumSetLoader
    :return:
    """
    loader = ContinuumSetLoader(get_fake_dataset)
    batches = data.DataLoader(loader, batch_size=10, shuffle=True, num_workers=6)
    for samples, targets in batches:
        # The label type is checked first, then the batch type; either
        # mismatch fails the test with the same message.
        if not isinstance(targets, torch.LongTensor) or not isinstance(
            samples, torch.FloatTensor
        ):
            raise AssertionError("Test fail")
        # Only the first batch is inspected.
        break
def test_DataLoader_init_label_is_dict(get_fake_dataset):
    """
    Check that the loader exposes its labels as a dict
    :param get_fake_dataset: fake dataset handed to ContinuumSetLoader
    :return:
    """
    loader = ContinuumSetLoader(get_fake_dataset)
    if isinstance(loader.labels, dict):
        return
    raise AssertionError("Test fail")
import os.path
import torch
from copy import deepcopy
from .continuum_loader import ContinuumSetLoader
from .data_utils import load_data, check_and_Download_data, get_images_format
class ContinuumBuilder(ContinuumSetLoader):
'''Parent Class for Sequence Formers'''
def __init__(self, path, dataset, tasks_number, scenario, num_classes, download=False, train=True, path_only=False, verbose=False):
self.tasks_number = tasks_number
self.num_classes = num_classes
self.dataset = dataset
self.i = os.path.join(path, "Datasets")
self.o = os.path.join(path, "Continua", self.dataset)
self.train = train
self.imageSize, self.img_channels = get_images_format(self.dataset)
self.scenario = scenario
self.verbose = verbose
self.path_only = path_only
self.download = download