Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_fifo_scheduler():
    """Run ``train_fn`` through a FIFO scheduler and wait for its jobs.

    The scheduler is configured for ten trials, four CPUs and no GPU per
    trial, with ``'accuracy'`` as reward attribute, ``'epoch'`` as time
    attribute, and checkpointing disabled.
    """
    fifo_options = {
        'resource': {'num_cpus': 4, 'num_gpus': 0},
        'num_trials': 10,
        'reward_attr': 'accuracy',
        'time_attr': 'epoch',
        'checkpoint': None,
    }
    fifo = ag.scheduler.FIFOScheduler(train_fn, **fifo_options)
    fifo.run()
    fifo.join_jobs()
def test_scheduler():
    """Queue ten ``my_task`` tasks on a ``TaskScheduler`` and wait for them.

    Every task receives the same argument parser object and the same
    config dict, whose ``'lr'`` entry samples 10**u with u drawn uniformly
    from [-4, -1).
    """
    task_scheduler = ag.scheduler.TaskScheduler()
    print('scheduler', task_scheduler)

    parser = argparse.ArgumentParser()
    lr_sampler = ag.searcher.sample_from(
        lambda: np.power(10.0, np.random.uniform(-4, -1)))
    search_config = {'lr': lr_sampler}

    for _ in range(10):
        # A fresh resource request and argument dict per task, as each task
        # object may keep its own reference to them.
        per_task_resource = ag.resource.Resources(num_cpus=2, num_gpus=1)
        task_kwargs = {'args': parser, 'config': search_config}
        task_scheduler.add_task(
            ag.scheduler.Task(my_task, task_kwargs, per_task_resource))

    task_scheduler.join_tasks()
def test_scheduler():
    """Submit ten tasks to a ``TaskScheduler`` and block until all finish.

    Each task runs ``my_task`` with an ``ArgumentParser`` instance under
    ``'args'`` and a shared config under ``'config'``; the config's ``'lr'``
    value is a sampler returning ``10 ** uniform(-4, -1)``.
    """
    sched = ag.scheduler.TaskScheduler()
    print('scheduler', sched)
    cli = argparse.ArgumentParser()
    cfg = {
        'lr': ag.searcher.sample_from(
            lambda: np.power(10.0, np.random.uniform(-4, -1))),
    }
    num_tasks = 10
    for _ in range(num_tasks):
        # Each task requests two CPUs and one GPU.
        res = ag.resource.Resources(num_cpus=2, num_gpus=1)
        sched.add_task(
            ag.scheduler.Task(my_task, {'args': cli, 'config': cfg}, res))
    sched.join_tasks()
def test_hyperband_scheduler():
    """Run ``train_fn`` through a Hyperband scheduler and wait for its jobs.

    Uses ten trials, four CPUs and no GPU per trial, ``'accuracy'`` as the
    reward attribute, ``'epoch'`` as the time attribute, a grace period of
    one, and no checkpoint file.
    """
    hyperband_options = {
        'resource': {'num_cpus': 4, 'num_gpus': 0},
        'num_trials': 10,
        'reward_attr': 'accuracy',
        'time_attr': 'epoch',
        'grace_period': 1,
        'checkpoint': None,
    }
    hyperband = ag.scheduler.HyperbandScheduler(train_fn, **hyperband_options)
    hyperband.run()
    hyperband.join_jobs()
