Skip to content

Commit

Permalink
Numpy array conversion in pyfunc scoring server (mlflow#1479)
Browse files Browse the repository at this point in the history
* numpy conversion

* added conversion from split and records to numpy array in pyfunc

* formatting fix

* formatting fix

* remove unused imports

* remove decode utf-8

* change the name of some tests

* keep ordereddict for python2.7

* using ordereddict in tests

* revert two lines of existing pandas variables

* wrap numpy in pandas to be consistent with other conversion apis

* remove comments...

* include columns in records orient

* remove dtype setting in numpy

* format minor fix in tests

* remove records oriented support

* use infer object to convert pd dataframe to the correct data type

* fix format

* mistakenly removed 1 test, adding it back
  • Loading branch information
lennon310 authored and tomasatdatabricks committed Jun 26, 2019
1 parent cebc16b commit bc374f6
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
33 changes: 30 additions & 3 deletions mlflow/pyfunc/scoring_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""
from __future__ import print_function

from collections import OrderedDict
import flask
import json
from json import JSONEncoder
Expand Down Expand Up @@ -47,12 +48,14 @@
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_JSON_RECORDS_ORIENTED = "application/json; format=pandas-records"
CONTENT_TYPE_JSON_SPLIT_ORIENTED = "application/json; format=pandas-split"
CONTENT_TYPE_JSON_SPLIT_NUMPY = "application/json-numpy-split"

CONTENT_TYPES = [
CONTENT_TYPE_CSV,
CONTENT_TYPE_JSON,
CONTENT_TYPE_JSON_RECORDS_ORIENTED,
CONTENT_TYPE_JSON_SPLIT_ORIENTED
CONTENT_TYPE_JSON_SPLIT_ORIENTED,
CONTENT_TYPE_JSON_SPLIT_NUMPY
]

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -95,6 +98,28 @@ def parse_csv_input(csv_input):
error_code=MALFORMED_REQUEST)


def parse_split_oriented_json_input_to_numpy(json_input):
"""
:param json_input: A JSON-formatted string representation of a Pandas DataFrame with split
orient, or a stream containing such a string representation.
"""
# pylint: disable=broad-except
try:
json_input_list = json.loads(json_input, object_pairs_hook=OrderedDict)
return pd.DataFrame(index=json_input_list['index'],
data=np.array(json_input_list['data'], dtype=object),
columns=json_input_list['columns']).infer_objects()
except Exception:
_handle_serving_error(
error_message=(
"Failed to parse input as a Numpy. Ensure that the input is"
" a valid JSON-formatted Pandas DataFrame with the split orient"
" produced using the `pandas.DataFrame.to_json(..., orient='split')`"
" method."
),
error_code=MALFORMED_REQUEST)


def predictions_to_json(raw_predictions, output):
predictions = _get_jsonable_obj(raw_predictions, pandas_orient="records")
json.dump(predictions, output, cls=NumpyEncoder)
Expand Down Expand Up @@ -140,8 +165,8 @@ def ping(): # pylint: disable=unused-variable
def transformation(): # pylint: disable=unused-variable
"""
Do an inference on a single batch of data. In this sample server,
we take data as CSV or json, convert it to a Pandas DataFrame,
generate predictions and convert them back to CSV.
we take data as CSV or json, convert it to a Pandas DataFrame or Numpy,
generate predictions and convert them back to json.
"""
# Convert from CSV to pandas
if flask.request.content_type == CONTENT_TYPE_CSV:
Expand All @@ -154,6 +179,8 @@ def transformation(): # pylint: disable=unused-variable
elif flask.request.content_type == CONTENT_TYPE_JSON_RECORDS_ORIENTED:
data = parse_json_input(json_input=flask.request.data.decode('utf-8'),
orient="records")
elif flask.request.content_type == CONTENT_TYPE_JSON_SPLIT_NUMPY:
data = parse_split_oriented_json_input_to_numpy(flask.request.data.decode('utf-8'))
else:
return flask.Response(
response=("This predictor only supports the following content types,"
Expand Down
42 changes: 41 additions & 1 deletion tests/pyfunc/test_scoring_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import pandas as pd
import numpy as np
from collections import namedtuple
from collections import namedtuple, OrderedDict

import pytest
import sklearn.datasets as datasets
Expand Down Expand Up @@ -153,6 +153,19 @@ def test_scoring_server_successfully_evaluates_correct_dataframes_with_pandas_sp
assert response.status_code == 200


@pytest.mark.large
def test_scoring_server_successfully_evaluates_correct_split_to_numpy(
sklearn_model, model_path):
mlflow.sklearn.save_model(sk_model=sklearn_model.model, path=model_path)

pandas_split_content = pd.DataFrame(sklearn_model.inference_data).to_json(orient="split")
response_records_content_type = pyfunc_serve_and_score_model(
model_uri=os.path.abspath(model_path),
data=pandas_split_content,
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON_SPLIT_NUMPY)
assert response_records_content_type.status_code == 200


@pytest.mark.large
def test_scoring_server_responds_to_invalid_content_type_request_with_unsupported_content_type_code(
sklearn_model, model_path):
Expand Down Expand Up @@ -190,6 +203,22 @@ def test_parse_json_input_split_oriented():
assert all(p1 == p2)


@pytest.mark.large
def test_parse_json_input_split_oriented_to_numpy_array():
size = 200
data = OrderedDict([("col_m", [random_int(0, 1000) for _ in range(size)]),
("col_z", [random_str(4) for _ in range(size)]),
("col_a", [random_int() for _ in range(size)])])
p0 = pd.DataFrame.from_dict(data)
np_array = np.array([[a, b, c] for a, b, c in
zip(data['col_m'], data['col_z'], data['col_a'])],
dtype=object)
p1 = pd.DataFrame(np_array).infer_objects()
p2 = pyfunc_scoring_server.parse_split_oriented_json_input_to_numpy(
p0.to_json(orient="split"))
np.testing.assert_array_equal(p1, p2)


@pytest.mark.large
def test_records_oriented_json_to_df():
# test that datatype for "zip" column is not converted to "int64"
Expand All @@ -215,6 +244,17 @@ def test_split_oriented_json_to_df():
assert set(str(dt) for dt in df.dtypes) == {'object', 'float64', 'int64'}


@pytest.mark.large
def test_split_oriented_json_to_numpy_array():
# test that datatype for "zip" column is not converted to "int64"
jstr = '{"columns":["zip","cost","count"],"index":[0,1,2],' \
'"data":[["95120",10.45,-8],["95128",23.0,-1],["95128",12.1,1000]]}'
df = pyfunc_scoring_server.parse_split_oriented_json_input_to_numpy(jstr)

assert set(df.columns) == {'zip', 'cost', 'count'}
assert set(str(dt) for dt in df.dtypes) == {'object', 'float64', 'int64'}


def test_get_jsonnable_obj():
from mlflow.pyfunc.scoring_server import _get_jsonable_obj
from mlflow.pyfunc.scoring_server import NumpyEncoder
Expand Down

0 comments on commit bc374f6

Please sign in to comment.