Skip to content

Commit

Permalink
Handle PyTorch models mlflow#198 (mlflow#264)
Browse files Browse the repository at this point in the history
Add support for logging & loading PyTorch models
  • Loading branch information
vfdev-5 authored and smurching committed Aug 17, 2018
1 parent 7bd4c51 commit b00826f
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ TensorFlow (``tensorflow``)

The ``tensorflow`` model flavor enables logging TensorFlow ``Saved Models`` and loading them back as ``Python Function`` models for inference on pandas DataFrames. Given a directory containing a saved model, you can log the model to MLflow via ``log_saved_model``. The saved model can then be loaded for inference via ``load_pyfunc()``. For more information, see :py:mod:`mlflow.tensorflow`.

PyTorch (``pytorch``)
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ``pytorch`` model flavor enables logging and loading PyTorch models. The entire model is stored in ``.pth`` format using ``torch.save(model)``. You can log a PyTorch model to MLflow via ``mlflow.pytorch.log_model()`` or save it to a local path via ``mlflow.pytorch.save_model()``. The saved model can then be loaded back as a PyTorch model via ``mlflow.pytorch.load_model()``, or as a ``Python Function`` model for inference on pandas DataFrames via ``load_pyfunc()``. For more information, see :py:mod:`mlflow.pytorch`.


H\ :sub:`2`\ O (``h2o``)
^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
7 changes: 7 additions & 0 deletions docs/source/python_api/mlflow.pytorch.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mlflow.pytorch
==================

