Add support for specifying metric step in Python fluent APIs (mlflow#1143)

Adds the ability to optionally pass metric x-coordinates via a step argument to mlflow.log_metric and mlflow.log_metrics in the Python fluent APIs. If a step is unspecified, metrics are logged with a default step of 0. Existing logged metrics are also read back with a step value of 0.
smurching committed Apr 22, 2019
1 parent 8af29d0 commit 0d3690f
Showing 4 changed files with 105 additions and 76 deletions.
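
For orientation before the diffs, here is a minimal sketch of the new fluent-API surface; the metric names and values are illustrative, not part of the commit:

import mlflow

with mlflow.start_run():
    for epoch in range(3):
        # New in this change: pass the metric's x-coordinate via step.
        mlflow.log_metric("loss", 1.0 / (epoch + 1), step=epoch)
        # log_metrics applies one step to every metric in the dictionary.
        mlflow.log_metrics({"accuracy": 0.5 + 0.1 * epoch}, step=epoch)
    # Omitting step keeps the old behavior: the metric is logged at step 0.
    mlflow.log_metric("final_loss", 0.33)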
9 changes: 5 additions & 4 deletions mlflow/tracking/client.py
@@ -162,14 +162,15 @@ def rename_experiment(self, experiment_id, new_name):
         """
         self.store.rename_experiment(experiment_id, new_name)
 
-    def log_metric(self, run_id, key, value, timestamp=None):
+    def log_metric(self, run_id, key, value, timestamp=None, step=None):
         """
         Log a metric against the run ID. If timestamp is not provided, uses
-        the current timestamp.
+        the current timestamp. The metric's step defaults to 0 if unspecified.
         """
         timestamp = timestamp if timestamp is not None else int(time.time())
-        _validate_metric(key, value, timestamp, 0)
-        metric = Metric(key, value, timestamp, 0)
+        step = step if step is not None else 0
+        _validate_metric(key, value, timestamp, step)
+        metric = Metric(key, value, timestamp, step)
         self.store.log_metric(run_id, metric)
 
     def log_param(self, run_id, key, value):
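
The same option is available one level down on MlflowClient. A short sketch mirroring the calls used in the tests below; the experiment name and metric values are made up:

from mlflow.tracking import MlflowClient

client = MlflowClient()
experiment_id = client.create_experiment("step-demo")  # hypothetical experiment name
run_id = client.create_run(experiment_id).info.run_uuid
# timestamp and step are both optional; step falls back to 0 when omitted.
client.log_metric(run_id, key="loss", value=0.42, timestamp=789, step=7)
client.log_metric(run_id, key="loss", value=0.40)  # logged at step 0, as before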
12 changes: 8 additions & 4 deletions mlflow/tracking/fluent.py
@@ -188,26 +188,30 @@ def set_tag(key, value):
     MlflowClient().set_tag(run_id, key, value)
 
 
-def log_metric(key, value):
+def log_metric(key, value, step=None):
     """
     Log a metric under the current run, creating a run if necessary.
     :param key: Metric name (string).
     :param value: Metric value (float).
+    :param step: Metric step (int). Defaults to zero if unspecified.
     """
     run_id = _get_or_start_run().info.run_uuid
-    MlflowClient().log_metric(run_id, key, value, int(time.time()))
+    MlflowClient().log_metric(run_id, key, value, int(time.time()), step or 0)
 
 
-def log_metrics(metrics):
+def log_metrics(metrics, step=None):
     """
     Log multiple metrics for the current run, starting a run if no runs are active.
     :param metrics: Dictionary of metric_name: String -> value: Float
+    :param step: A single integer step at which to log the specified
+                 Metrics. If unspecified, each metric is logged at step zero.
     :returns: None
     """
     run_id = _get_or_start_run().info.run_uuid
     timestamp = int(time.time())
-    metrics_arr = [Metric(key, value, timestamp, 0) for key, value in metrics.items()]
+    metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
     MlflowClient().log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
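
To read steps back, note that a client-side get_metric_history API does not exist yet (see the TODOs in the tests below). A sketch that verifies logged steps goes through the FileStore directly, assuming a local ./mlruns tracking root, a run_id from an earlier run, and the FileStore import path of this MLflow version:

from mlflow.store.file_store import FileStore

fs = FileStore("./mlruns")  # assumes the default local tracking root
for m in fs.get_metric_history(run_id, "loss"):  # run_id: a run logged earlier
    # Metrics logged before this change come back with step == 0.
    print(m.key, m.value, m.timestamp, m.step)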
29 changes: 19 additions & 10 deletions tests/tracking/test_rest_tracking.py
@@ -224,32 +224,40 @@ def test_log_metrics_params_tags(mlflow_client):
     experiment_id = mlflow_client.create_experiment('Oh My')
     created_run = mlflow_client.create_run(experiment_id)
     run_id = created_run.info.run_uuid
-    # TODO(sid): pass and assert on step
-    mlflow_client.log_metric(run_id, key='metric', value=123.456, timestamp=789)
+    mlflow_client.log_metric(run_id, key='metric', value=123.456, timestamp=789, step=2)
+    mlflow_client.log_metric(run_id, key='stepless-metric', value=987.654, timestamp=321)
     mlflow_client.log_param(run_id, 'param', 'value')
     mlflow_client.set_tag(run_id, 'taggity', 'do-dah')
     run = mlflow_client.get_run(run_id)
     assert run.data.metrics.get('metric') == 123.456
+    assert run.data.metrics.get('stepless-metric') == 987.654
     assert run.data.params.get('param') == 'value'
     assert run.data.tags.get('taggity') == 'do-dah'
+    # TODO(sid): replace this with mlflow_client.get_metric_history
     fs = FileStore(server_root_dir)
-    metric_history = fs.get_metric_history(run_id, "metric")
-    assert len(metric_history) == 1
-    metric = metric_history[0]
-    assert metric.key == "metric"
-    assert metric.value == 123.456
-    assert metric.timestamp == 789
+    metric_history0 = fs.get_metric_history(run_id, "metric")
+    assert len(metric_history0) == 1
+    metric0 = metric_history0[0]
+    assert metric0.key == "metric"
+    assert metric0.value == 123.456
+    assert metric0.timestamp == 789
+    assert metric0.step == 2
+    metric_history1 = fs.get_metric_history(run_id, "stepless-metric")
+    assert len(metric_history1) == 1
+    metric1 = metric_history1[0]
+    assert metric1.key == "stepless-metric"
+    assert metric1.value == 987.654
+    assert metric1.timestamp == 321
+    assert metric1.step == 0
 
 
 def test_log_batch(mlflow_client):
     experiment_id = mlflow_client.create_experiment('Batch em up')
     created_run = mlflow_client.create_run(experiment_id)
     run_id = created_run.info.run_uuid
-    # TODO(sid): pass and assert on step
     mlflow_client.log_batch(
         run_id=run_id,
-        metrics=[Metric("metric", 123.456, 789, 0)], params=[Param("param", "value")],
+        metrics=[Metric("metric", 123.456, 789, 3)], params=[Param("param", "value")],
         tags=[RunTag("taggity", "do-dah")])
     run = mlflow_client.get_run(run_id)
     assert run.data.metrics.get('metric') == 123.456
@@ -263,6 +271,7 @@ def test_log_batch(mlflow_client):
     assert metric.key == "metric"
     assert metric.value == 123.456
     assert metric.timestamp == 789
+    assert metric.step == 3
 
 
 def test_set_terminated_defaults(mlflow_client):
131 changes: 73 additions & 58 deletions tests/tracking/test_tracking.py
@@ -81,24 +81,21 @@ def test_set_experiment(tracking_uri_mock, reset_active_experiment):
     name = "random_exp"
     exp_id = mlflow.create_experiment(name)
     mlflow.set_experiment(name)
-    run = start_run()
-    assert run.info.experiment_id == exp_id
-    end_run()
+    with start_run() as run:
+        assert run.info.experiment_id == exp_id
 
     another_name = "another_experiment"
     mlflow.set_experiment(another_name)
     exp_id2 = mlflow.tracking.MlflowClient().get_experiment_by_name(another_name)
-    another_run = start_run()
-    assert another_run.info.experiment_id == exp_id2.experiment_id
-    end_run()
+    with start_run() as another_run:
+        assert another_run.info.experiment_id == exp_id2.experiment_id
 
 
 def test_set_experiment_with_deleted_experiment_name(tracking_uri_mock):
     name = "dead_exp"
     mlflow.set_experiment(name)
-    run = start_run()
-    end_run()
-    exp_id = run.info.experiment_id
+    with start_run() as run:
+        exp_id = run.info.experiment_id
 
     tracking.MlflowClient().delete_experiment(exp_id)
 
@@ -120,9 +117,8 @@ def test_set_experiment_with_zero_id(reset_mock, reset_active_experiment):
 
 
 def test_start_run_context_manager(tracking_uri_mock):
-    first_run = start_run()
-    first_uuid = first_run.info.run_uuid
-    with first_run:
+    with start_run() as first_run:
+        first_uuid = first_run.info.run_uuid
         # Check that start_run() causes the run information to be persisted in the store
         persisted_run = tracking.MlflowClient().get_run(first_uuid)
         assert persisted_run is not None
@@ -131,12 +127,12 @@ def test_start_run_context_manager(tracking_uri_mock):
     assert finished_run.info.status == RunStatus.FINISHED
     # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
     # different
-    second_run = start_run()
-    assert second_run.info.run_uuid != first_uuid
     with pytest.raises(Exception):
-        with second_run:
+        with start_run() as second_run:
+            second_run_id = second_run.info.run_uuid
             raise Exception("Failing run!")
-    finished_run2 = tracking.MlflowClient().get_run(second_run.info.run_uuid)
+    assert second_run_id != first_uuid
+    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
     assert finished_run2.info.status == RunStatus.FAILED
 
 
@@ -151,28 +147,39 @@ def test_start_and_end_run(tracking_uri_mock):
     assert finished_run.data.metrics["name_1"] == 25
 
 
-def test_log_batch(tracking_uri_mock):
+def test_log_batch(tracking_uri_mock, tmpdir):
     expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
     expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
     exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
     approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
 
     t = int(time.time())
-    metrics = [Metric(key=key, value=value, timestamp=t, step=0)
-               for key, value in expected_metrics.items()]
+    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
+    metrics = [Metric(key=key, value=value, timestamp=t, step=i)
+               for i, (key, value) in enumerate(sorted_expected_metrics)]
     params = [Param(key=key, value=value) for key, value in expected_params.items()]
     tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]
 
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run:
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
         mlflow.tracking.MlflowClient().log_batch(run_id=run_uuid, metrics=metrics, params=params,
                                                  tags=tags)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     # Validate metrics
     assert len(finished_run.data.metrics) == 2
     for key, value in finished_run.data.metrics.items():
         assert expected_metrics[key] == value
+    # TODO: use client get_metric_history API here instead once it exists
+    fs = FileStore(os.path.join(tmpdir.strpath, "mlruns"))
+    metric_history0 = fs.get_metric_history(run_uuid, "metric-key0")
+    assert set([(m.value, m.timestamp, m.step) for m in metric_history0]) == set([
+        (1.0, t, 0),
+    ])
+    metric_history1 = fs.get_metric_history(run_uuid, "metric-key1")
+    assert set([(m.value, m.timestamp, m.step) for m in metric_history1]) == set([
+        (4.0, t, 1),
+    ])
 
     # Validate tags (for automatically-set tags)
     assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
     for tag_key, tag_value in finished_run.data.tags.items():
@@ -184,36 +191,51 @@ def test_log_batch(tracking_uri_mock):
     assert finished_run.data.params == expected_params
 
 
-def test_log_metric(tracking_uri_mock):
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run:
+def test_log_metric(tracking_uri_mock, tmpdir):
+    with start_run() as active_run, mock.patch("time.time") as time_mock:
+        time_mock.side_effect = range(300, 400)
+        run_uuid = active_run.info.run_uuid
         mlflow.log_metric("name_1", 25)
         mlflow.log_metric("name_2", -3)
-        mlflow.log_metric("name_1", 30)
+        mlflow.log_metric("name_1", 30, 5)
+        mlflow.log_metric("name_1", 40, -2)
        mlflow.log_metric("nested/nested/name", 40)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     # Validate metrics
     assert len(finished_run.data.metrics) == 3
     expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
     for key, value in finished_run.data.metrics.items():
         assert expected_pairs[key] == value
+    # TODO: use client get_metric_history API here instead once it exists
+    fs = FileStore(os.path.join(tmpdir.strpath, "mlruns"))
+    metric_history_name1 = fs.get_metric_history(run_uuid, "name_1")
+    assert set([(m.value, m.timestamp, m.step) for m in metric_history_name1]) == set([
+        (25, 300, 0),
+        (30, 302, 5),
+        (40, 303, -2),
+    ])
+    metric_history_name2 = fs.get_metric_history(run_uuid, "name_2")
+    assert set([(m.value, m.timestamp, m.step) for m in metric_history_name2]) == set([
+        (-3, 301, 0),
+    ])
 
 
-def test_log_metrics(tracking_uri_mock):
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
+@pytest.mark.parametrize("step_kwarg", [None, -10, 5])
+def test_log_metrics(tracking_uri_mock, step_kwarg):
     expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
-    with active_run:
-        mlflow.log_metrics(expected_metrics)
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
+        mlflow.log_metrics(expected_metrics, step=step_kwarg)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     # Validate metric key/values match what we expect, and that all metrics have the same timestamp
     assert len(finished_run.data.metrics) == len(expected_metrics)
     for key, value in finished_run.data.metrics.items():
         assert expected_metrics[key] == value
     common_timestamp = finished_run.data._metric_objs[0].timestamp
+    expected_step = step_kwarg if step_kwarg is not None else 0
     for metric_obj in finished_run.data._metric_objs:
         assert metric_obj.timestamp == common_timestamp
+        assert metric_obj.step == expected_step
 
 
 @pytest.fixture
@@ -225,9 +247,8 @@ def get_store_mock(tmpdir):
 def test_set_tags(tracking_uri_mock):
     exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
     approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run:
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
         mlflow.set_tags(exact_expected_tags)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     # Validate tags
@@ -240,19 +261,18 @@ def test_set_tags(tracking_uri_mock):
 
 
 def test_log_metric_validation(tracking_uri_mock):
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run, pytest.raises(MlflowException) as e:
-        mlflow.log_metric("name_1", "apple")
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
+        with pytest.raises(MlflowException) as e:
+            mlflow.log_metric("name_1", "apple")
     assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     assert len(finished_run.data.metrics) == 0
 
 
 def test_log_param(tracking_uri_mock):
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run:
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
         mlflow.log_param("name_1", "a")
         mlflow.log_param("name_2", "b")
         mlflow.log_param("name_1", "c")
@@ -264,17 +284,15 @@ def test_log_param(tracking_uri_mock):
 
 def test_log_params(tracking_uri_mock):
     expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
-    active_run = start_run()
-    run_uuid = active_run.info.run_uuid
-    with active_run:
+    with start_run() as active_run:
+        run_uuid = active_run.info.run_uuid
         mlflow.log_params(expected_params)
     finished_run = tracking.MlflowClient().get_run(run_uuid)
     # Validate params
     assert finished_run.data.params == {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
 
 
 def test_log_batch_validates_entity_names_and_values(tracking_uri_mock):
-    active_run = start_run()
     bad_kwargs = {
         "metrics": [
             [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)],
@@ -284,7 +302,7 @@ def test_log_batch_validates_entity_names_and_values(tracking_uri_mock):
         "params": [[Param(key="../bad/param/name", value="my-val")]],
         "tags": [[Param(key="../bad/tag/name", value="my-val")]],
     }
-    with active_run:
+    with start_run() as active_run:
         for kwarg, bad_values in bad_kwargs.items():
             for bad_kwarg_value in bad_values:
                 final_kwargs = {
@@ -370,22 +388,19 @@ def test_with_startrun():
 
 
 def test_parent_create_run(tracking_uri_mock):
-    parent_run = mlflow.start_run()
-    with pytest.raises(Exception, match='To start a nested run'):
-        mlflow.start_run()
-    child_run = mlflow.start_run(nested=True)
-    grand_child_run = mlflow.start_run(nested=True)
+    with mlflow.start_run() as parent_run:
+        with pytest.raises(Exception, match='To start a nested run'):
+            mlflow.start_run()
+        with mlflow.start_run(nested=True) as child_run:
+            with mlflow.start_run(nested=True) as grand_child_run:
+                pass
 
     def verify_has_parent_id_tag(child_id, expected_parent_id):
         tags = tracking.MlflowClient().get_run(child_id).data.tags
         assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id
 
     verify_has_parent_id_tag(child_run.info.run_uuid, parent_run.info.run_uuid)
     verify_has_parent_id_tag(grand_child_run.info.run_uuid, child_run.info.run_uuid)
-
-    mlflow.end_run()
-    mlflow.end_run()
-    mlflow.end_run()
     assert mlflow.active_run() is None
