Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import torch
import os
from continuum.continuumbuilder import ContinuumBuilder
from copy import deepcopy
class Permutations(ContinuumBuilder):
    '''Scenario: in this scenario all classes are available in every task, but the pixels of each task are permuted.

    The goal is to test algorithms where the data of each class is not available all at once, but instead comes from
    different modes of the distribution (different permutation modes).'''

    def __init__(self, path="./Data", dataset="MNIST", tasks_number=5, download=False, train=True):
        """Create a permuted-pixels continuum.

        Args:
            path: root data directory, forwarded to ContinuumBuilder.
            dataset: dataset name, forwarded to ContinuumBuilder.
            tasks_number: number of tasks (one permutation mode per task, presumably — TODO confirm).
            download: forwarded to ContinuumBuilder.
            train: forwarded to ContinuumBuilder (train vs. test split, presumably — TODO confirm).
        """
        # These attributes are initialised before calling the base constructor; per the original
        # comments they are filled in later by prepare_formatting.
        self.num_pixels = 0  # will be set in prepare_formatting
        self.perm_file = ""  # will be set in prepare_formatting
        self.list_perm = []
        # The scenario name must be this scenario's own name ("Permutations"), not "Rotations".
        # The scenario string is formatted into the cached '<scenario>_<tasks>_train/test*.pt'
        # output file name, and an existing cache file is loaded instead of rebuilt. Reusing
        # "Rotations" would therefore make this builder share (and possibly load) the Rotations
        # scenario's cached data.
        super(Permutations, self).__init__(path=path,
                                           dataset=dataset,
                                           tasks_number=tasks_number,
                                           scenario="Permutations",
                                           download=download,
                                           train=train,
                                           num_classes=10)
from continuum.continuumbuilder import ContinuumBuilder
class Disjoint(ContinuumBuilder):
    """Scenario: each new task brings classes that have never been seen before.

    The code lets you choose into how many tasks a dataset is split, and therefore how many classes each task has.
    This scenario tests algorithms when there is no intersection between tasks."""

    def __init__(self, path="./Data", dataset="MNIST", tasks_number=1, download=False, train=True):
        """Create a disjoint-classes continuum.

        All arguments are forwarded to ContinuumBuilder together with scenario="Disjoint" and num_classes=10.
        """
        super(Disjoint, self).__init__(path=path,
                                       dataset=dataset,
                                       tasks_number=tasks_number,
                                       scenario="Disjoint",
                                       download=download,
                                       train=train,
                                       num_classes=10)

    def select_index(self, ind_task, y):
        """Per-task selection hook (as written, returns None).

        NOTE(review): this body looks garbled. ``cpt`` is computed but never used, and ``ind_task`` and ``y``
        are ignored. The statements after ``cpt`` (create the output directory, build the cache path, then
        download/format or load, then call the next ``__init__`` in the MRO) look like constructor code, possibly
        pasted from ContinuumBuilder.__init__ — TODO confirm and restore the intended class-range selection.
        ``light_id`` and ``check_and_Download_data`` are not defined in the visible code; unless they exist at
        module scope elsewhere, calling this method raises NameError.
        """
        # Number of classes per task: num_classes / tasks_number, truncated to an int.
        cpt = int(self.num_classes / self.tasks_number)
        # Ensure the output directory self.o exists.
        if not os.path.exists(self.o):
            os.makedirs(self.o)
        # Cache file name encodes the scenario, the task count, the split and light_id.
        if self.train:
            self.out_file = os.path.join(self.o, '{}_{}_train{}.pt'.format(self.scenario, self.tasks_number, light_id))
        else:
            self.out_file = os.path.join(self.o, '{}_{}_test{}.pt'.format(self.scenario, self.tasks_number, light_id))
        check_and_Download_data(self.i, self.dataset, scenario=self.scenario)
        # Rebuild when a download is requested or no cache exists; otherwise reuse the cached continuum.
        if self.download or not os.path.isfile(self.out_file):
            self.formating_data()
        else:
            # NOTE(review): torch.load unpickles the file; only load caches from trusted locations.
            self.continuum = torch.load(self.out_file)
        # super(ContinuumBuilder, self) skips ContinuumBuilder.__init__ and calls the next class in the MRO.
        super(ContinuumBuilder, self).__init__(self.continuum)
from torchvision import transforms
import torch
from continuum.continuumbuilder import ContinuumBuilder
class Rotations(ContinuumBuilder):
'''Scenario : In this scenario, for each tasks all classes are available, however for each task data rotate a bit.
The goal is to test algorithms where all data for each classes are not available simultaneously and there is a concept
drift.'''
def __init__(self, path="./Data", dataset="MNIST", tasks_number=1, rotation_number=None, download=False, train=True, min_rot=0.0,
max_rot=90.0):
self.max_rot = max_rot
self.min_rot = min_rot
if rotation_number is None:
rotation_number = tasks_number
self.rotation_number = rotation_number
super(Rotations, self).__init__(path=path,
dataset=dataset,
tasks_number=tasks_number,