label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
outputs = [net(X) for X in data]
metric.update(label, outputs)
return metric.get()[1]
if __name__ == '__main__':
    # Script entry point: parse the command line, configure logging, pick a
    # scheduler according to ``args.scheduler``, run it and wait for the
    # submitted tasks.
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Hand the requested number of epochs to the training function.
    train_cifar.update(epochs=args.epochs)

    # create searcher and scheduler
    resource = {'num_cpus': 2, 'num_gpus': args.num_gpus}
    if args.scheduler == 'hyperband':
        # ``args.epochs // 4`` is 0 whenever fewer than four epochs are
        # requested; clamp to one so the grace period is always positive.
        grace_period = max(1, args.epochs // 4)
        myscheduler = ag.scheduler.HyperbandScheduler(
            train_cifar,
            resource=resource,
            num_trials=args.num_trials,
            checkpoint=args.checkpoint,
            time_attr='epoch', reward_attr="accuracy",
            max_t=args.epochs, grace_period=grace_period)
    elif args.scheduler == 'fifo':
        myscheduler = ag.scheduler.FIFOScheduler(
            train_cifar,
            resource=resource,
            num_trials=args.num_trials,
            checkpoint=args.checkpoint,
            reward_attr="accuracy")
    else:
        # Name the rejected value so a mistyped option is easy to spot.
        raise RuntimeError(
            'Unsupported Scheduler: {!r}'.format(args.scheduler))

    myscheduler.run()
    myscheduler.join_tasks()
# Continuation of a script body: ``args`` is expected to have been parsed
# before this point (the parsing call is not part of this span).
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

# Hand the requested number of epochs to the training function.
train_cifar.update(epochs=args.epochs)

# create searcher and scheduler
resource = {'num_cpus': 2, 'num_gpus': args.num_gpus}
if args.scheduler == 'hyperband':
    # ``args.epochs // 4`` is 0 whenever fewer than four epochs are
    # requested; clamp to one so the grace period is always positive.
    grace_period = max(1, args.epochs // 4)
    myscheduler = ag.scheduler.HyperbandScheduler(
        train_cifar,
        resource=resource,
        num_trials=args.num_trials,
        checkpoint=args.checkpoint,
        time_attr='epoch', reward_attr="accuracy",
        max_t=args.epochs, grace_period=grace_period)
elif args.scheduler == 'fifo':
    myscheduler = ag.scheduler.FIFOScheduler(
        train_cifar,
        resource=resource,
        num_trials=args.num_trials,
        checkpoint=args.checkpoint,
        reward_attr="accuracy")
else:
    # Name the rejected value so a mistyped option is easy to spot.
    raise RuntimeError(
        'Unsupported Scheduler: {!r}'.format(args.scheduler))

myscheduler.run()
myscheduler.join_tasks()

# The curves file name is the checkpoint path with its extension swapped
# for ``.png``.
curves_path = '{}.png'.format(os.path.splitext(args.checkpoint)[0])
myscheduler.get_training_curves(curves_path)
print('The Best Configuration and Accuracy are: {}, {}'.format(
    myscheduler.get_best_config(),
    myscheduler.get_best_reward()))
args,
{'num_cpus': int(resources_per_trial['max_num_cpus']),
'num_gpus': int(resources_per_trial['max_num_gpus'])},
searcher,
checkpoint=savedir,
resume=resume,
time_attr='epoch',
reward_attr='accuracy',
max_t=resources_per_trial[
'max_training_epochs'],
grace_period=resources_per_trial[
'max_training_epochs']//4,
visualizer=visualizer)
# TODO (cgraywang): use empiral val now
else:
trial_scheduler = ag.scheduler.FIFO_Scheduler(
train_image_classification,
args,
{'num_cpus': int(resources_per_trial['max_num_cpus']),
'num_gpus': int(resources_per_trial['max_num_gpus'])},
searcher,
checkpoint=savedir,
resume=resume,
visualizer=visualizer)
trial_scheduler.run(num_trials=stop_criterion['max_trial_count'])
# TODO (cgraywang)
trials = None
best_result = trial_scheduler.get_best_reward()
best_config = trial_scheduler.get_best_config()
results = Results(trials, best_result, best_config, time.time() - start_fit_time)
logger.info('Finished.')
return results
# Search space: a single log-scaled learning rate between 1e-4 and 1e-1.
cs = CS.ConfigurationSpace()
lr = CSH.UniformFloatHyperparameter('lr', lower=1e-4, upper=1e-1, log=True)
cs.add_hyperparameter(lr)

# create searcher and scheduler
searcher = ag.searcher.RandomSampling(cs)
if args.scheduler != 'hyperband':
    # Anything other than 'hyperband' falls back to the FIFO scheduler.
    myscheduler = ag.scheduler.FIFO_Scheduler(
        train_mnist,
        args,
        {'num_cpus': 2, 'num_gpus': 0},
        searcher,
        num_trials=args.num_trials,
        checkpoint=args.checkpoint,
        resume=args.resume,
        reward_attr="accuracy",
    )
else:
    myscheduler = ag.scheduler.Hyperband_Scheduler(
        train_mnist,
        args,
        {'num_cpus': 2, 'num_gpus': 0},
        searcher,
        num_trials=args.num_trials,
        checkpoint=args.checkpoint,
        resume=args.resume,
        time_attr='epoch',
        reward_attr="accuracy",
        max_t=args.epochs,
        grace_period=1,
    )

myscheduler.run()
myscheduler.join_tasks()

# The curves file name is the checkpoint path with its extension swapped
# for ``.png``.
checkpoint_stem = os.path.splitext(args.checkpoint)[0]
myscheduler.get_training_curves('{}.png'.format(checkpoint_stem))

best_config = myscheduler.get_best_config()
best_reward = myscheduler.get_best_reward()
print('The Best Configuration and Accuracy are: {}, {}'.format(best_config,
                                                               best_reward))
curr_loss = nd.mean(loss).asscalar()
moving_loss = (curr_loss if ((i == 0) and (e == 0))
else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)
test_accuracy = evaluate_accuracy(test_data, net)
train_accuracy = evaluate_accuracy(train_data, net)
print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" % (e, moving_loss, train_accuracy, test_accuracy))
if __name__ == "__main__":
    # Entry point: submit five ``train_mnist`` tasks to a TaskScheduler.
    args = parser.parse_args()
    # train_mnist(args) # standard execution

    # The 'lr' entry samples 10**u with u drawn uniformly from [-4, -1).
    config = {
        'lr': ag.distribution.sample_from(
            lambda: np.power(10.0, np.random.uniform(-4, -1))),
    }

    myscheduler = ag.scheduler.TaskScheduler()
    for _ in range(5):
        # Every task asks for two CPUs and one GPU and shares args/config.
        myscheduler.add_task(
            ag.scheduler.Task(
                train_mnist,
                {'args': args, 'config': config},
                ag.resource.Resources(num_cpus=2, num_gpus=1),
            )
        )