def test_set_experiment_tag(self):
exp_id = self._experiment_factory('setExperimentTagExp')
tag = entities.ExperimentTag("tag0", "value0")
new_tag = entities.RunTag("tag0", "value00000")
self.store.set_experiment_tag(exp_id, tag)
experiment = self.store.get_experiment(exp_id)
self.assertTrue(experiment.tags["tag0"] == "value0")
# test that updating a tag works
self.store.set_experiment_tag(exp_id, new_tag)
experiment = self.store.get_experiment(exp_id)
self.assertTrue(experiment.tags["tag0"] == "value00000")
# test that setting a tag on 1 experiment does not impact another experiment.
exp_id_2 = self._experiment_factory('setExperimentTagExp2')
experiment2 = self.store.get_experiment(exp_id_2)
self.assertTrue(len(experiment2.tags) == 0)
# setting a tag on different experiments maintains different values across experiments
different_tag = entities.RunTag("tag0", "differentValue")
self.store.set_experiment_tag(exp_id_2, different_tag)
experiment = self.store.get_experiment(exp_id)
self.assertTrue(experiment.tags["tag0"] == "value00000")
experiment2 = self.store.get_experiment(exp_id_2)
self.assertTrue(experiment2.tags["tag0"] == "differentValue")
# test can set multi-line tags
multiLineTag = entities.ExperimentTag("multiline tag", "value2\nvalue2\nvalue2")
self.store.set_experiment_tag(exp_id, multiLineTag)
experiment = self.store.get_experiment(exp_id)
self.assertTrue(experiment.tags["multiline tag"] == "value2\nvalue2\nvalue2")
# test cannot set tags that are too long
longTag = entities.ExperimentTag("longTagKey", "a" * 5001)
with pytest.raises(MlflowException):
self.store.set_experiment_tag(exp_id, longTag)
# test can set tags that are somewhat long
longTag = entities.ExperimentTag("longTagKey", "a" * 4999)
def test_set_tags(self):
fs = FileStore(self.test_root)
run_id = self.exp_data[FileStore.DEFAULT_EXPERIMENT_ID]["runs"][0]
fs.set_tag(run_id, RunTag("tag0", "value0"))
fs.set_tag(run_id, RunTag("tag1", "value1"))
tags = fs.get_run(run_id).data.tags
assert tags["tag0"] == "value0"
assert tags["tag1"] == "value1"
# Can overwrite tags.
fs.set_tag(run_id, RunTag("tag0", "value2"))
tags = fs.get_run(run_id).data.tags
assert tags["tag0"] == "value2"
assert tags["tag1"] == "value1"
# Can set multiline tags.
fs.set_tag(run_id, RunTag("multiline_tag", "value2\nvalue2\nvalue2"))
tags = fs.get_run(run_id).data.tags
assert tags["multiline_tag"] == "value2\nvalue2\nvalue2"
def test_log_batch(self):
experiment_id = self._experiment_factory('log_batch')
run_id = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
metric_entities = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, 1)]
param_entities = [Param("p1", "p1val"), Param("p2", "p2val")]
tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
self.store.log_batch(
run_id=run_id, metrics=metric_entities, params=param_entities, tags=tag_entities)
run = self.store.get_run(run_id)
assert run.data.tags == {"t1": "t1val", "t2": "t2val"}
assert run.data.params == {"p1": "p1val", "p2": "p2val"}
metric_histories = sum(
[self.store.get_metric_history(run_id, key) for key in run.data.metrics], [])
metrics = [(m.key, m.value, m.timestamp, m.step) for m in metric_histories]
assert set(metrics) == set([("m1", 0.87, 12345, 0), ("m2", 0.49, 12345, 1)])
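# Minimal sketch of the same batch logging through the public client API
# (the experiment id "0" and timestamp handling are assumptions for illustration):
import time
from mlflow.tracking import MlflowClient
from mlflow.entities import Metric, Param, RunTag

client = MlflowClient()
run = client.create_run(experiment_id="0")
ts = int(time.time() * 1000)
client.log_batch(
    run.info.run_id,
    metrics=[Metric("m1", 0.87, ts, 0), Metric("m2", 0.49, ts, 1)],
    params=[Param("p1", "p1val"), Param("p2", "p2val")],
    tags=[RunTag("t1", "t1val"), RunTag("t2", "t2val")])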
def test_search_tags(self):
experiment_id = self._experiment_factory('search_tags')
r1 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
r2 = self._run_factory(self._get_run_configs(experiment_id)).info.run_id
self.store.set_tag(r1, entities.RunTag('generic_tag', 'p_val'))
self.store.set_tag(r2, entities.RunTag('generic_tag', 'p_val'))
self.store.set_tag(r1, entities.RunTag('generic_2', 'some value'))
self.store.set_tag(r2, entities.RunTag('generic_2', 'another value'))
self.store.set_tag(r1, entities.RunTag('p_a', 'abc'))
self.store.set_tag(r2, entities.RunTag('p_b', 'ABC'))
# test search returns both runs
six.assertCountEqual(self, [r1, r2],
self._search(experiment_id,
filter_string="tags.generic_tag = 'p_val'"))
# test search returns appropriate run (same key different values per run)
six.assertCountEqual(self, [r1],
self._search(experiment_id,
filter_string="tags.generic_2 = 'some value'"))
six.assertCountEqual(self, [r2],
self._search(experiment_id,
filter_string="tags.generic_2 = 'another value'"))
six.assertCountEqual(self, [],
self._search(experiment_id,
filter_string="tags.generic_tag = 'wrong_val'"))
six.assertCountEqual(self, [],
self._search(experiment_id,
filter_string="tags.generic_tag != 'p_val'"))
def create_run(start_time, end):
return self.store.create_run(
experiment_id,
user_id="MrDuck",
start_time=start_time,
tags=[entities.RunTag(mlflow_tags.MLFLOW_RUN_NAME, end)]).info.run_id
def test_validate_batch_log_data():
metrics_with_bad_key = [Metric("good-metric-key", 1.0, 0, 0),
Metric("super-long-bad-key" * 1000, 4.0, 0, 0)]
metrics_with_bad_val = [Metric("good-metric-key", "not-a-double-val", 0, 0)]
metrics_with_bad_ts = [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)]
metrics_with_neg_ts = [Metric("good-metric-key", 1.0, -123, 0)]
metrics_with_bad_step = [Metric("good-metric-key", 1.0, 0, "not-a-step")]
params_with_bad_key = [Param("good-param-key", "hi"),
Param("super-long-bad-key" * 1000, "but-good-val")]
params_with_bad_val = [Param("good-param-key", "hi"),
Param("another-good-key", "but-bad-val" * 1000)]
tags_with_bad_key = [RunTag("good-tag-key", "hi"),
RunTag("super-long-bad-key" * 1000, "but-good-val")]
tags_with_bad_val = [RunTag("good-tag-key", "hi"),
RunTag("another-good-key", "but-bad-val" * 1000)]
bad_kwargs = {
"metrics": [metrics_with_bad_key, metrics_with_bad_val, metrics_with_bad_ts,
metrics_with_neg_ts, metrics_with_bad_step],
"params": [params_with_bad_key, params_with_bad_val],
"tags": [tags_with_bad_key, tags_with_bad_val],
}
good_kwargs = {"metrics": [], "params": [], "tags": []}
for arg_name, arg_values in bad_kwargs.items():
for arg_value in arg_values:
final_kwargs = copy.deepcopy(good_kwargs)
final_kwargs[arg_name] = arg_value
with pytest.raises(MlflowException):
_validate_batch_log_data(**final_kwargs)
# Test that we don't reject entities within the limit
_validate_batch_log_data(
metrics=[Metric("metric-key", 1.0, 0, 0)], params=[Param("param-key", "param-val")],
def test_log_batch(self):
fs = FileStore(self.test_root)
run = fs.create_run(
experiment_id=FileStore.DEFAULT_EXPERIMENT_ID, user_id='user', start_time=0, tags=[])
run_id = run.info.run_id
metric_entities = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, 0)]
param_entities = [Param("p1", "p1val"), Param("p2", "p2val")]
tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
fs.log_batch(
run_id=run_id, metrics=metric_entities, params=param_entities, tags=tag_entities)
self._verify_logged(fs, run_id, metric_entities, param_entities, tag_entities)
def test_log_batch_allows_tag_overwrite(self):
fs = FileStore(self.test_root)
run = self._create_run(fs)
fs.log_batch(run.info.run_id, metrics=[], params=[], tags=[RunTag("t-key", "val")])
fs.log_batch(run.info.run_id, metrics=[], params=[], tags=[RunTag("t-key", "newval")])
self._verify_logged(fs, run.info.run_id, metrics=[], params=[],
tags=[RunTag("t-key", "newval")])
def set_tag(self, run_id, key, value):
"""
Set a tag on the run with the specified ID. The value is converted to a string.
"""
_validate_tag_name(key)
tag = RunTag(key, str(value))
self.store.set_tag(run_id, tag)
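# Illustrative call of this client method; the run creation and tag value are
# assumptions for the example, not part of the implementation above.
from mlflow.tracking import MlflowClient

client = MlflowClient()
run = client.create_run(experiment_id="0")
client.set_tag(run.info.run_id, "tag0", 42)  # non-string values are stored as str: "42"
assert client.get_run(run.info.run_id).data.tags["tag0"] == "42"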
def set_dst_user_id(tags, user_id, use_src_user_id):
from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import MLFLOW_USER
user_id = user_id if use_src_user_id else get_user_id()
tags.append(RunTag(PREFIX_SRC_RUN + "." + MLFLOW_USER, user_id))
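# Hypothetical call sketch: PREFIX_SRC_RUN and get_user_id are assumed to be
# defined elsewhere in the enclosing module. Since MLFLOW_USER == "mlflow.user",
# the appended tag key is f"{PREFIX_SRC_RUN}.mlflow.user".
tags = []
set_dst_user_id(tags, user_id="alice", use_src_user_id=True)
# tags now holds one RunTag carrying the source run's user id ("alice")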