# Shared imports these fragments rely on. The Dataset base class and the
# get_asset helper are assumed to come from the torchmeta package.
import io
import json
import os
import shutil

import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm

from torchmeta.utils.data import Dataset
from torchmeta.datasets.utils import get_asset

# mini-ImageNet: tail of the preprocessing step (excerpted from a download
# method, so names like `self`, `filename`, `classes`, and `images` are
# defined by the surrounding class). Each class's images are written as one
# dense dataset inside the 'datasets' group, the sorted class names are
# dumped to JSON, and the source pickle is removed.
with h5py.File(filename, 'w') as f:
    group = f.create_group('datasets')
    for name, indices in classes.items():
        group.create_dataset(name, data=images[indices])

labels_filename = os.path.join(self.root, self.filename_labels.format(split))
with open(labels_filename, 'w') as f:
    labels = sorted(list(classes.keys()))
    json.dump(labels, f)

if os.path.isfile(pkl_filename):
    os.remove(pkl_filename)
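# A minimal read-back sketch for the file just written, assuming the same
# 'datasets' group layout as above; it only checks that every listed class
# has at least one image.
with open(labels_filename, 'r') as f:
    label_list = json.load(f)
with h5py.File(filename, 'r') as f:
    assert all(len(f['datasets'][label]) > 0 for label in label_list)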
class MiniImagenetDataset(Dataset):
    def __init__(self, data, class_name, transform=None, target_transform=None):
        super(MiniImagenetDataset, self).__init__(transform=transform,
            target_transform=target_transform)
        self.data = data
        self.class_name = class_name

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # Rows of `data` are dense uint8 image arrays; the target is the class name.
        image = Image.fromarray(self.data[index])
        target = self.class_name

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, target)
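# Hypothetical usage sketch for the class above: wrap one class's images
# read back from the HDF5 file. `filename` and `labels_filename` are the
# paths used during preprocessing.
with open(labels_filename, 'r') as f:
    class_name = json.load(f)[0]
with h5py.File(filename, 'r') as f:
    dataset = MiniImagenetDataset(f['datasets'][class_name][...], class_name)
    image, target = dataset[0]   # PIL image and the class-name target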
# CUB: tail of the preprocessing step (excerpted from a download method).
# Images are stored as variable-length uint8 arrays, the raw bytes of each
# image file, so they can be decoded lazily at access time.
dataset = group.create_dataset(label, (len(images),), dtype=dtype)
for i, image in enumerate(images):
    with open(image, 'rb') as f:
        array = bytearray(f.read())
    dataset[i] = np.asarray(array, dtype=np.uint8)

tar_folder, _ = os.path.splitext(tgz_filename)
if os.path.isdir(tar_folder):
    shutil.rmtree(tar_folder)

attributes_filename = os.path.join(self.root, 'attributes.txt')
if os.path.isfile(attributes_filename):
    os.remove(attributes_filename)
class CUBDataset(Dataset):
    def __init__(self, data, label, transform=None, target_transform=None):
        super(CUBDataset, self).__init__(transform=transform,
            target_transform=target_transform)
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Entries are raw image-file bytes; decode with PIL and force RGB.
        image = Image.open(io.BytesIO(self.data[index])).convert('RGB')
        target = self.label

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, target)
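# Round-trip sketch for the variable-length storage above: each HDF5 entry
# holds one image file's raw bytes, which CUBDataset decodes on access.
# The path and the 'datasets' group name are illustrative assumptions.
with h5py.File('cub_data.hdf5', 'r') as f:
    label = list(f['datasets'])[0]
    dataset = CUBDataset(f['datasets'][label][...], label)
    image, target = dataset[0]   # decoded RGB PIL image and the species label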
# Omniglot: tail of the preprocessing step (excerpted from a download method).
# For each split of the Vinyals et al. class split, flatten the nested
# {folder: {alphabet: [characters]}} asset into sorted
# (folder, alphabet, character) triples and dump them to JSON.
shutil.rmtree(os.path.join(self.root, name))

for split in ['train', 'val', 'test']:
    filename = os.path.join(self.root, self.filename_labels.format(
        'vinyals_', split))
    data = get_asset(self.folder, '{0}.json'.format(split), dtype='json')

    with open(filename, 'w') as f:
        labels = sorted([('images_{0}'.format(name), alphabet, character)
            for (name, alphabets) in data.items()
            for (alphabet, characters) in alphabets.items()
            for character in characters])
        json.dump(labels, f)
class OmniglotDataset(Dataset):
    def __init__(self, data, character_name, transform=None, target_transform=None):
        super(OmniglotDataset, self).__init__(transform=transform,
            target_transform=target_transform)
        self.data = data
        self.character_name = character_name

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Entries are dense uint8 arrays; the target names the character class.
        image = Image.fromarray(self.data[index])
        target = self.character_name

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, target)
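# Sketch of reading back the split file written above (here, the last split's
# `filename`): each entry is a (folder, alphabet, character) triple naming one
# character class. The example values are illustrative.
with open(filename, 'r') as f:
    labels = json.load(f)
folder, alphabet, character = labels[0]
# e.g. ('images_background', 'Alphabet_of_the_Magi', 'character01')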
# CIFAR-100: tail of the preprocessing step (excerpted from a download
# method). Images are grouped by fine label within each coarse label, the
# coarse-to-fine name mapping is dumped to JSON, and the extracted archive
# is cleaned up.
dataset = group.create_dataset(fine_label_names[j],
    data=images[fine_labels == j])

fine_names[coarse_name] = [fine_label_names[j] for j in fine_indices]

filename_fine_names = os.path.join(self.root, self.filename_fine_names)
with open(filename_fine_names, 'w') as f:
    json.dump(fine_names, f)

gz_folder = os.path.join(self.root, self.gz_folder)
if os.path.isdir(gz_folder):
    shutil.rmtree(gz_folder)
if os.path.isfile('{0}.tar.gz'.format(gz_folder)):
    os.remove('{0}.tar.gz'.format(gz_folder))
class CIFAR100Dataset(Dataset):
    def __init__(self, data, coarse_label_name, fine_label_name,
                 transform=None, target_transform=None):
        super(CIFAR100Dataset, self).__init__(transform=transform,
            target_transform=target_transform)
        self.data = data
        self.coarse_label_name = coarse_label_name
        self.fine_label_name = fine_label_name

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        # The target pairs the superclass name with the fine class name.
        target = (self.coarse_label_name, self.fine_label_name)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, target)
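# Hypothetical usage sketch: one dense dataset per fine label, assumed here to
# be nested under its coarse label. The HDF5 path is illustrative;
# `filename_fine_names` comes from the preprocessing fragment above.
with open(filename_fine_names, 'r') as f:
    fine_names = json.load(f)
coarse = sorted(fine_names)[0]
fine = fine_names[coarse][0]
with h5py.File('cifar100_data.hdf5', 'r') as f:
    dataset = CIFAR100Dataset(f[coarse][fine][...], coarse, fine)
    image, target = dataset[0]   # PIL image and the (coarse, fine) name pair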
# tieredImageNet: tail of the preprocessing step (excerpted from a download
# method). Each specific class gets a variable-length uint8 dataset of
# compressed image bytes, annotated with its general (superclass) and
# specific labels as HDF5 attributes.
dtype = h5py.special_dtype(vlen=np.uint8)
for i, label in enumerate(tqdm(labels_str, desc=filename)):
    indices, = np.where(labels['label_specific'] == i)
    dataset = group.create_dataset(label, (len(indices),), dtype=dtype)
    general_idx = general_labels[indices[0]]
    dataset.attrs['label_general'] = (general_labels_str[general_idx]
        if general_idx < len(general_labels_str) else '')
    dataset.attrs['label_specific'] = label
    for j, k in enumerate(indices):
        dataset[j] = np.squeeze(images[k])

if os.path.isdir(tar_folder):
    shutil.rmtree(tar_folder)
class TieredImagenetDataset(Dataset):
    def __init__(self, data, general_class_name, specific_class_name,
                 transform=None, target_transform=None):
        super(TieredImagenetDataset, self).__init__(transform=transform,
            target_transform=target_transform)
        self.data = data
        self.general_class_name = general_class_name
        self.specific_class_name = specific_class_name

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Entries are compressed image bytes; decode with PIL. The target
        # pairs the general (superclass) name with the specific class name.
        image = Image.open(io.BytesIO(self.data[index]))
        target = (self.general_class_name, self.specific_class_name)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, target)
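# Hypothetical usage sketch: each class dataset stores compressed image bytes
# plus its labels as attributes, so the wrapper can be built straight from the
# HDF5 node. The path and 'datasets' group name are illustrative assumptions.
with h5py.File('tiered_data.hdf5', 'r') as f:
    node = f['datasets'][list(f['datasets'])[0]]
    dataset = TieredImagenetDataset(node[...],
        node.attrs['label_general'], node.attrs['label_specific'])
    image, target = dataset[0]   # PIL image and (general, specific) names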