import logging
import os
import shutil
import threading
import time

import dtlpy as dl
import torch

logger = logging.getLogger(__name__)


def _collect_metrics(self, inputs_dict, trial_id, results_dict):
    thread_name = threading.current_thread().name
    logger.info('starting thread: ' + thread_name)
    if self.remote:
        try:
            # checkpoint_path = 'best_' + trial_id + '.pt'
            checkpoint_path = 'checkpoint.pt'
            path_to_tensorboard_dir = 'runs'
            logger.info('trying to get execution object')
            execution_obj = self._run_trial_remote_execution(inputs_dict)
            logger.info('got execution object')
            # TODO: turn execution_obj into metrics
            while execution_obj.latest_status['status'] != 'success':
                # TODO: read the sleep interval from an env variable
                time.sleep(5)
                execution_obj = dl.executions.get(execution_id=execution_obj.id)
                if execution_obj.latest_status['status'] == 'failed':
                    raise Exception('plugin execution failed')
            logger.info('execution completed successfully')
            # clear stale local copies before downloading fresh artifacts
            if os.path.exists(checkpoint_path):
                logger.info('overwriting checkpoint.pt ...')
                os.remove(checkpoint_path)
            if os.path.exists(path_to_tensorboard_dir):
                logger.info('overwriting tensorboard runs ...')
                shutil.rmtree(path_to_tensorboard_dir)
            # download artifacts; should contain metrics and tensorboard runs
            # TODO: when downloading many different metrics, include an id hash as well
            self.project.artifacts.download(package_name=self.package_name,
                                            execution_id=execution_obj.id,
                                            local_path=os.getcwd())
            logger.info('going to load ' + checkpoint_path + ' into checkpoint')
            if torch.cuda.is_available():
                self.service = self.global_project.services.get(service_name='predict')
            model_specs = self.optimal_model.unwrap()
            dataset_input = dl.FunctionIO(type='Dataset', name='dataset',
                                          value={'dataset_id': self.dataset_id})
            checkpoint_path_input = dl.FunctionIO(type='Json', name='checkpoint_path',
                                                  value={'checkpoint_path': checkpoint_path})
            val_query_input = dl.FunctionIO(type='Json', name='val_query', value=self.val_query)
            model_specs_input = dl.FunctionIO(type='Json', name='model_specs', value=model_specs)
            inputs = [dataset_input, val_query_input, checkpoint_path_input, model_specs_input]
            logger.info('checkpoint is type: ' + str(type(checkpoint_path)))
            logger.info('trying to get execution object')
            execution_obj = self._run_pred_remote_execution(inputs)
            logger.info('got execution object')
            # TODO: turn execution_obj into metrics
            while execution_obj.latest_status['status'] != 'success':
                time.sleep(5)
                execution_obj = dl.executions.get(execution_id=execution_obj.id)
                if execution_obj.latest_status['status'] == 'failed':
                    raise Exception('plugin execution failed')
            logger.info('execution completed successfully')
            # download artifacts; should contain a dir with txt-file annotations
            # TODO: when downloading many different metrics, include an id hash as well
            self.project.artifacts.download(package_name=self.package_name,
                                            execution_id=execution_obj.id,
                                            local_path=os.getcwd())
        except Exception as e:
            logger.exception('had an exception: \n' + repr(e))
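

# The status-polling loop above recurs three times in this file. Below is a
# minimal consolidation sketch, assuming only the dl.executions.get /
# latest_status API already used above; 'poll_interval' is a hypothetical
# parameter (see the TODO about reading the sleep interval from an env
# variable), not something the original code defines.
def _wait_for_execution(self, execution_obj, poll_interval=5):
    # poll until the remote execution either succeeds or fails
    while execution_obj.latest_status['status'] != 'success':
        time.sleep(poll_interval)
        execution_obj = dl.executions.get(execution_id=execution_obj.id)
        if execution_obj.latest_status['status'] == 'failed':
            raise Exception('remote execution ' + str(execution_obj.id) + ' failed')
    return execution_obj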


def _launch_remote_best_trial(self, best_trial):
    model_specs = self.optimal_model.unwrap()
    dataset_input = dl.FunctionIO(type='Dataset', name='dataset',
                                  value={'dataset_id': self.dataset_id})
    train_query_input = dl.FunctionIO(type='Json', name='train_query', value=self.train_query)
    val_query_input = dl.FunctionIO(type='Json', name='val_query', value=self.val_query)
    hp_value_input = dl.FunctionIO(type='Json', name='hp_values', value=best_trial['hp_values'])
    model_specs_input = dl.FunctionIO(type='Json', name='model_specs', value=model_specs)
    inputs = [dataset_input, train_query_input, val_query_input, hp_value_input, model_specs_input]
    execution_obj = self._run_trial_remote_execution(inputs)
    while execution_obj.latest_status['status'] != 'success':
        time.sleep(5)
        execution_obj = dl.executions.get(execution_id=execution_obj.id)
        if execution_obj.latest_status['status'] == 'failed':
            raise Exception('package execution failed')
    return execution_obj
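

# Hypothetical caller sketch: launch the best trial found by the tuner and
# pull its artifacts locally, mirroring the download pattern used in
# _collect_metrics above. The 'best_trial' dict shape follows the
# best_trial['hp_values'] access in _launch_remote_best_trial; nothing here
# is confirmed by the original code.
#
#     execution_obj = self._launch_remote_best_trial(best_trial)
#     self.project.artifacts.download(package_name=self.package_name,
#                                     execution_id=execution_obj.id,
#                                     local_path=os.getcwd())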


# download_and_organize, precision_recall_compute and init_logging are
# helpers defined elsewhere in the project
download_and_organize(path_to_dataset=path_to_dataset, dataset_obj=test_dataset, filters=filters)
json_file_path = os.path.join(path_to_dataset, 'json')
self.model_obj = self.project.models.get(model_name='retinanet')
self.adapter = self.model_obj.build(local_path=os.getcwd())
logger.info('model built')
while True:
    self.compute = precision_recall_compute()
    self.compute.add_dataloop_local_annotations(json_file_path)
    logger.info('running new execution')
    execution_obj = self.service.execute(function_name='search', execution_input=[self.configs_input],
                                         project_id='72bb623f-517f-472b-ad69-104fed8ee94a')
    while execution_obj.latest_status['status'] != 'success':
        time.sleep(5)
        execution_obj = dl.executions.get(execution_id=execution_obj.id)
        if execution_obj.latest_status['status'] == 'failed':
            raise Exception('plugin execution failed')
    logger.info('execution completed successfully')
    self.project.artifacts.download(package_name='zazuml',
                                    execution_id=execution_obj.id,
                                    local_path=os.getcwd())
    logs_file_name = 'timer_logs_' + str(execution_obj.id) + '.conf'
    graph_file_name = 'precision_recall_' + str(execution_obj.id) + '.png'
    self.cycle_logger = init_logging(__name__, filename=logs_file_name)
    logger.info('artifact download finished')
    logger.info(str(os.listdir('.')))
    # load the new checkpoint under a unique, per-execution name
    new_checkpoint_name = 'checkpoint_' + str(execution_obj.id) + '.pt'
    logger.info(str(os.listdir('.')))
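    # The unique name built above is not yet applied to the downloaded file.
    # A minimal continuation sketch, assuming the artifact lands locally as
    # 'checkpoint.pt' (the name used in _collect_metrics); the rename step
    # itself is an assumption, not part of the original loop.
    if os.path.exists('checkpoint.pt'):
        os.rename('checkpoint.pt', new_checkpoint_name)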