Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def patch_matplotlib():
# Capture matplotlib's original plt.show, plt.imshow and Figure.show callables on
# PatchedMatplotlib (the visible part of this function only records them).
# NOTE(review): this snippet is truncated -- the `try` below has no visible
# `except`/remaining body, and indentation was lost in extraction; the structure
# described here follows the lines as shown and should be confirmed against the
# full source.
# Returns True right away if the originals were already captured, and False
# (after logging a warning) when matplotlib's major version is below 2.
# only once
if PatchedMatplotlib._patched_original_plot is not None:
return True
# noinspection PyBroadException
try:
# we support matplotlib version 2.0.0 and above
import matplotlib
# Major version = integer before the first '.' of matplotlib.__version__.
# NOTE(review): int() raises on a non-numeric leading component; presumably the
# (not shown) broad except handles that -- confirm.
PatchedMatplotlib._matplot_major_version = int(matplotlib.__version__.split('.')[0])
if PatchedMatplotlib._matplot_major_version < 2:
LoggerRoot.get_base_logger().warning(
'matplotlib binding supports version 2.0 and above, found version {}'.format(
matplotlib.__version__))
return False
# When running_remotely() is true, force the non-interactive 'agg' backend both
# via rcParams and pyplot.switch_backend.
if running_remotely():
# disable GUI backend - make headless
matplotlib.rcParams['backend'] = 'agg'
import matplotlib.pyplot
matplotlib.pyplot.switch_backend('agg')
import matplotlib.pyplot as plt
import matplotlib.figure as figure
from matplotlib import _pylab_helpers
# Store the original callables. On Python 2 they are wrapped in staticmethod,
# presumably so that reading them back from the class does not turn them into
# unbound methods -- verify against the code that calls them.
if six.PY2:
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
PatchedMatplotlib._patched_original_figure = staticmethod(figure.Figure.show)
else:
PatchedMatplotlib._patched_original_plot = plt.show
PatchedMatplotlib._patched_original_imshow = plt.imshow
PatchedMatplotlib._patched_original_figure = figure.Figure.show
def add_tags(self, tags):
    """
    Add tags to this task. Old tags are not deleted.
    When running remotely as the main task, this is a no-op.

    When ``tags`` is a string it is split on spaces, and empty entries (produced by
    repeated, leading or trailing spaces) are dropped. The tag list sent to
    ``_edit`` is de-duplicated while preserving first-seen order.

    :param tags: An iterable or space separated string of new tags (string) to add.
    :type tags: str or iterable of str
    """
    if running_remotely() and self.is_main_task():
        # remote main task: intentionally a no-op (see docstring)
        return
    if isinstance(tags, six.string_types):
        # split(" ") yields '' for repeated/leading/trailing spaces; drop those
        tags = [tag for tag in tags.split(" ") if tag]
    self.data.tags.extend(tags)
    # de-duplicate while keeping a deterministic order; list(set(...)) would
    # return the tags in arbitrary order
    unique_tags = []
    seen = set()
    for tag in self.data.tags:
        if tag not in seen:
            seen.add(tag)
            unique_tags.append(tag)
    self._edit(tags=unique_tags)
def _running_remotely(self):
    """
    Report whether execution is remote and a task is attached.

    :return: False unless running_remotely() is truthy; otherwise whether
        self._task is not None.
    """
    if not running_remotely():
        return False
    return self._task is not None
def _connect_task_parameters(self, attr_class):
    """
    Bind ``attr_class`` to this task's parameters.

    The connected-parameter type is first recorded as task_parameters. When running
    remotely as the main task, ``attr_class`` is updated from ``self.get_parameters()``.
    Otherwise the task's parameters are set from ``attr_class.to_dict()``.
    """
    self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
    pull_from_task = running_remotely() and self.is_main_task()
    if not pull_from_task:
        self.set_parameters(attr_class.to_dict())
        return
    attr_class.update_from_dict(self.get_parameters())
def reset(self, set_started_on_success=False, force=False):
    """
    Reset the task; the task is reloaded following a successful reset.

    When running remotely as the main task the reset is skipped (it would clear
    all logs and metrics) unless ``force`` is set.

    :param set_started_on_success: automatically set started if reset was successful
    :param force: force task reset even if running remotely
    """
    protected_remote_run = running_remotely() and self.is_main_task() and not force
    if protected_remote_run:
        return
    super(Task, self).reset(set_started_on_success=set_started_on_success)
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
    """
    Send a tasks.CreateRequest for a new auto-generated task and return the id
    from the response.

    :param project_name: if given, resolved via get_or_create_project(); otherwise no project
    :param task_name: task name; if falsy an 'Anonymous task (...)' name is generated
    :param task_type: a TaskTypes member; its ``.value`` is converted to tasks.TaskTypeEnum
    :return: ``response.id`` of the create call
    """
    created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
    project_id = get_or_create_project(self, project_name, created_msg) if project_name else None
    # the development tag is only attached when not running remotely
    tags = [] if running_remotely() else [self._development_tag]
    # API 2.3+ takes the tags as 'system_tags'; older servers take 'tags'
    tags_field = 'system_tags' if Session.check_min_api_version('2.3') else 'tags'
    req = tasks.CreateRequest(
        name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
        type=tasks.TaskTypeEnum(task_type.value),
        comment=created_msg,
        project=project_id,
        input={'view': {}},
        **{tags_field: tags}
    )
    return self.send(req).response.id
