How to use the torchmeta.utils.data.Dataset class in torchmeta

To help you get started, we’ve selected a few torchmeta examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: tristandeleu / pytorch-meta / torchmeta / datasets / miniimagenet.py (view on GitHub; external link)
with h5py.File(filename, 'w') as f:
                group = f.create_group('datasets')
                for name, indices in classes.items():
                    group.create_dataset(name, data=images[indices])

            labels_filename = os.path.join(self.root, self.filename_labels.format(split))
            with open(labels_filename, 'w') as f:
                labels = sorted(list(classes.keys()))
                json.dump(labels, f)

            if os.path.isfile(pkl_filename):
                os.remove(pkl_filename)


class MiniImagenetDataset(Dataset):
    def __init__(self, data, class_name, transform=None, target_transform=None):
        """Hold the samples of a single miniImageNet class.

        Args:
            data: Indexable collection of samples. ``__getitem__`` passes each
                ``data[index]`` to ``Image.fromarray`` (presumably a uint8
                image array — confirm against the caller).
            class_name: Label used as the target for every sample.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable for the target, forwarded to
                the base ``Dataset``.
        """
        # The base class receives both transforms; only the payload is kept here.
        super().__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.class_name = class_name

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        target = self.class_name

        if self.transform is not None:
            image = self.transform(image)
GitHub: tristandeleu / pytorch-meta / torchmeta / datasets / cub.py (view on GitHub; external link)
dataset = group.create_dataset(label, (len(images),), dtype=dtype)
                    for i, image in enumerate(images):
                        with open(image, 'rb') as f:
                            array = bytearray(f.read())
                            dataset[i] = np.asarray(array, dtype=np.uint8)

        tar_folder, _ = os.path.splitext(tgz_filename)
        if os.path.isdir(tar_folder):
            shutil.rmtree(tar_folder)

        attributes_filename = os.path.join(self.root, 'attributes.txt')
        if os.path.isfile(attributes_filename):
            os.remove(attributes_filename)


class CUBDataset(Dataset):
    def __init__(self, data, label, transform=None, target_transform=None):
        """Hold the encoded images of a single CUB class.

        Args:
            data: Indexable collection of encoded image bytes. ``__getitem__``
                decodes each entry with ``Image.open(io.BytesIO(...))`` and
                converts it to RGB.
            label: Label used as the target for every sample.
            transform: Optional callable applied to each decoded PIL image.
            target_transform: Optional callable for the target, forwarded to
                the base ``Dataset``.
        """
        # Both transforms go to the base class; this subclass stores only data/label.
        super().__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = Image.open(io.BytesIO(self.data[index])).convert('RGB')
        target = self.label

        if self.transform is not None:
            image = self.transform(image)
GitHub: tristandeleu / pytorch-meta / torchmeta / datasets / omniglot.py (view on GitHub; external link)
shutil.rmtree(os.path.join(self.root, name))

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename_labels.format(
                'vinyals_', split))
            data = get_asset(self.folder, '{0}.json'.format(split), dtype='json')

            with open(filename, 'w') as f:
                labels = sorted([('images_{0}'.format(name), alphabet, character)
                    for (name, alphabets) in data.items()
                    for (alphabet, characters) in alphabets.items()
                    for character in characters])
                json.dump(labels, f)


class OmniglotDataset(Dataset):
    def __init__(self, data, character_name, transform=None, target_transform=None):
        """Hold the samples of a single Omniglot character.

        Args:
            data: Indexable collection of samples. ``__getitem__`` passes each
                ``data[index]`` to ``Image.fromarray`` (presumably an image
                array — confirm against the caller).
            character_name: Label used as the target for every sample.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable for the target, forwarded to
                the base ``Dataset``.
        """
        # Transforms are handled by the base class; keep only data and its label.
        super().__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.character_name = character_name

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        target = self.character_name

        if self.transform is not None:
            image = self.transform(image)
GitHub: tristandeleu / pytorch-meta / torchmeta / datasets / cifar100 / base.py (view on GitHub; external link)
dataset = group.create_dataset(fine_label_names[j],
                        data=images[fine_labels == j])
                fine_names[coarse_name] = [fine_label_names[j] for j in fine_indices]

        filename_fine_names = os.path.join(self.root, self.filename_fine_names)
        with open(filename_fine_names, 'w') as f:
            json.dump(fine_names, f)

        gz_folder = os.path.join(self.root, self.gz_folder)
        if os.path.isdir(gz_folder):
            shutil.rmtree(gz_folder)
        if os.path.isfile('{0}.tar.gz'.format(gz_folder)):
            os.remove('{0}.tar.gz'.format(gz_folder))


class CIFAR100Dataset(Dataset):
    def __init__(self, data, coarse_label_name, fine_label_name,
                 transform=None, target_transform=None):
        """Hold the samples of a single CIFAR-100 fine class.

        Args:
            data: Indexable collection of samples. ``__getitem__`` passes each
                ``data[index]`` to ``Image.fromarray`` (presumably a uint8
                image array — confirm against the caller).
            coarse_label_name: First element of the ``(coarse, fine)`` target.
            fine_label_name: Second element of the ``(coarse, fine)`` target.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable for the target, forwarded to
                the base ``Dataset``.
        """
        # The base class receives both transforms; the label pair is stored here.
        super().__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.coarse_label_name = coarse_label_name
        self.fine_label_name = fine_label_name

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        target = (self.coarse_label_name, self.fine_label_name)

        if self.transform is not None:
GitHub: tristandeleu / pytorch-meta / torchmeta / datasets / tieredimagenet.py (view on GitHub; external link)
dtype = h5py.special_dtype(vlen=np.uint8)
                for i, label in enumerate(tqdm(labels_str, desc=filename)):
                    indices, = np.where(labels['label_specific'] == i)
                    dataset = group.create_dataset(label, (len(indices),), dtype=dtype)
                    general_idx = general_labels[indices[0]]
                    dataset.attrs['label_general'] = (general_labels_str[general_idx]
                        if general_idx < len(general_labels_str) else '')
                    dataset.attrs['label_specific'] = label
                    for j, k in enumerate(indices):
                        dataset[j] = np.squeeze(images[k])

        if os.path.isdir(tar_folder):
            shutil.rmtree(tar_folder)


class TieredImagenetDataset(Dataset):
    def __init__(self, data, general_class_name, specific_class_name,
                 transform=None, target_transform=None):
        """Hold the encoded images of a single tieredImageNet specific class.

        Args:
            data: Indexable collection of encoded image bytes. ``__getitem__``
                decodes each entry with ``Image.open(io.BytesIO(...))``.
            general_class_name: First element of the ``(general, specific)``
                target.
            specific_class_name: Second element of the ``(general, specific)``
                target.
            transform: Optional callable applied to each decoded PIL image.
            target_transform: Optional callable for the target, forwarded to
                the base ``Dataset``.
        """
        # Transforms live on the base class; this subclass keeps data and labels.
        super().__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.general_class_name = general_class_name
        self.specific_class_name = specific_class_name

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = Image.open(io.BytesIO(self.data[index]))
        target = (self.general_class_name, self.specific_class_name)

        if self.transform is not None: