Skip to content

Commit

Permalink
Fixing search by filter string for params (mlflow#1042)
Browse files Browse the repository at this point in the history
Handling different types of quotes and special characters for keys and values in search filters.
  • Loading branch information
mparkhe committed Mar 29, 2019
1 parent 0f9aab2 commit 0ed27c5
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 30 deletions.
114 changes: 88 additions & 26 deletions mlflow/utils/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ class SearchFilter(object):
_ALTERNATE_PARAM_IDENTIFIERS = set(["param", "params"])
VALID_KEY_TYPE = set([_METRIC_IDENTIFIER] + list(_ALTERNATE_METRIC_IDENTIFIERS)
+ [_PARAM_IDENTIFIER] + list(_ALTERNATE_PARAM_IDENTIFIERS))
VALUE_TYPES = set([TokenType.Literal.String.Single,
TokenType.Literal.Number.Integer,
TokenType.Literal.Number.Float])
STRING_VALUE_TYPES = set([TokenType.Literal.String.Single])
NUMERIC_VALUE_TYPES = set([TokenType.Literal.Number.Integer, TokenType.Literal.Number.Float])

def __init__(self, search_runs=None):
self._filter_string = search_runs.filter if search_runs else None
Expand All @@ -37,19 +36,44 @@ def filter_string(self):
def search_expressions(self):
return self._search_expressions

@classmethod
def _trim_ends(cls, string_value):
return string_value[1:-1]

@classmethod
def _is_quoted(cls, value, pattern):
return len(value) >= 2 and value.startswith(pattern) and value.endswith(pattern)

@classmethod
def _trim_backticks(cls, entity_type):
if entity_type.startswith("`"):
assert entity_type.endswith("`")
return entity_type[1:-1]
"""Remove backticks from identifier like `param`, if they exist."""
if cls._is_quoted(entity_type, "`"):
return cls._trim_ends(entity_type)
return entity_type

@classmethod
def _strip_quotes(cls, value, expect_quoted_value=False):
"""
Remove quotes for input string.
Values of type strings are expected to have quotes.
Keys containing special characters are also expected to be enclose in quotes.
"""
if cls._is_quoted(value, "'") or cls._is_quoted(value, '"'):
return cls._trim_ends(value)
elif expect_quoted_value:
raise MlflowException("Parameter value is either not quoted or unidentified quote "
"types used for string value %s. Use either single or double "
"quotes." % value, error_code=INVALID_PARAMETER_VALUE)
else:
return value

@classmethod
def _valid_entity_type(cls, entity_type):
entity_type = cls._trim_backticks(entity_type)
if entity_type not in cls.VALID_KEY_TYPE:
raise MlflowException("Invalid search expression type '%s'. "
"Valid values are '%s" % (entity_type, cls.VALID_KEY_TYPE))
"Valid values are '%s" % (entity_type, cls.VALID_KEY_TYPE),
error_code=INVALID_PARAMETER_VALUE)

if entity_type in cls._ALTERNATE_PARAM_IDENTIFIERS:
return cls._PARAM_IDENTIFIER
Expand All @@ -68,25 +92,56 @@ def _get_identifier(cls, identifier):
"'metric.<key> <comparator> <value>' or"
"'params.<key> <comparator> <value>'." % identifier,
error_code=INVALID_PARAMETER_VALUE)
return {"type": cls._valid_entity_type(entity_type), "key": key}
return {"type": cls._valid_entity_type(entity_type), "key": cls._strip_quotes(key)}

@classmethod
def _process_token(cls, token):
    """
    Translate a single token into a partial comparison dictionary.

    :return: ``{"comparator": ...}`` for a comparison operator,
             ``{"value": ...}`` for a token whose type is in ``VALUE_TYPES``,
             and an empty dictionary for anything else.
    """
    if token.ttype == TokenType.Operator.Comparison:
        return {"comparator": token.value}
    elif token.ttype in cls.VALUE_TYPES:
        return {"value": token.value}
    else:
        return {}

@classmethod
def _get_value(cls, identifier_type, token):
    """
    Extract the comparison value from ``token`` for the given identifier type.

    Metrics require a numeric literal, whose text is returned as is.
    Parameters require a string: either a string literal or an ``Identifier``
    token, which is then passed through ``_strip_quotes`` with quoting
    required.

    :param identifier_type: ``cls._METRIC_IDENTIFIER`` or ``cls._PARAM_IDENTIFIER``.
    :param token: Token on the right-hand side of the comparison.
    :return: The value as a string.
    :raises MlflowException: If the token type does not fit the identifier
                             type, or the identifier type is unknown.
    """
    if identifier_type == cls._METRIC_IDENTIFIER:
        if token.ttype not in cls.NUMERIC_VALUE_TYPES:
            raise MlflowException("Expected numeric value type for metric. "
                                  "Found {}".format(token.value),
                                  error_code=INVALID_PARAMETER_VALUE)
        return token.value
    elif identifier_type == cls._PARAM_IDENTIFIER:
        if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier):
            return cls._strip_quotes(token.value, expect_quoted_value=True)
        raise MlflowException("Expected string value type for parameter. "
                              "Found {}".format(token.value),
                              error_code=INVALID_PARAMETER_VALUE)
    else:
        # Expected to be either "param" or "metric".
        raise MlflowException("Invalid identifier type. Expected one of "
                              "{}.".format([cls._METRIC_IDENTIFIER, cls._PARAM_IDENTIFIER]))

@classmethod
def _validate_comparison(cls, tokens):
    """
    Check that ``tokens`` has the shape ``<identifier> <comparator> <value>``.

    :param tokens: Tokens of one comparison clause with whitespace removed.
    :raises MlflowException: If there are not exactly three tokens, the first
                             is not an ``Identifier``, the second is not a
                             comparison operator, or the third is neither a
                             string/numeric literal nor an ``Identifier``.
    """
    base_error_string = "Invalid comparison clause"
    if len(tokens) != 3:
        raise MlflowException("{}. Expected 3 tokens found {}".format(base_error_string,
                                                                      len(tokens)),
                              error_code=INVALID_PARAMETER_VALUE)
    identifier, comparator, value = tokens
    if not isinstance(identifier, Identifier):
        raise MlflowException("{}. Expected 'Identifier' found '{}'".format(base_error_string,
                                                                            str(identifier)),
                              error_code=INVALID_PARAMETER_VALUE)
    # `or`, not `and`: with `and` the check could never fire for a Token,
    # because `not isinstance(..., Token)` was then always False.
    if not isinstance(comparator, Token) or comparator.ttype != TokenType.Operator.Comparison:
        raise MlflowException("{}. Expected comparison found '{}'".format(base_error_string,
                                                                          str(comparator)),
                              error_code=INVALID_PARAMETER_VALUE)
    # An Identifier is allowed as a value here because _get_value accepts
    # Identifier tokens for parameters and reports its own, more specific
    # error for metrics.
    literal_types = cls.STRING_VALUE_TYPES.union(cls.NUMERIC_VALUE_TYPES)
    is_value_token = isinstance(value, Identifier) or \
        (isinstance(value, Token) and value.ttype in literal_types)
    if not is_value_token:
        raise MlflowException("{}. Expected value token found '{}'".format(base_error_string,
                                                                           str(value)),
                              error_code=INVALID_PARAMETER_VALUE)

@classmethod
def _get_comparison(cls, comparison):
    """
    Convert one parsed comparison clause into a dictionary.

    Whitespace tokens are dropped, the remaining tokens are validated with
    ``_validate_comparison`` and then read positionally as
    ``<identifier> <comparator> <value>``.

    :param comparison: Parsed comparison whose ``tokens`` are inspected.
    :return: Dictionary with the entries produced by ``_get_identifier`` plus
             ``"comparator"`` and ``"value"``.
    :raises MlflowException: Propagated from validation, identifier parsing or
                             value extraction.
    """
    stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace]
    # Validation guarantees exactly three tokens, so unpacking is safe.
    cls._validate_comparison(stripped_comparison)
    identifier, comparator, value = stripped_comparison
    comp = cls._get_identifier(identifier.value)
    comp["comparator"] = comparator.value
    comp["value"] = cls._get_value(comp.get("type"), value)
    return comp

@classmethod
Expand All @@ -106,7 +161,8 @@ def _process_statement(cls, statement):
invalids = list(filter(cls._invalid_statement_token, statement.tokens))
if len(invalids) > 0:
invalid_clauses = ", ".join("'%s'" % token for token in invalids)
raise MlflowException("Invalid clause(s) in filter string: %s" % invalid_clauses)
raise MlflowException("Invalid clause(s) in filter string: %s" % invalid_clauses,
error_code=INVALID_PARAMETER_VALUE)
return [cls._get_comparison(si) for si in statement.tokens if isinstance(si, Comparison)]

