
Commit

Allow writing Spark models directly to the target artifact store when possible (mlflow#808)

Allow writing Spark models directly to the target artifact store when the artifact store lies in a filesystem that Spark can write to.
smurching committed Jan 17, 2019
1 parent b867fbf commit a92d1bf
Showing 2 changed files with 96 additions and 34 deletions.
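As context for the diffs below, here is a hedged end-to-end sketch of how a caller exercises the new behavior. The pipeline setup mirrors the module's own docstring example; the tracking/artifact configuration is assumed to already point at a DFS-backed store (for example S3 or DBFS), which is what enables the direct write path:

import mlflow
import mlflow.spark
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
training = spark.createDataFrame(
    [(0, "a b c d e spark", 1.0), (1, "b d", 0.0)], ["id", "text", "label"])
pipeline = Pipeline(stages=[
    Tokenizer(inputCol="text", outputCol="words"),
    HashingTF(inputCol="words", outputCol="features"),
    LogisticRegression(maxIter=10)])
model = pipeline.fit(training)

with mlflow.start_run():
    # With a DFS-backed artifact store, the model data is now written directly to the
    # artifact location; with a local artifact root, the pre-existing Model.log() path is used.
    mlflow.spark.log_model(model, "spark-model")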
113 changes: 80 additions & 33 deletions mlflow/spark.py
@@ -25,20 +25,25 @@
import yaml
import logging

from py4j.protocol import Py4JJavaError
import pyspark
from pyspark import SparkContext
from pyspark.ml.pipeline import PipelineModel

import mlflow
from mlflow import pyfunc, mleap
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration
from mlflow.utils.file_utils import TempDir

FLAVOR_NAME = "spark"

# Default temporary directory on DFS. Used to write / read from Spark ML models.
DFS_TMP = "/tmp/mlflow"
_SPARK_MODEL_PATH_SUB = "sparkml"

DEFAULT_CONDA_ENV = _mlflow_conda_env(
    additional_conda_deps=[
@@ -103,9 +108,35 @@ def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=
    >>> model = pipeline.fit(training)
    >>> mlflow.spark.log_model(model, "spark-model")
    """
    return Model.log(artifact_path=artifact_path, flavor=mlflow.spark, spark_model=spark_model,
                     jars=jars, conda_env=conda_env, dfs_tmpdir=dfs_tmpdir,
                     sample_input=sample_input)
    _validate_model(spark_model, jars)
    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_uuid
    run_root_artifact_uri = mlflow.get_artifact_uri()
    # If the artifact URI is a local filesystem path, defer to Model.log() to persist the model,
    # since Spark may not be able to write directly to the driver's filesystem. For example,
    # writing to `file:/uri` will write to the local filesystem from each executor, which will
    # be incorrect on multi-node clusters - to avoid such issues we just use the Model.log() path
    # here.
    if mlflow.tracking.utils._is_local_uri(run_root_artifact_uri):
        return Model.log(artifact_path=artifact_path, flavor=mlflow.spark, spark_model=spark_model,
                         jars=jars, conda_env=conda_env, dfs_tmpdir=dfs_tmpdir,
                         sample_input=sample_input)
    # If Spark cannot write directly to the artifact repo, defer to Model.log() to persist the
    # model
    model_dir = os.path.join(run_root_artifact_uri, artifact_path)
    try:
        spark_model.save(os.path.join(model_dir, _SPARK_MODEL_PATH_SUB))
    except Py4JJavaError:
        return Model.log(artifact_path=artifact_path, flavor=mlflow.spark, spark_model=spark_model,
                         jars=jars, conda_env=conda_env, dfs_tmpdir=dfs_tmpdir,
                         sample_input=sample_input)

    # Otherwise, override the default model log behavior and save model directly to artifact repo
    mlflow_model = Model(artifact_path=artifact_path, run_id=run_id)
    with TempDir() as tmp:
        tmp_model_metadata_dir = tmp.path()
        _save_model_metadata(
            tmp_model_metadata_dir, spark_model, mlflow_model, sample_input, conda_env)
        mlflow.tracking.fluent.log_artifacts(tmp_model_metadata_dir, artifact_path)
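Condensed for readability, the branching introduced in this hunk amounts to the following sketch. It is a paraphrase, not the verbatim code: the function name is hypothetical, the keyword plumbing is abbreviated as **kwargs, and the module-level imports above are assumed.

def _log_spark_model_sketch(spark_model, artifact_path, **kwargs):
    root_uri = mlflow.get_artifact_uri()
    if mlflow.tracking.utils._is_local_uri(root_uri):
        # Executors would each write to their own local file:/ path, so defer to Model.log().
        return Model.log(artifact_path=artifact_path, flavor=mlflow.spark,
                         spark_model=spark_model, **kwargs)
    try:
        # Happy path: Spark writes the model data straight into the artifact store.
        spark_model.save(os.path.join(root_uri, artifact_path, _SPARK_MODEL_PATH_SUB))
    except Py4JJavaError:
        # Spark cannot reach the artifact store's filesystem; fall back to the old behavior.
        return Model.log(artifact_path=artifact_path, flavor=mlflow.spark,
                         spark_model=spark_model, **kwargs)
    # Only the lightweight MLmodel metadata still needs to be uploaded via the tracking client.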


def _tmp_path(dfs_tmp):
@@ -135,11 +166,14 @@ def _jvm(cls):
    @classmethod
    def _fs(cls):
        if not cls._filesystem:
            sc = SparkContext.getOrCreate()
            cls._conf = sc._jsc.hadoopConfiguration()
            cls._filesystem = cls._jvm().org.apache.hadoop.fs.FileSystem.get(cls._conf)
            cls._filesystem = cls._jvm().org.apache.hadoop.fs.FileSystem.get(cls._conf())
        return cls._filesystem

    @classmethod
    def _conf(cls):
        sc = SparkContext.getOrCreate()
        return sc._jsc.hadoopConfiguration()

    @classmethod
    def _local_path(cls, path):
        return cls._jvm().org.apache.hadoop.fs.Path(os.path.abspath(path))
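A minimal interactive sketch of the refactored helper, assuming an active SparkContext; `fs.defaultFS` is a standard Hadoop configuration key and the printed value is illustrative only:

# _conf() now resolves the Hadoop configuration from the current SparkContext on every
# call, rather than caching it on the class the first time _fs() runs.
conf = _HadoopFileSystem._conf()
fs = _HadoopFileSystem._fs()           # backed by FileSystem.get(conf)
print(conf.get("fs.defaultFS"))        # e.g. "hdfs://namenode:8020" on a cluster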
@@ -181,6 +215,41 @@ def delete(cls, path):
        cls._fs().delete(cls._remote_path(path), True)


def _save_model_metadata(dst_dir, spark_model, mlflow_model, sample_input, conda_env):
    """
    Saves model metadata into the passed-in directory. The persisted metadata assumes that a
    model can be loaded from a relative path to the metadata file (currently hard-coded to
    "sparkml").
    """
    if sample_input is not None:
        mleap.add_to_model(mlflow_model, dst_dir, spark_model, sample_input)

    pyspark_version = pyspark.version.__version__

    conda_env_subpath = "conda.yaml"
    if conda_env is None:
        conda_env = DEFAULT_CONDA_ENV
    elif not isinstance(conda_env, dict):
        with open(conda_env, "r") as f:
            conda_env = yaml.safe_load(f)
    with open(os.path.join(dst_dir, conda_env_subpath), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    mlflow_model.add_flavor(FLAVOR_NAME, pyspark_version=pyspark_version,
                            model_data=_SPARK_MODEL_PATH_SUB)
    pyfunc.add_to_model(mlflow_model, loader_module="mlflow.spark", data=_SPARK_MODEL_PATH_SUB,
                        env=conda_env_subpath)
    mlflow_model.save(os.path.join(dst_dir, "MLmodel"))
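If it helps to picture the output, here is a hedged sketch of the metadata written by _save_model_metadata when read back through _get_flavor_configuration (already imported at the top of the module); `local_model_dir` is a hypothetical directory containing the MLmodel file, and the version string is only an example:

from mlflow.utils.model_utils import _get_flavor_configuration

spark_conf = _get_flavor_configuration(model_path=local_model_dir, flavor_name="spark")
# e.g. {"pyspark_version": "2.4.0", "model_data": "sparkml"}
pyfunc_conf = _get_flavor_configuration(model_path=local_model_dir, flavor_name="python_function")
# e.g. {"loader_module": "mlflow.spark", "data": "sparkml", "env": "conda.yaml", ...}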


def _validate_model(spark_model, jars):
    if not isinstance(spark_model, PipelineModel):
        raise MlflowException("Not a PipelineModel. SparkML can only save PipelineModels.",
                              INVALID_PARAMETER_VALUE)
    if jars:
        raise MlflowException("JAR dependencies are not implemented", INVALID_PARAMETER_VALUE)


def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=None,
               dfs_tmpdir=None, sample_input=None):
    """
@@ -228,40 +297,18 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
    >>> model = ...
    >>> mlflow.spark.save_model(model, "spark-model")
    """
    if jars:
        raise Exception("jar dependencies are not implemented")

    if sample_input is not None:
        mleap.add_to_model(mlflow_model, path, spark_model, sample_input)

    if not isinstance(spark_model, PipelineModel):
        raise Exception("Not a PipelineModel. SparkML can only save PipelineModels.")

    _validate_model(spark_model, jars)
    # Spark ML stores the model on DFS if running on a cluster
    # Save it to a DFS temp dir first and copy it to local path
    if dfs_tmpdir is None:
        dfs_tmpdir = DFS_TMP
    tmp_path = _tmp_path(dfs_tmpdir)
    spark_model.save(tmp_path)
    sparkml_data_path_sub = "sparkml"
    sparkml_data_path = os.path.abspath(os.path.join(path, sparkml_data_path_sub))
    sparkml_data_path = os.path.abspath(os.path.join(path, _SPARK_MODEL_PATH_SUB))
    _HadoopFileSystem.copy_to_local_file(tmp_path, sparkml_data_path, remove_src=True)
    pyspark_version = pyspark.version.__version__

    conda_env_subpath = "conda.yaml"
    if conda_env is None:
        conda_env = DEFAULT_CONDA_ENV
    elif not isinstance(conda_env, dict):
        with open(conda_env, "r") as f:
            conda_env = yaml.safe_load(f)
    with open(os.path.join(path, conda_env_subpath), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    mlflow_model.add_flavor(FLAVOR_NAME, pyspark_version=pyspark_version,
                            model_data=sparkml_data_path_sub)
    pyfunc.add_to_model(mlflow_model, loader_module="mlflow.spark", data=sparkml_data_path_sub,
                        env=conda_env_subpath)
    mlflow_model.save(os.path.join(path, "MLmodel"))
    _save_model_metadata(
        dst_dir=path, spark_model=spark_model, mlflow_model=mlflow_model,
        sample_input=sample_input, conda_env=conda_env)
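A short usage sketch of save_model after the refactor; the fitted `model` and the local output path are assumed, and the conda environment shows the dict form that is accepted alongside a YAML file path:

import mlflow.spark

custom_env = dict(mlflow.spark.DEFAULT_CONDA_ENV)
custom_env["dependencies"].append("pytest")

# `model` is a fitted pyspark.ml PipelineModel (see the log_model example near the top).
mlflow.spark.save_model(model, path="/tmp/exported-spark-model", conda_env=custom_env)
# The output directory now contains MLmodel, conda.yaml, and the "sparkml" model data.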


def _load_model(model_path, dfs_tmpdir=None):
17 changes: 16 additions & 1 deletion tests/spark/test_spark_model_export.py
@@ -23,6 +23,7 @@
import mlflow.tracking
from mlflow import active_run, pyfunc, mleap
from mlflow import spark as sparkm
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.tracking.utils import _get_model_log_dir
from mlflow.utils.environment import _mlflow_conda_env
@@ -242,6 +243,20 @@ def test_sparkml_model_save_persists_specified_conda_env_in_mlflow_model_directo
    assert saved_conda_env_parsed == spark_custom_env_parsed


def test_sparkml_model_log_invalid_args(spark_model_iris, model_path):
    with pytest.raises(MlflowException) as e:
        sparkm.log_model(
            spark_model=spark_model_iris.model.stages[0],
            artifact_path="model0")
    assert "SparkML can only save PipelineModels" in str(e.value)
    with pytest.raises(MlflowException) as e:
        sparkm.log_model(
            spark_model=spark_model_iris.model,
            artifact_path="model1",
            jars=["something.jar"])
    assert "JAR dependencies are not implemented" in str(e.value)


def test_sparkml_model_save_accepts_conda_env_as_dict(spark_model_iris, model_path):
    conda_env = dict(mlflow.spark.DEFAULT_CONDA_ENV)
    conda_env["dependencies"].append("pytest")
@@ -350,7 +365,7 @@ def _transform(self, dataset):
    unsupported_pipeline = Pipeline(stages=[CustomTransformer()])
    unsupported_model = unsupported_pipeline.fit(spark_model_iris.spark_df)

    with pytest.raises(mleap.MLeapSerializationException):
    with pytest.raises(ValueError):
        sparkm.save_model(spark_model=unsupported_model,
                          path=model_path,
                          sample_input=spark_model_iris.spark_df)
