# Staticmethod of the FakeKerasModule test double used below: rebuilds a MyModel
# from the value stored under the "x" key of the saved file.
@staticmethod
def load_model(file, **kwargs):
    return MyModel(file.get("x").value)
# Route imports of the fake module's name to FakeKerasModule itself; everything
# else falls through to the real importlib.import_module.
def _import_module(name, **kwargs):
    if name.startswith(FakeKerasModule.__name__):
        return FakeKerasModule
    else:
        return importlib.import_module(name, **kwargs)
with mock.patch("importlib.import_module") as import_module_mock:
    import_module_mock.side_effect = _import_module
    x = MyModel("x123")
    path0 = os.path.join(model_path, "0")
    with pytest.raises(MlflowException):
        mlflow.keras.save_model(x, path0)
    mlflow.keras.save_model(x, path0, keras_module=FakeKerasModule)
    y = mlflow.keras.load_model(path0)
    assert x == y
    path1 = os.path.join(model_path, "1")
    mlflow.keras.save_model(x, path1, keras_module=FakeKerasModule.__name__)
    z = mlflow.keras.load_model(path1)
    assert x == z
    # Tests model log
    with mlflow.start_run() as active_run:
        with pytest.raises(MlflowException):
            mlflow.keras.log_model(x, "model0")
        mlflow.keras.log_model(x, "model0", keras_module=FakeKerasModule)
        a = mlflow.keras.load_model("runs:/{}/model0".format(active_run.info.run_id))
        assert x == a
        mlflow.keras.log_model(x, "model1", keras_module=FakeKerasModule.__name__)
        b = mlflow.keras.load_model("runs:/{}/model1".format(active_run.info.run_id))
        assert x == b
def _parse_uri(uri):
    """
    Returns (name, version, stage). Since a models:/ URI can only have one of {version, stage},
    it will return (name, version, None) or (name, None, stage).
    """
    parsed = urllib.parse.urlparse(uri)
    if parsed.scheme != "models":
        raise MlflowException(ModelsArtifactRepository._improper_model_uri_msg(uri))
    path = parsed.path
    if not path.startswith('/') or len(path) <= 1:
        raise MlflowException(ModelsArtifactRepository._improper_model_uri_msg(uri))
    parts = path[1:].split("/")
    if len(parts) != 2 or parts[0].strip() == "":
        raise MlflowException(ModelsArtifactRepository._improper_model_uri_msg(uri))
    if parts[1].isdigit():
        return parts[0], int(parts[1]), None
    else:
        return parts[0], None, parts[1]
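
# Minimal sketch of the parsing rules above (illustrative model names; assumes
# _parse_uri is callable in the scope shown here):
assert _parse_uri("models:/MyModel/3") == ("MyModel", 3, None)                 # numeric part -> version
assert _parse_uri("models:/MyModel/Staging") == ("MyModel", None, "Staging")   # non-numeric part -> stage
# Malformed URIs such as "models:/MyModel" or "s3://bucket/key" raise MlflowException.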
def _get_preferred_deployment_flavor(model_config):
    """
    Obtains the flavor that MLflow would prefer to use when deploying the model.
    If the model does not contain any supported flavors for deployment, an exception
    will be thrown.

    :param model_config: An MLflow model object
    :return: The name of the preferred deployment flavor for the specified model
    """
    if mleap.FLAVOR_NAME in model_config.flavors:
        return mleap.FLAVOR_NAME
    elif pyfunc.FLAVOR_NAME in model_config.flavors:
        return pyfunc.FLAVOR_NAME
    else:
        raise MlflowException(
            message=(
                "The specified model does not contain any of the supported flavors for"
                " deployment. The model contains the following flavors: {model_flavors}."
                " Supported flavors: {supported_flavors}".format(
                    model_flavors=model_config.flavors.keys(),
                    supported_flavors=SUPPORTED_DEPLOYMENT_FLAVORS)),
            error_code=RESOURCE_DOES_NOT_EXIST)
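
# Sketch of the selection order above, using a stand-in object (not part of the source):
# mleap wins when both flavors are present, pyfunc is the fallback, and a model with
# neither raises MlflowException(RESOURCE_DOES_NOT_EXIST).
from types import SimpleNamespace

both = SimpleNamespace(flavors={mleap.FLAVOR_NAME: {}, pyfunc.FLAVOR_NAME: {}})
assert _get_preferred_deployment_flavor(both) == mleap.FLAVOR_NAME
pyfunc_only = SimpleNamespace(flavors={pyfunc.FLAVOR_NAME: {}})
assert _get_preferred_deployment_flavor(pyfunc_only) == pyfunc.FLAVOR_NAME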
def verify_rest_response(response, endpoint):
    """Verify the return code and raise exception if the request was not successful."""
    if response.status_code != 200:
        if _can_parse_as_json(response.text):
            raise RestException(json.loads(response.text))
        else:
            base_msg = "API request to endpoint %s failed with error code " \
                       "%s != 200" % (endpoint, response.status_code)
            raise MlflowException("%s. Response body: '%s'" % (base_msg, response.text))
    return response
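
# Minimal sketch of the three outcomes above, using stand-in response objects
# rather than real HTTP calls (endpoint string is illustrative):
from types import SimpleNamespace

ok = SimpleNamespace(status_code=200, text="{}")
assert verify_rest_response(ok, "/api/2.0/mlflow/runs/get") is ok  # passed through untouched
json_err = SimpleNamespace(status_code=400, text='{"error_code": "INVALID_PARAMETER_VALUE"}')
# verify_rest_response(json_err, ...)  would raise RestException built from the JSON body
plain_err = SimpleNamespace(status_code=500, text="Internal Server Error")
# verify_rest_response(plain_err, ...) would raise MlflowException quoting the raw body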
def _get_run_info(self, run_uuid):
    """
    Note: Will get both active and deleted runs.
    """
    exp_id, run_dir = self._find_run_root(run_uuid)
    if run_dir is None:
        raise MlflowException("Run '%s' not found" % run_uuid,
                              databricks_pb2.RESOURCE_DOES_NOT_EXIST)
    meta = read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
    run_info = _read_persisted_run_info_dict(meta)
    if run_info.experiment_id != exp_id:
        logging.warning("Wrong experiment ID (%s) recorded for run '%s'. It should be %s. "
                        "Run will be ignored.", str(run_info.experiment_id),
                        str(run_info.run_id), str(exp_id), exc_info=True)
        return None
    return run_info
def _validate_deployment_flavor(model_config, flavor):
    """
    Checks that the specified flavor is a supported deployment flavor
    and is contained in the specified model. If one of these conditions
    is not met, an exception is thrown.

    :param model_config: An MLflow Model object
    :param flavor: The deployment flavor to validate
    """
    if flavor not in SUPPORTED_DEPLOYMENT_FLAVORS:
        raise MlflowException(
            message=(
                "The specified flavor: `{flavor_name}` is not supported for deployment."
                " Please use one of the supported flavors: {supported_flavor_names}".format(
                    flavor_name=flavor,
                    supported_flavor_names=SUPPORTED_DEPLOYMENT_FLAVORS)),
            error_code=INVALID_PARAMETER_VALUE)
    elif flavor not in model_config.flavors:
        raise MlflowException(
            message=("The specified model does not contain the specified deployment flavor:"
                     " `{flavor_name}`. Please use one of the following deployment flavors"
                     " that the model contains: {model_flavors}".format(
                         flavor_name=flavor, model_flavors=model_config.flavors.keys())),
            error_code=RESOURCE_DOES_NOT_EXIST)
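
# Sketch of the two failure modes above (stand-in model object; assumes pyfunc and mleap
# are the entries in SUPPORTED_DEPLOYMENT_FLAVORS, as the preferred-flavor helper suggests):
from types import SimpleNamespace

model = SimpleNamespace(flavors={pyfunc.FLAVOR_NAME: {}})
_validate_deployment_flavor(model, pyfunc.FLAVOR_NAME)        # supported and present: no exception
# _validate_deployment_flavor(model, "nonexistent_flavor")    # unsupported: INVALID_PARAMETER_VALUE
# _validate_deployment_flavor(model, mleap.FLAVOR_NAME)       # supported but absent: RESOURCE_DOES_NOT_EXIST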
def get_run(self, run_id):
    """
    Note: Will get both active and deleted runs.
    """
    _validate_run_id(run_id)
    run_info = self._get_run_info(run_id)
    if run_info is None:
        raise MlflowException("Run '%s' metadata is in invalid state." % run_id,
                              databricks_pb2.INVALID_STATE)
    metrics = self.get_all_metrics(run_id)
    params = self.get_all_params(run_id)
    tags = self.get_all_tags(run_id)
    return Run(run_info, RunData(metrics, params, tags))
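
# Minimal usage sketch for the FileStore methods above. The import path for FileStore
# varies across MLflow versions and the run ID is hypothetical, so treat this as illustrative:
from mlflow.store.tracking.file_store import FileStore

store = FileStore("./mlruns")
run = store.get_run("0123456789abcdef0123456789abcdef")  # hypothetical run ID
print(run.info.status, run.data.metrics, run.data.params, run.data.tags)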
# on attributes and on joined tables as we must keep all clauses in the same order
if order_by_list:
    for order_by_clause in order_by_list:
        clause_id += 1
        (key_type, key, ascending) = SearchUtils.parse_order_by(order_by_clause)
        if SearchUtils.is_attribute(key_type, '='):
            order_value = getattr(SqlRun, SqlRun.get_attribute_name(key))
        else:
            if SearchUtils.is_metric(key_type, '='):  # any valid comparator
                entity = SqlLatestMetric
            elif SearchUtils.is_tag(key_type, '='):
                entity = SqlTag
            elif SearchUtils.is_param(key_type, '='):
                entity = SqlParam
            else:
                raise MlflowException("Invalid identifier type '%s'" % key_type,
                                      error_code=INVALID_PARAMETER_VALUE)
            # build a subquery first because we will join it in the main request so that the
            # metric we want to sort on is available when we apply the sorting clause
            subquery = session \
                .query(entity) \
                .filter(entity.key == key) \
                .subquery()
            ordering_joins.append(subquery)
            order_value = subquery.c.value
        # sqlite does not support NULLS LAST expression, so we sort first by
        # presence of the field (and is_nan for metrics), then by actual value
        # As the subqueries are created independently and used later in the
        # same main query, the CASE WHEN columns need to have unique names to
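
# Illustrative parse results for the order_by handling above. The (key_type, key, ascending)
# tuples below are assumptions based on the unpacking used in this snippet, not verified output:
# SearchUtils.parse_order_by("metrics.rmse ASC")            -> ("metric", "rmse", True)
# SearchUtils.parse_order_by("params.lr DESC")              -> ("parameter", "lr", False)
# SearchUtils.parse_order_by("attributes.start_time DESC")  -> ("attribute", "start_time", False)
# Attribute keys sort directly on SqlRun columns; metric/tag/param keys sort on the value
# column of the per-key subquery that gets joined into the main query.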