Skip to content

Commit

Permalink
Make flavor implementation of load_pyfunc private, i.e. falvor_loader…
Browse files Browse the repository at this point in the history
…_module.load_pyfunc is renamed to flavor_loader_module._load_pyfunc for all flavors. (mlflow#539)
  • Loading branch information
tomasatdatabricks authored and aarondav committed Sep 24, 2018
1 parent cf6633d commit e27d677
Show file tree
Hide file tree
Showing 12 changed files with 25 additions and 79 deletions.
6 changes: 3 additions & 3 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ A ``python_function`` model directory must contain an ``MLmodel`` file in its ro
(for example, ``mlflow.sklearn``) importable via ``importlib.import_module``.
The imported module must contain a function with the following signature:

load_pyfunc(path: string) -> <pyfunc model>
_load_pyfunc(path: string) -> <pyfunc model>

The path argument is specified by the ``data`` parameter and may refer to a file or directory.

Expand Down Expand Up @@ -208,7 +208,7 @@ The ``mleap`` model flavor supports saving models using the MLeap persistence me
PyTorch (``pytorch``)
^^^^^^^^^^^^^^^^^^^^^

The ``pytorch`` model flavor enables logging and loading PyTorch models. Model is completely stored in `.pth` format using `torch.save(model)` method. Given a directory containing a saved model, you can log the model to MLflow via ``log_saved_model``. The saved model can then be loaded for inference via ``load_pyfunc()``. For more information, see :py:mod:`mlflow.pytorch`.
The ``pytorch`` model flavor enables logging and loading PyTorch models. Model is completely stored in `.pth` format using `torch.save(model)` method. Given a directory containing a saved model, you can log the model to MLflow via ``log_saved_model``. The saved model can then be loaded for inference via ``mlflow.pyfunc.load_pyfunc()``. For more information, see :py:mod:`mlflow.pytorch`.

Scikit-learn (``sklearn``)
^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -232,7 +232,7 @@ MLflow. For more information, see :py:mod:`mlflow.spark`.
TensorFlow (``tensorflow``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ``tensorflow`` model flavor enables logging TensorFlow ``Saved Models`` and loading them back as ``Python Function`` models for inference on pandas DataFrames. Given a directory containing a saved model, you can log the model to MLflow via ``log_saved_model`` and then load the saved model for inference using ``load_pyfunc``. For more information, see :py:mod:`mlflow.tensorflow`.
The ``tensorflow`` model flavor enables logging TensorFlow ``Saved Models`` and loading them back as ``Python Function`` models for inference on pandas DataFrames. Given a directory containing a saved model, you can log the model to MLflow via ``log_saved_model`` and then load the saved model for inference using ``mlflow.pyfunc.load_pyfunc``. For more information, see :py:mod:`mlflow.tensorflow`.

Custom Flavors
--------------
Expand Down
6 changes: 3 additions & 3 deletions examples/tensorflow/train_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import print_function

import mlflow
from mlflow import tensorflow, tracking
from mlflow import tensorflow, tracking, pyfunc
import numpy as np
import pandas as pd
import shutil
Expand Down Expand Up @@ -38,10 +38,10 @@ def main(argv):
# Logging the saved model
tensorflow.log_saved_model(saved_model_dir=saved_estimator_path, signature_def_key="predict", artifact_path="model")
# Reloading the model
pyfunc = tensorflow.load_pyfunc(saved_estimator_path)
pyfunc_model = pyfunc.load_pyfunc(saved_estimator_path)
df = pd.DataFrame(data=x_test, columns=["features"] * x_train.shape[1])
# Predicting on the loaded Python Function
predict_df = pyfunc.predict(df)
predict_df = pyfunc_model.predict(df)
predict_df['original_labels'] = y_test
print(predict_df)
finally:
Expand Down
11 changes: 2 additions & 9 deletions mlflow/h2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,9 @@ def predict(self, dataframe):
return predicted


def load_pyfunc(path):
def _load_pyfunc(path):
"""
Load a persisted H2O model as a ``python_function`` model.
This method calls ``h2o.init``, so the right version of h2o(-py) must be in the
environment. The arguments given to ``h2o.init`` can be customized in ``path/h2o.yaml``
under the key ``init``.
:param path: Local filesystem path to the model saved by :py:func:`mlflow.h2o.save_model`.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
return _H2OModelWrapper(_load_model(path, init=True))

Expand Down
13 changes: 2 additions & 11 deletions mlflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,9 @@ def predict(self, dataframe):
return predicted


def load_pyfunc(model_file):
def _load_pyfunc(model_file):
"""
Load a persisted Keras model as a ``python_function`` model.
:param model_file: Local filesystem path to model saved by :py:func:`mlflow.keras.log_model`.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
>>> model_file = "/tmp/pyfunc-keras-model"
>>> keras_model = mlflow.keras.load_pyfunc(model_file)
>>> # We can apply the loaded PyFunc for inference on a pandas DataFrame via predict()
>>> predictions = keras_model.predict(x_test)
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
if K._BACKEND == 'tensorflow':
import tensorflow as tf
Expand Down
6 changes: 3 additions & 3 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
e.g. ``mlflow.sklearn``, it will be imported via ``importlib.import_module``.
The imported module must contain function with the following signature::
load_pyfunc(path: string) -> <pyfunc model>
_load_pyfunc(path: string) -> <pyfunc model>
The path argument is specified by the ``data`` parameter and may refer to a file or
directory.
Expand Down Expand Up @@ -166,7 +166,7 @@ def load_pyfunc(path, run_id=None, suppress_warnings=False):
code_path = os.path.join(path, conf[CODE])
sys.path = [code_path] + _get_code_dirs(code_path) + sys.path
data_path = os.path.join(path, conf[DATA]) if (DATA in conf) else path
return importlib.import_module(conf[MAIN]).load_pyfunc(data_path)
return importlib.import_module(conf[MAIN])._load_pyfunc(data_path)


def _warn_potentially_incompatible_py_version_if_necessary(model_py_version):
Expand Down Expand Up @@ -327,6 +327,6 @@ def get_module_loader_src(src_path, dst_path):
import sys
def load_pyfunc():
{update_path}return importlib.import_module('{main}').load_pyfunc('{data_path}')
{update_path}return importlib.import_module('{main}')._load_pyfunc('{data_path}')
"""
23 changes: 2 additions & 21 deletions mlflow/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,9 @@ def load_model(path, run_id=None, **kwargs):
return _load_model(path, **kwargs)


def load_pyfunc(path, **kwargs):
def _load_pyfunc(path, **kwargs):
"""
Load a persisted PyTorch model as a ``python_function`` model.
The loaded PyFunc exposes a ``predict(pd.DataFrame) -> pd.DataFrame``
method that, given an input DataFrame of n rows and k float-valued columns, feeds a
corresponding (n x k) ``torch.FloatTensor`` (or ``torch.cuda.FloatTensor``) as input to the
PyTorch model. ``predict`` returns the model's predictions (output tensor) in a single column
DataFrame.
:param path: Local filesystem path to the model saved by :py:func:`mlflow.pytorch.log_model`.
:param kwargs: kwargs to pass to ``torch.load`` method.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # set values
>>> model_path_dir = ...
>>> new_pandas_df = ...
>>> pytorch_model = mlfow.pytorch.load_pyfunc(model_path_dir)
>>> predictions = pytorch_model.predict(new_pandas_df)
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
return _PyTorchWrapper(_load_model(os.path.dirname(path), **kwargs))

Expand Down
9 changes: 2 additions & 7 deletions mlflow/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,10 @@ def _load_model_from_local_file(path):
return pickle.load(f)


def load_pyfunc(path):
def _load_pyfunc(path):
"""
Load a persisted scikit-learn model as a ``python_function`` model.
:param path: Local filesystem path to the model saved by :py:func:`mlflow.sklearn.save_model`.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""

with open(path, "rb") as f:
return pickle.load(f)

Expand Down
11 changes: 2 additions & 9 deletions mlflow/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,9 @@ def load_model(path, run_id=None, dfs_tmpdir=None):
return _load_model(model_path=model_path, dfs_tmpdir=dfs_tmpdir)


def load_pyfunc(path):
def _load_pyfunc(path):
"""
Load a persisted Spark MLlib PipelineModel as a ``python_function`` model.
>>> pyfunc_model = load_pyfunc("/tmp/pyfunc-spark-model")
>>> predictions = pyfunc_model.predict(test_pandas_df)
:param path: Local filesystem path to the model saved by :py:func:`mlflow.spark.log_model`.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
# NOTE: The getOrCreate() call below may change settings of the active session which we do not
# intend to do here. In particular, setting master to local[1] can break distributed clusters.
Expand Down
13 changes: 3 additions & 10 deletions mlflow/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
You must save your own ``saved_model`` and pass its
path to ``log_saved_model(saved_model_dir)``. To load the model to predict on it, you call
``model = load_pyfunc(saved_model_dir)`` followed by
``model = pyfunc.load_pyfunc(saved_model_dir)`` followed by
``prediction = model.predict(pandas DataFrame)`` to obtain a prediction in a pandas DataFrame.
The loaded :py:mod:`mlflow.pyfunc` model *does not* expose any APIs for model training.
Expand Down Expand Up @@ -81,15 +81,8 @@ def log_saved_model(saved_model_dir, signature_def_key, artifact_path):
log_artifacts(saved_model_dir, artifact_path)


def load_pyfunc(saved_model_dir):
def _load_pyfunc(saved_model_dir):
"""
Load a persisted TensorFlow model as a PyFunc.
The loaded model object exposes a ``predict(pandas DataFrame)`` method that returns a pandas
DataFrame containing the model's inference output on an input DataFrame.
:param saved_model_dir: Directory where the TensorFlow model is saved.
:rtype: Pyfunc format model with function
``model.predict(pandas DataFrame) -> pandas DataFrame``.
Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.
"""
return _TFWrapper(saved_model_dir)
2 changes: 1 addition & 1 deletion tests/azureml/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mlflow.azureml import cli


def load_pyfunc(path):
def _load_pyfunc(path):
with open(path, "rb") as f:
return pickle.load(f)

Expand Down
2 changes: 1 addition & 1 deletion tests/pyfunc/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mlflow.utils.file_utils import TempDir


def load_pyfunc(path):
def _load_pyfunc(path):
with open(path, "rb") as f:
if six.PY2:
return pickle.load(f)
Expand Down
2 changes: 1 addition & 1 deletion tests/sklearn/test_sklearn_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mlflow.utils.environment import _mlflow_conda_env


def load_pyfunc(path):
def _load_pyfunc(path):
with open(path, "rb") as f:
return pickle.load(f)

Expand Down

0 comments on commit e27d677

Please sign in to comment.