.. automodule:: mlflow.pytorch
:members:
:undoc-members:
:show-inheritance:
145 changes: 145 additions & 0 deletions mlflow/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
MLflow integration for PyTorch.
Manages logging and loading PyTorch models; logged models can be loaded back as PyTorch
models or as Python Function models.
"""

from __future__ import absolute_import

import os

import numpy as np
import pandas as pd
import torch

from mlflow import pyfunc
from mlflow.models import Model
import mlflow.tracking


FLAVOR_NAME = "pytorch"


def log_model(pytorch_model, artifact_path, conda_env=None, **kwargs):
    """
    Record a PyTorch model as an MLflow artifact of the current run.

    :param pytorch_model: PyTorch model to be saved. Must accept a single ``torch.FloatTensor``
                          as input and produce a single output tensor.
    :param artifact_path: Run-relative artifact path.
    :param conda_env: Path to a Conda environment file. If provided, this defines the
                      environment for the model. At minimum, it should specify Python, PyTorch
                      and MLflow with appropriate versions.
    :param kwargs: kwargs to pass to ``torch.save`` method.
    """
    # Bundle everything destined for this flavor's save routine into one mapping;
    # ``Model.log`` receives it as keyword arguments alongside the flavor module.
    save_arguments = dict(kwargs, pytorch_model=pytorch_model, conda_env=conda_env)
    Model.log(artifact_path=artifact_path, flavor=mlflow.pytorch, **save_arguments)


def save_model(pytorch_model, path, conda_env=None, mlflow_model=None, **kwargs):
    """
    Save a PyTorch model to a path on the local file system.

    :param pytorch_model: PyTorch model to be saved. Must accept a single ``torch.FloatTensor``
                          as input and produce a single output tensor.
    :param path: Local path where the model is to be saved. Must not exist yet.
    :param conda_env: Path to a Conda environment file. If provided, this describes the
                      environment this model should be run in. At minimum, it should specify
                      Python, PyTorch and MLflow with appropriate versions.
    :param mlflow_model: MLflow model config this flavor is being added to. A new ``Model`` is
                         created when omitted.
    :param kwargs: kwargs to pass to ``torch.save`` method.
    :raises TypeError: if ``pytorch_model`` is not a ``torch.nn.Module``.
    :raises RuntimeError: if ``path`` already exists.
    """
    if not isinstance(pytorch_model, torch.nn.Module):
        raise TypeError("Argument 'pytorch_model' should be a torch.nn.Module")

    # Build the default config per call. A ``Model()`` default in the signature is
    # evaluated once at definition time, so flavors added by one call would leak
    # into every later call that relies on the default.
    if mlflow_model is None:
        mlflow_model = Model()

    path = os.path.abspath(path)
    if os.path.exists(path):
        raise RuntimeError("Path '{}' already exists".format(path))
    os.makedirs(path)

    # The serialized model lives next to the MLmodel file; only its file name
    # (relative to ``path``) is recorded in the flavor configuration.
    model_file = "model.pth"
    torch.save(pytorch_model, os.path.join(path, model_file), **kwargs)

    mlflow_model.add_flavor(FLAVOR_NAME, model_data=model_file, pytorch_version=torch.__version__)
    pyfunc.add_to_model(mlflow_model, loader_module="mlflow.pytorch",
                        data=model_file, env=conda_env)
    mlflow_model.save(os.path.join(path, "MLmodel"))


def _load_model(path, **kwargs):
    """
    Load the PyTorch model stored in the MLflow model directory ``path``.

    :param path: Directory containing the ``MLmodel`` file and the serialized model.
    :param kwargs: kwargs to pass to ``torch.load`` method.
    :raises RuntimeError: if no ``MLmodel`` file exists in ``path``.
    :raises ValueError: if the model has no ``pytorch`` flavor, or if it was saved with a
                        PyTorch version different from the installed one.
    """
    config_file = os.path.join(path, "MLmodel")
    if not os.path.exists(config_file):
        raise RuntimeError("MLmodel is not found at '{}'".format(path))

    available_flavors = Model.load(config_file).flavors
    if FLAVOR_NAME not in available_flavors:
        raise ValueError("Could not find flavor '{}' amongst available flavors {}, "
                         "unable to load stored model"
                         .format(FLAVOR_NAME, list(available_flavors.keys())))

    flavor = available_flavors[FLAVOR_NAME]
    saved_version = flavor["pytorch_version"]
    # This maybe replaced by a warning and then try/except torch.load
    if saved_version != torch.__version__:
        raise ValueError("Unfortunately stored model version '{}' does not match "
                         "installed PyTorch version '{}'"
                         .format(saved_version, torch.__version__))

    # 'model_data' holds the serialized model's file name relative to the model directory.
    model_file = os.path.join(os.path.abspath(path), flavor['model_data'])
    return torch.load(model_file, **kwargs)


def load_model(path, run_id=None, **kwargs):
    """
    Load a PyTorch model from a local directory (when ``run_id`` is None) or from a run.

    :param path: Local filesystem path or Run-relative artifact path to the model saved
                 by `mlflow.pytorch.log_model`.
    :param run_id: Run ID. If provided it is combined with path to identify the model.
    :param kwargs: kwargs to pass to `torch.load` method
    """
    model_dir = path
    if run_id is not None:
        # Resolve the run-relative artifact path to the model's log directory.
        model_dir = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
    return _load_model(model_dir, **kwargs)


def load_pyfunc(path, **kwargs):
    """
    Load the model as PyFunc. The returned object exposes a
    `predict(pd.DataFrame) -> pd.DataFrame` method that, given an input DataFrame of n rows and
    k float-valued columns, feeds a corresponding (n x k) torch.FloatTensor (or
    torch.cuda.FloatTensor) as input to the PyTorch model. ``predict`` returns the model's
    predictions (output tensor) as a DataFrame.

    :param path: Local filesystem path to the model saved by `mlflow.pytorch.log_model`; the
                 ``MLmodel`` file is looked up in the parent directory of this path.
    :param kwargs: kwargs to pass to `torch.load` method.
    """
    model_dir = os.path.dirname(path)
    pytorch_model = _load_model(model_dir, **kwargs)
    return _PyTorchWrapper(pytorch_model)


class _PyTorchWrapper(object):
"""
Wrapper class that creates a predict function such that
predict(data: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)
"""
def __init__(self, pytorch_model):
self.pytorch_model = pytorch_model

def predict(self, data, device='cpu'):
if not isinstance(data, pd.DataFrame):
raise TypeError("Input data should be pandas.DataFrame")
self.pytorch_model.to(device)
self.pytorch_model.eval()
with torch.no_grad():
input_tensor = torch.from_numpy(data.values.astype(np.float32)).to(device)
preds = self.pytorch_model(input_tensor)
if not isinstance(preds, torch.Tensor):
raise TypeError("Expected PyTorch model to output a single output tensor, "
"but got output of type '{}'".format(type(preds)))
predicted = pd.DataFrame(preds.numpy())
predicted.index = data.index
return predicted
4 changes: 3 additions & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ pytest-cov
rstcheck==3.2
scipy
tensorflow
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl ; python_version == '2.7'
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl ; python_version == '3.6'
pysftp
keras
keras
149 changes: 149 additions & 0 deletions tests/pytorch/test_pytorch_model_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import print_function

import pytest

import numpy as np
import pandas as pd

import sklearn.datasets as datasets

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import mlflow.pytorch
from mlflow import tracking
from mlflow.utils.file_utils import TempDir


@pytest.fixture(scope='module')
def data():
    """Return the iris dataset as a ``(features DataFrame, target Series)`` pair."""
    iris = datasets.load_iris()
    column_names = iris['feature_names'] + ['target']
    # Stack features and target side by side so they share a single frame.
    frame = pd.DataFrame(data=np.c_[iris['data'], iris['target']], columns=column_names)
    labels = frame['target']
    features = frame.drop('target', axis=1)
    return features, labels


def get_dataset(data):
    """
    Turn an ``(x, y)`` pair of pandas objects into a list of ``(features, target)`` samples.

    Each sample holds one row of ``x`` and the matching entry of ``y``, both cast to float32.
    """
    features, targets = data
    samples = []
    for row, label in zip(features.values, targets.values):
        samples.append((row.astype(np.float32), label.astype(np.float32)))
    return samples


@pytest.fixture(scope='module')
def model(data):
    """
    Return a small feed-forward regressor (4 -> 3 -> 1) trained for five epochs on ``data``.

    :param data: ``(x, y)`` pair as produced by the ``data`` fixture.
    """
    dataset = get_dataset(data)
    model = nn.Sequential(
        nn.Linear(4, 3),
        nn.ReLU(),
        nn.Linear(3, 1),
    )

    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    batch_size = 16
    num_workers = 4
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers, shuffle=True, drop_last=False)

    model.train()
    for _ in range(5):
        for batch in dataloader:
            optimizer.zero_grad()
            # The model emits shape (batch, 1); drop the trailing axis to match the targets.
            y_pred = model(batch[0]).squeeze(dim=1)
            loss = criterion(y_pred, batch[1])
            loss.backward()
            optimizer.step()

    return model


def _predict(model, data):
    """
    Run ``model`` in eval mode over ``data`` in order and return its predictions.

    :param model: Callable PyTorch module producing an output of shape (batch, 1).
    :param data: ``(x, y)`` pair accepted by ``get_dataset``.
    :return: 1-D numpy array with one prediction per sample.
    """
    batch_size = 16
    loader = DataLoader(get_dataset(data), batch_size=batch_size,
                        num_workers=4, shuffle=False, drop_last=False)
    outputs = np.zeros((len(loader.sampler),))
    model.eval()
    with torch.no_grad():
        for batch_index, batch in enumerate(loader):
            start = batch_index * batch_size
            # The last batch may be shorter; the slice is clipped to the array length.
            outputs[start:start + batch_size] = model(batch[0]).squeeze(dim=1).numpy()
    return outputs


@pytest.fixture(scope='module')
def predicted(model, data):
    """Return the reference predictions of the in-memory ``model`` on ``data``."""
    reference_predictions = _predict(model, data)
    return reference_predictions


def test_log_model(model, data, predicted):
    """A logged model, loaded back through its run ID, reproduces the reference predictions."""
    original_uri = tracking.get_tracking_uri()
    # Cover both cases: log_model() starting a run on its own, and an explicitly started run.
    for start_run_explicitly in (False, True):
        with TempDir(chdr=True, remove_on_exit=True) as tmp_dir:
            try:
                tracking.set_tracking_uri(tmp_dir.path("test"))
                if start_run_explicitly:
                    mlflow.start_run()

                mlflow.pytorch.log_model(model, artifact_path="pytorch")

                # Reload the model through the run that now holds the artifact.
                active_run_id = mlflow.active_run().info.run_uuid
                reloaded = mlflow.pytorch.load_model("pytorch", run_id=active_run_id)

                assert np.all(_predict(reloaded, data) == predicted)
            finally:
                # Always restore global tracking state, even when an assertion fails.
                mlflow.end_run()
                tracking.set_tracking_uri(original_uri)


def test_raise_exception(model):
    """Check the error raised by each invalid save/load scenario of the pytorch flavor."""
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        path = tmp.path("model")

        # Nothing has been saved at ``path`` yet.
        with pytest.raises(RuntimeError):
            mlflow.pytorch.load_model(path)

        # Only torch.nn.Module instances can be saved.
        with pytest.raises(TypeError):
            mlflow.pytorch.save_model([1, 2, 3], path)

        # Saving twice to the same location must be refused.
        mlflow.pytorch.save_model(model, path)
        with pytest.raises(RuntimeError):
            mlflow.pytorch.save_model(model, path)

        # Imported under explicit names so that neither the module-level ``mlflow`` name
        # nor the scikit-learn package is shadowed inside this function.
        from mlflow import sklearn as mlflow_sklearn
        from sklearn.neighbors import KNeighborsClassifier

        # A model saved by another flavor must not load as a PyTorch model.
        sklearn_model_path = tmp.path("knn")
        mlflow_sklearn.save_model(KNeighborsClassifier(), path=sklearn_model_path)
        with pytest.raises(ValueError):
            mlflow.pytorch.load_model(sklearn_model_path)


def test_save_and_load_model(model, data, predicted):
    """A saved model reproduces the reference predictions as a PyTorch model and as a pyfunc."""
    features, _ = data
    with TempDir(chdr=True, remove_on_exit=True) as tmp_dir:
        model_dir = tmp_dir.path("model")
        mlflow.pytorch.save_model(model, model_dir)

        # Native PyTorch round trip.
        reloaded = mlflow.pytorch.load_model(model_dir)
        assert np.all(_predict(reloaded, data) == predicted)

        # Generic pyfunc round trip; predictions come back in the first DataFrame column.
        pyfunc_model = mlflow.pyfunc.load_pyfunc(model_dir)
        pyfunc_predictions = pyfunc_model.predict(features).values[:, 0]
        assert np.all(pyfunc_predictions == predicted)

0 comments on commit b00826f

Please sign in to comment.