How to use the learn2learn.vision.datasets.MiniImagenet class in learn2learn

To help you get started, we've selected a few learn2learn examples based on how MiniImagenet is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source (GitHub): learnables/learn2learn — tests/integration/maml_miniimagenet_test_notravis.py (view on GitHub)
adaptation_steps=1,
        num_iterations=60000,
        cuda=False,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    device = th.device('cpu')
    if cuda and th.cuda.device_count():
        th.cuda.manual_seed(seed)
        device = th.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    train_transforms = [
        l2l.data.transforms.NWays(train_dataset, ways),
        l2l.data.transforms.KShots(train_dataset, 2*shots),
        l2l.data.transforms.LoadData(train_dataset),
        l2l.data.transforms.RemapLabels(train_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
Source (GitHub): learnables/learn2learn — examples/vision/maml_miniimagenet.py (view on GitHub)
adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    device = th.device('cpu')
    if cuda and th.cuda.device_count():
        th.cuda.manual_seed(seed)
        device = th.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, 2*shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)
Source (GitHub): learnables/learn2learn — examples/vision/protonet_miniimagenet.py (view on GitHub)
device = torch.device('cpu')
    if args.gpu:
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')

    model = Convnet()
    model.to(device)

    path_data = '/datadrive/few-shot/miniimagenetdata'
    train_dataset = l2l.vision.datasets.MiniImagenet(
        root=path_data, mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(
        root=path_data, mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(
        root=path_data, mode='test')

    train_sampler = l2l.data.NShotKWayTaskSampler(
        train_dataset.y, 100, args.train_way, args.shot, args.train_query)

    train_loader = DataLoader(dataset=train_dataset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    val_sampler = l2l.data.NShotKWayTaskSampler(
        valid_dataset.y, 400, args.test_way, args.shot, args.train_query)

    val_loader = DataLoader(dataset=valid_dataset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    test_sampler = l2l.data.NShotKWayTaskSampler(
        test_dataset.y, 2000, args.test_way, args.test_shot, args.test_query)