Skip to content

Commit

Permalink
Make errors from RestStore more descriptive. (mlflow#582)
Browse files Browse the repository at this point in the history
* init

* address comments

* address corey again
  • Loading branch information
andrewmchen committed Oct 3, 2018
1 parent 1d9924a commit 53de566
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 12 deletions.
26 changes: 26 additions & 0 deletions mlflow/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
import json

from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode


class MlflowException(Exception):
"""Base exception in MLflow."""
def __init__(self, message, error_code=INTERNAL_ERROR):
try:
self.error_code = ErrorCode.Name(error_code)
except (ValueError, TypeError):
self.error_code = ErrorCode.Name(INTERNAL_ERROR)
self.message = message
super(MlflowException, self).__init__(message)

def serialize_as_json(self):
return json.dumps({'error_code': self.error_code, 'message': self.message})


class RestException(MlflowException):
"""Exception thrown on non 200-level responses from the REST API"""
def __init__(self, json):
error_code = json['error_code']
message = error_code
if 'message' in json:
message = "%s: %s" % (error_code, json['message'])
super(RestException, self).__init__(message, error_code=error_code)
self.json = json


class IllegalArtifactPathError(MlflowException):
Expand Down
35 changes: 35 additions & 0 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import re
import six

from functools import wraps
from flask import Response, request, send_file
from querystring_parser import parser

from mlflow.entities import Metric, Param, RunTag, ViewType
from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
GetRun, SearchRuns, ListArtifacts, GetMetricHistory, CreateRun, \
Expand Down Expand Up @@ -59,6 +61,19 @@ def _get_request_message(request_message, flask_request=request):
return request_message


def catch_mlflow_exception(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except MlflowException as e:
response = Response(mimetype='application/json')
response.set_data(e.serialize_as_json())
response.status_code = 500
return response
return wrapper


def get_handler(request_class):
"""
:param request_class: The type of protobuf message
Expand All @@ -71,6 +86,7 @@ def get_handler(request_class):
'csv', 'tsv', 'md', 'rst', 'MLmodel', 'MLproject']


@catch_mlflow_exception
def get_artifact_handler():
query_string = request.query_string.decode('utf-8')
request_dict = parser.parse(query_string, normalized=True)
Expand All @@ -89,6 +105,7 @@ def _not_implemented():
return response


@catch_mlflow_exception
def _create_experiment():
request_message = _get_request_message(CreateExperiment())
experiment_id = _get_store().create_experiment(request_message.name,
Expand All @@ -100,6 +117,7 @@ def _create_experiment():
return response


@catch_mlflow_exception
def _get_experiment():
request_message = _get_request_message(GetExperiment())
response_message = GetExperiment.Response()
Expand All @@ -113,6 +131,7 @@ def _get_experiment():
return response


@catch_mlflow_exception
def _delete_experiment():
request_message = _get_request_message(DeleteExperiment())
_get_store().delete_experiment(request_message.experiment_id)
Expand All @@ -122,6 +141,7 @@ def _delete_experiment():
return response


@catch_mlflow_exception
def _restore_experiment():
request_message = _get_request_message(RestoreExperiment())
_get_store().restore_experiment(request_message.experiment_id)
Expand All @@ -131,6 +151,7 @@ def _restore_experiment():
return response


@catch_mlflow_exception
def _create_run():
request_message = _get_request_message(CreateRun())

Expand All @@ -154,6 +175,7 @@ def _create_run():
return response


@catch_mlflow_exception
def _update_run():
request_message = _get_request_message(UpdateRun())
updated_info = _get_store().update_run_info(request_message.run_uuid, request_message.status,
Expand All @@ -164,6 +186,7 @@ def _update_run():
return response


@catch_mlflow_exception
def _delete_run():
request_message = _get_request_message(DeleteRun())
_get_store().delete_run(request_message.run_id)
Expand All @@ -173,6 +196,7 @@ def _delete_run():
return response


@catch_mlflow_exception
def _restore_run():
request_message = _get_request_message(RestoreRun())
_get_store().restore_run(request_message.run_id)
Expand All @@ -182,6 +206,7 @@ def _restore_run():
return response


@catch_mlflow_exception
def _log_metric():
request_message = _get_request_message(LogMetric())
metric = Metric(request_message.key, request_message.value, request_message.timestamp)
Expand All @@ -192,6 +217,7 @@ def _log_metric():
return response


@catch_mlflow_exception
def _log_param():
request_message = _get_request_message(LogParam())
param = Param(request_message.key, request_message.value)
Expand All @@ -202,6 +228,7 @@ def _log_param():
return response


@catch_mlflow_exception
def _set_tag():
request_message = _get_request_message(SetTag())
tag = RunTag(request_message.key, request_message.value)
Expand All @@ -212,6 +239,7 @@ def _set_tag():
return response


@catch_mlflow_exception
def _get_run():
request_message = _get_request_message(GetRun())
response_message = GetRun.Response()
Expand All @@ -221,6 +249,7 @@ def _get_run():
return response


@catch_mlflow_exception
def _search_runs():
request_message = _get_request_message(SearchRuns())
response_message = SearchRuns.Response()
Expand All @@ -236,6 +265,7 @@ def _search_runs():
return response


@catch_mlflow_exception
def _list_artifacts():
request_message = _get_request_message(ListArtifacts())
response_message = ListArtifacts.Response()
Expand All @@ -252,6 +282,7 @@ def _list_artifacts():
return response


@catch_mlflow_exception
def _get_metric_history():
request_message = _get_request_message(GetMetricHistory())
response_message = GetMetricHistory.Response()
Expand All @@ -263,6 +294,7 @@ def _get_metric_history():
return response


@catch_mlflow_exception
def _get_metric():
request_message = _get_request_message(GetMetric())
response_message = GetMetric.Response()
Expand All @@ -273,6 +305,7 @@ def _get_metric():
return response


@catch_mlflow_exception
def _get_param():
request_message = _get_request_message(GetParam())
response_message = GetParam.Response()
Expand All @@ -283,6 +316,7 @@ def _get_param():
return response


@catch_mlflow_exception
def _list_experiments():
request_message = _get_request_message(ListExperiments())
experiment_entities = _get_store().list_experiments(request_message.view_type)
Expand All @@ -293,6 +327,7 @@ def _list_experiments():
return response


@catch_mlflow_exception
def _get_artifact_repo(run):
store = _get_store()
if run.info.artifact_uri:
Expand Down
11 changes: 1 addition & 10 deletions mlflow/store/rest_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

from mlflow.exceptions import RestException
from mlflow.store.abstract_store import AbstractStore

from mlflow.entities import Experiment, Run, RunInfo, RunTag, Param, Metric, ViewType
Expand Down Expand Up @@ -35,16 +36,6 @@ def _api_method_to_info():
_METHOD_TO_INFO = _api_method_to_info()


class RestException(Exception):
"""Exception thrown on 400-level errors from the REST API"""
def __init__(self, json):
message = json['error_code']
if 'message' in json:
message = "%s: %s" % (message, json['message'])
super(RestException, self).__init__(message)
self.json = json


class RestStore(AbstractStore):
"""
Client for a remote tracking server accessed via REST API calls
Expand Down
5 changes: 4 additions & 1 deletion mlflow/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import os.path
import re

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

_VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[/\w.\- ]*$")

# Regex for valid run IDs: must be a 32-character hex string.
Expand Down Expand Up @@ -55,4 +58,4 @@ def _validate_tag_name(name):
def _validate_run_id(run_id):
"""Check that `run_id` is a valid run ID and raise an exception if it isn't."""
if _RUN_ID_REGEX.match(run_id) is None:
raise Exception("Invalid run ID: '%s'" % run_id)
raise MlflowException("Invalid run ID: '%s'" % run_id, error_code=INVALID_PARAMETER_VALUE)
19 changes: 18 additions & 1 deletion tests/server/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import json

import mock
import pytest

from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, ErrorCode
from mlflow.server.handlers import get_endpoints, _create_experiment, _get_request_message, \
_search_runs
_search_runs, catch_mlflow_exception
from mlflow.protos.service_pb2 import CreateExperiment, SearchRuns


Expand Down Expand Up @@ -72,3 +76,16 @@ def test_search_runs_default_view_type(mock_get_request_message, mock_store):
_search_runs()
args, _ = mock_store.search_runs.call_args
assert args[2] == ViewType.ACTIVE_ONLY


def test_catch_mlflow_exception():
@catch_mlflow_exception
def test_handler():
raise MlflowException('test error', error_code=INTERNAL_ERROR)

# pylint: disable=assignment-from-no-return
response = test_handler()
json_response = json.loads(response.get_data())
assert response.status_code == 500
assert json_response['error_code'] == ErrorCode.Name(INTERNAL_ERROR)
assert json_response['message'] == 'test error'
19 changes: 19 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import json

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


class TestMlflowException(object):
def test_error_code_constructor(self):
assert MlflowException('test', error_code=INVALID_PARAMETER_VALUE).error_code == \
'INVALID_PARAMETER_VALUE'

def test_default_error_code(self):
assert MlflowException('test').error_code == 'INTERNAL_ERROR'

def test_serialize_to_json(self):
mlflow_exception = MlflowException('test')
deserialized = json.loads(mlflow_exception.serialize_as_json())
assert deserialized['message'] == 'test'
assert deserialized['error_code'] == 'INTERNAL_ERROR'

0 comments on commit 53de566

Please sign in to comment.