Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_search_metrics(self):
    """Both runs log ``common`` = 1.0, so every comparison on it that 1.0 satisfies matches both."""
    experiment_id = self._experiment_factory('search_metric')
    first_run = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    second_run = self._run_factory(self._get_run_configs(experiment_id)).info.run_id

    # (run, key, value, timestamp) -- every metric is logged at step 0, in this order.
    metric_log = [
        (first_run, "common", 1.0, 1),
        (second_run, "common", 1.0, 1),
        (first_run, "measure_a", 1.0, 1),
        (second_run, "measure_a", 200.0, 2),
        (second_run, "measure_a", 400.0, 3),
        (first_run, "m_a", 2.0, 2),
        (second_run, "m_b", 3.0, 2),
        (second_run, "m_b", 4.0, 8),  # this is last timestamp
        (second_run, "m_b", 8.0, 3),
    ]
    for run_id, key, value, timestamp in metric_log:
        self.store.log_metric(run_id, entities.Metric(key, value, timestamp, 0))

    expected_runs = [first_run, second_run]
    for filter_string in ("metrics.common = 1.0",
                          "metrics.common > 0.0",
                          "metrics.common >= 0.0",
                          "metrics.common < 4.0",
                          "metrics.common <= 4.0"):
        six.assertCountEqual(self, expected_runs, self._search(experiment_id, filter_string))
def _check(metric, key, value, timestamp, step):
    """Assert that ``metric`` is exactly a ``Metric`` carrying the given fields.

    The type check is an identity comparison, so subclasses of ``Metric`` are
    rejected (the same strictness as comparing the types for equality).
    Each assertion reports the actual and expected values on failure.
    """
    assert type(metric) is Metric, "expected Metric, got %r" % (type(metric),)
    assert metric.key == key, "key: %r != %r" % (metric.key, key)
    assert metric.value == value, "value: %r != %r" % (metric.value, value)
    assert metric.timestamp == timestamp, \
        "timestamp: %r != %r" % (metric.timestamp, timestamp)
    assert metric.step == step, "step: %r != %r" % (metric.step, step)
def create_and_log_run(names):
    """Create a run in ``experiment_id`` derived from the pair ``names`` and return its id.

    ``names[0]`` may be None. ``names[1]`` is concatenated with a str, so it is
    presumably a str.

    The run is created with user "MrDuck", start_time 123, and two tags:
    MLFLOW_RUN_NAME set to "<names[0]>/<names[1]>", and "metric" set to names[1].
    When names[0] is not None, metric "x" is logged as float(names[0]), and
    (as nested here) metric "y" as float(names[1]). A "metric" param equal to
    names[1] is then logged.
    """
    # NOTE(review): ``self`` and ``experiment_id`` are free variables here;
    # presumably this is a closure inside a test method -- confirm.
    name = str(names[0]) + "/" + names[1]
    run_id = self.store.create_run(
        experiment_id,
        user_id="MrDuck",
        start_time=123,
        tags=[entities.RunTag(mlflow_tags.MLFLOW_RUN_NAME, name),
              entities.RunTag("metric", names[1])]
    ).info.run_id
    if names[0] is not None:
        # NOTE(review): source indentation was lost. "y" is assumed to be logged
        # only together with "x" (inside this branch), and the param below
        # unconditionally -- confirm against the original file.
        self.store.log_metric(run_id, entities.Metric("x", float(names[0]), 1, 0))
        self.store.log_metric(run_id, entities.Metric("y", float(names[1]), 1, 0))
    self.store.log_param(run_id, entities.Param("metric", names[1]))
    return run_id
# Each block below patches ``mlflow.utils.rest_utils.http_request``, calls one
# store method, and then passes the mock, ``creds``, the endpoint, the HTTP
# method and the expected JSON body to ``self._verify_requests`` (presumably it
# asserts that this exact request was sent).
# NOTE(review): ``store``, ``creds`` and ``self`` are not defined in these lines;
# presumably they come from an enclosing test method -- confirm.

