Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def train(net, train_iter, test_iter, loss, trainer, num_epochs,
          ctx_list=d2l.try_all_gpus()):
    """Train ``net`` for ``num_epochs`` epochs and print progress.

    Each batch is handed to ``d2l.train_batch_ch12``, whose two return
    values are accumulated as running sums.
    NOTE(review): these are presumably the summed batch loss and the number
    of correct predictions; confirm against the d2l version in use.

    Running averages are printed roughly five times per epoch. At the end of
    each epoch, the test accuracy from ``d2l.evaluate_accuracy_gpus`` is
    printed. After training, the measured throughput is printed.

    Args:
        net: model to train; forwarded to the d2l helpers.
        train_iter: iterable of ``(features, labels)`` batches. It must
            support ``len()``.
        test_iter: evaluation data passed to ``d2l.evaluate_accuracy_gpus``.
        loss: loss function forwarded to ``d2l.train_batch_ch12``.
        trainer: optimizer forwarded to ``d2l.train_batch_ch12``.
        num_epochs: number of passes over ``train_iter``.
        ctx_list: devices forwarded to ``d2l.train_batch_ch12``. The default
            is evaluated once, when this function is defined, not per call.

    Returns:
        None. All reporting is done with ``print``.
    """
    num_batches, timer = len(train_iter), d2l.Timer()
    # Report about five times per epoch. Clamp to 1 so that loaders with
    # fewer than 5 batches do not hit a modulus by zero (ZeroDivisionError).
    print_every = max(num_batches // 5, 1)
    for epoch in range(num_epochs):
        # Per-epoch running sums: training loss, training accuracy,
        # number of examples (labels.shape[0]), number of label
        # elements (labels.size).
        metric = [0.0] * 4
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            batch_loss, batch_acc = d2l.train_batch_ch12(
                net, features, labels, loss, trainer, ctx_list)
            metric = [a + b for a, b in zip(
                metric,
                (batch_loss, batch_acc, labels.shape[0], labels.size))]
            timer.stop()
            if (i + 1) % print_every == 0:
                print('loss %.3f, train acc %.3f' % (
                    metric[0] / metric[2], metric[1] / metric[3]))
        test_acc = d2l.evaluate_accuracy_gpus(net, test_iter)
        print('loss %.3f, train acc %.3f, test acc %.3f' % (
            metric[0] / metric[2], metric[1] / metric[3], test_acc))
    # Only time spent inside train_batch_ch12 is measured. With zero epochs
    # nothing was timed and `metric` was never bound, so skip the report.
    elapsed = timer.sum()
    if elapsed > 0:
        # Throughput estimate: last epoch's example count times num_epochs.
        print('%.1f examples/sec on %s' % (
            metric[2] * num_epochs / elapsed, ctx_list))