def test_get_preferred_deployment_flavor_obtains_valid_flavor_from_model(pretrained_model):
    model_config_path = os.path.join(
        _download_artifact_from_uri(pretrained_model.model_uri), "MLmodel")
    model_config = Model.load(model_config_path)

    selected_flavor = mfs._get_preferred_deployment_flavor(model_config=model_config)

    assert selected_flavor in mfs.SUPPORTED_DEPLOYMENT_FLAVORS
    assert selected_flavor in model_config.flavors
def test_model_log(xgb_model):
    old_uri = mlflow.get_tracking_uri()
    model = xgb_model.model
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        try:
            # Use a scratch tracking URI and an active run so the runs:/ URI
            # below can be resolved.
            mlflow.set_tracking_uri("test")
            mlflow.start_run()

            artifact_path = "model"
            conda_env = os.path.join(tmp.path(), "conda_env.yaml")
            _mlflow_conda_env(conda_env, additional_pip_deps=["xgboost"])

            mlflow.xgboost.log_model(
                xgb_model=model,
                artifact_path=artifact_path,
                conda_env=conda_env)
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=mlflow.active_run().info.run_id,
                artifact_path=artifact_path)

            reloaded_model = mlflow.xgboost.load_model(model_uri=model_uri)
            np.testing.assert_array_almost_equal(
                model.predict(xgb_model.inference_dmatrix),
                reloaded_model.predict(xgb_model.inference_dmatrix))

            model_path = _download_artifact_from_uri(artifact_uri=model_uri)
            model_config = Model.load(os.path.join(model_path, "MLmodel"))
            assert pyfunc.FLAVOR_NAME in model_config.flavors
            assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
            env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]
            assert os.path.exists(os.path.join(model_path, env_path))
        finally:
            mlflow.end_run()
            mlflow.set_tracking_uri(old_uri)
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        sequential_model, pytorch_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.pytorch.log_model(pytorch_model=sequential_model,
                                 artifact_path=artifact_path,
                                 conda_env=pytorch_custom_env)
        model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != pytorch_custom_env

    with open(pytorch_custom_env, "r") as f:
        pytorch_custom_env_text = f.read()
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_text = f.read()
    assert saved_conda_env_text == pytorch_custom_env_text
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        sklearn_knn_model, sklearn_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sk_model=sklearn_knn_model.model,
                                 artifact_path=artifact_path,
                                 conda_env=sklearn_custom_env)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != sklearn_custom_env

    with open(sklearn_custom_env, "r") as f:
        sklearn_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == sklearn_custom_env_parsed
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        onnx_model):
    import mlflow.onnx
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.onnx.log_model(onnx_model=onnx_model, artifact_path=artifact_path, conda_env=None)
        model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == mlflow.onnx.get_default_conda_env()
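# For orientation (not from the source): get_default_conda_env() for an MLflow
# flavor is built with _mlflow_conda_env and yields a conda YAML mapping roughly
# like the sketch below. The python and package pins here are assumptions for
# illustration only; the real values track the installed library versions.
illustrative_default_env = {
    "name": "mlflow-env",
    "channels": ["defaults"],
    "dependencies": [
        "python=3.7.0",  # assumed pin
        {"pip": ["mlflow", "onnx==1.5.0", "onnxruntime==0.4.0"]},  # assumed pins
    ],
}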
def test_log_model_persists_specified_conda_env_in_mlflow_model_directory(
        saved_tf_iris_model, tf_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(tf_saved_model_dir=saved_tf_iris_model.path,
                                    tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                                    tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                                    artifact_path=artifact_path,
                                    conda_env=tf_custom_env)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != tf_custom_env

    with open(tf_custom_env, "r") as f:
        tf_custom_env_text = f.read()
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_text = f.read()
    assert saved_conda_env_text == tf_custom_env_text
def _get_flavor_backend(model_uri, **kwargs):
    with TempDir() as tmp:
        if ModelsArtifactRepository.is_models_uri(model_uri):
            underlying_model_uri = ModelsArtifactRepository.get_underlying_uri(model_uri)
        else:
            underlying_model_uri = model_uri
        local_path = _download_artifact_from_uri(posixpath.join(underlying_model_uri, "MLmodel"),
                                                 output_path=tmp.path())
        model = Model.load(local_path)
    flavor_name, flavor_backend = get_flavor_backend(model, **kwargs)
    if flavor_backend is None:
        raise Exception("No suitable flavor backend was found for the model.")
    _logger.info("Selected backend for flavor '%s'", flavor_name)
    return flavor_backend
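# Hypothetical usage sketch of the helper above. The registry URI and the
# no_conda keyword are assumptions for illustration, not taken from the source;
# **kwargs is simply forwarded to the selected backend.
backend = _get_flavor_backend("models:/MyRegisteredModel/1", no_conda=False)
backend.serve(model_uri="models:/MyRegisteredModel/1", port=5000, host="127.0.0.1")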
                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``
                      - ``s3://my_bucket/path/to/model``
                      - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                      - ``models:/<model_name>/<model_version>``
                      - ``models:/<model_name>/<stage>``

                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#artifact-locations>`_.

    :return: An `H2OEstimator model object
             <http://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/intro.html#models>`_.
    """
    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
    # Flavor configurations for models saved in MLflow version <= 0.8.0 may not contain a
    # `data` key; in this case, we assume the model artifact path to be `model.h2o`
    h2o_model_file_path = os.path.join(local_model_path, flavor_conf.get("data", "model.h2o"))
    return _load_model(path=h2o_model_file_path)
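# Usage sketch for the URI schemes documented above; the run ID, model name,
# and stage are placeholders, not values from the source.
import mlflow.h2o

model = mlflow.h2o.load_model("runs:/<mlflow_run_id>/model")
registry_model = mlflow.h2o.load_model("models:/MyH2OModel/Production")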
    if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
        if not tf_sess:
            tf_sess = tensorflow.get_default_session()
            if not tf_sess:
                raise MlflowException("No TensorFlow session found while calling load_model(). " +
                                      "You can set the default TensorFlow session before calling" +
                                      " load_model via `session.as_default()`, or directly pass " +
                                      "a session in which to load the model via the tf_sess " +
                                      "argument.")
    else:
        if tf_sess:
            warnings.warn("A TensorFlow session was passed into load_model, but the " +
                          "currently used version is TF 2.0 where sessions are deprecated. " +
                          "The tf_sess argument will be ignored.", FutureWarning)
    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key = \
        _get_and_parse_flavor_configuration(model_path=local_model_path)
    return _load_tensorflow_saved_model(tf_saved_model_dir=tf_saved_model_dir,
                                        tf_meta_graph_tags=tf_meta_graph_tags,
                                        tf_signature_def_key=tf_signature_def_key,
                                        tf_sess=tf_sess)
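# Sketch of the two call paths gated above; model_uri values are placeholders,
# not from the source. On TF 1.x, supply or default a session:
#
#     with tensorflow.Session() as sess:
#         loaded = mlflow.tensorflow.load_model("runs:/<run_id>/model", tf_sess=sess)
#
# On TF 2.x, per the warning above, tf_sess is ignored and should be omitted:
#
#     loaded = mlflow.tensorflow.load_model("runs:/<run_id>/model")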
def serve(self, model_uri, port, host):
    """
    Serve an R model locally.
    """
    model_path = _download_artifact_from_uri(model_uri)
    command = "mlflow::mlflow_rfunc_serve('{0}', port = {1}, host = '{2}')".format(
        shlex_quote(model_path), port, host)
    _execute(command)
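# For concreteness (assumed inputs, not from the source): with
# model_path="/tmp/model", port=8090, and host="127.0.0.1", the format string
# above builds the command
#     mlflow::mlflow_rfunc_serve('/tmp/model', port = 8090, host = '127.0.0.1')
# which _execute is then responsible for running in an R interpreter.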