# delete_tag -> POST runs/delete-tag
with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.delete_tag("some_uuid", "t1")
    body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
    self._verify_requests(mock_http, creds,
                          "runs/delete-tag", "POST", body)
# log_metric -> POST runs/log-metric. The expected body sets both run_uuid and
# run_id to the same id (presumably for backwards compatibility -- confirm).
with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
    body = message_to_json(LogMetric(
        run_uuid="u2", run_id="u2", key="m1", value=0.87, timestamp=12345, step=3))
    self._verify_requests(mock_http, creds,
                          "runs/log-metric", "POST", body)
# log_batch -> POST runs/log-batch. The expected body is built from each
# entity's to_proto(). Note that metric "m2" uses a negative step (-1).
with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    metrics = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, -1),
               Metric("m3", 0.58, 12345, 2)]
    params = [Param("p1", "p1val"), Param("p2", "p2val")]
    tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
    store.log_batch(run_id="u2", metrics=metrics, params=params, tags=tags)
    metric_protos = [metric.to_proto() for metric in metrics]
    param_protos = [param.to_proto() for param in params]
    tag_protos = [tag.to_proto() for tag in tags]
    body = message_to_json(LogBatch(run_id="u2", metrics=metric_protos,
                                    params=param_protos, tags=tag_protos))
    self._verify_requests(mock_http, creds,
                          "runs/log-batch", "POST", body)
# delete_run -> POST runs/delete
with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.delete_run("u25")
    self._verify_requests(mock_http, creds,
                          "runs/delete", "POST",
                          message_to_json(DeleteRun(run_id="u25")))
def test_weird_metric_names(self):
    """A metric key containing spaces and slashes is stored and read back intact.

    The test checks both the run's latest-metric view and the metric history,
    including every field logged (key, value, timestamp and step).
    """
    WEIRD_METRIC_NAME = "this is/a weird/but valid metric"
    fs = FileStore(self.test_root)
    run_id = self.exp_data[FileStore.DEFAULT_EXPERIMENT_ID]["runs"][0]
    fs.log_metric(run_id, Metric(WEIRD_METRIC_NAME, 10, 1234, 0))
    run = fs.get_run(run_id)
    assert run.data.metrics[WEIRD_METRIC_NAME] == 10
    history = fs.get_metric_history(run_id, WEIRD_METRIC_NAME)
    assert len(history) == 1
    metric = history[0]
    assert metric.key == WEIRD_METRIC_NAME
    assert metric.value == 10
    assert metric.timestamp == 1234
    # The metric was logged with an explicit step, so verify it round-trips too.
    assert metric.step == 0
def test_log_batch_same_metric_repeated_multiple_reqs(self):
    """Values for one metric key sent in separate log_batch calls accumulate.

    After the first batch only the first value is expected. After the second,
    both values are expected.
    """
    run_id = self._run_factory().info.run_id
    first = Metric(key="metric-key", value=1, timestamp=2, step=0)
    second = Metric(key="metric-key", value=2, timestamp=3, step=0)
    expected = []
    for metric in (first, second):
        self.store.log_batch(run_id, params=[], metrics=[metric], tags=[])
        expected.append(metric)
        # Pass a fresh list so later appends cannot affect an earlier call's argument.
        self._verify_logged(run_id, params=[], metrics=list(expected), tags=[])
def test_search_metrics(self):
    """Both runs log ``common`` = 1.0, so every comparison on it that 1.0 satisfies matches both.

    This covers all five comparison operators, including ``<=``.
    """
    experiment_id = self._experiment_factory('search_metric')
    r1 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    r2 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    self.store.log_metric(r1, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r2, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r1, entities.Metric("measure_a", 1.0, 1, 0))
    self.store.log_metric(r2, entities.Metric("measure_a", 200.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("measure_a", 400.0, 3, 0))
    self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8, 0))  # this is last timestamp
    self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3, 0))
    for filter_string in ("metrics.common = 1.0",
                          "metrics.common > 0.0",
                          "metrics.common >= 0.0",
                          "metrics.common < 4.0",
                          "metrics.common <= 4.0"):
        six.assertCountEqual(self, [r1, r2], self._search(experiment_id, filter_string))
