with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_num = 200
    for i in range(batch_num):
        batch_size = nni.choice({50: 50, 250: 250, 500: 500}, name='batch_size')
        batch = mnist.train.next_batch(batch_size)
        dropout_rate = nni.choice({1: 1, 5: 5}, name='dropout_rate')
        mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0],
                                                mnist_network.y: batch[1],
                                                mnist_network.keep_prob: dropout_rate})
        if i % 100 == 0:
            test_acc = mnist_network.accuracy.eval(feed_dict={mnist_network.x: mnist.test.images,
                                                              mnist_network.y: mnist.test.labels,
                                                              mnist_network.keep_prob: 1.0})
            nni.report_intermediate_result(test_acc)
    test_acc = mnist_network.accuracy.eval(feed_dict={mnist_network.x: mnist.test.images,
                                                      mnist_network.y: mnist.test.labels,
                                                      mnist_network.keep_prob: 1.0})
    nni.report_final_result(test_acc)
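
The dict-style `nni.choice` calls in the snippet above are the form produced by NNI's annotation converter; the same choices can also be declared in a standard search-space file. A minimal sketch of that equivalent, written as Python for illustration (the `search_space.json` file name is an assumption):

import json

# NNI search-space schema: each entry has a "_type" and a "_value" list
search_space = {
    "batch_size": {"_type": "choice", "_value": [50, 250, 500]},
    "dropout_rate": {"_type": "choice", "_value": [1, 5]},
}
with open('search_space.json', 'w') as f:
    json.dump(search_space, f, indent=4)
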
def on_epoch_end(self, epoch, logs=None):
    '''
    Run on end of each epoch
    '''
    logs = logs or {}  # avoid a mutable default argument
    LOG.debug(logs)
    nni.report_intermediate_result(logs["val_acc"])
def on_epoch_end(self, epoch, logs=None):
    """Reports intermediate accuracy to NNI framework"""
    # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
    if 'val_acc' in logs:
        nni.report_intermediate_result(logs['val_acc'])
    else:
        nni.report_intermediate_result(logs['val_accuracy'])
def on_epoch_end(self, epoch, logs=None):
    """
    Run on end of each epoch
    """
    if logs is None:
        logs = dict()
    logger.debug(logs)
    nni.report_intermediate_result(logs["val_acc"])
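
Callbacks like the ones above only take effect once they are passed to `model.fit`. A minimal sketch of how such a callback is wired up, assuming a compiled Keras `model` and data arrays `x_train`/`y_train`/`x_val`/`y_val` (these names, and the `SendMetrics` class name, are assumptions):

import nni
from tensorflow import keras

class SendMetrics(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # prefer the TF2 key, fall back to the TF1-era one
        nni.report_intermediate_result(logs.get('val_accuracy', logs.get('val_acc')))

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),  # produces the val_* keys seen in `logs`
          epochs=10,
          callbacks=[SendMetrics()])
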
num_batches = ceil(len(train_set.dataset) / float(bsz))
total_loss = 0.0
for epoch in range(3):
    epoch_loss = 0.0
    for data, target in train_set:
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        epoch_loss += loss.item()
        loss.backward()
        average_gradients(model)
        optimizer.step()
    # logger.debug('Rank: ', rank, ', epoch: ', epoch, ': ', epoch_loss / num_batches)
    if rank == 0:
        nni.report_intermediate_result(epoch_loss / num_batches)
    total_loss += (epoch_loss / num_batches)
total_loss /= 3
logger.debug('Final loss: {}'.format(total_loss))
if rank == 0:
    nni.report_final_result(total_loss)
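
`average_gradients(model)` is called above but not defined in this excerpt. A minimal sketch of the usual all-reduce implementation from the PyTorch distributed tutorial pattern, assuming `torch.distributed` has already been initialized:

import torch.distributed as dist

def average_gradients(model):
    """Average gradients across all workers: sum via all-reduce, then divide by world size."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size
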
batch_num = nni.choice(50, 250, 500, name='batch_num')
for i in range(batch_num):
    batch = mnist.train.next_batch(batch_num)
    dropout_rate = nni.choice(1, 5, name='dropout_rate')
    mnist_network.train_step.run(feed_dict={mnist_network.images: batch[0],
                                            mnist_network.labels: batch[1],
                                            mnist_network.keep_prob: dropout_rate})
    if i % 100 == 0:
        test_acc = mnist_network.accuracy.eval(
            feed_dict={mnist_network.images: mnist.test.images,
                       mnist_network.labels: mnist.test.labels,
                       mnist_network.keep_prob: 1.0})
        nni.report_intermediate_result(test_acc)
        logger.debug('test accuracy %g', test_acc)
        logger.debug('Pipe send intermediate result done.')
test_acc = mnist_network.accuracy.eval(
    feed_dict={mnist_network.images: mnist.test.images,
               mnist_network.labels: mnist.test.labels,
               mnist_network.keep_prob: 1.0})
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
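
Bare calls such as `nni.choice(50, 250, 500, name='batch_num')` originate from NNI's annotation feature, where the search space is declared in a string literal directly above the assignment it tunes. A rough sketch of what the pre-conversion source typically looks like (values copied from the snippet above; exact syntax per NNI's annotation docs):

'''@nni.variable(nni.choice(50, 250, 500), name=batch_num)'''
batch_num = 200

'''@nni.variable(nni.choice(1, 5), name=dropout_rate)'''
dropout_rate = 1
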
answers = generate_predict_json(position1, position2, ids, contexts)
if save_path is not None:
    with open(os.path.join(save_path, 'epoch%d.prediction' % epoch), 'w') as file:
        json.dump(answers, file)
else:
    # round-trip through JSON to normalize types
    answers = json.dumps(answers)
    answers = json.loads(answers)
iter = epoch + 1
acc = evaluate.evaluate_with_predictions(args.dev_file, answers)
logger.debug('Send intermediate acc: %s', str(acc))
nni.report_intermediate_result(acc)
logger.debug('Send intermediate result done.')
if acc > bestacc:
    if acc * improvement_threshold > bestacc:
        patience = max(patience, iter * patience_increase)
    bestacc = acc
    if save_path is not None:
        saver.save(sess, os.path.join(save_path, 'epoch%d.model' % epoch))
        with open(os.path.join(save_path, 'epoch%d.score' % epoch), 'wb') as file:
            pickle.dump((position1, position2, ids, contexts), file)
logger.debug('epoch %d acc %g bestacc %g' % (epoch, acc, bestacc))
if patience <= iter:
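
Across all of these snippets the contract is the same: call `nni.report_intermediate_result` once per epoch or evaluation step, and `nni.report_final_result` exactly once when the trial finishes. A minimal self-contained sketch of that contract (the `train_one_epoch` helper is a hypothetical stand-in for real training):

import nni

def train_one_epoch(params, epoch):
    """Hypothetical stand-in for a real training epoch; returns a validation metric."""
    return min(1.0, 0.1 * (epoch + 1) * params.get('lr', 1.0))

if __name__ == '__main__':
    params = nni.get_next_parameter() or {}  # hyperparameters chosen by the tuner
    best_acc = 0.0
    for epoch in range(10):
        acc = train_one_epoch(params, epoch)
        nni.report_intermediate_result(acc)  # per-epoch metric; lets assessors stop bad trials early
        best_acc = max(best_acc, acc)
    nni.report_final_result(best_acc)  # reported exactly once; the tuner optimizes this value
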