@classmethod
Expand All @@ -122,7 +178,8 @@ def search_expression_to_dict(cls, search_expression):
comparator = search_expression.metric.double.comparator
value = search_expression.metric.double.value
else:
raise MlflowException("Invalid metric type: '%s', expected float or double")
raise MlflowException("Invalid metric type: '%s', expected float or double",
error_code=INVALID_PARAMETER_VALUE)
return {
"type": cls._METRIC_IDENTIFIER,
"key": key,
Expand All @@ -140,7 +197,8 @@ def search_expression_to_dict(cls, search_expression):
"value": value
}
else:
raise MlflowException("Invalid search expression type '%s'" % key_type)
raise MlflowException("Invalid search expression type '%s'" % key_type,
error_code=INVALID_PARAMETER_VALUE)

def _parse(self):
if self._filter_string:
Expand Down Expand Up @@ -171,18 +229,22 @@ def does_run_match_clause(cls, run, sed):
if key_type == cls._METRIC_IDENTIFIER:
if comparator not in cls.VALID_METRIC_COMPARATORS:
raise MlflowException("Invalid comparator '%s' "
"not one of '%s" % (comparator, cls.VALID_METRIC_COMPARATORS))
"not one of '%s" % (comparator,
cls.VALID_METRIC_COMPARATORS),
error_code=INVALID_PARAMETER_VALUE)
metric = next((m for m in run.data.metrics if m.key == key), None)
lhs = metric.value if metric else None
value = float(value)
elif key_type == cls._PARAM_IDENTIFIER:
if comparator not in cls.VALID_PARAM_COMPARATORS:
raise MlflowException("Invalid comparator '%s' "
"not one of '%s" % (comparator, cls.VALID_PARAM_COMPARATORS))
"not one of '%s" % (comparator, cls.VALID_PARAM_COMPARATORS),
error_code=INVALID_PARAMETER_VALUE)
param = next((p for p in run.data.params if p.key == key), None)
lhs = param.value if param else None
else:
raise MlflowException("Invalid search expression type '%s'" % key_type)
raise MlflowException("Invalid search expression type '%s'" % key_type,
error_code=INVALID_PARAMETER_VALUE)

if lhs is None:
return False
Expand Down
73 changes: 69 additions & 4 deletions tests/utils/test_search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,53 @@ def test_anded_expression_2():

@pytest.mark.parametrize("filter_string, parsed_filter", [
    ("metric.acc >= 0.94", [{'comparator': '>=', 'key': 'acc', 'type': 'metric', 'value': '0.94'}]),
    ("metric.acc>=100", [{'comparator': '>=', 'key': 'acc', 'type': 'metric', 'value': '100'}]),
    ("params.m!='tf'", [{'comparator': '!=', 'key': 'm', 'type': 'parameter', 'value': 'tf'}]),
    ('params."m"!="tf"', [{'comparator': '!=', 'key': 'm', 'type': 'parameter', 'value': 'tf'}]),
    ('metric."legit name" >= 0.243', [{'comparator': '>=',
                                       'key': 'legit name',
                                       'type': 'metric',
                                       'value': '0.243'}]),
    ("metrics.XYZ = 3", [{'comparator': '=', 'key': 'XYZ', 'type': 'metric', 'value': '3'}]),
    ('params."cat dog" = "pets"', [{'comparator': '=',
                                    'key': 'cat dog',
                                    'type': 'parameter',
                                    'value': 'pets'}]),
    ('metrics."X-Y-Z" = 3', [{'comparator': '=', 'key': 'X-Y-Z', 'type': 'metric', 'value': '3'}]),
    ('metrics."X//Y#$$@&Z" = 3', [{'comparator': '=',
                                   'key': 'X//Y#$$@&Z',
                                   'type': 'metric',
                                   'value': '3'}]),
    ("params.model = 'LinearRegression'", [{'comparator': '=',
                                            'key': 'model',
                                            'type': 'parameter',
                                            'value': "LinearRegression"}]),
    ("metrics.rmse < 1 and params.model_class = 'LR'", [
        {'comparator': '<', 'key': 'rmse', 'type': 'metric', 'value': '1'},
        {'comparator': '=', 'key': 'model_class', 'type': 'parameter', 'value': "LR"}
    ]),
    ('', []),
    ("`metric`.a >= 0.1", [{'comparator': '>=', 'key': 'a', 'type': 'metric', 'value': '0.1'}]),
    ("`params`.model >= 'LR'", [{'comparator': '>=',
                                 'key': 'model',
                                 'type': 'parameter',
                                 'value': "LR"}]),
])
def test_filter(filter_string, parsed_filter):
    """Each filter string parses to the expected list of comparison dictionaries.

    Expected keys and parameter values carry no surrounding quotes.
    """
    assert SearchFilter(SearchRuns(filter=filter_string))._parse() == parsed_filter