def test_search_full(self):
    """Filters that ``and`` param and metric clauses match only runs satisfying every clause.

    Both runs share ``generic_param`` and ``common``. Only r1 logs ``m_a``.
    No run has ``random_bad_name``.
    """
    experiment_id = self._experiment_factory('search_params')
    r1 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    r2 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    self.store.log_param(r1, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r2, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r1, entities.Param('p_a', 'abc'))
    self.store.log_param(r2, entities.Param('p_b', 'ABC'))
    self.store.log_metric(r1, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r2, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3, 0))
    filter_string = "params.generic_param = 'p_val' and metrics.common = 1.0"
    six.assertCountEqual(self, [r1, r2], self._search(experiment_id, filter_string))
    # all params and metrics match.
    # The trailing space matters: adjacent string literals are joined verbatim,
    # so without it the filter would read "... = 1.0and metrics.m_a > 1.0".
    filter_string = ("params.generic_param = 'p_val' and metrics.common = 1.0 "
                     "and metrics.m_a > 1.0")
    six.assertCountEqual(self, [r1], self._search(experiment_id, filter_string))
    # test with mismatch param
    filter_string = ("params.random_bad_name = 'p_val' and metrics.common = 1.0 "
                     "and metrics.m_a > 1.0")
    six.assertCountEqual(self, [], self._search(experiment_id, filter_string))
def test_search_full(self):
    """Filters that ``and`` param and metric clauses match only runs satisfying every clause.

    Both runs share ``generic_param`` and ``common``. Only r1 logs ``m_a``.
    A clause on a param no run has (``random_bad_name``) matches nothing.
    """
    experiment_id = self._experiment_factory('search_params')
    r1 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    r2 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    self.store.log_param(r1, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r2, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r1, entities.Param('p_a', 'abc'))
    self.store.log_param(r2, entities.Param('p_b', 'ABC'))
    self.store.log_metric(r1, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r2, entities.Metric("common", 1.0, 1, 0))
    self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8, 0))
    self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3, 0))
    filter_string = "params.generic_param = 'p_val' and metrics.common = 1.0"
    six.assertCountEqual(self, [r1, r2], self._search(experiment_id, filter_string))
    # all params and metrics match.
    # The trailing space matters: adjacent string literals are joined verbatim,
    # so without it the filter would read "... = 1.0and metrics.m_a > 1.0".
    filter_string = ("params.generic_param = 'p_val' and metrics.common = 1.0 "
                     "and metrics.m_a > 1.0")
    six.assertCountEqual(self, [r1], self._search(experiment_id, filter_string))
    # test with mismatch param
    filter_string = ("params.random_bad_name = 'p_val' and metrics.common = 1.0 "
                     "and metrics.m_a > 1.0")
    six.assertCountEqual(self, [], self._search(experiment_id, filter_string))
def test_set_deleted_run(self):
    """
    Setting metrics/tags/params/updating run info should not be allowed on deleted runs.
    """
    # NOTE(review): the visible checks cover set_tag, log_metric and log_param only.
    # No run-info update is attempted here despite the docstring -- confirm
    # whether that check exists beyond these lines or should be added.
    fs = FileStore(self.test_root)
    # Pick a random experiment and take its first run.
    exp_id = self.experiments[random_int(0, len(self.experiments) - 1)]
    run_id = self.exp_data[exp_id]['runs'][0]
    fs.delete_run(run_id)
    # Deleting moves the run to the DELETED lifecycle stage; it is still readable.
    assert fs.get_run(run_id).info.lifecycle_stage == LifecycleStage.DELETED
    # Each write against the deleted run must raise MlflowException.
    with pytest.raises(MlflowException):
        fs.set_tag(run_id, RunTag('a', 'b'))
    with pytest.raises(MlflowException):
        fs.log_metric(run_id, Metric('a', 0.0, timestamp=0, step=0))
    with pytest.raises(MlflowException):
        fs.log_param(run_id, Param('a', 'b'))