Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Meta-test task pipeline for MiniImagenet. Each task samples `ways` classes
# and 2*shots examples per class (presumably split into adaptation and
# evaluation halves by fast_adapt -- confirm against that helper).
test_transforms = [
    l2l.data.transforms.NWays(test_dataset, ways),
    l2l.data.transforms.KShots(test_dataset, 2*shots),
    l2l.data.transforms.LoadData(test_dataset),
    l2l.data.transforms.RemapLabels(test_dataset),
    # Every transform must be built on the same dataset the tasks are sampled
    # from. This previously used `train_dataset`, so test-task labels would be
    # looked up in the wrong dataset.
    l2l.data.transforms.ConsecutiveLabels(test_dataset),
]
# Fixed pool of 600 meta-test tasks drawn from `test_dataset`.
test_tasks = l2l.data.TaskDataset(test_dataset,
                                  task_transforms=test_transforms,
                                  num_tasks=600)
# Create model: a CNN with `ways` output classes, moved to `device`.
model = l2l.vision.models.MiniImagenetCNN(ways)
model.to(device)
# MAML wrapper: `fast_lr` is the inner-loop (adaptation) step size, and
# first_order=False differentiates through the adaptation steps.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
# Outer-loop optimizer over the meta-parameters held by the MAML wrapper.
opt = optim.Adam(maml.parameters(), meta_lr)
# `size_average` is deprecated in PyTorch. Passing it alongside `reduction`
# triggers a warning and the legacy reduction path. `reduction='mean'` alone
# gives the same averaged loss.
loss = nn.CrossEntropyLoss(reduction='mean')
# NOTE(review): this fragment lost its indentation in the paste and is cut
# off mid-call at the `fast_adapt(batch,` line. In the original script the
# statements below are nested inside the two `for` loops.
# Outer meta-training loop: one pass per iteration (presumably ending in an
# opt.step() meta-update, which is not visible here).
for iteration in range(num_iterations):
# Clear gradients left on the meta-parameters by the previous iteration.
opt.zero_grad()
# Per-iteration error/accuracy accumulators for train, valid and test tasks.
meta_train_error = 0.0
meta_train_accuracy = 0.0
meta_valid_error = 0.0
meta_valid_accuracy = 0.0
meta_test_error = 0.0
meta_test_accuracy = 0.0
# One inner adaptation per task in the meta-batch.
for task in range(meta_batch_size):
# Compute meta-training loss
# Per-task clone of the MAML-wrapped model, adapted without touching `maml`.
learner = maml.clone()
batch = train_tasks.sample()
evaluation_error, evaluation_accuracy = fast_adapt(batch,
# NOTE(review): this fragment begins mid-list. The opening
# `test_transforms = [` line of this (Omniglot) pipeline is missing from the paste.
# Meta-test tasks use only classes[1200:], the classes after index 1200.
l2l.data.transforms.FilterLabels(dataset, classes[1200:]),
l2l.data.transforms.NWays(dataset, ways),
l2l.data.transforms.KShots(dataset, 2*shots),
l2l.data.transforms.LoadData(dataset),
l2l.data.transforms.RemapLabels(dataset),
l2l.data.transforms.ConsecutiveLabels(dataset),
# Class-level rotation augmentation over multiples of 90 degrees (presumably
# one angle chosen per class -- confirm with learn2learn docs).
l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
]
# Fixed pool of 1024 meta-test tasks drawn from `dataset`.
test_tasks = l2l.data.TaskDataset(dataset,
task_transforms=test_transforms,
num_tasks=1024)
# Build the Omniglot learner: a fully-connected network whose input size is
# 28 * 28 (presumably flattened 28x28 images) with `ways` output classes.
model = l2l.vision.models.OmniglotFC(28 * 28, ways)
model.to(device)
maml = l2l.algorithms.MAML(
    model,
    lr=fast_lr,          # inner-loop (adaptation) step size
    first_order=False,   # differentiate through the adaptation steps
)
# The outer-loop optimizer updates the meta-parameters held by the MAML wrapper.
opt = optim.Adam(maml.parameters(), lr=meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')
# NOTE(review): this fragment lost its indentation in the paste and runs past
# the end of the visible source (cut off mid-call). In the original script
# the statements below are nested inside the two `for` loops.
# Outer meta-training loop: one pass per iteration (presumably ending in an
# opt.step() meta-update, which is not visible here).
for iteration in range(num_iterations):
# Clear gradients left on the meta-parameters by the previous iteration.
opt.zero_grad()
# Per-iteration error/accuracy accumulators for train, valid and test tasks.
meta_train_error = 0.0
meta_train_accuracy = 0.0
meta_valid_error = 0.0
meta_valid_accuracy = 0.0
meta_test_error = 0.0
meta_test_accuracy = 0.0
# One inner adaptation per task in the meta-batch.
for task in range(meta_batch_size):
# Compute meta-training loss
# Per-task clone of the MAML-wrapped model, adapted without touching `maml`.
learner = maml.clone()
batch = train_tasks.sample()
evaluation_error, evaluation_accuracy = fast_adapt(batch,