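# Imports assumed by the snippets below. Task, DistributedResource, and
# DistStatusReporter are AutoGluon-internal classes whose exact module paths
# vary across versions, so they are left to the surrounding package.
import json
import logging
import pickle
import threading

import mxnet as mx

logger = logging.getLogger(__name__)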
def load_state_dict(self, state_dict):
    """Load from the saved state dict.

    Examples
    --------
    >>> scheduler.load_state_dict(ag.load('checkpoint.ag'))
    """
    self.finished_tasks = pickle.loads(state_dict['finished_tasks'])
    Task.set_id(state_dict['TASK_ID'])
    logger.debug('\nLoading finished_tasks: {}'.format(self.finished_tasks))
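# Illustration only: the minimal checkpoint this loader consumes has just two
# keys. The literal below is a hand-built example, not the scheduler's actual
# save format, and `scheduler` stands for any instance with load_state_dict.
minimal_ckpt = {
    'finished_tasks': pickle.dumps([]),  # pickled list of finished-task dicts
    'TASK_ID': 0,                        # global Task counter to resume from
}
# scheduler.load_state_dict(minimal_ckpt)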
def _dict_from_task(self, task):
    if isinstance(task, Task):
        return {'TASK_ID': task.task_id, 'Args': task.args}
    else:
        assert isinstance(task, dict)
        return {'TASK_ID': task['TASK_ID'], 'Args': task['Args']}
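# Either input form normalizes to the same two-key dict, e.g. (illustrative):
#   self._dict_from_task(task)                            -> {'TASK_ID': 7, 'Args': {...}}
#   self._dict_from_task({'TASK_ID': 7, 'Args': {...}})   -> {'TASK_ID': 7, 'Args': {...}}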
def schedule_next(self):
    """Schedule the next task suggested by the searcher."""
    # Allow for the promotion of a previously chosen config. Also,
    # extra_kwargs contains extra info passed to both add_job and to
    # get_config (if no config is promoted)
    config, extra_kwargs = self._promote_config()
    if config is None:
        # No config to promote: query the next config to evaluate from the searcher
        config = self.searcher.get_config(**extra_kwargs)
        extra_kwargs['new_config'] = True
    else:
        # Not a new config, but a paused one which is now promoted
        extra_kwargs['new_config'] = False
    task = Task(self.train_fn, {'args': self.args, 'config': config},
                DistributedResource(**self.resource))
    self.add_job(task, **extra_kwargs)
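# Hedged sketch of how schedule_next() is typically driven: a run loop that
# keeps asking the searcher for configs until a trial budget is exhausted.
# The `run` name and `num_trials` parameter are assumptions for illustration.
def run(self, num_trials):
    # Schedule one searcher-suggested task per trial slot.
    for _ in range(num_trials):
        self.schedule_next()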
def sync_schedule_tasks(self, configs):
    # Launch one task per config and block until every task has finished.
    # The enclosing signature and the `results` dict are reconstructed from
    # usage: `configs` feeds the launch loop, `results` collects one final
    # reward per (pickled) config.
    results = {}

    def _run_reporter(task, task_job, reporter):
        # Consume intermediate reports from the training task until it is done.
        last_result = None
        while not task_job.done():
            reported_result = reporter.fetch()
            if 'done' in reported_result and reported_result['done'] is True:
                reporter.move_on()
                break
            self._add_training_result(task.task_id, reported_result, task.args['config'])
            reporter.move_on()
            last_result = reported_result
        if last_result is not None:
            config = task.args['config']
            self.searcher.update(config, last_result[self._reward_attr], done=True)
            with self.lock:
                results[pickle.dumps(config)] = last_result[self._reward_attr]

    # launch the tasks
    tasks = []
    task_jobs = []
    reporter_threads = []
    for config in configs:
        logger.debug('scheduling config: {}'.format(config))
        # create task
        task = Task(self.train_fn, {'args': self.args, 'config': config},
                    DistributedResource(**self.resource))
        reporter = DistStatusReporter()
        task.args['reporter'] = reporter
        task_job = self.add_job(task)
        # run reporter in a thread that consumes intermediate results
        reporter_thread = threading.Thread(target=_run_reporter, args=(task, task_job, reporter))
        reporter_thread.start()
        tasks.append(task)
        task_jobs.append(task_job)
        reporter_threads.append(reporter_thread)

    # wait for every task and its reporter thread to finish
    for task_job, reporter_thread in zip(task_jobs, reporter_threads):
        task_job.result()
        reporter_thread.join()
    with self.LOCK:
        for task in tasks:
            # Loop body was missing here; filled in to match _dict_from_task
            # and the finished_tasks bookkeeping above.
            self.finished_tasks.append(self._dict_from_task(task))
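# Hedged sketch of the other side of the reporter protocol used above: a
# train_fn receives the reporter injected via task.args['reporter'], and each
# report blocks until the scheduler's fetch()/move_on() pair consumes it. A
# final report with done=True (sent by user code or a framework wrapper) ends
# the fetch loop. `evaluate_config` and `args.epochs` are hypothetical.
def train_fn(args, config, reporter):
    for epoch in range(args.epochs):
        accuracy = evaluate_config(config, epoch)  # hypothetical training step
        reporter(epoch=epoch, accuracy=accuracy)   # produces one fetch()able result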
def _async_run_trial():
    # One asynchronous trial: sample a config from the RL controller, train
    # it, and stream intermediate results back through the reporter.
    self.mp_count.value += 1
    self.mp_seed.value += 1
    seed = self.mp_seed.value
    mx.random.seed(seed)
    with mx.autograd.record():
        # sample one configuration
        with self.lock:
            config, log_prob, entropy = self.controller.sample(with_details=True)
        config = config[0]
        task = Task(self.train_fn, {'args': self.args, 'config': config},
                    DistributedResource(**self.resource))
        # start training task
        reporter = DistStatusReporter()
        task.args['reporter'] = reporter
        task_thread = self.add_job(task)
        # run reporter
        last_result = None
        config = task.args['config']
        while task_thread.is_alive():
            reported_result = reporter.fetch()
            if 'done' in reported_result and reported_result['done'] is True:
                reporter.move_on()
                task_thread.join()
                break
            self._add_training_result(task.task_id, reported_result, task.args['config'])
            # Without move_on() the next fetch() would block forever; these two
            # lines mirror the same loop in _run_reporter above.
            reporter.move_on()
            last_result = reported_result
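        # --- Hedged sketch (not in the original fragment): REINFORCE update. ---
        # The reward from the finished trial, minus a moving-average baseline,
        # scales the negative log-likelihood of the sampled config; the sampled
        # entropy adds an exploration bonus. `self.baseline`, the 0.95 decay,
        # and the entropy coefficient are assumptions for illustration.
        loss = None
        if last_result is not None:
            reward = last_result[self._reward_attr]
            self.searcher.update(config, reward, done=True)
            with self.lock:
                if self.baseline is None:
                    self.baseline = reward
                advantage = mx.nd.array([reward - self.baseline], ctx=log_prob.context)
                loss = -(log_prob * advantage).sum() - 1e-4 * entropy.sum()
                self.baseline = 0.95 * self.baseline + 0.05 * reward

    if loss is not None:
        # Gradient step on the controller, outside the autograd scope;
        # `self.controller_optimizer` is a hypothetical mx.gluon.Trainer
        # over the controller parameters.
        with self.lock:
            loss.backward()
            self.controller_optimizer.step(1)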
def load_state_dict(self, state_dict):
    """Load from the saved state dict.

    Examples
    --------
    >>> scheduler.load_state_dict(ag.load('checkpoint.ag'))
    """
    self.finished_tasks = pickle.loads(state_dict['finished_tasks'])
    # self.baseline = pickle.loads(state_dict['baseline'])
    Task.set_id(state_dict['TASK_ID'])
    self.searcher.load_state_dict(state_dict['searcher'])
    self.training_history = json.loads(state_dict['training_history'])
    if self.visualizer in ('mxboard', 'tensorboard'):
        self.mxboard._scalar_dict = json.loads(state_dict['visualizer'])
    logger.debug('Loading Searcher State {}'.format(self.searcher))
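# Hedged sketch of the writer this loader mirrors, reconstructed purely from
# the keys read above; Task.TASK_ID.value as the counter accessor is an
# assumption suggested by Task.set_id().
def state_dict(self):
    destination = {}
    destination['finished_tasks'] = pickle.dumps(self.finished_tasks)
    destination['TASK_ID'] = Task.TASK_ID.value  # assumed shared counter on Task
    destination['searcher'] = self.searcher.state_dict()
    destination['training_history'] = json.dumps(self.training_history)
    if self.visualizer in ('mxboard', 'tensorboard'):
        destination['visualizer'] = json.dumps(self.mxboard._scalar_dict)
    return destination
# A round trip would then look like the docstring example:
#   ag.save(scheduler.state_dict(), 'checkpoint.ag')
#   scheduler.load_state_dict(ag.load('checkpoint.ag'))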