def test_register_model_with_non_runs_uri():
    create_model_patch = mock.patch.object(MlflowClient, "create_registered_model",
                                           return_value=RegisteredModel("Model 1"))
    create_version_patch = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(RegisteredModel("Model 1"), 1))
    with create_model_patch, create_version_patch:
        register_model("s3:/some/path/to/model", "Model 1")
        MlflowClient.create_registered_model.assert_called_once_with("Model 1")
        MlflowClient.create_model_version.assert_called_once_with(
            "Model 1", run_id=None, source="s3:/some/path/to/model")
def test_metric_timestamp(tracking_uri_mock):
    with mlflow.start_run() as active_run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_1", 30)
        run_id = active_run.info.run_uuid
    # Check that metric timestamps are between run start and finish
    client = mlflow.tracking.MlflowClient()
    history = client.get_metric_history(run_id, "name_1")
    finished_run = client.get_run(run_id)
    assert len(history) == 2
    assert all([
        m.timestamp >= finished_run.info.start_time and m.timestamp <= finished_run.info.end_time
        for m in history
    ])
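# Related sketch: the lower-level client API lets the caller pass the timestamp
# explicitly (milliseconds since the Unix epoch), which is the value the assertions
# above compare against the run's start_time/end_time.
import time
import mlflow

client = mlflow.tracking.MlflowClient()
with mlflow.start_run() as run:
    client.log_metric(run.info.run_id, "name_1", 25, timestamp=int(time.time() * 1000))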
def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
    """Best-effort detection of whether a run with the given entrypoint name,
    parameters, and experiment id already ran. The run must have completed
    successfully and have at least the parameters provided.
    """
    experiment_id = experiment_id if experiment_id is not None else _get_experiment_id()
    client = mlflow.tracking.MlflowClient()
    all_run_infos = reversed(client.list_run_infos(experiment_id))
    for run_info in all_run_infos:
        full_run = client.get_run(run_info.run_id)
        tags = full_run.data.tags
        if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT, None) != entry_point_name:
            continue
        match_failed = False
        for param_key, param_value in six.iteritems(parameters):
            run_value = full_run.data.params.get(param_key)
            if run_value != param_value:
                match_failed = True
                break
        if match_failed:
            continue
        # RunInfo.status is a string, so compare against the proto enum via to_proto().
        if run_info.to_proto().status != RunStatus.FINISHED:
            continue
        # Skip runs produced from a different source version.
        if tags.get(mlflow_tags.MLFLOW_GIT_COMMIT, None) != git_commit:
            continue
        return client.get_run(run_info.run_id)
    return None
def __init__(self, experiment_name):
    self.client = MlflowClient()
    self.experiment = self.client.get_experiment_by_name(experiment_name)
    self.experiment_id = self.experiment.experiment_id
    self.url = "{host}/_mlflow/#/experiments/{experiment_id}/runs/".format(
        host=os.environ["DBJL_HOST"].strip("/"), experiment_id=self.experiment_id
    )
    self.runs = []
    self.table = None
def _print_description_and_log_tags(self):
    _logger.info(
        "=== Launched MLflow run as Databricks job run with ID %s."
        " Getting run status page URL... ===",
        self._databricks_run_id)
    run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
    jobs_page_url = run_info["run_page_url"]
    _logger.info("=== Check the run's status at %s ===", jobs_page_url)
    host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
    tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                    MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
    job_id = run_info.get('job_id')
    # Some Databricks releases do not return the job ID; it is included in
    # Databricks releases 2.80 and above.
    if job_id is not None:
        tracking.MlflowClient().set_tag(self._mlflow_run_id,
                                        MLFLOW_DATABRICKS_SHELL_JOB_ID, job_id)
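# Design note: each set_tag call above issues a separate request. A sketch of the same
# tagging done in a single request with MlflowClient.log_batch; the tag-key constants
# are the same ones used above (defined in mlflow.utils.mlflow_tags in the MLflow
# source) and the tag values here are placeholders.
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import (
    MLFLOW_DATABRICKS_RUN_URL, MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, MLFLOW_DATABRICKS_WEBAPP_URL)

client = MlflowClient()
with mlflow.start_run() as run:
    client.log_batch(run.info.run_id, tags=[
        RunTag(MLFLOW_DATABRICKS_RUN_URL, "<run-page-url>"),
        RunTag(MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, "<job-run-id>"),
        RunTag(MLFLOW_DATABRICKS_WEBAPP_URL, "<webapp-host>"),
    ])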
from __future__ import print_function
from pyspark.sql import SparkSession, Row
import traceback
import os, time
import mlflow
from mlflow_fun.common import mlflow_utils
from mlflow_fun.metrics.dataframe_builder import get_data_frame_builder
from mlflow_fun.metrics import file_api

mlflow_utils.dump_mlflow_info()
mlflow_client = mlflow.tracking.MlflowClient()
spark = SparkSession.builder.appName("mlflow_metrics").enableHiveSupport().getOrCreate()

class TableBuilder(object):
    def __init__(self, database, data_dir, data_frame="slow", use_parquet=False):
        print("database:", database)
        print("data_dir:", data_dir)
        print("use_parquet:", use_parquet)
        print("data_frame:", data_frame)
        self.database = database
        self.data_dir = data_dir
        self.use_parquet = use_parquet
        self.file_api = file_api.get_file_api(data_dir)
        print("file_api:", type(self.file_api).__name__)
        self.df_builder = get_data_frame_builder(data_frame)
        print("df_builder:", type(self.df_builder).__name__)
        self.delimiter = "\t"
def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
    existing_run = _already_ran(entrypoint, parameters, git_commit)
    if use_cache and existing_run:
        print("Found existing run for entrypoint=%s and parameters=%s" %
              (entrypoint, parameters))
        return existing_run
    print("Launching new run for entrypoint=%s and parameters=%s" %
          (entrypoint, parameters))
    submitted_run = mlflow.run(".", entrypoint, parameters=parameters, use_conda=False)
    return mlflow.tracking.MlflowClient().get_run(submitted_run.run_id)
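# A sketch of how _get_or_run might be chained in a multistep workflow: each step is
# launched (or reused) under a parent run, keyed to the parent's git commit. The
# entry-point names and the "ratings-csv-dir" artifact path are illustrative only.
import os
import mlflow
from mlflow.utils import mlflow_tags

with mlflow.start_run() as active_run:
    git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
    load_raw_data_run = _get_or_run("load_raw_data", {}, git_commit)
    ratings_csv = os.path.join(load_raw_data_run.info.artifact_uri, "ratings-csv-dir")
    etl_data_run = _get_or_run("etl_data", {"ratings_csv": ratings_csv}, git_commit)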
def __init__(self, mlflow_client=None, use_src_user_id=False):
    self.client = mlflow_client or mlflow.tracking.MlflowClient()
    self.run_importer = RunImporter(self.client, use_src_user_id)
parser.add_argument("--artifact_max_level", dest="artifact_max_level", help="Number of artifact levels to recurse", required=False, default=1, type=int)
parser.add_argument("--sort", dest="sort", help="Show run info", required=False, default=False, action='store_true')
parser.add_argument("--pretty_time", dest="pretty_time", help="Show info", required=False, default=False, action='store_true')
parser.add_argument("--duration", dest="duration", help="Show duration", required=False, default=False, action='store_true')
parser.add_argument("--nan_to_blank", dest="nan_to_blank", help="nan_to_blank", required=False, default=False, action='store_true')
parser.add_argument("--skip_params", dest="skip_params", help="skip_params", required=False, default=False, action='store_true')
parser.add_argument("--skip_metrics", dest="skip_metrics", help="skip_metrics", required=False, default=False, action='store_true')
parser.add_argument("--skip_tags", dest="skip_tags", help="skip_tags", required=False, default=False, action='store_true')
parser.add_argument("--csv_file", dest="csv_file", help="CSV file")
args = parser.parse_args()
print("Options:")
for arg in vars(args):
    print(" {}: {}".format(arg, getattr(args, arg)))

client = mlflow.tracking.MlflowClient()
smart_client = MlflowSmartClient()
exp = mlflow_utils.get_experiment(client, args.experiment_id_or_name)
exp_id = exp.experiment_id
print("experiment_id:", exp_id)

runs = smart_client.list_runs(exp_id)
converter = RunsToPandasConverter(args.sort, args.pretty_time, args.duration,
                                  args.skip_params, args.skip_metrics, args.skip_tags)
df = converter.to_pandas_df(runs)
print(tabulate(df, headers='keys', tablefmt='psql'))

path = "exp_runs_{}.csv".format(exp_id) if args.csv_file is None else args.csv_file
print("Output CSV file:", path)
with open(path, 'w') as f:
    df.to_csv(f, index=False)
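# An alternative sketch using the stock API: mlflow.search_runs (available in more
# recent MLflow releases) returns a pandas DataFrame of run info, params, metrics, and
# tags directly, without the custom converter above. The experiment ID is a placeholder.
import mlflow

df = mlflow.search_runs(experiment_ids=["0"])
df.to_csv("exp_runs_0.csv", index=False)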
def train(training_data, max_runs, epochs, metric, algo, seed):
    """
    Run hyperparameter optimization.
    """
    # create random file to store run ids of the training tasks
    tracking_client = mlflow.tracking.MlflowClient()

    def new_eval(nepochs,
                 experiment_id,
                 null_train_loss,
                 null_valid_loss,
                 null_test_loss,
                 return_all=False):
        """
        Create a new eval function.
        :param nepochs: Number of epochs to train the model.
        :param experiment_id: Experiment id for the training run.
        :param null_train_loss: Loss of a null model on the training dataset.
        :param null_valid_loss: Loss of a null model on the validation dataset.
        :param null_test_loss: Loss of a null model on the test dataset.
        :param return_all: Return both validation and test loss if set.