
Commit

Use absolute path for MLeap bundle serialization (mlflow#671)
dbczumar committed Oct 26, 2018
1 parent 8ccc546 commit 4edde91
Showing 2 changed files with 42 additions and 3 deletions.
mlflow/mleap.py: 3 additions & 0 deletions
@@ -136,6 +136,9 @@ def add_to_model(mlflow_model, path, spark_model, sample_input):
         raise Exception("The sample input must be a PySpark dataframe of type `{df_type}`".format(
             df_type=DataFrame.__module__))

+    # MLeap's model serialization routine requires an absolute output path
+    path = os.path.abspath(path)
+
     mleap_path_full = os.path.join(path, "mleap")
     mleap_datapath_sub = os.path.join("mleap", "model")
     mleap_datapath_full = os.path.join(path, mleap_datapath_sub)
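
For context: the change above only prepends an os.path.abspath call before the MLeap bundle sub-paths are assembled. The following is a minimal, stdlib-only sketch of that path handling; the helper name _mleap_bundle_paths and the example output are illustrative, not part of MLflow.

# Minimal sketch (not MLflow's actual implementation) of the path handling the
# patch introduces: resolve the caller-supplied output path to an absolute path
# before composing the MLeap bundle locations. The variable names mirror the
# diff above; everything else is illustrative.
import os

def _mleap_bundle_paths(path):
    # A relative path such as "model" is anchored to the current working
    # directory; MLeap's serialization routine requires an absolute output path.
    path = os.path.abspath(path)
    mleap_path_full = os.path.join(path, "mleap")
    mleap_datapath_sub = os.path.join("mleap", "model")
    mleap_datapath_full = os.path.join(path, mleap_datapath_sub)
    return mleap_path_full, mleap_datapath_full

print(_mleap_bundle_paths("model"))
# On a POSIX system with cwd /tmp/work this prints:
# ('/tmp/work/model/mleap', '/tmp/work/model/mleap/model')

Resolving against the current working directory keeps behavior unchanged for callers that already pass an absolute path, since os.path.abspath leaves absolute paths intact apart from normalization.
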
tests/spark/test_spark_model_export.py: 39 additions & 3 deletions
@@ -21,6 +21,7 @@
 from mlflow import active_run, pyfunc, mleap
 from mlflow import spark as sparkm
 from mlflow.models import Model
+from mlflow.utils.file_utils import TempDir

 from mlflow.utils.environment import _mlflow_conda_env
 from tests.helper_functions import score_model_in_sagemaker_docker_container
@@ -47,8 +48,8 @@ def spark_conda_env(tmpdir):
 def spark_context():
     conf = pyspark.SparkConf()
     conf.set(key="spark.jars.packages",
-             value='ml.combust.mleap:mleap-spark-base_2.11:0.10.0,'
-                   'ml.combust.mleap:mleap-spark_2.11:0.10.0')
+             value='ml.combust.mleap:mleap-spark-base_2.11:0.12.0,'
+                   'ml.combust.mleap:mleap-spark_2.11:0.12.0')
     conf.set(key="spark_session.python.worker.reuse", value=True)
     sc = pyspark.SparkContext(master="local-cluster[2, 1, 1024]", conf=conf).getOrCreate()
     return sc
@@ -236,8 +237,43 @@ def _transform(self, dataset):
                           sample_input=spark_model_iris.spark_df)


-def test_mleap_module_model_save_with_valid_sample_input_produces_mleap_flavor(
+def test_spark_module_model_save_with_relative_path_and_valid_sample_input_produces_mleap_flavor(
+        spark_model_iris):
+    with TempDir(chdr=True) as tmp:
+        model_path = os.path.basename(tmp.path("model"))
+        mlflow_model = Model()
+        sparkm.save_model(spark_model=spark_model_iris.model,
+                          path=model_path,
+                          sample_input=spark_model_iris.spark_df,
+                          mlflow_model=mlflow_model)
+        assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+        config_path = os.path.join(model_path, "MLmodel")
+        assert os.path.exists(config_path)
+        config = Model.load(config_path)
+        assert mleap.FLAVOR_NAME in config.flavors
+
+
+def test_mleap_module_model_save_with_relative_path_and_valid_sample_input_produces_mleap_flavor(
+        spark_model_iris):
+    with TempDir(chdr=True) as tmp:
+        model_path = os.path.basename(tmp.path("model"))
+        mlflow_model = Model()
+        mleap.save_model(spark_model=spark_model_iris.model,
+                         path=model_path,
+                         sample_input=spark_model_iris.spark_df,
+                         mlflow_model=mlflow_model)
+        assert mleap.FLAVOR_NAME in mlflow_model.flavors
+
+        config_path = os.path.join(model_path, "MLmodel")
+        assert os.path.exists(config_path)
+        config = Model.load(config_path)
+        assert mleap.FLAVOR_NAME in config.flavors
+
+
+def test_mleap_module_model_save_with_absolute_path_and_valid_sample_input_produces_mleap_flavor(
         spark_model_iris, model_path):
+    model_path = os.path.abspath(model_path)
     mlflow_model = Model()
     mleap.save_model(spark_model=spark_model_iris.model,
                      path=model_path,
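
The two relative-path tests differ only in whether sparkm.save_model or mleap.save_model is called; each changes into a scratch directory via TempDir(chdr=True) and passes a bare directory name, so the os.path.abspath conversion added in mlflow/mleap.py is actually exercised. Below is a stdlib-only sketch of that setup, using tempfile and os.chdir instead of MLflow's TempDir helper; the path names are illustrative.

# Illustrative sketch of the scenario the new tests exercise, using only the
# standard library (the tests themselves use mlflow's TempDir(chdr=True)
# helper, as shown in the diff above). A relative model path only resolves
# correctly if it is converted to an absolute path before serialization.
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    previous_cwd = os.getcwd()
    os.chdir(tmp)
    try:
        # Analogous to os.path.basename(tmp.path("model")) in the tests above:
        # a bare, relative output path.
        model_path = "model"
        assert not os.path.isabs(model_path)

        # The fix in mlflow/mleap.py resolves the relative path against the
        # current working directory before handing it to MLeap.
        resolved = os.path.abspath(model_path)
        assert os.path.isabs(resolved)
        assert resolved == os.path.join(os.getcwd(), "model")
    finally:
        os.chdir(previous_cwd)
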