@pytest.mark.parametrize("filter_string, parsed_filter", [
    (
        "params.m = 'LR'",
        [{'type': 'parameter', 'comparator': '=', 'key': 'm', 'value': 'LR'}],
    ),
    (
        "params.m = \"LR\"",
        [{'type': 'parameter', 'comparator': '=', 'key': 'm', 'value': 'LR'}],
    ),
    (
        'params.m = "LR"',
        [{'type': 'parameter', 'comparator': '=', 'key': 'm', 'value': 'LR'}],
    ),
    (
        'params.m = "L\'Hosp"',
        [{'type': 'parameter', 'comparator': '=', 'key': 'm', 'value': "L'Hosp"}],
    ),
])
def test_correct_quote_trimming(filter_string, parsed_filter):
    """Only the outer quotes of a parameter value are removed.

    Covers single and double quoting, and a quote character of the other kind
    inside the value.
    """
    search_filter = SearchFilter(SearchRuns(filter=filter_string))
    assert search_filter._parse() == parsed_filter


@pytest.mark.parametrize("filter_string, error_message", [
("metric.acc >= 0.94; metrics.rmse < 1", "Search filter contained multiple expression"),
("m.acc >= 0.94", "Invalid search expression type"),
Expand All @@ -95,3 +119,44 @@ def test_error_filter(filter_string, error_message):
with pytest.raises(MlflowException) as e:
SearchFilter(SearchRuns(filter=filter_string))._parse()
assert error_message in e.value.message


@pytest.mark.parametrize(
    "filter_string, error_message",
    [
        ("metric.model = 'LR'", "Expected numeric value type for metric"),
        ("metric.model = '5'", "Expected numeric value type for metric"),
        ("params.acc = 5", "Expected string value type for param"),
        ("metrics.acc != metrics.acc", "Expected numeric value type for metric"),
        ("1.0 > metrics.acc", "Expected 'Identifier' found"),
    ],
)
def test_error_comparison_clauses(filter_string, error_message):
    """A value of the wrong kind, or a clause not led by an identifier, is rejected."""
    with pytest.raises(MlflowException) as exc_info:
        SearchFilter(SearchRuns(filter=filter_string))._parse()
    raised = exc_info.value
    assert error_message in raised.message


@pytest.mark.parametrize(
    "filter_string, error_message",
    [
        ("params.acc = LR", "value is either not quoted or unidentified quote types"),
        ("params.'acc = LR", "Invalid clause(s) in filter string"),
        ("params.acc = 'LR", "Invalid clause(s) in filter string"),
        ("params.acc = LR'", "Invalid clause(s) in filter string"),
        ("params.acc = \"LR'", "Invalid clause(s) in filter string"),
    ],
)
def test_bad_quotes(filter_string, error_message):
    """Missing, unbalanced or mismatched quotes produce the expected error message."""
    with pytest.raises(MlflowException) as exc_info:
        SearchFilter(SearchRuns(filter=filter_string))._parse()
    raised = exc_info.value
    assert error_message in raised.message


@pytest.mark.parametrize(
    "filter_string, error_message",
    [
        ("params.acc LR !=", "Invalid clause(s) in filter string"),
        ("params.acc LR", "Invalid clause(s) in filter string"),
        ("metric.acc !=", "Invalid clause(s) in filter string"),
        ("acc != 1.0", "Invalid filter string"),
        ("foo is null", "Invalid clause(s) in filter string"),
        ("1=1", "Expected 'Identifier' found"),
        ("1==2", "Expected 'Identifier' found"),
    ],
)
def test_invalid_clauses(filter_string, error_message):
    """Clauses that are not ``<identifier> <comparator> <value>`` are rejected."""
    with pytest.raises(MlflowException) as exc_info:
        SearchFilter(SearchRuns(filter=filter_string))._parse()
    raised = exc_info.value
    assert error_message in raised.message

0 comments on commit 0ed27c5

Please sign in to comment.