# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test(self):
    """Check the backend state of the monitored task and return a stop reason.

    :return: a ``TaskStopReason`` value when the task should stop
        (``stopped`` / ``status_changed`` / ``reset``), otherwise ``None``.
    """
    # NOTE(review): in the original chunk this `try` had no matching `except`
    # (truncated/flattened source); restored as a best-effort guard so that a
    # transient API error never crashes the monitor.
    # noinspection PyBroadException
    try:
        status = str(self.task.status)
        message = self.task.data.status_message
        # An explicit "stopping" request issued while the task is running.
        if status == str(tasks.TaskStatusEnum.in_progress) and "stopping" in message:
            return TaskStopReason.stopped
        _expected_statuses = (
            str(tasks.TaskStatusEnum.created),
            str(tasks.TaskStatusEnum.queued),
            str(tasks.TaskStatusEnum.in_progress),
        )
        # Any unexpected status not caused by a worker message means an
        # external status change we should honor.
        if status not in _expected_statuses and "worker" not in message:
            return TaskStopReason.status_changed
        if status == str(tasks.TaskStatusEnum.created):
            # Task dropped back to "created": count consecutive observations
            # before declaring it a reset.
            self._task_reset_state_counter += 1
            if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
                return TaskStopReason.reset
        else:
            # Status recovered; restart the consecutive-reset count.
            self._task_reset_state_counter = 0
    except Exception:
        # Cannot determine a stop reason (e.g. connectivity error) -- no stop.
        return None
def test(self):
    """Inspect the task's backend status and decide whether to stop.

    :return: a ``TaskStopReason`` member when a stop is required, else ``None``.
    """
    # NOTE(review): the original chunk ended with a dangling `else:` and had no
    # `except` for this `try` (flattened/truncated source); both restored so the
    # method is syntactically complete and failure-tolerant.
    # noinspection PyBroadException
    try:
        status = str(self.task.status)
        message = self.task.data.status_message
        # User/server asked the running task to stop.
        if status == str(tasks.TaskStatusEnum.in_progress) and "stopping" in message:
            return TaskStopReason.stopped
        _expected_statuses = (
            str(tasks.TaskStatusEnum.created),
            str(tasks.TaskStatusEnum.queued),
            str(tasks.TaskStatusEnum.in_progress),
        )
        # Unexpected status that was not set by a worker => external change.
        if status not in _expected_statuses and "worker" not in message:
            return TaskStopReason.status_changed
        if status == str(tasks.TaskStatusEnum.created):
            # Looks like a task reset; require several consecutive sightings
            # before terminating.
            self._task_reset_state_counter += 1
            if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
                return TaskStopReason.reset
            self.task.log.warning(
                "Task {} was reset! if state is consistent we shall terminate.".format(self.task.id),
            )
        else:
            # Back to a normal status; clear the consecutive-reset counter.
            self._task_reset_state_counter = 0
    except Exception:
        # Swallow transient API errors -- no decision means "keep running".
        return None
def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
    """Register *artifact_object* under *name* and upload it to the server.

    NOTE(review): this block is truncated in the visible chunk -- the pandas
    branch (and, presumably, further type branches plus the actual upload and
    registration steps) continue past the last line shown here. Only what is
    visible is documented.

    :param name: unique artifact name; registering an existing name raises.
    :param artifact_object: the object to serialize (numpy array and pandas
        DataFrame branches are visible here).
    :param metadata: unused in the visible portion -- presumably attached to the
        artifact later; confirm against the full source.
    :param delete_after_upload: when True the local temp file is removed after
        upload; forced to True for the serialized numpy branch.
    :return: False when the server is too old to support artifacts.
    """
    # Artifacts require server API >= 2.3; warn and bail out on older servers.
    if not Session.check_min_api_version('2.3'):
        LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                             'please upgrade to the latest server version')
        return False
    # Dynamic artifacts must go through register_artifact, not upload_artifact.
    if name in self._artifacts_dict:
        raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
    artifact_type_data = tasks.ArtifactTypeData()
    # Optional overrides for the destination filename/extension in the URI.
    override_filename_in_uri = None
    override_filename_ext_in_uri = None
    if np and isinstance(artifact_object, np.ndarray):
        # Serialize numpy arrays to a compressed .npz temp file.
        artifact_type = 'numpy'
        artifact_type_data.content_type = 'application/numpy'
        artifact_type_data.preview = str(artifact_object.__repr__())
        override_filename_ext_in_uri = '.npz'
        override_filename_in_uri = name+override_filename_ext_in_uri
        fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
        # Close the OS-level descriptor; np.savez_compressed reopens by path.
        os.close(fd)
        np.savez_compressed(local_filename, **{name: artifact_object})
        # Temp file is ours -- always delete it once uploaded.
        delete_after_upload = True
    elif pd and isinstance(artifact_object, pd.DataFrame):
        # DataFrames are exported as CSV (branch truncated in this chunk).
        artifact_type = 'pandas'
        artifact_type_data.content_type = 'text/csv'
        artifact_type_data.preview = str(artifact_object.__repr__())
