Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def train(args):
    """Set up and run the training loop for a prototypical network on Omniglot.

    Builds a meta-training Omniglot dataset (``test_shots=15``, shuffled,
    optionally downloaded) and wraps it in a ``BatchMetaDataLoader``. It then
    creates a ``PrototypicalNetwork``, moves it to ``args.device``, puts it in
    training mode and attaches an Adam optimizer with ``lr=1e-3``. Finally it
    iterates over the loader behind a ``tqdm`` progress bar.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``folder``,
            ``num_shots``, ``num_ways``, ``download``, ``batch_size``,
            ``num_workers``, ``embedding_size``, ``hidden_size``, ``device``
            and ``num_batches``.

    NOTE(review): this looks like a truncated excerpt. As written:

    * The loop body only zeroes gradients and unpacks ``batch['train']``.
      ``train_inputs`` and ``train_targets`` are never used, no loss is
      computed, and ``optimizer`` is never stepped, so no parameters are
      updated.
    * ``total=args.num_batches`` only sizes the progress bar. It does not
      stop iteration, so the loop runs until the dataloader is exhausted.
    * A later ``def train(args)`` in this file rebinds the name ``train``,
      so this definition is shadowed at module level.

    Restore the missing remainder of the loop (and, presumably, a
    ``num_batches`` bound) before relying on this function.
    """
    dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways,
                       shuffle=True, test_shots=15, meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    # The leading positional ``1`` is presumably the number of input channels
    # (single-channel images) -- confirm against PrototypicalNetwork.
    model = PrototypicalNetwork(1, args.embedding_size,
                                hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()
            train_inputs, train_targets = batch['train']
def train(args):
dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways,
shuffle=True, test_shots=15, meta_train=True, download=args.download)
dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers)
model = ConvolutionalNeuralNetwork(1, args.num_ways,
hidden_size=args.hidden_size)
model.to(device=args.device)
model.train()
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training loop
with tqdm(dataloader, total=args.num_batches) as pbar:
for batch_idx, batch in enumerate(pbar):
model.zero_grad()
train_inputs, train_targets = batch['train']