def _set_storage_uri(self, value):
value = value.rstrip('/') if value else None
self._storage_uri = StorageHelper.conform_url(value)
self.data.output.destination = self._storage_uri
self._edit(output_dest=self._storage_uri or ('' if Session.check_min_api_version('2.3') else None))
if self._storage_uri or self._output_model:
self.output_model.upload_storage_uri = self._storage_uri
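# A minimal usage sketch of the setter above (assuming it backs a `storage_uri` property on the
# task object; the exact normalization performed by StorageHelper.conform_url is not shown here):
# task.storage_uri = 's3://my-bucket/outputs/'   # trailing slash is stripped before storing
# task.storage_uri = None                        # clears the output destination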
def set_artifacts(self, artifacts_list=None):
"""
Update the task with the given list of artifacts (tasks.Artifact), replacing existing artifacts that have matching keys
:param list artifacts_list: list of artifacts (type tasks.Artifact)
"""
if not Session.check_min_api_version('2.3'):
return False
if not (isinstance(artifacts_list, (list, tuple))
and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
raise ValueError('Expected artifacts as a list of tasks.Artifact objects')
with self._edit_lock:
self.reload()
execution = self.data.execution
keys = [a.key for a in artifacts_list]
execution.artifacts = [a for a in execution.artifacts or [] if a.key not in keys] + artifacts_list
self._edit(execution=execution)
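# A hedged usage sketch for set_artifacts (assumes an existing Task instance named `task`; the
# tasks.Artifact field names follow the backend API schema). Artifacts whose key matches an
# existing entry replace it, all others are appended:
example_artifacts = [
    tasks.Artifact(key='train_stats.csv', type='csv'),
    tasks.Artifact(key='eval_stats.csv', type='csv'),
]
task.set_artifacts(example_artifacts)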
local_csv = Path(local_csv)
pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
try:
local_csv.unlink()
except Exception:
pass
return
self._last_artifacts_upload[name] = current_sha2
# If this is an old trains-server (API < 2.3), upload the table as a debug image instead
if not Session.check_min_api_version('2.3'):
logger.report_image(title='artifacts', series=name, local_path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
return
# Find our artifact
artifact = None
for an_artifact in self._task_artifact_list:
if an_artifact.key == name:
artifact = an_artifact
break
file_size = local_csv.stat().st_size
# upload file
uri = self._upload_local_file(local_csv, name, delete_after_upload=True,
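# A hedged sketch of what a sha256sum(path, skip_header=N) helper could compute: one digest over
# the file without its first N bytes (used above to decide whether the upload can be skipped) and
# one over the full file. This is an illustration, not the SDK implementation.
import hashlib

def sha256sum_sketch(path, skip_header=0, block_size=65536):
    h_skipped, h_full = hashlib.sha256(), hashlib.sha256()
    with open(path, 'rb') as f:
        h_full.update(f.read(skip_header))
        for chunk in iter(lambda: f.read(block_size), b''):
            h_full.update(chunk)
            h_skipped.update(chunk)
    return h_skipped.hexdigest(), h_full.hexdigest()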
:param label_enumeration: dictionary mapping label names (strings) to integer values, enumerating the model output labels.
example: {'background': 0, 'person': 1}
:param name: optional, name for the newly imported model
:param tags: optional, list of strings as tags
:param comment: optional, string description for the model
:param is_package: Boolean. Indicates that the imported weights file is a package.
If True, and a new model is created, a package tag will be added.
:param create_as_published: Boolean. If True, and a new model is created, it will be published.
:param framework: optional, string name of the framework of the model or Framework
"""
config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
weights_url = StorageHelper.conform_url(weights_url)
if not weights_url:
raise ValueError("Please provide a valid weights_url parameter")
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name", "created"],
**extra
))
if result.response.models:
logger = get_logger()
logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))
model = get_single_result(
entity='model',
query=weights_url,
results=result.response.models,
log=logger,
def _complete_update_for_task(self, uri, task_id=None, name=None, comment=None, tags=None, override_model_id=None,
cb=None):
if self._data:
name = name or self.data.name
comment = comment or self.data.comment
tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags)
uri = (uri or self.data.uri) if not override_model_id else None
if tags:
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
else:
extra = {}
res = self.send(
models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
override_model_id=override_model_id, **extra))
if self.id is None:
# update the model id. in case it was just created, this will trigger a reload of the model object
self.id = res.response.id
else:
self.reload()
try:
if cb:
cb(uri)
except Exception as ex:
self.log.warning('Failed calling callback on complete_update_for_task: %s' % str(ex))
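# A hedged helper sketch capturing the recurring pattern above: servers with API version >= 2.3
# expect 'system_tags', older servers expect 'tags'. The helper name is illustrative only.
def _tags_payload(tag_list):
    if not tag_list:
        return {}
    key = 'system_tags' if Session.check_min_api_version('2.3') else 'tags'
    return {key: tag_list}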
if repo_info.is_empty():
_log("no info for {}", script_dir)
repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
if check_uncommitted:
diff = cls._get_script_code(script_path.as_posix()) \
if not plugin or not repo_info.commit else repo_info.diff
else:
diff = ''
# if this is not jupyter, get the requirements.txt
requirements = ''
# create requirements if backend supports requirements
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if not jupyter_filepath and Session.check_min_api_version('2.2'):
script_requirements = ScriptRequirements(
Path(repo_root).as_posix() if repo_info.url else script_path.as_posix())
if create_requirements:
requirements = script_requirements.get_requirements()
else:
script_requirements = None
script_info = dict(
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
branch=repo_info.branch,
version_num=repo_info.commit,
entry_point=entry_point,
working_dir=working_dir,
diff=diff,
requirements={'pip': requirements} if requirements else None,
)
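# Illustration of the credential scrubbing above: furl strips embedded credentials from the
# repository URL before it is stored on the task, e.g. (output hedged, not taken from the SDK):
# furl('https://user:token@github.com/org/repo.git').remove(username=True, password=True).tostr()
# -> 'https://github.com/org/repo.git'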
def clone(self, name, comment=None, child=True, tags=None, task=None, ready=True):
"""
Clone this model into a new model.
:param name: Name for the new model
:param comment: Optional comment for the new model
:param child: Should the new model be a child of this model? (default True)
:param tags: Optional list of tags for the new model (defaults to this model's tags)
:param task: Optional ID of the task to associate with the new model
:param ready: If True (default), mark the new model as ready
:return: The new model's ID
"""
data = self.data
assert isinstance(data, models.Model)
parent = self.id if child else None
extra = {'system_tags': tags or data.system_tags} \
if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
req = models.CreateRequest(
uri=data.uri,
name=name,
labels=data.labels,
comment=comment or data.comment,
framework=data.framework,
design=data.design,
ready=ready,
project=data.project,
parent=parent,
task=task,
**extra
)
res = self.send(req)
return res.response.id
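# A minimal usage sketch for clone() (the model instance and values are illustrative):
# new_model_id = model.clone('resnet50-finetuned', comment='cloned for finetuning', tags=['finetune'])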
# serialize notebook to a temp file
# noinspection PyBroadException
try:
get_ipython().run_line_magic('notebook', local_jupyter_filename)
except Exception as ex:
continue
# get notebook python script
script_code, resources = _script_exporter.from_filename(local_jupyter_filename)
current_script_hash = hash(script_code)
if prev_script_hash and prev_script_hash == current_script_hash:
continue
requirements_txt = ''
# parse jupyter python script and prepare pip requirements (pigar)
# if backend supports requirements
if file_import_modules and Session.check_min_api_version('2.2'):
fmodules, _ = file_import_modules(notebook.parts[-1], script_code)
installed_pkgs = get_installed_pkgs_detail()
reqs = ReqsModules()
for name in fmodules:
if name in installed_pkgs:
pkg_name, version = installed_pkgs[name]
reqs.add(pkg_name, version, fmodules[name])
requirements_txt = ScriptRequirements.create_requirements_txt(reqs)
# update script
prev_script_hash = current_script_hash
data_script = task.data.script
data_script.diff = script_code
data_script.requirements = {'pip': requirements_txt}
task._update_script(script=data_script)
# update requirements
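# A hedged sketch of the notebook-to-script conversion step used in this fragment, assuming
# _script_exporter is an nbconvert ScriptExporter instance (the filename is illustrative):
from nbconvert.exporters.script import ScriptExporter

exporter = ScriptExporter()
script_code, resources = exporter.from_filename('example_notebook.ipynb')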
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
project_id = None
if project_name:
project_id = get_or_create_project(self, project_name, created_msg)
tags = [self._development_tag] if not running_remotely() else []
extra_properties = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
req = tasks.CreateRequest(
name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
type=tasks.TaskTypeEnum(task_type.value),
comment=created_msg,
project=project_id,
input={'view': {}},
**extra_properties
)
res = self.send(req)
return res.response.id
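# Illustration of the auto-generated naming above (values are illustrative): make_message fills in
# the current user, host and time, e.g. 'Anonymous task (alice@train-box-01 2020-06-01 14:05:11)'.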
try:
if 'IPython' in sys.modules:
from IPython import get_ipython
ip = get_ipython()
if ip and matplotlib.is_interactive():
# instead of hooking ipython, hook matplotlib's draw_all directly
import matplotlib.pyplot as plt
PatchedMatplotlib.__patched_original_draw_all = plt.draw_all
plt.draw_all = PatchedMatplotlib.__patched_draw_all
# ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
except Exception:
pass
# check whether the backend API version supports image plots
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')
# create plotly renderer
try:
from plotly import optional_imports
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
PatchedMatplotlib._plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
except Exception:
pass
return True
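# A minimal sketch of the monkey-patching pattern used above (names are illustrative): the wrapper
# gets a chance to capture/report open figures, then defers to the original draw_all.
import matplotlib.pyplot as plt

_original_draw_all = plt.draw_all

def _patched_draw_all(*args, **kwargs):
    # e.g. report plt.get_fignums() here before the figures are drawn
    return _original_draw_all(*args, **kwargs)

plt.draw_all = _patched_draw_all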