From d1e8025e689d7c262a3f2591a5ff7b5dc69a3b75 Mon Sep 17 00:00:00 2001
From: dbczumar <39497902+dbczumar@users.noreply.github.com>
Date: Tue, 28 Aug 2018 10:58:36 -0700
Subject: [PATCH] MLeap Flavor/Deployment #2: Python support for the MLeap
 flavor (#324)

* Add SparkJava plugin

* Add MLeap flavor and modify SparkML module to use it

* Fixes

* Add mleap

* Import model in test

* Import mleap

* Add missing assert

* Reorder spark session creation params in test

* Docs fix

* revert pom xml change

* remove java

* Add standalone to mleap

* Import fix

* Add docs and

* Add warning about py incompatibility

* Address comments

* Code spacing fix

* Revert log test changes

* Whitespace fixes

* Exception import fixes

* Fix lint issues

* Fix method call

* callfix

* Remove unused imports

* Py4j log level fix

* reorder tests

* testfix

* test fixes

* Spacing fix

* lint fixes

* Add mleap schema test

* Fix test

* Whitespace fix

* Test fix

* Fix exception path

* Update test_exception.py

* Disable warning

* Disable unused variable warning
---
 mlflow/exceptions.py                   |   5 +
 mlflow/mleap.py                        | 120 +++++++++++++
 mlflow/projects/__init__.py            |   2 +-
 mlflow/projects/_project_spec.py       |   2 +-
 mlflow/projects/databricks.py          |   2 +-
 mlflow/pyfunc/__init__.py              |  13 +-
 mlflow/spark.py                        |  71 +++++---
 mlflow/utils/exception.py              |   7 -
 mlflow/utils/file_utils.py             |  11 ++
 setup.py                               |   1 +
 tests/projects/test_entry_point.py     |   2 +-
 tests/projects/test_project_spec.py    |   2 +-
 tests/projects/test_projects.py        |   2 +-
 tests/spark/test_spark_model_export.py | 239 +++++++++++++++++++------
 tests/utils/test_exception.py          |   2 +-
 15 files changed, 379 insertions(+), 102 deletions(-)
 create mode 100644 mlflow/mleap.py
 delete mode 100644 mlflow/utils/exception.py

diff --git a/mlflow/exceptions.py b/mlflow/exceptions.py
index cd42639e15f12..6cf1696a59886 100644
--- a/mlflow/exceptions.py
+++ b/mlflow/exceptions.py
@@ -4,3 +4,8 @@ class MlflowException(Exception):
 
 class IllegalArtifactPathError(MlflowException):
     """The artifact_path parameter was invalid."""
+
+
+class ExecutionException(MlflowException):
+    """Exception thrown when executing a project fails."""
+    pass
diff --git a/mlflow/mleap.py b/mlflow/mleap.py
new file mode 100644
index 0000000000000..9a9bbe0eb3b47
--- /dev/null
+++ b/mlflow/mleap.py
@@ -0,0 +1,120 @@
+"""
+MLflow integration of the MLeap serialization tool for PySpark MLlib pipelines
+
+This module provides utilities for saving models using the MLeap
+library's persistence mechanism.
+
+A companion module for loading MLFlow models with the MLeap flavor format is available in the
+`mlflow/java` package.
+
+For more information about MLeap, see https://github.com/combust/mleap.
+"""
+
+from __future__ import absolute_import
+
+import os
+import sys
+import traceback
+import json
+from six import reraise
+
+import mlflow
+from mlflow.models import Model
+
+FLAVOR_NAME = "mleap"
+
+
+def log_model(spark_model, sample_input, artifact_path):
+    """
+    Log a Spark MLlib model in MLeap format as an MLflow artifact
+    for the current run. The logged model will have the MLeap flavor.
+
+    NOTE: The MLeap model flavor cannot be loaded in Python. It must be loaded using the
+    Java module within the `mlflow/java` package.
+
+    :param spark_model: Spark PipelineModel to be saved. This model must be MLeap-compatible and
+                        cannot contain any custom transformers.
+    :param sample_input: A sample PySpark Dataframe input that the model can evaluate. This is
+                         required by MLeap for data schema inference.
+    """
+    return Model.log(artifact_path=artifact_path, flavor=mlflow.mleap,
+                     spark_model=spark_model, sample_input=sample_input)
+
+
+def save_model(spark_model, sample_input, path, mlflow_model=Model()):
+    """
+    Save a Spark MLlib PipelineModel in MLeap format at the given local path.
+    The saved model will have the MLeap flavor.
+
+    NOTE: The MLeap model flavor cannot be loaded in Python. It must be loaded using the
+    Java module within the `mlflow/java` package.
+
+    :param path: Path of the MLFlow model to which this flavor is being added.
+    :param spark_model: Spark PipelineModel to be saved. This model must be MLeap-compatible and
+                        cannot contain any custom transformers.
+    :param sample_input: A sample PySpark Dataframe input that the model can evaluate. This is
+                         required by MLeap for data schema inference.
+    :param mlflow_model: MLFlow model config to which this flavor is being added.
+    """
+    add_to_model(mlflow_model, path, spark_model, sample_input)
+    mlflow_model.save(os.path.join(path, "MLmodel"))
+
+
+def add_to_model(mlflow_model, path, spark_model, sample_input):
+    """
+    Add the MLeap flavor to a pre-existing MLFlow model.
+
+    :param mlflow_model: MLFlow model config to which this flavor is being added.
+    :param path: Path of the MLFlow model to which this flavor is being added.
+    :param spark_model: Spark PipelineModel to be saved. This model must be MLeap-compatible and
+                        cannot contain any custom transformers.
+    :param sample_input: A sample PySpark Dataframe input that the model can evaluate. This is
+                         required by MLeap for data schema inference.
+    """
+    from pyspark.ml.pipeline import PipelineModel
+    from pyspark.sql import DataFrame
+    import mleap.version
+    from mleap.pyspark.spark_support import SimpleSparkSerializer  # pylint: disable=unused-variable
+    from py4j.protocol import Py4JError
+
+    if not isinstance(spark_model, PipelineModel):
+        raise Exception("Not a PipelineModel."
+                        " MLeap can currently only save PipelineModels.")
+    if sample_input is None:
+        raise Exception("A sample input must be specified in order to add the MLeap flavor.")
+    if not isinstance(sample_input, DataFrame):
+        raise Exception("The sample input must be a PySpark dataframe of type `{df_type}`".format(
+            df_type=DataFrame.__module__))
+
+    mleap_path_full = os.path.join(path, "mleap")
+    mleap_datapath_sub = os.path.join("mleap", "model")
+    mleap_datapath_full = os.path.join(path, mleap_datapath_sub)
+    if os.path.exists(mleap_path_full):
+        raise Exception("MLeap model data path already exists at: {path}".format(
+            path=mleap_path_full))
+    os.makedirs(mleap_path_full)
+
+    dataset = spark_model.transform(sample_input)
+    model_path = "file:{mp}".format(mp=mleap_datapath_full)
+    try:
+        spark_model.serializeToBundle(path=model_path,
+                                      dataset=dataset)
+    except Py4JError as e:
+        tb = sys.exc_info()[2]
+        error_str = ("MLeap encountered an error while serializing the model. Please ensure that"
+                     " the model is compatible with MLeap"
+                     " (i.e., does not contain any custom transformers). Error text: {err}".format(
+                         err=str(e)))
+        traceback.print_exc()
+        reraise(Exception, error_str, tb)
+
+    input_schema = json.loads(sample_input.schema.json())
+    mleap_schemapath_sub = os.path.join("mleap", "schema.json")
+    mleap_schemapath_full = os.path.join(path, mleap_schemapath_sub)
+    with open(mleap_schemapath_full, "w") as out:
+        json.dump(input_schema, out, indent=4)
+
+    mlflow_model.add_flavor(FLAVOR_NAME,
+                            mleap_version=mleap.version.__version__,
+                            model_data=mleap_datapath_sub,
+                            input_schema=mleap_schemapath_sub)
diff --git a/mlflow/projects/__init__.py b/mlflow/projects/__init__.py
index 025cb5e36a635..5a3c515c025c5 100644
--- a/mlflow/projects/__init__.py
+++ b/mlflow/projects/__init__.py
@@ -12,7 +12,7 @@
 from mlflow.projects.submitted_run import LocalSubmittedRun
 from mlflow.projects import _project_spec
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 from mlflow.entities import RunStatus, SourceType, Param
 import mlflow.tracking as tracking
 from mlflow.tracking.fluent import _get_experiment_id, _get_git_commit
diff --git a/mlflow/projects/_project_spec.py b/mlflow/projects/_project_spec.py
index 5bd73dee9dc0c..1af7cd5b86e02 100644
--- a/mlflow/projects/_project_spec.py
+++ b/mlflow/projects/_project_spec.py
@@ -7,7 +7,7 @@
 from six.moves import shlex_quote
 
 from mlflow import data
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 
 
 MLPROJECT_FILE_NAME = "MLproject"
diff --git a/mlflow/projects/databricks.py b/mlflow/projects/databricks.py
index fbd4bd4b09f18..99c1c569d8968 100644
--- a/mlflow/projects/databricks.py
+++ b/mlflow/projects/databricks.py
@@ -12,7 +12,7 @@
 from mlflow.projects import _fetch_project
 from mlflow.projects.submitted_run import SubmittedRun
 from mlflow.utils import rest_utils, file_utils
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 from mlflow.utils.logging_utils import eprint
 from mlflow import tracking
 from mlflow.version import VERSION
diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py
index 3304596738f81..e3c86a531ac59 100644
--- a/mlflow/pyfunc/__init__.py
+++ b/mlflow/pyfunc/__init__.py
@@ -83,7 +83,7 @@
 from mlflow import tracking
 from mlflow.models import Model
 from mlflow.utils import PYTHON_VERSION, get_major_minor_py_version
-from mlflow.utils.file_utils import TempDir
+from mlflow.utils.file_utils import TempDir, _copy_file_or_tree
 from mlflow.utils.logging_utils import eprint
 
 FLAVOR_NAME = "python_function"
@@ -229,17 +229,6 @@ def predict(*args):
     return pandas_udf(predict, result_type)
 
 
-def _copy_file_or_tree(src, dst, dst_dir):
-    name = os.path.join(dst_dir, os.path.basename(os.path.abspath(src)))
-    if dst_dir:
-        os.mkdir(os.path.join(dst, dst_dir))
-    if os.path.isfile(src):
-        shutil.copy(src=src, dst=os.path.join(dst, name))
-    else:
-        shutil.copytree(src=src, dst=os.path.join(dst, name))
-    return name
-
-
 def save_model(dst_path, loader_module, data_path=None, code_path=(), conda_env=None,
                model=Model()):
     """
diff --git a/mlflow/spark.py b/mlflow/spark.py
index bb6af320c0d82..1f428da2ede0c 100644
--- a/mlflow/spark.py
+++ b/mlflow/spark.py
@@ -1,10 +1,20 @@
 """
 MLflow integration for Spark MLlib models.
-
-Spark MLlib models are saved and loaded using native Spark MLlib persistence.
-The models can be exported as pyfunc for out-of Spark deployment or it can be loaded back as Spark
-Transformer in order to score it in Spark. The pyfunc flavor instantiates SparkContext internally
-and reads the input data as Spark DataFrame prior to scoring.
+This module enables the exporting of Spark MLlib models with the following flavors (formats):
+    1. Spark MLlib (native) format - Allows models to be loaded as Spark Transformers for scoring
+                                     in a Spark session. Models with this flavor can be loaded
+                                     back as PySpark PipelineModel objects in Python. This
+                                     is the main flavor and is always produced.
+    2. PyFunc - Supports deployment outside of Spark by instantiating a SparkContext and reading
+                input data as a Spark DataFrame prior to scoring. Also supports deployment in Spark
+                as a Spark UDF. Models with this flavor can be loaded back as Python functions
+                for performing inference. This flavor is always produced.
+    3. MLeap - Enables high-performance deployment outside of Spark by leveraging MLeap's
+               custom dataframe and pipeline representations. For more information about MLeap,
+               see https://github.com/combust/mleap. Models with this flavor *cannot* be loaded
+               back as Python objects. Rather, they must be deserialized in Java using the
+               `mlflow/java` package. This flavor is only produced if MLeap-compatible arguments
+               are specified.
 """
 
 from __future__ import absolute_import
@@ -15,10 +25,9 @@
 import pyspark
 from pyspark import SparkContext
 from pyspark.ml.pipeline import PipelineModel
-from pyspark.ml.base import Transformer
 
 import mlflow
-from mlflow import pyfunc
+from mlflow import pyfunc, mleap
 from mlflow.models import Model
 from mlflow.utils.logging_utils import eprint
 
@@ -28,9 +37,11 @@
 DFS_TMP = "/tmp/mlflow"
 
 
-def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=None):
+def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=None,
+              sample_input=None):
     """
-    Log a Spark MLlib model as an MLflow artifact for the current run.
+    Log a Spark MLlib model as an MLflow artifact for the current run. This will use the
+    MLlib persistence format, and the logged model will have the Spark flavor.
 
     :param spark_model: PipelineModel to be saved.
     :param artifact_path: Run relative artifact path.
@@ -43,7 +54,10 @@ def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=
                        destination and then copied into the model's artifact directory. This is
                        necessary as Spark ML models read / write from / to DFS if running on a
                        cluster. All temporary files created on the DFS will be removed if this
-                       operation completes successfully. Defaults to /tmp/mlflow.
+                       operation completes successfully. Defaults to /tmp/mlflow.`
+    :param sample_input: A sample input that will be used to add the MLeap flavor to the model.
+                         This must be a PySpark dataframe that the model can evaluate. If
+                         `sample_input` is `None`, the MLeap flavor will not be added.
 
     >>> from pyspark.ml import Pipeline
     >>> from pyspark.ml.classification import LogisticRegression
@@ -62,7 +76,8 @@ def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=
     """
     return Model.log(artifact_path=artifact_path, flavor=mlflow.spark, spark_model=spark_model,
-                     jars=jars, conda_env=conda_env, dfs_tmpdir=dfs_tmpdir)
+                     jars=jars, conda_env=conda_env, dfs_tmpdir=dfs_tmpdir,
+                     sample_input=sample_input)
 
 
 def _tmp_path(dfs_tmp):
@@ -118,11 +133,13 @@ def delete(cls, path):
 
 
 def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=None,
-               dfs_tmpdir=None):
+               dfs_tmpdir=None, sample_input=None):
     """
-    Save Spark MLlib PipelineModel at given local path.
+    Save a Spark MLlib PipelineModel at the given local path.
 
-    Uses Spark MLlib persistence mechanism.
+    By default, this function saves models using the Spark MLlib persistence mechanism.
+    Additionally, if a sample input is specified via the `sample_input` parameter, the model
+    will also be serialized in MLeap format and the MLeap flavor will be added.
 
     :param spark_model: Spark PipelineModel to be saved. Can save only PipelineModels.
     :param path: Local path where the model is to be saved.
@@ -135,7 +152,9 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
                        as Spark ML models read / write from / to DFS if running on a cluster. All
                        temporary files created on the DFS will be removed if this operation
                        completes successfully. Defaults to /tmp/mlflow.
-
+    :param sample_input: A sample input that will be used to add the MLeap flavor to the model.
+                         This must be a PySpark dataframe that the model can evaluate. If
+                         `sample_input` is `None`, the MLeap flavor will not be added.
 
     >>> from mlflow import spark
     >>> from pyspark.ml.pipeline.PipelineModel
@@ -147,26 +166,28 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
     dfs_tmpdir = dfs_tmpdir if dfs_tmpdir is not None else DFS_TMP
     if jars:
         raise Exception("jar dependencies are not implemented")
-    if not isinstance(spark_model, Transformer):
-        raise Exception("Unexpected type {}. SparkML model works only with Transformers".format(
-            str(type(spark_model))))
+
+    if sample_input is not None:
+        mleap.add_to_model(mlflow_model, path, spark_model, sample_input)
+
     if not isinstance(spark_model, PipelineModel):
-        raise Exception("Not a PipelineModel. SparkML can save only PipelineModels.")
+        raise Exception("Not a PipelineModel. SparkML can only save PipelineModels.")
+
     # Spark ML stores the model on DFS if running on a cluster
     # Save it to a DFS temp dir first and copy it to local path
     tmp_path = _tmp_path(dfs_tmpdir)
     spark_model.save(tmp_path)
-    model_path = os.path.abspath(os.path.join(path, "model"))
-    _HadoopFileSystem.copy_to_local_file(tmp_path, model_path, removeSrc=True)
+    sparkml_data_path_sub = "sparkml"
+    sparkml_data_path = os.path.abspath(os.path.join(path, sparkml_data_path_sub))
+    _HadoopFileSystem.copy_to_local_file(tmp_path, sparkml_data_path, removeSrc=True)
     pyspark_version = pyspark.version.__version__
     model_conda_env = None
     if conda_env:
         model_conda_env = os.path.basename(os.path.abspath(conda_env))
         shutil.copyfile(conda_env, os.path.join(path, model_conda_env))
-    if jars:
-        raise Exception("JAR dependencies are not yet implemented")
-    mlflow_model.add_flavor(FLAVOR_NAME, pyspark_version=pyspark_version, model_data="model")
-    pyfunc.add_to_model(mlflow_model, loader_module="mlflow.spark", data="model",
+    mlflow_model.add_flavor(FLAVOR_NAME, pyspark_version=pyspark_version,
+                            model_data=sparkml_data_path_sub)
+    pyfunc.add_to_model(mlflow_model, loader_module="mlflow.spark", data=sparkml_data_path_sub,
                         env=model_conda_env)
     mlflow_model.save(os.path.join(path, "MLmodel"))
diff --git a/mlflow/utils/exception.py b/mlflow/utils/exception.py
deleted file mode 100644
index e19eb830e8cb0..0000000000000
--- a/mlflow/utils/exception.py
+++ /dev/null
@@ -1,7 +0,0 @@
-class MLflowException(Exception):
-    pass
-
-
-class ExecutionException(Exception):
-    """Exception thrown when executing a project fails."""
-    pass
diff --git a/mlflow/utils/file_utils.py b/mlflow/utils/file_utils.py
index dca375c7469ee..83fcd5db1792f 100644
--- a/mlflow/utils/file_utils.py
+++ b/mlflow/utils/file_utils.py
@@ -299,3 +299,14 @@ def ignore(_, names):
     shutil.copytree(src_path, os.path.join(dst_path, mlflow_dir),
                     ignore=_docker_ignore(src_path))
     return mlflow_dir
+
+
+def _copy_file_or_tree(src, dst, dst_dir):
+    name = os.path.join(dst_dir, os.path.basename(os.path.abspath(src)))
+    if dst_dir:
+        os.mkdir(os.path.join(dst, dst_dir))
+    if os.path.isfile(src):
+        shutil.copy(src=src, dst=os.path.join(dst, name))
+    else:
+        shutil.copytree(src=src, dst=os.path.join(dst, name))
+    return name
diff --git a/setup.py b/setup.py
index 5431083a1b51b..203f289a1b548 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@ def package_files(directory):
         'boto3',
         'querystring_parser',
         'simplejson',
+        'mleap>=0.8.1',
     ],
     entry_points='''
         [console_scripts]
diff --git a/tests/projects/test_entry_point.py b/tests/projects/test_entry_point.py
index a79d3f73be8fd..c879b9fa2389b 100644
--- a/tests/projects/test_entry_point.py
+++ b/tests/projects/test_entry_point.py
@@ -5,7 +5,7 @@
 
 from six.moves import shlex_quote
 
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 from mlflow.utils.file_utils import TempDir
 from tests.projects.utils import load_project, TEST_PROJECT_DIR
diff --git a/tests/projects/test_project_spec.py b/tests/projects/test_project_spec.py
index 060686cad607d..903bb5cb1d6d3 100644
--- a/tests/projects/test_project_spec.py
+++ b/tests/projects/test_project_spec.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 from mlflow.projects import _project_spec
 from tests.projects.utils import load_project
diff --git a/tests/projects/test_projects.py b/tests/projects/test_projects.py
index c5e818aa7d09e..5361bc466c1fb 100644
--- a/tests/projects/test_projects.py
+++ b/tests/projects/test_projects.py
@@ -9,7 +9,7 @@
 
 import mlflow
 from mlflow.entities import RunStatus
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 from mlflow.store.file_store import FileStore
 from mlflow.utils import env
diff --git a/tests/spark/test_spark_model_export.py b/tests/spark/test_spark_model_export.py
index 805fc22ee9785..8b163d8b79300 100644
--- a/tests/spark/test_spark_model_export.py
+++ b/tests/spark/test_spark_model_export.py
@@ -1,18 +1,23 @@
 import os
+import json
 
 import pandas as pd
 import pyspark
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.feature import VectorAssembler
 from pyspark.ml.pipeline import Pipeline
+from pyspark.ml.wrapper import JavaModel
 from pyspark.version import __version__ as pyspark_version
 import pytest
 from sklearn import datasets
 import shutil
+from collections import namedtuple
 
 import mlflow
+import mlflow.tracking
-from mlflow import active_run, pyfunc
+from mlflow import active_run, pyfunc, mleap
 from mlflow import spark as sparkm
+from mlflow.models import Model
 from mlflow.utils.environment import _mlflow_conda_env
 
 from tests.helper_functions import score_model_in_sagemaker_docker_container
@@ -20,6 +25,55 @@
 from tests.pyfunc.test_spark import score_model_as_udf
 
 
+@pytest.fixture
+def spark_conda_env(tmpdir):
+    conda_env = os.path.join(str(tmpdir), "conda_env.yml")
+    _mlflow_conda_env(conda_env, additional_pip_deps=["pyspark=={}".format(pyspark_version)])
+    return conda_env
+
+
+SparkModelWithData = namedtuple("SparkModelWithData",
+                                ["model", "training_df", "inference_df"])
+
+
+# Specify `autouse=True` to ensure that a context is created
+# before any tests are executed. This ensures that the Hadoop filesystem
+# does not create its own SparkContext without the MLeap libraries required by
+# other tests.
+@pytest.fixture(scope="session", autouse=True)
+def spark_context():
+    conf = pyspark.SparkConf()
+    conf.set(key="spark.jars.packages",
+             value='ml.combust.mleap:mleap-spark-base_2.11:0.10.0,'
+                   'ml.combust.mleap:mleap-spark_2.11:0.10.0')
+    conf.set(key="spark_session.python.worker.reuse", value=True)
+    sc = pyspark.SparkContext(master="local-cluster[2, 1, 1024]", conf=conf).getOrCreate()
+    return sc
+
+
+@pytest.fixture(scope="session")
+def spark_model_iris(spark_context):
+    iris = datasets.load_iris()
+    X = iris.data  # we only take the first two features.
+    y = iris.target
+    feature_names = ["0", "1", "2", "3"]
+    pandas_df = pd.DataFrame(X, columns=feature_names)  # to make spark_udf work
+    pandas_df['label'] = pd.Series(y)
+    spark_session = pyspark.sql.SparkSession(spark_context)
+    spark_df = spark_session.createDataFrame(pandas_df)
+    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
+    lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
+    pipeline = Pipeline(stages=[assembler, lr])
+    # Fit the model
+    model = pipeline.fit(spark_df)
+    return SparkModelWithData(model=model, training_df=spark_df, inference_df=pandas_df)
+
+
+@pytest.fixture
+def model_path(tmpdir):
+    return str(tmpdir.mkdir("model"))
+
+
 def test_hadoop_filesystem(tmpdir):
     # copy local dir to and back from HadoopFS and make sure the results match
     from mlflow.spark import _HadoopFileSystem as FS
@@ -58,37 +112,20 @@ def test_hadoop_filesystem(tmpdir):
 
 
 @pytest.mark.large
-def test_model_export(tmpdir):
-    conda_env = os.path.join(str(tmpdir), "conda_env.yml")
-    _mlflow_conda_env(conda_env, additional_pip_deps=["pyspark=={}".format(pyspark_version)])
-    iris = datasets.load_iris()
-    X = iris.data  # we only take the first two features.
-    y = iris.target
-    pandas_df = pd.DataFrame(X, columns=iris.feature_names)
-    pandas_df['label'] = pd.Series(y)
-    spark_session = pyspark.sql.SparkSession.builder \
-        .config(key="spark_session.python.worker.reuse", value=True) \
-        .master("local-cluster[2, 1, 1024]") \
-        .getOrCreate()
-    spark_df = spark_session.createDataFrame(pandas_df)
-    model_path = tmpdir.mkdir("model")
-    assembler = VectorAssembler(inputCols=iris.feature_names, outputCol="features")
-    lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
-    pipeline = Pipeline(stages=[assembler, lr])
-    # Fit the model
-    model = pipeline.fit(spark_df)
-    # Print the coefficients and intercept for multinomial logistic regression
-    preds_df = model.transform(spark_df)
+def test_model_export(spark_model_iris, model_path, spark_conda_env):
+    preds_df = spark_model_iris.model.transform(spark_model_iris.training_df)
     preds1 = [x.prediction for x in preds_df.select("prediction").collect()]
-    sparkm.save_model(model, path=str(model_path), conda_env=conda_env)
-    reloaded_model = sparkm.load_model(path=str(model_path))
-    preds_df_1 = reloaded_model.transform(spark_df)
+    sparkm.save_model(spark_model_iris.model, path=model_path,
+                      conda_env=spark_conda_env)
+    reloaded_model = sparkm.load_model(path=model_path)
+    preds_df_1 = reloaded_model.transform(spark_model_iris.training_df)
     preds1_1 = [x.prediction for x in preds_df_1.select("prediction").collect()]
     assert preds1 == preds1_1
-    m = pyfunc.load_pyfunc(str(model_path))
-    preds2 = m.predict(pandas_df)
+    m = pyfunc.load_pyfunc(model_path)
+    preds2 = m.predict(spark_model_iris.inference_df)
     assert preds1 == preds2
-    preds3 = score_model_in_sagemaker_docker_container(model_path=str(model_path), data=pandas_df)
+    preds3 = score_model_in_sagemaker_docker_container(model_path=model_path,
+                                                       data=spark_model_iris.inference_df)
     assert preds1 == preds3
     assert os.path.exists(sparkm.DFS_TMP)
     print(os.listdir(sparkm.DFS_TMP))
@@ -97,25 +134,9 @@ def test_model_export(tmpdir):
 
 
 @pytest.mark.large
-def test_model_log(tmpdir):
-    conda_env = os.path.join(str(tmpdir), "conda_env.yml")
-    _mlflow_conda_env(conda_env, additional_pip_deps=["pyspark=={}".format(pyspark_version)])
-    iris = datasets.load_iris()
-    feature_names = ["0", "1", "2", "3"]
-    pandas_df = pd.DataFrame(iris.data, columns=feature_names)  # to make spark_udf work
-    pandas_df['label'] = pd.Series(iris.target)
-    spark_session = pyspark.sql.SparkSession.builder \
-        .config(key="spark_session.python.worker.reuse", value=True) \
-        .master("local-cluster[2, 1, 1024]") \
-        .getOrCreate()
-    spark_df = spark_session.createDataFrame(pandas_df)
-    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
-    lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
-    pipeline = Pipeline(stages=[assembler, lr])
-    # Fit the model
-    model = pipeline.fit(spark_df)
+def test_model_log_with_sparkml_format(tmpdir, spark_model_iris):
     # Print the coefficients and intercept for multinomial logistic regression
-    preds_df = model.transform(spark_df)
+    preds_df = spark_model_iris.model.transform(spark_model_iris.training_df)
     preds1 = [x.prediction for x in preds_df.select("prediction").collect()]
     old_tracking_uri = mlflow.get_tracking_uri()
     cnt = 0
@@ -130,21 +151,21 @@ def test_model_log(tmpdir):
             mlflow.start_run()
             artifact_path = "model%d" % cnt
             cnt += 1
-            sparkm.log_model(artifact_path=artifact_path, spark_model=model,
+            sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                              dfs_tmpdir=dfs_tmp_dir)
             run_id = active_run().info.run_uuid
            # test pyfunc
             x = pyfunc.load_pyfunc(artifact_path, run_id=run_id)
-            preds2 = x.predict(pandas_df)
+            preds2 = x.predict(spark_model_iris.inference_df)
             assert preds1 == preds2
            # test load model
             reloaded_model = sparkm.load_model(artifact_path, run_id=run_id,
                                                dfs_tmpdir=dfs_tmp_dir)
-            preds_df_1 = reloaded_model.transform(spark_df)
+            preds_df_1 = reloaded_model.transform(spark_model_iris.training_df)
             preds3 = [x.prediction for x in preds_df_1.select("prediction").collect()]
             assert preds1 == preds3
-            # test spar_udf
-            preds4 = score_model_as_udf(artifact_path, run_id, pandas_df)
+            # test spark_udf
+            preds4 = score_model_as_udf(artifact_path, run_id, spark_model_iris.inference_df)
             assert preds1 == preds4
            # We expect not to delete the DFS tempdir.
             x = dfs_tmp_dir or sparkm.DFS_TMP
@@ -155,3 +176,119 @@ def test_model_log(tmpdir):
         mlflow.end_run()
         mlflow.set_tracking_uri(old_tracking_uri)
         shutil.rmtree(tracking_dir)
+
+
+def test_spark_module_model_save_with_sample_input_produces_sparkml_and_mleap_flavors(
+        spark_model_iris, model_path):
+    mlflow_model = Model()
+    sparkm.save_model(spark_model=spark_model_iris.model,
+                      path=model_path,
+                      sample_input=spark_model_iris.training_df,
+                      mlflow_model=mlflow_model)
+    assert sparkm.FLAVOR_NAME in mlflow_model.flavors
+    assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+    config_path = os.path.join(model_path, "MLmodel")
+    assert os.path.exists(config_path)
+    config = Model.load(config_path)
+    assert sparkm.FLAVOR_NAME in config.flavors
+    assert mleap.FLAVOR_NAME in config.flavors
+
+
+def test_spark_module_model_log_with_sample_input_produces_sparkml_and_mleap_flavors(
+        spark_model_iris):
+    artifact_path = "model"
+    mlflow_model = sparkm.log_model(spark_model=spark_model_iris.model,
+                                    sample_input=spark_model_iris.training_df,
+                                    artifact_path=artifact_path)
+    rid = active_run().info.run_uuid
+    model_path = mlflow.tracking.utils._get_model_log_dir(model_name=artifact_path, run_id=rid)
+    config_path = os.path.join(model_path, "MLmodel")
+    mlflow_model = Model.load(config_path)
+    assert sparkm.FLAVOR_NAME in mlflow_model.flavors
+    assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+
+def test_mleap_module_model_log_produces_mleap_flavor(spark_model_iris):
+    artifact_path = "model"
+    mlflow_model = mleap.log_model(spark_model=spark_model_iris.model,
+                                   sample_input=spark_model_iris.training_df,
+                                   artifact_path=artifact_path)
+    rid = active_run().info.run_uuid
+    model_path = mlflow.tracking.utils._get_model_log_dir(model_name=artifact_path, run_id=rid)
+    config_path = os.path.join(model_path, "MLmodel")
+    mlflow_model = Model.load(config_path)
+    assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+
+def test_mleap_model_save_outputs_json_formatted_schema_with_named_fields(
+        spark_model_iris, model_path):
+    mlflow_model = Model()
+    mleap.save_model(spark_model=spark_model_iris.model,
+                     path=model_path,
+                     sample_input=spark_model_iris.training_df,
+                     mlflow_model=mlflow_model)
+    mleap_conf = mlflow_model.flavors[mleap.FLAVOR_NAME]
+    schema_path_sub = mleap_conf["input_schema"]
+    schema_path_full = os.path.join(model_path, schema_path_sub)
+    with open(schema_path_full, "r") as f:
+        json_schema = json.load(f)
+
+    assert "fields" in json_schema.keys()
+    assert len(json_schema["fields"]) > 0
+    assert type(json_schema["fields"][0]) == dict
+    assert "name" in json_schema["fields"][0]
+
+
+def test_spark_module_model_save_with_mleap_and_unsupported_transformer_raises_exception(
+        spark_model_iris, model_path):
+    class CustomTransformer(JavaModel):
+        def _transform(self, dataset):
+            return dataset
+
+    unsupported_pipeline = Pipeline(stages=[CustomTransformer()])
+    unsupported_model = unsupported_pipeline.fit(spark_model_iris.training_df)
+
+    with pytest.raises(Exception):
+        sparkm.save_model(spark_model=unsupported_model,
+                          path=model_path,
+                          sample_input=spark_model_iris.training_df)
+
+
+def test_mleap_module_model_save_with_valid_sample_input_produces_mleap_flavor(
+        spark_model_iris, model_path):
+    mlflow_model = Model()
+    mleap.save_model(spark_model=spark_model_iris.model,
+                     path=model_path,
+                     sample_input=spark_model_iris.training_df,
+                     mlflow_model=mlflow_model)
+    assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+    config_path = os.path.join(model_path, "MLmodel")
+    assert os.path.exists(config_path)
+    config = Model.load(config_path)
+    assert mleap.FLAVOR_NAME in config.flavors
+
+
+def test_mleap_module_model_save_with_invalid_sample_input_type_raises_exception(
+        spark_model_iris, model_path):
+    with pytest.raises(Exception):
+        invalid_input = pd.DataFrame()
+        sparkm.save_model(spark_model=spark_model_iris.model,
+                          path=model_path,
+                          sample_input=invalid_input)
+
+
+def test_mleap_module_model_save_with_unsupported_transformer_raises_exception(
+        spark_model_iris, model_path):
+    class CustomTransformer(JavaModel):
+        def _transform(self, dataset):
+            return dataset
+
+    unsupported_pipeline = Pipeline(stages=[CustomTransformer()])
+    unsupported_model = unsupported_pipeline.fit(spark_model_iris.training_df)
+
+    with pytest.raises(Exception):
+        mleap.save_model(spark_model=unsupported_model,
+                         path=model_path,
+                         sample_input=spark_model_iris.training_df)
diff --git a/tests/utils/test_exception.py b/tests/utils/test_exception.py
index 279ef17059987..86fd511a9416d 100644
--- a/tests/utils/test_exception.py
+++ b/tests/utils/test_exception.py
@@ -1,4 +1,4 @@
-from mlflow.utils.exception import ExecutionException
+from mlflow.exceptions import ExecutionException
 
 
 def test_execution_exception_string_repr():
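
A minimal usage sketch of the two entry points added by this patch, mlflow.spark.log_model(..., sample_input=...) and mlflow.mleap.log_model. It assumes a local SparkSession with the MLeap Spark packages (as pinned in the test fixture above) on the driver classpath; the toy DataFrame, column names, and artifact paths are illustrative only, not part of the change.

import mlflow
import mlflow.spark
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Toy training data; any MLeap-compatible PipelineModel works the same way.
pdf = pd.DataFrame({"0": [0.0, 1.0, 2.0, 3.0],
                    "1": [1.0, 0.0, 1.0, 0.0],
                    "label": [0, 1, 0, 1]})
df = spark.createDataFrame(pdf)
pipeline = Pipeline(stages=[
    VectorAssembler(inputCols=["0", "1"], outputCol="features"),
    LogisticRegression(maxIter=10),
])
model = pipeline.fit(df)

with mlflow.start_run():
    # Passing sample_input adds the MLeap flavor alongside the Spark and pyfunc flavors.
    mlflow.spark.log_model(model, artifact_path="model", sample_input=df)
    # Or, to record only the MLeap flavor (loadable from the mlflow/java package):
    # mlflow.mleap.log_model(spark_model=model, sample_input=df, artifact_path="mleap-model")

Per the updated mlflow/spark.py docstring, logging with a sample_input yields the spark, python_function, and mleap flavors in the resulting MLmodel file; without it, only the first two are produced.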
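
The input_schema artifact recorded by the new MLeap flavor is simply the JSON form of the sample input's Spark schema, which is what test_mleap_model_save_outputs_json_formatted_schema_with_named_fields asserts against. A small sketch of how that schema is derived and roughly what mleap/schema.json contains; the exact field list depends on the sample DataFrame (here, df from the sketch above):

import json

# Mirrors mlflow.mleap.add_to_model: the schema comes straight from the sample input.
input_schema = json.loads(df.schema.json())
print(json.dumps(input_schema, indent=4))
# Roughly (Spark StructType JSON):
# {
#     "type": "struct",
#     "fields": [
#         {"name": "0", "type": "double", "nullable": true, "metadata": {}},
#         {"name": "1", "type": "double", "nullable": true, "metadata": {}},
#         {"name": "label", "type": "long", "nullable": true, "metadata": {}}
#     ]
# }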