Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_message_to_json():
    """Serializing an Experiment proto yields JSON whose keys mirror the entity's fields."""
    experiment = Experiment(experiment_id="123", name="name", artifact_location="arty",
                            lifecycle_stage='active')
    expected = {
        "experiment_id": "123",
        "name": "name",
        "artifact_location": "arty",
        "lifecycle_stage": 'active',
    }
    serialized = message_to_json(experiment.to_proto())
    assert json.loads(serialized) == expected
def test_get_experiment_by_name(self, store_class):
    """Check that ``get_experiment_by_name`` sends the expected REST request and parses the reply.

    ``mlflow.utils.rest_utils.http_request`` is patched so the test controls the HTTP
    response the store receives.
    """
    creds = MlflowHostCreds('https://hello')
    store = store_class(lambda: creds)
    with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
        # Instantiate the mock. Assigning attributes on the ``MagicMock`` *class* would
        # mutate the class globally (leaking ``status_code``/``text`` into every later
        # mock). It would also make every "response" in this test the same object.
        response = mock.MagicMock()
        response.status_code = 200
        experiment = Experiment(
            experiment_id="123", name="abc", artifact_location="/abc",
            lifecycle_stage=LifecycleStage.ACTIVE)
        response.text = json.dumps({
            "experiment": json.loads(message_to_json(experiment.to_proto()))})
        mock_http.return_value = response
        result = store.get_experiment_by_name("abc")
        expected_message0 = GetExperimentByName(experiment_name="abc")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message0))
        assert result.experiment_id == experiment.experiment_id
        assert result.name == experiment.name
        assert result.artifact_location == experiment.artifact_location
        assert result.lifecycle_stage == experiment.lifecycle_stage
        # Test GetExperimentByName against nonexistent experiment
        mock_http.reset_mock()
        # A fresh, independent mock instance for the 404 case (see note above).
        nonexistent_exp_response = mock.MagicMock()
        nonexistent_exp_response.status_code = 404
        nonexistent_exp_response.text = \
            MlflowException("Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
        # NOTE(review): the visible snippet ends here without wiring this response into
        # ``mock_http`` or asserting on it -- confirm the remainder of the test.
def get_latest_versions(self, registered_model, stages=None):
    """
    Fetch the most recent model version in each requested stage.

    :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    :param stages: List of desired stages. When ``None`` the request carries no stages and
                   the server chooses which stages to report. NOTE(review): earlier docs
                   disagreed on whether that means every stage or only 'Staging' and
                   'Production' -- confirm against the server.
    :return: List of :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` objects.
    """
    request = GetLatestVersions(registered_model=registered_model.to_proto(), stages=stages)
    response_proto = self._call_endpoint(GetLatestVersions, message_to_json(request))
    details = response_proto.model_versions_detailed
    return list(map(ModelVersionDetailed.from_proto, details))
def get_experiment(self, experiment_id):
    """
    Look up a single experiment by id through the GetExperiment endpoint.

    :param experiment_id: Experiment id; converted with ``str()`` before being sent.
    :return: :py:class:`mlflow.entities.Experiment` built from the response's
             ``experiment`` field. Failures are presumably raised by ``_call_endpoint``.
    """
    request = GetExperiment(experiment_id=str(experiment_id))
    reply = self._call_endpoint(GetExperiment, message_to_json(request))
    return Experiment.from_proto(reply.experiment)
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results, order_by,
                 page_token):
    """
    Send one SearchRuns request and return a single page of results.

    :return: Tuple ``(runs, next_page_token)``. ``runs`` is a list of
             :py:class:`mlflow.entities.Run`. ``next_page_token`` is ``None`` when
             the server sent no token.
    """
    request = SearchRuns(
        experiment_ids=list(map(str, experiment_ids)),
        filter=filter_string,
        run_view_type=ViewType.to_proto(run_view_type),
        max_results=max_results,
        order_by=order_by,
        page_token=page_token,
    )
    reply = self._call_endpoint(SearchRuns, message_to_json(request))
    runs = [Run.from_proto(run_proto) for run_proto in reply.runs]
    # An unset proto string field reads back as ""; callers expect None in that case.
    return runs, (reply.next_page_token or None)
def _restore_run():
    """
    Handle a RestoreRun request.

    Restores the run whose id is in the request message, then replies with a
    default-constructed ``RestoreRun.Response`` serialized as JSON.
    """
    run_id = _get_request_message(RestoreRun()).run_id
    _get_tracking_store().restore_run(run_id)
    http_response = Response(mimetype='application/json')
    http_response.set_data(message_to_json(RestoreRun.Response()))
    return http_response
def delete_model_version(self, model_version):
    """
    Ask the backend to delete a model version.

    :param model_version: :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    :return: None
    """
    request = DeleteModelVersion(model_version=model_version.to_proto())
    self._call_endpoint(DeleteModelVersion, message_to_json(request))
# NOTE(review): this run of statements has no visible ``def`` header and reads
# ``request_message`` without assigning it. It looks like the body of a SearchRuns
# request handler whose opening lines are not shown here -- confirm against the full file.
response_message = SearchRuns.Response()
# Default to active-only runs unless the request explicitly sets ``run_view_type``.
run_view_type = ViewType.ACTIVE_ONLY
if request_message.HasField('run_view_type'):
run_view_type = ViewType.from_proto(request_message.run_view_type)
filter_string = request_message.filter
max_results = request_message.max_results
experiment_ids = request_message.experiment_ids
order_by = request_message.order_by
page_token = request_message.page_token
run_entities = _get_tracking_store().search_runs(experiment_ids, filter_string, run_view_type,
max_results, order_by, page_token)
response_message.runs.extend([r.to_proto() for r in run_entities])
# The returned collection appears to carry a pagination token in ``.token``;
# it is only copied into the reply when non-empty.
if run_entities.token:
response_message.next_page_token = run_entities.token
response = Response(mimetype='application/json')
response.set_data(message_to_json(response_message))
return response
def create_registered_model(self, name):
    """
    Register a new model named ``name`` in the backend store.

    :param name: Name for the new model; expected to be unique in the backend store.
    :return: :py:class:`mlflow.entities.model_registry.RegisteredModel` built from the
             ``registered_model`` field of the endpoint's response.
    """
    payload = message_to_json(CreateRegisteredModel(name=name))
    created = self._call_endpoint(CreateRegisteredModel, payload).registered_model
    return RegisteredModel.from_proto(created)
def _update_experiment():
    """
    Handle an UpdateExperiment request.

    Only renaming is supported: when the request carries a non-empty ``new_name``,
    the experiment is renamed. Otherwise the tracking store is not touched. The reply
    is a default-constructed ``UpdateExperiment.Response`` serialized as JSON.
    """
    msg = _get_request_message(UpdateExperiment())
    new_name = msg.new_name
    if new_name:
        _get_tracking_store().rename_experiment(msg.experiment_id, new_name)
    reply = Response(mimetype='application/json')
    reply.set_data(message_to_json(UpdateExperiment.Response()))
    return reply