Skip to content

Commit

Permalink
Make mlflow.pytorch.pickle_module the default when saving PyTorch mod…
Browse files Browse the repository at this point in the history
…els (mlflow#861)

* Move PyTorch module to subfolder

* Add mlflow.pytorch.pickle module

* Module docs improvements

* Rename pickle to pickle_module

* Disable warnings

* Add test

* Docs fix

* Functioning default pickle_module argument with tests

* Format code, add docs, pickle module load override test

* Log error if pickle module cannot be imported

* Prefix internal constants with underscore

* Address review comments

* Lint

* Test simplifications
  • Loading branch information
dbczumar committed Feb 4, 2019
1 parent 38853eb commit 8352a03
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 13 deletions.
84 changes: 73 additions & 11 deletions mlflow/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

from __future__ import absolute_import

import importlib
import logging
import os
import yaml

import cloudpickle
import numpy as np
import pandas as pd
import torch
Expand All @@ -23,6 +25,8 @@
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
from mlflow.pytorch import pickle_module as mlflow_pytorch_pickle_module
import mlflow.tracking
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import _copy_file_or_tree
Expand All @@ -35,16 +39,25 @@
"pytorch={}".format(torch.__version__),
"torchvision={}".format(torchvision.__version__),
],
additional_pip_deps=None,
additional_pip_deps=[
# We include CloudPickle in the default environment because
# it's required by the default pickle module used by `save_model()`
# and `log_model()`: `mlflow.pytorch.pickle_module`.
"cloudpickle=={}".format(cloudpickle.__version__)
],
additional_conda_channels=[
"pytorch",
],
)

_SERIALIZED_TORCH_MODEL_FILE_NAME = "model.pth"
_PICKLE_MODULE_INFO_FILE_NAME = "pickle_module_info.txt"

_logger = logging.getLogger(__name__)


def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, **kwargs):
def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None,
pickle_module=mlflow_pytorch_pickle_module, **kwargs):
"""
Log a PyTorch model as an MLflow artifact for the current run.
Expand Down Expand Up @@ -78,6 +91,10 @@ def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, **k
:param code_paths: A list of local filesystem paths to Python file dependencies (or directories
containing file dependencies). These files will be *prepended* to the system
path when the model is loaded.
:param pickle_module: The module that PyTorch should use to serialize ("pickle") the specified
``pytorch_model``. This will be passed as the ``pickle_module`` parameter
to ``torch.save()``. By default, this module will also be used to
deserialize ("unpickle") the PyTorch model at load time.
:param kwargs: kwargs to pass to ``torch.save`` method.
>>> import torch
Expand Down Expand Up @@ -122,11 +139,11 @@ def log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, **k
>>> mlflow.pytorch.log_model(pytorch_model, "models")
"""
Model.log(artifact_path=artifact_path, flavor=mlflow.pytorch, pytorch_model=pytorch_model,
conda_env=conda_env, code_paths=code_paths, **kwargs)
conda_env=conda_env, code_paths=code_paths, pickle_module=pickle_module, **kwargs)


def save_model(pytorch_model, path, conda_env=None, mlflow_model=Model(), code_paths=None,
**kwargs):
pickle_module=mlflow_pytorch_pickle_module, **kwargs):
"""
Save a PyTorch model to a path on the local file system.
Expand Down Expand Up @@ -162,6 +179,10 @@ def save_model(pytorch_model, path, conda_env=None, mlflow_model=Model(), code_p
:param code_paths: A list of local filesystem paths to Python file dependencies (or directories
containing file dependencies). These files will be *prepended* to the system
path when the model is loaded.
:param pickle_module: The module that PyTorch should use to serialize ("pickle") the specified
``pytorch_model``. This will be passed as the ``pickle_module`` parameter
to ``torch.save()``. By default, this module will also be used to
deserialize ("unpickle") the PyTorch model at load time.
:param kwargs: kwargs to pass to ``torch.save`` method.
>>> import torch
Expand All @@ -186,11 +207,23 @@ def save_model(pytorch_model, path, conda_env=None, mlflow_model=Model(), code_p
if os.path.exists(path):
raise RuntimeError("Path '{}' already exists".format(path))
os.makedirs(path)
model_path = os.path.join(path, "model.pth")

model_data_subpath = "data"
model_data_path = os.path.join(path, model_data_subpath)
os.makedirs(model_data_path)
# Persist the pickle module name as a file in the model's `data` directory. This is necessary
# because the `data` directory is the only available parameter to `_load_pyfunc`, and it
# does not contain the MLmodel configuration; therefore, it is not sufficient to place
# the module name in the MLmodel
#
# TODO: Stop persisting this information to the filesystem once we have a mechanism for
# supplying the MLmodel configuration to `mlflow.pytorch._load_pyfunc`
pickle_module_path = os.path.join(model_data_path, _PICKLE_MODULE_INFO_FILE_NAME)
with open(pickle_module_path, "w") as f:
f.write(pickle_module.__name__)
# Save pytorch model
torch.save(pytorch_model, model_path, **kwargs)
model_file = os.path.basename(model_path)
model_path = os.path.join(model_data_path, _SERIALIZED_TORCH_MODEL_FILE_NAME)
torch.save(pytorch_model, model_path, pickle_module=pickle_module, **kwargs)

conda_env_subpath = "conda.yaml"
if conda_env is None:
Expand All @@ -208,9 +241,11 @@ def save_model(pytorch_model, path, conda_env=None, mlflow_model=Model(), code_p
else:
code_dir_subpath = None

mlflow_model.add_flavor(FLAVOR_NAME, model_data=model_file, pytorch_version=torch.__version__)
pyfunc.add_to_model(mlflow_model, loader_module="mlflow.pytorch", data=model_file,
code=code_dir_subpath, env=conda_env_subpath)
mlflow_model.add_flavor(
FLAVOR_NAME, model_data=model_data_subpath, pytorch_version=torch.__version__)
pyfunc.add_to_model(mlflow_model, loader_module="mlflow.pytorch", data=model_data_subpath,
pickle_module_name=pickle_module.__name__, code=code_dir_subpath,
env=conda_env_subpath)
mlflow_model.save(os.path.join(path, "MLmodel"))


Expand All @@ -219,7 +254,34 @@ def _load_model(path, **kwargs):
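:param path: The path to a serialized PyTorch model file, or to a directory containing the
serialized model file and a text file naming the pickle module that was used to save it.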
:param kwargs: Additional kwargs to pass to the PyTorch ``torch.load`` function.
"""
return torch.load(path, **kwargs)
if os.path.isdir(path):
# `path` is a directory containing a serialized PyTorch model and a text file containing
# information about the pickle module that should be used by PyTorch to load it
model_path = os.path.join(path, "model.pth")
pickle_module_path = os.path.join(path, _PICKLE_MODULE_INFO_FILE_NAME)
with open(pickle_module_path, "r") as f:
pickle_module_name = f.read()
if "pickle_module" in kwargs and kwargs["pickle_module"].__name__ != pickle_module_name:
_logger.warning(
"Attempting to load the PyTorch model with a pickle module, '%s', that does not"
" match the pickle module that was used to save the model: '%s'.",
kwargs["pickle_module"].__name__,
pickle_module_name)
else:
try:
kwargs["pickle_module"] = importlib.import_module(pickle_module_name)
except ImportError:
raise MlflowException(
message=(
"Failed to import the pickle module that was used to save the PyTorch"
" model. Pickle module name: `{pickle_module_name}`".format(
pickle_module_name=pickle_module_name)),
error_code=RESOURCE_DOES_NOT_EXIST)

else:
model_path = path

return torch.load(model_path, **kwargs)


def load_model(path, run_id=None, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions mlflow/pytorch/pickle_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@
# into a CloudPickle release or the ``torch.save`` API has been updated to be compatible with
# the existing CloudPickle API.
from cloudpickle import CloudPickler as Pickler
# CloudPickle does not include `Unpickler` in its namespace, which is required by PyTorch for
# deserialization. Noting that CloudPickle's `load()` and `loads()` routines are aliases for
# `pickle.load()` and `pickle.loads()`, we therefore import Unpickler from the native
# Python pickle library.
# pylint: disable=unused-import
from pickle import Unpickler
189 changes: 187 additions & 2 deletions tests/pytorch/test_pytorch_model_export.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import print_function

import importlib
import os
import logging
import json
import logging
import mock
import pickle

import pytest
import numpy as np
Expand Down Expand Up @@ -210,7 +212,6 @@ def test_raise_exception(sequential_model):

from mlflow import sklearn
import sklearn.neighbors as knn
import pickle
path = tmp.path("knn.pkl")
knn = knn.KNeighborsClassifier()
with open(path, "wb") as f:
Expand Down Expand Up @@ -440,6 +441,190 @@ def predict(self, context, model_input):
decimal=4)


def test_load_pyfunc_loads_torch_model_using_pickle_module_specified_at_save_time(
        module_scoped_subclassed_model, model_path):
    """
    Verify that ``pyfunc.load_pyfunc`` imports the pickle module that was supplied to
    ``mlflow.pytorch.save_model`` and hands it to ``torch.load`` as ``pickle_module``.
    """
    custom_pickle_module = pickle

    mlflow.pytorch.save_model(
        pytorch_model=module_scoped_subclassed_model,
        path=model_path,
        conda_env=None,
        pickle_module=custom_pickle_module)

    # Capture the real import function *before* patching so the spy below delegates to it
    # rather than to the mock that replaces it.
    real_import_module = importlib.import_module
    requested_module_names = []

    def record_and_import(module_name):
        requested_module_names.append(module_name)
        return real_import_module(module_name)

    with mock.patch("importlib.import_module", side_effect=record_and_import):
        with mock.patch("torch.load") as torch_load_mock:
            pyfunc.load_pyfunc(model_path)

    torch_load_mock.assert_called_with(mock.ANY, pickle_module=custom_pickle_module)
    assert custom_pickle_module.__name__ in requested_module_names


def test_load_model_loads_torch_model_using_pickle_module_specified_at_save_time(
        module_scoped_subclassed_model):
    """
    Verify that ``mlflow.pytorch.load_model`` imports the pickle module that was supplied to
    ``mlflow.pytorch.log_model`` and hands it to ``torch.load`` as ``pickle_module``.

    The model is logged to a run and then loaded back through ``mlflow.pytorch.load_model``
    using the artifact path and run ID, so this test covers the native PyTorch loader rather
    than the pyfunc loader (which has its own, separate test).
    """
    custom_pickle_module = pickle

    artifact_path = "pytorch_model"
    with mlflow.start_run():
        mlflow.pytorch.log_model(
            artifact_path=artifact_path,
            pytorch_model=module_scoped_subclassed_model,
            conda_env=None,
            pickle_module=custom_pickle_module)
        # The active run is only available while the ``start_run`` context is open.
        run_id = mlflow.active_run().info.run_uuid

    # Capture the real import function *before* patching so the spy below delegates to it
    # rather than to the mock that replaces it.
    import_module_fn = importlib.import_module
    imported_modules = []

    def track_module_imports(module_name):
        imported_modules.append(module_name)
        return import_module_fn(module_name)

    with mock.patch("importlib.import_module") as import_mock,\
            mock.patch("torch.load") as torch_load_mock:
        import_mock.side_effect = track_module_imports
        # Exercise the loader this test is named after; previously this called
        # ``pyfunc.load_pyfunc``, leaving ``mlflow.pytorch.load_model`` uncovered.
        mlflow.pytorch.load_model(path=artifact_path, run_id=run_id)

    torch_load_mock.assert_called_with(mock.ANY, pickle_module=custom_pickle_module)
    assert custom_pickle_module.__name__ in imported_modules


def test_load_pyfunc_succeeds_when_data_is_model_file_instead_of_directory(
        module_scoped_subclassed_model, model_path, data):
    """
    Verify that ``pyfunc.load_pyfunc`` can load a PyTorch model whose pyfunc ``data`` entry is
    a serialized model file rather than a directory.

    The model is saved in the current format and its MLmodel configuration is then rewritten so
    that the pyfunc ``data`` path points directly at the serialized model file. This mimics the
    layout described for models saved by older versions of MLflow, as opposed to the current
    format: a directory containing a serialized model file and pickle module information.
    """
    mlflow.pytorch.save_model(
        pytorch_model=module_scoped_subclassed_model,
        path=model_path,
        conda_env=None)

    mlmodel_file = os.path.join(model_path, "MLmodel")
    mlmodel = Model.load(mlmodel_file)
    pyfunc_flavor = mlmodel.flavors.get(pyfunc.FLAVOR_NAME)
    assert pyfunc_flavor is not None

    data_dir = os.path.join(model_path, pyfunc_flavor[pyfunc.DATA])
    assert os.path.exists(data_dir)
    serialized_model_name = mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME
    assert serialized_model_name in os.listdir(data_dir)

    # Point the pyfunc flavor at the model *file* and persist the edited configuration.
    pyfunc_flavor[pyfunc.DATA] = os.path.join(data_dir, serialized_model_name)
    mlmodel.save(mlmodel_file)

    loaded_pyfunc = pyfunc.load_pyfunc(model_path)

    actual_predictions = loaded_pyfunc.predict(data[0])
    expected_predictions = pd.DataFrame(
        _predict(model=module_scoped_subclassed_model, data=data))
    np.testing.assert_array_almost_equal(actual_predictions, expected_predictions, decimal=4)


def test_load_model_succeeds_when_data_is_model_file_instead_of_directory(
        module_scoped_subclassed_model, model_path, data):
    """
    Verify that a model logged with ``mlflow.pytorch.log_model`` can still be loaded when its
    pyfunc ``data`` entry is a serialized model file rather than a directory.

    The model is logged to a run in the current format and its MLmodel configuration is then
    rewritten so that the pyfunc ``data`` path points directly at the serialized model file.
    This mimics the layout described for models saved by older versions of MLflow, as opposed
    to the current format: a directory containing a serialized model file and pickle module
    information.

    NOTE(review): the model is loaded back with ``pyfunc.load_pyfunc``; the test name suggests
    ``mlflow.pytorch.load_model`` was intended -- confirm.
    """
    artifact_path = "pytorch_model"
    with mlflow.start_run():
        mlflow.pytorch.log_model(
            pytorch_model=module_scoped_subclassed_model,
            artifact_path=artifact_path,
            conda_env=None)
        # The active run is only available while the ``start_run`` context is open.
        run_id = mlflow.active_run().info.run_uuid
    # The ``model_path`` fixture value is deliberately replaced by the run's model directory.
    model_path = tracking.utils._get_model_log_dir(artifact_path, run_id)

    mlmodel_file = os.path.join(model_path, "MLmodel")
    mlmodel = Model.load(mlmodel_file)
    pyfunc_flavor = mlmodel.flavors.get(pyfunc.FLAVOR_NAME)
    assert pyfunc_flavor is not None

    data_dir = os.path.join(model_path, pyfunc_flavor[pyfunc.DATA])
    assert os.path.exists(data_dir)
    serialized_model_name = mlflow.pytorch._SERIALIZED_TORCH_MODEL_FILE_NAME
    assert serialized_model_name in os.listdir(data_dir)

    # Point the pyfunc flavor at the model *file* and persist the edited configuration.
    pyfunc_flavor[pyfunc.DATA] = os.path.join(data_dir, serialized_model_name)
    mlmodel.save(mlmodel_file)

    loaded_pyfunc = pyfunc.load_pyfunc(model_path)

    actual_predictions = loaded_pyfunc.predict(data[0])
    expected_predictions = pd.DataFrame(
        _predict(model=module_scoped_subclassed_model, data=data))
    np.testing.assert_array_almost_equal(actual_predictions, expected_predictions, decimal=4)


def test_load_model_allows_user_to_override_pickle_module_via_keyword_argument(
        module_scoped_subclassed_model, model_path):
    """
    Verify that a ``pickle_module`` keyword argument passed to ``mlflow.pytorch.load_model``
    takes precedence over the pickle module recorded at save time.

    The model is saved with the standard ``pickle`` module and loaded with
    ``mlflow.pytorch.pickle_module``. The test checks that the override module's ``load``
    function is actually invoked and that a warning naming both modules is logged.
    """
    mlflow.pytorch.save_model(
        path=model_path,
        pytorch_model=module_scoped_subclassed_model,
        conda_env=None,
        pickle_module=pickle)

    # Capture the real ``load`` before patching so the spy below can delegate to it.
    mlflow_torch_pickle_load = mlflow_pytorch_pickle_module.load
    pickle_call_results = {
        "mlflow_torch_pickle_load_called": False,
    }

    def validate_mlflow_torch_pickle_load_called(*args, **kwargs):
        pickle_call_results["mlflow_torch_pickle_load_called"] = True
        return mlflow_torch_pickle_load(*args, **kwargs)

    log_messages = []

    def custom_warn(message_text, *args, **kwargs):
        # Mirror how ``logging`` renders a record: only the positional ``args`` are
        # %-format arguments. Keyword arguments are logging options (``exc_info``,
        # ``extra``, ...) and must not be applied to the message text.
        log_messages.append(message_text % args if args else message_text)

    with mock.patch("mlflow.pytorch.pickle_module.load") as mlflow_torch_pickle_load_mock,\
            mock.patch("mlflow.pytorch._logger.warning") as warn_mock:
        mlflow_torch_pickle_load_mock.side_effect = validate_mlflow_torch_pickle_load_called
        warn_mock.side_effect = custom_warn
        mlflow.pytorch.load_model(path=model_path, pickle_module=mlflow_pytorch_pickle_module)

    assert all(pickle_call_results.values())
    assert any(
        "does not match the pickle module that was used to save the model" in log_message and
        pickle.__name__ in log_message and
        mlflow_pytorch_pickle_module.__name__ in log_message
        for log_message in log_messages
    )


def test_load_model_raises_exception_when_pickle_module_cannot_be_imported(
        main_scoped_subclassed_model, model_path):
    """
    Verify that ``mlflow.pytorch.load_model`` raises an ``MlflowException`` naming the
    offending module when the recorded pickle module cannot be imported.

    The model is saved normally, after which the pickle module info file in the model's data
    directory is overwritten with the name of a module that does not exist.
    """
    mlflow.pytorch.save_model(
        path=model_path,
        pytorch_model=main_scoped_subclassed_model,
        conda_env=None)

    bad_pickle_module_name = "not.a.real.module"

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    model_data_path = os.path.join(model_path, pyfunc_conf[pyfunc.DATA])
    assert os.path.exists(model_data_path)
    assert mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME in os.listdir(model_data_path)
    pickle_module_info_path = os.path.join(
        model_data_path, mlflow.pytorch._PICKLE_MODULE_INFO_FILE_NAME)
    with open(pickle_module_info_path, "w") as f:
        f.write(bad_pickle_module_name)

    with pytest.raises(MlflowException) as exc_info:
        mlflow.pytorch.load_model(model_path)

    # Inspect the raised exception itself: ``exc_info`` is pytest's ``ExceptionInfo`` wrapper
    # and its own string form is not guaranteed to contain the exception message.
    error_message = str(exc_info.value)
    assert "Failed to import the pickle module" in error_message
    assert bad_pickle_module_name in error_message

@pytest.mark.release
def test_sagemaker_docker_model_scoring_with_sequential_model_and_default_conda_env(
model, model_path, data, sequential_predicted):
Expand Down

0 comments on commit 8352a03

Please sign in to comment.