Skip to content

Commit

Permalink
Server-side implementation of batched logging endpoints (mlflow#951)
Browse files Browse the repository at this point in the history
Adds OSS server handlers & tests for the new /runs/log-batch endpoint
  • Loading branch information
smurching committed Mar 12, 2019
1 parent af714c2 commit 8c48f27
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 8 deletions.
17 changes: 16 additions & 1 deletion mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from flask import Response, request, send_file
from querystring_parser import parser

import mlflow
from mlflow.entities import Metric, Param, RunTag, ViewType
from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
GetRun, SearchRuns, ListArtifacts, GetMetricHistory, CreateRun, \
UpdateRun, LogMetric, LogParam, SetTag, ListExperiments, \
DeleteExperiment, RestoreExperiment, RestoreRun, DeleteRun, UpdateExperiment
DeleteExperiment, RestoreExperiment, RestoreRun, DeleteRun, UpdateExperiment, LogBatch
from mlflow.store.artifact_repository_registry import get_artifact_repository
from mlflow.tracking.utils import _is_database_uri, _is_local_uri
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
Expand Down Expand Up @@ -332,6 +333,20 @@ def _get_artifact_repo(run):
return get_artifact_repository(run.info.artifact_uri, store)


@catch_mlflow_exception
def _log_batch():
    """Flask handler for the ``runs/log-batch`` endpoint.

    Parses a ``LogBatch`` proto out of the incoming request, converts the
    proto metrics/params/tags into MLflow entities, hands them to the active
    tracking store in a single ``log_batch`` call, and replies with an empty
    JSON-serialized ``LogBatch.Response``.
    """
    batch_request = _get_request_message(LogBatch())
    metric_entities = [Metric.from_proto(m) for m in batch_request.metrics]
    param_entities = [Param.from_proto(p) for p in batch_request.params]
    tag_entities = [RunTag.from_proto(t) for t in batch_request.tags]
    store = mlflow.tracking.utils._get_store()
    store.log_batch(run_id=batch_request.run_id, metrics=metric_entities,
                    params=param_entities, tags=tag_entities)
    result = Response(mimetype='application/json')
    result.set_data(message_to_json(LogBatch.Response()))
    return result


def _get_paths(base_path):
"""
A service endpoint's base path is typically something like /preview/mlflow/experiment.
Expand Down
10 changes: 7 additions & 3 deletions mlflow/store/rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from mlflow.store.abstract_store import AbstractStore

from mlflow.entities import Experiment, Run, RunInfo, RunTag, Metric, ViewType
from mlflow.exceptions import MlflowException

from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
Expand All @@ -12,7 +11,7 @@
from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
GetRun, SearchRuns, ListExperiments, GetMetricHistory, LogMetric, LogParam, SetTag, \
UpdateRun, CreateRun, DeleteRun, RestoreRun, DeleteExperiment, RestoreExperiment, \
UpdateExperiment
UpdateExperiment, LogBatch

from mlflow.protos import databricks_pb2

Expand Down Expand Up @@ -237,4 +236,9 @@ def restore_run(self, run_id):
self._call_endpoint(RestoreRun, req_body)

def log_batch(self, run_id, metrics, params, tags):
    """Log a batch of metrics, params, and tags for a run via one REST call.

    Serializes the entities into a ``LogBatch`` proto and POSTs it to the
    ``runs/log-batch`` endpoint through ``_call_endpoint``.

    :param run_id: String ID of the run to log under.
    :param metrics: List of ``mlflow.entities.Metric`` to log.
    :param params: List of ``mlflow.entities.Param`` to log.
    :param tags: List of ``mlflow.entities.RunTag`` to log.
    """
    request = LogBatch(
        run_id=run_id,
        metrics=[metric.to_proto() for metric in metrics],
        params=[param.to_proto() for param in params],
        tags=[tag.to_proto() for tag in tags])
    self._call_endpoint(LogBatch, message_to_json(request))
81 changes: 78 additions & 3 deletions tests/server/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import mock
import pytest

from mlflow.entities import ViewType
import mlflow
from mlflow.entities import ViewType, Metric, RunTag, Param
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode
from mlflow.server.handlers import get_endpoints, _create_experiment, _get_request_message, \
_search_runs, catch_mlflow_exception
from mlflow.protos.service_pb2 import CreateExperiment, SearchRuns
_search_runs, _log_batch, catch_mlflow_exception
from mlflow.protos.service_pb2 import CreateExperiment, SearchRuns, LogBatch
from mlflow.store.file_store import FileStore
from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_TYPE, MLFLOW_SOURCE_NAME


@pytest.fixture()
Expand Down Expand Up @@ -78,6 +81,78 @@ def test_search_runs_default_view_type(mock_get_request_message, mock_store):
assert args[2] == ViewType.ACTIVE_ONLY


def _assert_logged_entities(run_id, metric_entities, param_entities, tag_entities):
    """Assert that the backing store contains the expected data for `run_id`.

    Verifies that the run's full metric history matches `metric_entities`
    exactly, and that every logged param/tag appears in `param_entities` /
    `tag_entities` (tags auto-set by MLflow at run creation are ignored).
    """
    client = mlflow.tracking.MlflowClient()
    store = mlflow.tracking.utils._get_store()
    run = client.get_run(run_id)
    # Assert logged metrics: collect the complete history across every metric
    # key and compare it (order-insensitively) against the expected entities.
    all_logged_metrics = sum([store.get_metric_history(run_id, m.key)
                              for m in run.data.metrics], [])
    assert len(all_logged_metrics) == len(metric_entities)
    logged_metrics_dicts = [dict(m) for m in all_logged_metrics]
    for metric in metric_entities:
        assert dict(metric) in logged_metrics_dicts
    # Assert logged params
    param_entities_dict = [dict(p) for p in param_entities]
    for p in run.data.params:
        assert dict(p) in param_entities_dict
    # Assert logged tags, skipping tags MLflow sets automatically on run
    # creation (source name/type), which the caller did not log explicitly.
    tag_entities_dict = [dict(t) for t in tag_entities]
    auto_set_tags = [MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE]
    for t in run.data.tags:
        if t.key not in auto_set_tags:
            assert dict(t) in tag_entities_dict


def test_log_batch_handler_success(mock_get_request_message, tmpdir):
    """Test success cases for the LogBatch API handler."""
    def _test_log_batch_helper_success(
            metric_entities, param_entities, tag_entities,
            expected_metrics=None, expected_params=None, expected_tags=None):
        """
        Simulates a LogBatch API request using the provided metrics/params/tags,
        asserting that it succeeds & that the backing store contains either the
        set of expected metrics/params/tags (if provided) or, by default, the
        metrics/params/tags used in the API request.
        """
        with mlflow.start_run() as active_run:
            run_id = active_run.info.run_uuid
            mock_get_request_message.return_value = LogBatch(
                run_id=run_id,
                metrics=[m.to_proto() for m in metric_entities],
                params=[p.to_proto() for p in param_entities],
                tags=[t.to_proto() for t in tag_entities])
            response = _log_batch()
            assert response.status_code == 200
            # LogBatch.Response carries no fields, so a successful call must
            # serialize to an empty JSON object.
            json_response = json.loads(response.get_data())
            assert json_response == {}
            _assert_logged_entities(
                run_id, expected_metrics or metric_entities,
                expected_params or param_entities,
                expected_tags or tag_entities)

    store = FileStore(tmpdir.strpath)
    with mock.patch('mlflow.tracking.utils._get_store', return_value=store):
        mlflow.set_experiment("log-batch-experiment")
        # Log an empty payload
        _test_log_batch_helper_success([], [], [])
        # Log multiple metrics/params/tags
        _test_log_batch_helper_success(
            metric_entities=[Metric(key="m-key", value=3.2 * i, timestamp=i)
                             for i in range(3)],
            param_entities=[Param(key="p-key-%s" % i, value="p-val-%s" % i)
                            for i in range(4)],
            tag_entities=[RunTag(key="t-key-%s" % i, value="t-val-%s" % i)
                          for i in range(5)])
        # Log metrics with the same key
        _test_log_batch_helper_success(
            metric_entities=[Metric(key="m-key", value=3.2 * i, timestamp=3)
                             for i in range(3)],
            param_entities=[], tag_entities=[])
        # Log tags with the same key, verify the last one gets written
        same_key_tags = [RunTag(key="t-key", value="t-val-%s" % i) for i in range(5)]
        _test_log_batch_helper_success(
            metric_entities=[], param_entities=[], tag_entities=same_key_tags,
            expected_tags=[same_key_tags[-1]])


def test_catch_mlflow_exception():
@catch_mlflow_exception
def test_handler():
Expand Down
15 changes: 14 additions & 1 deletion tests/store/test_rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mlflow.exceptions import MlflowException
from mlflow.entities import Param, Metric, RunTag, SourceType
from mlflow.protos.service_pb2 import DeleteExperiment, RestoreExperiment, LogParam, LogMetric, \
SetTag, DeleteRun, RestoreRun, CreateRun, RunTag as ProtoRunTag
SetTag, DeleteRun, RestoreRun, CreateRun, RunTag as ProtoRunTag, LogBatch
from mlflow.store.rest_store import RestStore
from mlflow.utils.proto_json_utils import message_to_json

Expand Down Expand Up @@ -136,6 +136,19 @@ def test_requestor(self, request):
self._verify_requests(mock_http, creds,
"runs/log-metric", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
metrics = [Metric("m1", 0.87, 12345), Metric("m2", 0.49, 12345)]
params = [Param("p1", "p1val"), Param("p2", "p2val")]
tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
store.log_batch(run_id="u2", metrics=metrics, params=params, tags=tags)
metric_protos = [metric.to_proto() for metric in metrics]
param_protos = [param.to_proto() for param in params]
tag_protos = [tag.to_proto() for tag in tags]
body = message_to_json(LogBatch(run_id="u2", metrics=metric_protos,
params=param_protos, tags=tag_protos))
self._verify_requests(mock_http, creds,
"runs/log-batch", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
store.delete_run("u25")
self._verify_requests(mock_http, creds,
Expand Down

0 comments on commit 8c48f27

Please sign in to comment.