def _validate(self, check_output_dest_credentials=False):
    """
    Run the parent class validation, but only when running remotely.

    :param check_output_dest_credentials: accepted for interface compatibility;
        the parent is always called with ``check_output_dest_credentials=False``.
    """
    # NOTE(review): the argument is deliberately (?) not forwarded -- confirm intent.
    if not running_remotely():
        return
    super(Task, self)._validate(check_output_dest_credentials=False)
# --- Fragment of a task-initialisation routine: its `def` line and the end of the
# `try` below are not part of this snippet, and indentation was lost.
# `cls`, `task_type`, `project_name`, `task_name`, `reuse_last_task_id` and
# `output_uri` come from the enclosing (not shown) scope. ---
# Record the current process id in PROC_MASTER_ID_ENV_VAR.
PROC_MASTER_ID_ENV_VAR.set(os.getpid())
if task_type is None:
# Backwards compatibility: if called from Task.current_task and task_type
# was not specified, keep legacy default value of TaskTypes.training
task_type = cls.TaskTypes.training
elif isinstance(task_type, six.string_types):
# Map accepted string aliases onto TaskTypes members; any other string raises ValueError.
# NOTE(review): 'inference' maps to TaskTypes.testing -- confirm this is intended.
task_type_lookup = {'testing': cls.TaskTypes.testing, 'inference': cls.TaskTypes.testing,
'train': cls.TaskTypes.training, 'training': cls.TaskTypes.training,}
if task_type not in task_type_lookup:
raise ValueError("Task type '{}' not supported, options are: {}".format(task_type,
list(task_type_lookup.keys())))
task_type = task_type_lookup[task_type]
# NOTE(review): the matching `except`/`finally` for this `try` is not in this snippet.
try:
# Local run: create the development task via cls._create_dev_task(), then set its
# output_uri from the `output_uri` argument, falling back to cls.__default_output_uri.
if not running_remotely():
task = cls._create_dev_task(
project_name,
task_name,
task_type,
reuse_last_task_id,
)
if output_uri:
task.output_uri = output_uri
elif cls.__default_output_uri:
task.output_uri = cls.__default_output_uri
else:
# Presumably the remote branch of `if not running_remotely()` (indentation lost --
# verify): attach to the task id from get_remote_task_id() with log_to_backend=False.
task = cls(
private=cls.__create_protection,
task_id=get_remote_task_id(),
log_to_backend=False,
)
# --- Fragment: body of a model `connect(task)` method (its `def` line is not shown). ---
"""
Connect current model with a specific task, only supported for preexisting models,
i.e. not supported on objects created with create_and_connect()
When running in debug mode (i.e. locally), the task is updated with the model object
(i.e. task input model is the load_model_id)
When running remotely (i.e. from a daemon) the model is being updated from the task
Notice! when running remotely the load_model_id is ignored and loaded from the task object
regardless of the code
:param task: Task object
:raises ValueError: if ``task`` is not this model's own task (``self._task``)
"""
if self._task != task:
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
# Remote main task: pull the design and label enumeration from the task into the model.
# NOTE(review): unlike the `elif` below, this branch does not check that
# self._floating_data is not None -- confirm it is always set here.
if running_remotely() and task.is_main_task():
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
self._floating_data.labels = self._task.get_labels_enumeration()
elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
# NOTE(review): indentation was lost, so which `if` the following `else` belongs to
# cannot be read from this snippet -- check the original source.
if _Model._unwrap_design(self._floating_data.design):
if not task.get_model_config_text():
task.set_model_config(config_text=self._floating_data.design)
else:
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
# Push the model's labels to the task if present; otherwise take them from the task.
if self._floating_data.labels:
task.set_model_label_enumeration(self._floating_data.labels)
else:
self._floating_data.labels = self._task.get_labels_enumeration()
# NOTE(review): uses `self.task` while the lines above use `self._task`;
# presumably a property alias -- verify.
self.task._save_output_model(self)
# --- Fragment: body of an XGBoost model-load hook (its `def` line is not shown;
# `original_fn`, `f`, `args` and `kwargs` come from the enclosing signature).
# Indentation was lost in extraction. ---
# Derive a filename: `f` itself if it is a string, else `f.name` if present, else a
# single string positional argument; otherwise None.
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
elif len(args) == 1 and isinstance(args[0], six.string_types):
filename = args[0]
else:
filename = None
# Without a main task, call straight through to the original loader.
if not PatchXGBoostModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
# Hack: disabled
# NOTE(review): `False and ...` makes this branch unreachable; only the `else` path runs.
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(f, *args, **kwargs)
# NOTE(review): `filename` may be None here (see above) -- confirm
# restore_weights_file accepts that.
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task)
# Best-effort: copy empty.trains_in_model onto the loaded model object; any
# exception while setting the attribute is ignored.
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model