Commit 898b23e
Avoid unnecessary copy from / to DFS when loading spark models. (mlflow#2008)

* fix

* fix

* fix

* test

* nits

* lint.

* lint

* Addressed review comments.

* Update mlflow/spark.py

Co-Authored-By: dbczumar <39497902+dbczumar@users.noreply.github.com>

* lint.

* Update.

* Debug print.

* Debug print.

* update.

* fix

* nit

* fix

* Fixed exception text

* fix

* fix

* fix

* lint

* lint
tomasatdatabricks committed Oct 29, 2019
1 parent ac5dadc commit 898b23e
Showing 2 changed files with 80 additions and 18 deletions.
59 changes: 46 additions & 13 deletions mlflow/spark.py
@@ -24,6 +24,7 @@
 import os
 import yaml
 import logging
+import posixpath

 import mlflow
 from mlflow import pyfunc, mleap
@@ -32,9 +33,10 @@
 from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 from mlflow.utils.environment import _mlflow_conda_env
-from mlflow.utils.model_utils import _get_flavor_configuration
+from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
 from mlflow.utils.file_utils import TempDir
 from mlflow.utils.uri import is_local_uri
+from mlflow.utils.model_utils import _get_flavor_configuration_from_uri

 FLAVOR_NAME = "spark"

@@ -227,6 +229,35 @@ def maybe_copy_from_local_file(cls, src, dst):
         _logger.info("Copied SparkML model to %s", dst)
         return dst

+    @classmethod
+    def _try_file_exists(cls, dfs_path):
+        try:
+            return cls._fs().exists(dfs_path)
+        except Exception as ex:  # pylint: disable=broad-except
+            _logger.warning(
+                "Unexpected exception while checking if model uri is visible on "
+                "DFS: %s", ex)
+            return False
+
+    @classmethod
+    def maybe_copy_from_uri(cls, src_uri, dst_path):
+        """
+        Conditionally copy the file to the Hadoop DFS from the source uri.
+        In case the file is already on the Hadoop DFS do nothing.
+
+        :return: If copied, return new target location, otherwise return source uri.
+        """
+        try:
+            # makeQualified throws if wrong schema / uri
+            dfs_path = cls._fs().makeQualified(cls._remote_path(src_uri))
+            if cls._try_file_exists(dfs_path):
+                _logger.info("File '%s' is already on DFS, copy is not necessary.", src_uri)
+                return src_uri
+        except Exception:  # pylint: disable=broad-except
+            _logger.info("URI '%s' does not point to the current DFS.", src_uri)
+        _logger.info("File '%s' not found on DFS. Will attempt to upload the file.", src_uri)
+        return cls.maybe_copy_from_local_file(_download_artifact_from_uri(src_uri), dst_path)
+
     @classmethod
     def delete(cls, path):
         cls._fs().delete(cls._remote_path(path), True)
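
Aside: the `_HadoopFileSystem` helpers above talk to the JVM-side Hadoop `FileSystem` API through Spark's py4j gateway. As a rough standalone sketch of the same existence probe (not part of this commit; it assumes an active local SparkSession, uses private `_jvm`/`_jsc` accessors, and the path is illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
jvm = spark.sparkContext._jvm
hadoop_conf = spark.sparkContext._jsc.hadoopConfiguration()

# Resolve the default Hadoop filesystem and an illustrative path on it.
fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
path = jvm.org.apache.hadoop.fs.Path("/tmp/mlflow/some-model")

# makeQualified resolves the path against the current filesystem and raises
# for a URI with a foreign scheme -- the same signal maybe_copy_from_uri uses
# to decide that the model must be uploaded to the DFS.
qualified = fs.makeQualified(path)
print(fs.exists(qualified))
```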
@@ -267,9 +298,9 @@ def _validate_model(spark_model):
             or not isinstance(spark_model, MLReadable) \
             or not isinstance(spark_model, MLWritable):
         raise MlflowException(
-                "Cannot serialize this model. MLFlow can only save descendants of pyspark.Model"
-                "that implement MLWritable and MLReadable.",
-                INVALID_PARAMETER_VALUE)
+            "Cannot serialize this model. MLFlow can only save descendants of pyspark.Model"
+            "that implement MLWritable and MLReadable.",
+            INVALID_PARAMETER_VALUE)


 def save_model(spark_model, path, mlflow_model=Model(), conda_env=None,
@@ -335,16 +366,15 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None,
                       sample_input=sample_input, conda_env=conda_env)


-def _load_model(model_path, dfs_tmpdir=None):
+def _load_model(model_uri, dfs_tmpdir=None):
     from pyspark.ml.pipeline import PipelineModel
-
     if dfs_tmpdir is None:
         dfs_tmpdir = DFS_TMP
     tmp_path = _tmp_path(dfs_tmpdir)
     # Spark ML expects the model to be stored on DFS
     # Copy the model to a temp DFS location first. We cannot delete this file, as
     # Spark may read from it at any point.
-    model_path = _HadoopFileSystem.maybe_copy_from_local_file(model_path, tmp_path)
+    model_path = _HadoopFileSystem.maybe_copy_from_uri(model_uri, tmp_path)
     return PipelineModel.load(model_path)


@@ -378,10 +408,13 @@ def load_model(model_uri, dfs_tmpdir=None):
     >>> # Make predictions on test documents.
     >>> prediction = model.transform(test)
     """
-    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
-    flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
-    spark_model_artifacts_path = os.path.join(local_model_path, flavor_conf['model_data'])
-    return _load_model(model_path=spark_model_artifacts_path, dfs_tmpdir=dfs_tmpdir)
+    if RunsArtifactRepository.is_runs_uri(model_uri):
+        runs_uri = model_uri
+        model_uri = RunsArtifactRepository.get_underlying_uri(model_uri)
+        _logger.info("'%s' resolved as '%s'", runs_uri, model_uri)
+    flavor_conf = _get_flavor_configuration_from_uri(model_uri, FLAVOR_NAME)
+    model_uri = posixpath.join(model_uri, flavor_conf["model_data"])
+    return _load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmpdir)


 def _load_pyfunc(path):
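
For illustration (not part of the diff), a usage sketch of the updated loader; the run id and the `test_df` DataFrame are placeholders. With this change, a `runs:/` URI is first resolved to its underlying artifact URI, and no local download/DFS re-upload round trip happens when the artifacts are already visible on the DFS:

```python
import mlflow.spark

# Load a logged Spark MLlib model by runs: URI ("<run_id>" is a placeholder).
model = mlflow.spark.load_model("runs:/<run_id>/model")

# Score an existing Spark DataFrame (`test_df` is assumed to exist).
predictions = model.transform(test_df)
```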
@@ -398,9 +431,9 @@ def _load_pyfunc(path):

     spark = pyspark.sql.SparkSession._instantiatedSession
     if spark is None:
-        spark = pyspark.sql.SparkSession.builder.config("spark.python.worker.reuse", True)\
+        spark = pyspark.sql.SparkSession.builder.config("spark.python.worker.reuse", True) \
             .master("local[1]").getOrCreate()
-    return _PyFuncModelWrapper(spark, _load_model(model_path=path))
+    return _PyFuncModelWrapper(spark, _load_model(model_uri=path))


 class _PyFuncModelWrapper(object):
39 changes: 34 additions & 5 deletions mlflow/utils/model_utils.py
@@ -1,8 +1,10 @@
 import os
+import posixpath

 from mlflow.exceptions import MlflowException
 from mlflow.models import Model
 from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
+from mlflow.tracking.artifact_utils import _download_artifact_from_uri


 def _get_flavor_configuration(model_path, flavor_name):
@@ -19,14 +21,41 @@ def _get_flavor_configuration(model_path, flavor_name):
     model_configuration_path = os.path.join(model_path, "MLmodel")
     if not os.path.exists(model_configuration_path):
         raise MlflowException(
-                "Could not find an \"MLmodel\" configuration file at \"{model_path}\"".format(
-                    model_path=model_path),
-                RESOURCE_DOES_NOT_EXIST)
+            "Could not find an \"MLmodel\" configuration file at \"{model_path}\"".format(
+                model_path=model_path),
+            RESOURCE_DOES_NOT_EXIST)

     model_conf = Model.load(model_configuration_path)
     if flavor_name not in model_conf.flavors:
         raise MlflowException(
-                "Model does not have the \"{flavor_name}\" flavor".format(flavor_name=flavor_name),
-                RESOURCE_DOES_NOT_EXIST)
+            "Model does not have the \"{flavor_name}\" flavor".format(flavor_name=flavor_name),
+            RESOURCE_DOES_NOT_EXIST)
     conf = model_conf.flavors[flavor_name]
     return conf


+def _get_flavor_configuration_from_uri(model_uri, flavor_name):
+    """
+    Obtains the configuration for the specified flavor from the specified
+    MLflow model uri. If the model does not contain the specified flavor,
+    an exception will be thrown.
+
+    :param model_uri: The path to the root directory of the MLflow model for which to load
+                      the specified flavor configuration.
+    :param flavor_name: The name of the flavor configuration to load.
+    :return: The flavor configuration as a dictionary.
+    """
+    try:
+        ml_model_file = _download_artifact_from_uri(
+            artifact_uri=posixpath.join(model_uri, "MLmodel"))
+    except Exception as ex:
+        raise MlflowException(
+            "Failed to download an \"MLmodel\" model file from \"{model_uri}\": {ex}".format(
+                model_uri=model_uri, ex=ex),
+            RESOURCE_DOES_NOT_EXIST)
+    model_conf = Model.load(ml_model_file)
+    if flavor_name not in model_conf.flavors:
+        raise MlflowException(
+            "Model does not have the \"{flavor_name}\" flavor".format(flavor_name=flavor_name),
+            RESOURCE_DOES_NOT_EXIST)
+    return model_conf.flavors[flavor_name]
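
A hypothetical call to the new helper (the run id is a placeholder). It fetches only the small `MLmodel` metadata file rather than the whole model directory, which is what lets `load_model` above skip the full local download:

```python
from mlflow.utils.model_utils import _get_flavor_configuration_from_uri

# Read the "spark" flavor entry straight from a model URI.
flavor_conf = _get_flavor_configuration_from_uri("runs:/<run_id>/model", "spark")
print(flavor_conf["model_data"])  # relative path to the SparkML artifacts
```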
