def test_log_batch(self):
    experiment_id = self._experiment_factory('log_batch')
    run_id = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    metric_entities = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, 1)]
    param_entities = [Param("p1", "p1val"), Param("p2", "p2val")]
    tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
    self.store.log_batch(
        run_id=run_id, metrics=metric_entities, params=param_entities, tags=tag_entities)
    run = self.store.get_run(run_id)
    assert run.data.tags == {"t1": "t1val", "t2": "t2val"}
    assert run.data.params == {"p1": "p1val", "p2": "p2val"}
    metric_histories = sum(
        [self.store.get_metric_history(run_id, key) for key in run.data.metrics], [])
    metrics = [(m.key, m.value, m.timestamp, m.step) for m in metric_histories]
    assert set(metrics) == {("m1", 0.87, 12345, 0), ("m2", 0.49, 12345, 1)}
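# For reference, the same batch-logging flow through the public tracking API;
# a minimal sketch, assuming a default local tracking URI (the experiment name
# "log_batch_demo" is illustrative):
import time

from mlflow.tracking import MlflowClient
from mlflow.entities import Metric, Param, RunTag

client = MlflowClient()
experiment_id = client.create_experiment("log_batch_demo")
run = client.create_run(experiment_id)
ts = int(time.time() * 1000)  # metric timestamps are milliseconds since epoch
client.log_batch(
    run.info.run_id,
    metrics=[Metric("m1", 0.87, ts, 0), Metric("m2", 0.49, ts, 1)],
    params=[Param("p1", "p1val"), Param("p2", "p2val")],
    tags=[RunTag("t1", "t1val"), RunTag("t2", "t2val")],
)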
def test_log_batch_params_idempotency(self):
    # Logging the same param key/value twice is a no-op rather than an error.
    fs = FileStore(self.test_root)
    run = self._create_run(fs)
    params = [Param("p-key", "p-val")]
    fs.log_batch(run.info.run_id, metrics=[], params=params, tags=[])
    fs.log_batch(run.info.run_id, metrics=[], params=params, tags=[])
    self._verify_logged(fs, run.info.run_id, metrics=[], params=params, tags=[])
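# The complementary case: re-logging the same key with a *different* value is
# rejected. A sketch under the same fixtures as the test above; the test name
# is hypothetical, and the exact error raised is worth verifying against your
# MLflow version:
def test_log_batch_param_overwrite_disallowed_sketch(self):
    fs = FileStore(self.test_root)
    run = self._create_run(fs)
    fs.log_batch(run.info.run_id, metrics=[], params=[Param("p-key", "v1")], tags=[])
    with self.assertRaises(MlflowException):
        fs.log_batch(run.info.run_id, metrics=[], params=[Param("p-key", "v2")], tags=[])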
def test_search_params(self):
    experiment_id = self._experiment_factory('search_params')
    r1 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    r2 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
    self.store.log_param(r1, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r2, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r1, entities.Param('generic_2', 'some value'))
    self.store.log_param(r2, entities.Param('generic_2', 'another value'))
    self.store.log_param(r1, entities.Param('p_a', 'abc'))
    self.store.log_param(r2, entities.Param('p_b', 'ABC'))

    # test search returns both runs
    filter_string = "params.generic_param = 'p_val'"
    six.assertCountEqual(self, [r1, r2], self._search(experiment_id, filter_string))

    # test search returns the appropriate run (same key, different values per run)
    filter_string = "params.generic_2 = 'some value'"
    six.assertCountEqual(self, [r1], self._search(experiment_id, filter_string))
    filter_string = "params.generic_2 = 'another value'"
    six.assertCountEqual(self, [r2], self._search(experiment_id, filter_string))
    filter_string = "params.generic_param = 'wrong_val'"
    six.assertCountEqual(self, [], self._search(experiment_id, filter_string))
    filter_string = "params.generic_param != 'p_val'"
    six.assertCountEqual(self, [], self._search(experiment_id, filter_string))
def test_log_batch_internal_error(self):
    # Verify that internal errors during log_batch result in MlflowExceptions
    fs = FileStore(self.test_root)
    run = self._create_run(fs)

    def _raise_exception_fn(*args, **kwargs):  # pylint: disable=unused-argument
        raise Exception("Some internal error")

    with mock.patch(FILESTORE_PACKAGE + ".FileStore.log_metric") as log_metric_mock, \
            mock.patch(FILESTORE_PACKAGE + ".FileStore.log_param") as log_param_mock, \
            mock.patch(FILESTORE_PACKAGE + ".FileStore.set_tag") as set_tag_mock:
        log_metric_mock.side_effect = _raise_exception_fn
        log_param_mock.side_effect = _raise_exception_fn
        set_tag_mock.side_effect = _raise_exception_fn
        for kwargs in [{"metrics": [Metric("a", 3, 1, 0)]}, {"params": [Param("b", "c")]},
                       {"tags": [RunTag("c", "d")]}]:
            log_batch_kwargs = {"metrics": [], "params": [], "tags": []}
            log_batch_kwargs.update(kwargs)
            with self.assertRaises(MlflowException) as e:
                fs.log_batch(run.info.run_id, **log_batch_kwargs)
            # assertIn(member, container): the internal error text must appear
            # in the wrapped MlflowException's message.
            self.assertIn("Some internal error", e.exception.message)
            assert e.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)
import copy

import pytest

from mlflow.entities import Metric, Param, RunTag
from mlflow.exceptions import MlflowException
from mlflow.utils.validation import _validate_batch_log_limits, _validate_batch_log_data


def test_validate_batch_log_limits():
    too_many_metrics = [Metric("metric-key-%s" % i, 1, 0, i * 2) for i in range(1001)]
    too_many_params = [Param("param-key-%s" % i, "b") for i in range(101)]
    too_many_tags = [RunTag("tag-key-%s" % i, "b") for i in range(101)]
    good_kwargs = {"metrics": [], "params": [], "tags": []}
    bad_kwargs = {
        "metrics": [too_many_metrics],
        "params": [too_many_params],
        "tags": [too_many_tags],
    }
    for arg_name, arg_values in bad_kwargs.items():
        for arg_value in arg_values:
            final_kwargs = copy.deepcopy(good_kwargs)
            final_kwargs[arg_name] = arg_value
            with pytest.raises(MlflowException):
                _validate_batch_log_limits(**final_kwargs)
    # Test the case where there are too many entities in aggregate: each slice
    # is under its per-type limit, but 900 + 51 + 50 = 1001 exceeds the total cap.
    with pytest.raises(MlflowException):
        _validate_batch_log_limits(too_many_metrics[:900], too_many_params[:51],
                                   too_many_tags[:50])
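# The limits exercised above imply that callers must chunk oversized batches.
# A minimal client-side chunking sketch; the per-type caps of 1000 metrics and
# 100 params/tags per request mirror the test data, and are worth verifying
# against your MLflow version:
def chunked(seq, size):
    """Split seq into consecutive slices of at most `size` items."""
    return [seq[i:i + size] for i in range(0, len(seq), size)]

def log_batch_chunked(client, run_id, metrics, params, tags):
    for chunk in chunked(metrics, 1000):
        client.log_batch(run_id, metrics=chunk)
    for chunk in chunked(params, 100):
        client.log_batch(run_id, params=chunk)
    for chunk in chunked(tags, 100):
        client.log_batch(run_id, tags=chunk)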
def test_validate_batch_log_data():
    metrics_with_bad_key = [Metric("good-metric-key", 1.0, 0, 0),
                            Metric("super-long-bad-key" * 1000, 4.0, 0, 0)]
    metrics_with_bad_val = [Metric("good-metric-key", "not-a-double-val", 0, 0)]
    metrics_with_bad_ts = [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)]
    metrics_with_neg_ts = [Metric("good-metric-key", 1.0, -123, 0)]
    metrics_with_bad_step = [Metric("good-metric-key", 1.0, 0, "not-a-step")]
    params_with_bad_key = [Param("good-param-key", "hi"),
                           Param("super-long-bad-key" * 1000, "but-good-val")]
    params_with_bad_val = [Param("good-param-key", "hi"),
                           Param("another-good-key", "but-bad-val" * 1000)]
    tags_with_bad_key = [RunTag("good-tag-key", "hi"),
                         RunTag("super-long-bad-key" * 1000, "but-good-val")]
    tags_with_bad_val = [RunTag("good-tag-key", "hi"),
                         RunTag("another-good-key", "but-bad-val" * 1000)]
    bad_kwargs = {
        "metrics": [metrics_with_bad_key, metrics_with_bad_val, metrics_with_bad_ts,
                    metrics_with_neg_ts, metrics_with_bad_step],
        "params": [params_with_bad_key, params_with_bad_val],
        "tags": [tags_with_bad_key, tags_with_bad_val],
    }
    good_kwargs = {"metrics": [], "params": [], "tags": []}
    for arg_name, arg_values in bad_kwargs.items():
        for arg_value in arg_values:
            final_kwargs = copy.deepcopy(good_kwargs)
            final_kwargs[arg_name] = arg_value
            with pytest.raises(MlflowException):
                _validate_batch_log_data(**final_kwargs)
# Compare the single mocked HTTP call against the expected request;
# expected_kwargs is built earlier from the anticipated endpoint and JSON body.
assert mock_http.call_count == 1
actual_kwargs = mock_http.call_args[1]

# Test the passed tag values separately from the rest of the request:
# tag order is inconsistent between Python 2 and 3, but the order does not matter.
expected_tags = expected_kwargs['json'].pop('tags')
actual_tags = actual_kwargs['json'].pop('tags')
assert (
    sorted(expected_tags, key=lambda t: t['key']) ==
    sorted(actual_tags, key=lambda t: t['key'])
)
assert expected_kwargs == actual_kwargs
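# A plausible shape for how expected_kwargs gets built; an illustrative
# assumption modeled on MLflow's REST endpoint conventions, not a verified
# copy of this suite's helper (the function name is hypothetical):
import json

def expected_request_kwargs(host_creds, endpoint, method, json_body):
    kwargs = {"host_creds": host_creds,
              "endpoint": "/api/2.0/mlflow/%s" % endpoint,
              "method": method}
    if method == "GET":
        # GET requests carry the payload as query params, others as a JSON body
        kwargs["params"] = json.loads(json_body)
    else:
        kwargs["json"] = json.loads(json_body)
    return kwargs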
with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.log_param("some_uuid", Param("k1", "v1"))
    body = message_to_json(LogParam(
        run_uuid="some_uuid", run_id="some_uuid", key="k1", value="v1"))
    self._verify_requests(mock_http, creds,
                          "runs/log-parameter", "POST", body)

with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.set_experiment_tag("some_id", ExperimentTag("t1", "abcd" * 1000))
    body = message_to_json(SetExperimentTag(
        experiment_id="some_id",
        key="t1",
        value="abcd" * 1000))
    self._verify_requests(mock_http, creds,
                          "experiments/set-experiment-tag", "POST", body)

with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
    store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
    body = message_to_json(SetTag(
        run_uuid="some_uuid", run_id="some_uuid", key="t1", value="abcd" * 1000))
    self._verify_requests(mock_http, creds,
                          "runs/set-tag", "POST", body)
def test_log_many_entities(self):
    """
    Sanity check: verify that we can log a reasonable number of entities without failures due
    to connection leaks etc.
    """
    run = self._run_factory()
    for i in range(100):
        self.store.log_metric(run.info.run_id, entities.Metric("key", i, i * 2, i * 3))
        self.store.log_param(run.info.run_id, entities.Param("pkey-%s" % i, "pval-%s" % i))
        self.store.set_tag(run.info.run_id, entities.RunTag("tkey-%s" % i, "tval-%s" % i))
def import_run_data(self, run_dct, run, src_user_id):
    from mlflow.entities import Metric, Param, RunTag
    now = int(time.time() * 1000)  # MLflow metric timestamps are milliseconds since epoch
    params = [Param(k, v) for k, v in run_dct['params'].items()]
    # TODO: the source dump carries no per-metric timestamp or step, so every
    # metric is imported at time `now` with step 0.
    metrics = [Metric(k, v, now, 0) for k, v in run_dct['metrics'].items()]
    tags = self._create_tags_for_metadata(run_dct['tags'])
    tags = utils.create_tags_for_mlflow_tags(tags, self.import_mlflow_tags)
    self.client.log_batch(run.info.run_id, metrics, params, tags)
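# Shape of the run_dct consumed above, inferred from the lookups in
# import_run_data (the keys come from the code; the values are illustrative):
example_run_dct = {
    "params": {"alpha": "0.5", "max_depth": "6"},
    "metrics": {"rmse": 0.72},
    "tags": {"mlflow.source.name": "train.py"},
}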