Skip to content

Commit

Permalink
Adding search_runs API to python MlflowClient (mlflow#1055)
Browse files Browse the repository at this point in the history
* Adding search_runs API to python MlflowClient

- Refactor to SearchFilter constructor to simplify interface to only support filter string based search for client.
- end-to-end tracking test per review comment
  • Loading branch information
mparkhe committed Mar 29, 2019
1 parent 0ed27c5 commit 167b280
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 25 deletions.
6 changes: 3 additions & 3 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def _search_runs():
run_view_type = ViewType.ACTIVE_ONLY
if request_message.HasField('run_view_type'):
run_view_type = ViewType.from_proto(request_message.run_view_type)
run_entities = _get_store().search_runs(request_message.experiment_ids,
SearchFilter(request_message),
run_view_type)
sf = SearchFilter(anded_expressions=request_message.anded_expressions,
filter_string=request_message.filter)
run_entities = _get_store().search_runs(request_message.experiment_ids, sf, run_view_type)
response_message.runs.extend([r.to_proto() for r in run_entities])
response = Response(mimetype='application/json')
response.set_data(message_to_json(response_message))
Expand Down
15 changes: 15 additions & 0 deletions mlflow/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from six import iteritems

from mlflow.tracking import utils
from mlflow.utils.search_utils import SearchFilter
from mlflow.utils.validation import _validate_param_name, _validate_tag_name, _validate_run_id, \
_validate_experiment_name, _validate_metric
from mlflow.entities import Param, Metric, RunStatus, RunTag, ViewType, SourceType
Expand Down Expand Up @@ -274,6 +275,20 @@ def restore_run(self, run_id):
"""
self.store.restore_run(run_id)

def search_runs(self, experiment_ids, filter_string, run_view_type=ViewType.ACTIVE_ONLY):
"""
Search experiments that fit the search criteria.
:param experiment_ids: List of experiment IDs
:param filter_string: Filter query string.
:param run_view_type: one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL runs
defined in :py:class:`mlflow.entities.ViewType`.
:return:
"""
return self.store.search_runs(experiment_ids=experiment_ids,
search_filter=SearchFilter(filter_string=filter_string),
run_view_type=run_view_type)


def _get_user_id():
"""Get the ID of the user for the current run."""
Expand Down
6 changes: 3 additions & 3 deletions mlflow/utils/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class SearchFilter(object):
STRING_VALUE_TYPES = set([TokenType.Literal.String.Single])
NUMERIC_VALUE_TYPES = set([TokenType.Literal.Number.Integer, TokenType.Literal.Number.Float])

def __init__(self, search_runs=None):
self._filter_string = search_runs.filter if search_runs else None
self._search_expressions = search_runs.anded_expressions if search_runs else None
def __init__(self, filter_string=None, anded_expressions=None):
self._filter_string = filter_string
self._search_expressions = anded_expressions
if self._filter_string and self._search_expressions:
raise MlflowException("Can specify only one of 'filter' or 'search_expression'",
error_code=INVALID_PARAMETER_VALUE)
Expand Down
2 changes: 1 addition & 1 deletion tests/store/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def _search(self, experiment_id, metrics_expressions=None, param_expressions=Non
search_runs = SearchRuns()
search_runs.anded_expressions.extend(metrics_expressions or [])
search_runs.anded_expressions.extend(param_expressions or [])
search_filter = SearchFilter(search_runs)
search_filter = SearchFilter(anded_expressions=search_runs.anded_expressions)
return [r.info.run_uuid
for r in self.store.search_runs([experiment_id], search_filter, run_view_type)]

Expand Down
18 changes: 17 additions & 1 deletion tests/tracking/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import mock

from mlflow.entities import RunTag, SourceType
from mlflow.entities import RunTag, SourceType, ViewType
from mlflow.tracking import MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE, MLFLOW_PARENT_RUN_ID, \
MLFLOW_GIT_COMMIT, MLFLOW_PROJECT_ENTRY_POINT
Expand All @@ -26,6 +26,12 @@ def mock_time():
yield time


@pytest.fixture
def mock_search_filter():
with mock.patch("mlflow.tracking.client.SearchFilter") as mock_search_filter:
yield mock_search_filter.return_value


def test_client_create_run(mock_store, mock_user_id, mock_time):

experiment_id = mock.Mock()
Expand Down Expand Up @@ -75,3 +81,13 @@ def test_client_create_run_overrides(mock_store):
entry_point_name=tags[MLFLOW_PROJECT_ENTRY_POINT],
source_version=tags[MLFLOW_GIT_COMMIT]
)


def test_client_search_runs(mock_store, mock_search_filter):
experiment_ids = [mock.Mock() for _ in range(5)]

MlflowClient().search_runs(experiment_ids, "metrics.acc > 0.93")

mock_store.search_runs.assert_called_once_with(experiment_ids=experiment_ids,
search_filter=mock_search_filter,
run_view_type=ViewType.ACTIVE_ONLY)
76 changes: 75 additions & 1 deletion tests/tracking/test_tracking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import copy
import filecmp
import os
Expand All @@ -11,7 +12,7 @@

import mlflow
from mlflow import tracking
from mlflow.entities import RunStatus, LifecycleStage, Metric, Param, RunTag
from mlflow.entities import RunStatus, LifecycleStage, Metric, Param, RunTag, ViewType
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE
from mlflow.tracking.client import MlflowClient
Expand Down Expand Up @@ -428,3 +429,76 @@ def test_get_artifact_uri_uses_currently_active_run_id():
assert mlflow.get_artifact_uri(artifact_path=artifact_path) ==\
tracking.utils.get_artifact_uri(
run_id=active_run.info.run_uuid, artifact_path=artifact_path)


def test_search_runs(tracking_uri_mock, reset_active_experiment):
mlflow.set_experiment("exp-for-search")
# Create a run and verify that the current active experiment is the one we just set
logged_runs = {}
with mlflow.start_run() as active_run:
logged_runs["first"] = active_run.info.run_uuid
mlflow.log_metric("m1", 0.001)
mlflow.log_metric("m2", 0.002)
mlflow.log_metric("m1", 0.002)
mlflow.log_param("p1", "a")
with mlflow.start_run() as active_run:
logged_runs["second"] = active_run.info.run_uuid
mlflow.log_metric("m1", 0.008)
mlflow.log_param("p2", "aa")

def verify_runs(runs, expected_set):
assert set([r.info.run_uuid for r in runs]) == set([logged_runs[r] for r in expected_set])

experiment_id = MlflowClient().get_experiment_by_name("exp-for-search").experiment_id

# 2 runs in this experiment
assert len(MlflowClient().list_run_infos(experiment_id, ViewType.ACTIVE_ONLY)) == 2

# 2 runs that have metric "m1" > 0.001
runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.0001")
verify_runs(runs, ["first", "second"])

# 1 run with has metric "m1" > 0.002
runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.002")
verify_runs(runs, ["second"])

# no runs with metric "m1" > 0.1
runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.1")
verify_runs(runs, [])

# 1 run with metric "m2" > 0
runs = MlflowClient().search_runs([experiment_id], "metrics.m2 > 0")
verify_runs(runs, ["first"])

# 1 run each with param "p1" and "p2"
runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
verify_runs(runs, ["first"])
runs = MlflowClient().search_runs([experiment_id], "params.p2 != 'a'", ViewType.ALL)
verify_runs(runs, ["second"])
runs = MlflowClient().search_runs([experiment_id], "params.p2 = 'aa'", ViewType.ALL)
verify_runs(runs, ["second"])

# delete "first" run
MlflowClient().delete_run(logged_runs["first"])
runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
verify_runs(runs, ["first"])

runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.DELETED_ONLY)
verify_runs(runs, ["first"])

runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ACTIVE_ONLY)
verify_runs(runs, [])


def test_search_runs_multiple_experiments(tracking_uri_mock, reset_active_experiment):
experiment_ids = [mlflow.create_experiment("exp__{}".format(id)) for id in range(1, 4)]
for eid in experiment_ids:
with mlflow.start_run(experiment_id=eid):
mlflow.log_metric("m0", 1)
mlflow.log_metric("m_{}".format(eid), 2)

assert len(MlflowClient().search_runs(experiment_ids, "metrics.m0 > 0", ViewType.ALL)) == 3

assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_1 > 0", ViewType.ALL)) == 1
assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_2 = 2", ViewType.ALL)) == 1
assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_3 < 4", ViewType.ALL)) == 1
32 changes: 16 additions & 16 deletions tests/utils/test_search_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from mlflow.exceptions import MlflowException
from mlflow.protos.service_pb2 import SearchRuns, SearchExpression, DoubleClause, \
from mlflow.protos.service_pb2 import SearchExpression, DoubleClause, \
MetricSearchExpression, FloatClause, ParameterSearchExpression, StringClause
from mlflow.utils.search_utils import SearchFilter

Expand All @@ -11,22 +11,22 @@ def test_search_filter_basics():
anded_expressions = [SearchExpression(), SearchExpression()]

# only anded_expressions
SearchFilter(SearchRuns(anded_expressions=anded_expressions))
SearchFilter(anded_expressions=anded_expressions)

# only search filter
SearchFilter(SearchRuns(filter=search_filter))
SearchFilter(filter_string=search_filter)

# both
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(anded_expressions=anded_expressions, filter=search_filter))
SearchFilter(anded_expressions=anded_expressions, filter_string=search_filter)
assert e.message.contains("Can specify only one of 'filter' or 'search_expression'")


def test_anded_expression():
se = SearchExpression(metric=MetricSearchExpression(key="accuracy",
double=DoubleClause(comparator=">=",
value=.94)))
sf = SearchFilter(SearchRuns(anded_expressions=[se]))
sf = SearchFilter(anded_expressions=[se])
assert sf._parse() == [{"type": "metric", "key": "accuracy", "comparator": ">=", "value": 0.94}]


Expand All @@ -36,11 +36,11 @@ def test_anded_expression_2():
m3 = MetricSearchExpression(key="mse", float=FloatClause(comparator=">=", value=5))
p1 = ParameterSearchExpression(key="a", string=StringClause(comparator="=", value="0"))
p2 = ParameterSearchExpression(key="b", string=StringClause(comparator="!=", value="blah"))
sf = SearchFilter(SearchRuns(anded_expressions=[SearchExpression(metric=m1),
SearchExpression(metric=m2),
SearchExpression(metric=m3),
SearchExpression(parameter=p1),
SearchExpression(parameter=p2)]))
sf = SearchFilter(anded_expressions=[SearchExpression(metric=m1),
SearchExpression(metric=m2),
SearchExpression(metric=m3),
SearchExpression(parameter=p1),
SearchExpression(parameter=p2)])

assert sf._parse() == [
{'comparator': '>=', 'key': 'accuracy', 'type': 'metric', 'value': 0.94},
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_anded_expression_2():
'value': "LR"}]),
])
def test_filter(filter_string, parsed_filter):
assert SearchFilter(SearchRuns(filter=filter_string))._parse() == parsed_filter
assert SearchFilter(filter_string=filter_string)._parse() == parsed_filter


@pytest.mark.parametrize("filter_string, parsed_filter", [
Expand All @@ -97,7 +97,7 @@ def test_filter(filter_string, parsed_filter):
'key': 'm', 'value': "L'Hosp"}]),
])
def test_correct_quote_trimming(filter_string, parsed_filter):
assert SearchFilter(SearchRuns(filter=filter_string))._parse() == parsed_filter
assert SearchFilter(filter_string=filter_string)._parse() == parsed_filter


@pytest.mark.parametrize("filter_string, error_message", [
Expand All @@ -117,7 +117,7 @@ def test_correct_quote_trimming(filter_string, parsed_filter):
])
def test_error_filter(filter_string, error_message):
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(filter=filter_string))._parse()
SearchFilter(filter_string=filter_string)._parse()
assert error_message in e.value.message


Expand All @@ -130,7 +130,7 @@ def test_error_filter(filter_string, error_message):
])
def test_error_comparison_clauses(filter_string, error_message):
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(filter=filter_string))._parse()
SearchFilter(filter_string=filter_string)._parse()
assert error_message in e.value.message


Expand All @@ -143,7 +143,7 @@ def test_error_comparison_clauses(filter_string, error_message):
])
def test_bad_quotes(filter_string, error_message):
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(filter=filter_string))._parse()
SearchFilter(filter_string=filter_string)._parse()
assert error_message in e.value.message


Expand All @@ -158,5 +158,5 @@ def test_bad_quotes(filter_string, error_message):
])
def test_invalid_clauses(filter_string, error_message):
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(filter=filter_string))._parse()
SearchFilter(filter_string=filter_string)._parse()
assert error_message in e.value.message

0 comments on commit 167b280

Please sign in to comment.