@classmethod
def enqueue(cls, task, queue_name=None, queue_id=None):
    """Enqueue a task for execution on a worker queue.

    NOTE(review): the `def` header was missing from this chunk (the body started
    mid-function); the signature is reconstructed from the names the body uses
    (`cls`, `task`, `queue_name`, `queue_id`) -- confirm against the full source.

    :param task: a task object or a task id string.
    :param queue_name: queue to look up when *queue_id* is not given.
    :param queue_id: explicit queue id; skips the name lookup.
    :raises ValueError: if the server is too old or the named queue is missing.
    :return: the EnqueueRequest response object.
    """
    # Enqueue requires server API >= 2.4 (trains-server 0.12.0+).
    if not Session.check_min_api_version('2.4'):
        raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
    task_id = task if isinstance(task, six.string_types) else task.id
    session = cls._get_default_session()
    if not queue_id:
        # Resolve the queue id from its name.
        req = queues.GetAllRequest(name=queue_name, only_fields=["id"])
        res = cls._send(session=session, req=req)
        if not res.response.queues:
            raise ValueError('Could not find queue named "{}"'.format(queue_name))
        queue_id = res.response.queues[0].id
        if len(res.response.queues) > 1:
            # Ambiguous name: proceed with the first match but tell the user.
            LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format(
                queue_name, queue_id))
    req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
    res = cls._send(session=session, req=req)
    resp = res.response
    return resp
def _conditionally_start_task(self):
    """Mark the task as started, but only if it is still in the 'created' state."""
    created_state = str(tasks.TaskStatusEnum.created)
    if str(self.status) != created_state:
        return
    self.started()
def _update_requirements(self, requirements):
    """Store package requirements on the task, best-effort.

    A non-dict argument is treated as a plain pip requirements listing.
    """
    normalized = requirements if isinstance(requirements, dict) else {'pip': requirements}
    # Best-effort: older API servers may not support SetRequirements at all,
    # so any failure here is deliberately ignored.
    try:
        self.data.script.requirements = normalized
        self.send(tasks.SetRequirementsRequest(task=self.id, requirements=normalized))
    except Exception:
        pass
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
    """Create a new task entry on the server and return its id.

    :param project_name: optional project to create/reuse for the task.
    :param task_name: task name; an anonymous name is generated when omitted.
    :param task_type: a TaskTypes member describing the task kind.
    :return: the id of the newly created task.
    """
    creation_note = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
    project_id = get_or_create_project(self, project_name, creation_note) if project_name else None
    # Only development (non-remote) runs carry the development tag.
    dev_tags = [] if running_remotely() else [self._development_tag]
    # Newer servers store these under 'system_tags'; older ones under 'tags'.
    tag_field = 'system_tags' if Session.check_min_api_version('2.3') else 'tags'
    request = tasks.CreateRequest(
        name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
        type=tasks.TaskTypeEnum(task_type.value),
        comment=creation_note,
        project=project_id,
        input={'view': {}},
        **{tag_field: dev_tags}
    )
    return self.send(request).response.id
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
    """Register a fresh task with the server; return the new task id.

    :param project_name: when given, the task is placed in this project
        (created on the server if needed).
    :param task_name: explicit name, or an auto-generated anonymous one.
    :param task_type: TaskTypes member converted to the API enum.
    :return: server-assigned task id.
    """
    created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
    project_id = None
    if project_name:
        project_id = get_or_create_project(self, project_name, created_msg)
    # Remote (agent) executions do not get the development tag.
    if running_remotely():
        tags = []
    else:
        tags = [self._development_tag]
    # API >= 2.3 renamed the field from 'tags' to 'system_tags'.
    if Session.check_min_api_version('2.3'):
        extra_properties = {'system_tags': tags}
    else:
        extra_properties = {'tags': tags}
    req = tasks.CreateRequest(
        name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
        type=tasks.TaskTypeEnum(task_type.value),
        comment=created_msg,
        project=project_id,
        input={'view': {}},
        **extra_properties
    )
    res = self.send(req)
    return res.response.id
def register(self, task):
    """Attach the monitor to *task* and start the background daemon thread.

    Fix: the original returned bare ``None`` when there was nothing to
    monitor, while other paths returned ``True`` -- now returns an explicit
    ``False`` (still falsy, so callers are unaffected).

    :param task: the task object to monitor.
    :return: True when monitoring is active (or already was), False when
        neither ping nor stop-signal monitoring is available.
    """
    if self._thread:
        # Already registered and running.
        return True
    if TaskStopSignal.enabled:
        self._dev_stop_signal = TaskStopSignal(task=task)
    # Older API servers may not expose PingRequest.
    self._support_ping = hasattr(tasks, 'PingRequest')
    # if there is nothing to monitor, leave
    if not self._support_ping and not self._dev_stop_signal:
        return False
    self._task = task
    self._exit_event.clear()
    # Daemon thread so it never blocks interpreter shutdown.
    self._thread = Thread(target=self._daemon)
    self._thread.daemon = True
    self._thread.start()
    return True
# NOTE(review): the lines below are a mid-function fragment -- the enclosing
# `def` and the loop that scans the task's artifact list (setting `an_artifact`)
# start before this chunk. Judging by the names (`pd_artifact`, `local_csv`,
# `_pd_artifact_type`), this presumably belongs to a pandas data-audit artifact
# upload routine -- confirm against the full source. Code kept byte-identical.
# Loop found the matching registered artifact entry; stop scanning.
artifact = an_artifact
break
file_size = local_csv.stat().st_size
# upload file
uri = self._upload_local_file(local_csv, name, delete_after_upload=True,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
# update task artifacts
with self._task_edit_lock:
# Not previously registered: create a new artifact entry for this name.
if not artifact:
artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
self._task_artifact_list.append(artifact)
artifact_type_data = tasks.ArtifactTypeData()
artifact_type_data.data_hash = current_sha2
artifact_type_data.content_type = "text/csv"
# Preview combines the DataFrame repr with computed statistics.
artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})
artifact.type_data = artifact_type_data
artifact.uri = uri
artifact.content_size = file_size
artifact.hash = file_sha2
artifact.timestamp = int(time())
# Metadata rendered as string key/value pairs for display; None when absent.
artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None
self._task.set_artifacts(self._task_artifact_list)