Skip to content

Commit

Permalink
Remove _get_model_log_dir() method (mlflow#1224)
Browse files Browse the repository at this point in the history
* Initial changes to load_model and tests

* UDF and CLI methods

* Move S3 and boto fixtures into helper functions

* Revert space diff

* Test fix

* Lint

* Lint2

* Remote URI tests for h2o, keras, sklearn, pytorch

* Remote URI tests for tensorflow and spark

* Test cases fixes

* Lint

* Pytorch tests fix

* Test fixes

* Param fixes

* Address subset of comments

* Address remaining comments

* Lint

* Remove remote URI test

* Remove unused variable

* H2o and keras

* Param fix in Keras

* Remove get_model_log_dir from tests

* Arg fix for pytorch tests

* Remove _get_model_log_dir
  • Loading branch information
dbczumar committed May 9, 2019
1 parent cbc9c9e commit fcc87c4
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 37 deletions.
9 changes: 0 additions & 9 deletions mlflow/tracking/artifact_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,3 @@ def _download_artifact_from_uri(artifact_uri, output_path=None):
artifact_repo = get_artifact_repository(artifact_uri=artifact_src_dir)
return artifact_repo.download_artifacts(artifact_path=artifact_src_relative_path,
dst_path=output_path)


def _get_model_log_dir(model_name, run_id):
if not run_id:
raise Exception("Must specify a run_id to get logging directory for a model.")
store = _get_store()
run = store.get_run(run_id)
artifact_repo = get_artifact_repository(run.info.artifact_uri)
return artifact_repo.download_artifacts(model_name)
10 changes: 5 additions & 5 deletions tests/h2o/test_h2o_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.store.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _get_model_log_dir
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration
Expand Down Expand Up @@ -170,8 +170,8 @@ def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
mlflow.h2o.log_model(h2o_model=h2o_iris_model.model,
artifact_path=artifact_path,
conda_env=h2o_custom_env)
run_id = mlflow.active_run().info.run_id
model_path = _get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down Expand Up @@ -204,8 +204,8 @@ def test_model_log_without_specified_conda_env_uses_default_env_with_expected_de
artifact_path = "model"
with mlflow.start_run():
mlflow.h2o.log_model(h2o_model=h2o_iris_model.model, artifact_path=artifact_path)
run_id = mlflow.active_run().info.run_id
model_path = _get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down
10 changes: 5 additions & 5 deletions tests/keras/test_keras_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.store.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _get_model_log_dir
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration
from tests.helper_functions import pyfunc_serve_and_score_model
Expand Down Expand Up @@ -176,8 +176,8 @@ def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(model,
with mlflow.start_run():
mlflow.keras.log_model(
keras_model=model, artifact_path=artifact_path, conda_env=keras_custom_env)
run_id = mlflow.active_run().info.run_id
model_path = _get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down Expand Up @@ -209,8 +209,8 @@ def test_model_log_without_specified_conda_env_uses_default_env_with_expected_de
artifact_path = "model"
with mlflow.start_run():
mlflow.keras.log_model(keras_model=model, artifact_path=artifact_path, conda_env=None)
run_id = mlflow.active_run().info.run_id
model_path = _get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down
10 changes: 5 additions & 5 deletions tests/pyfunc/test_model_export_with_class_and_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mlflow.models import Model
from mlflow.store.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import get_artifact_uri as utils_get_artifact_uri, \
_get_model_log_dir
_download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

Expand Down Expand Up @@ -410,9 +410,9 @@ def test_log_model_persists_specified_conda_env_in_mlflow_model_directory(
},
python_model=main_scoped_model_class(predict_fn=None),
conda_env=pyfunc_custom_env)
pyfunc_run_id = mlflow.active_run().info.run_id
pyfunc_model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=pyfunc_artifact_path))

pyfunc_model_path = _get_model_log_dir(pyfunc_artifact_path, pyfunc_run_id)
pyfunc_conf = _get_flavor_configuration(
model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME)
saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV])
Expand Down Expand Up @@ -465,9 +465,9 @@ def test_log_model_without_specified_conda_env_uses_default_env_with_expected_de
run_id=sklearn_run_id)
},
python_model=main_scoped_model_class(predict_fn=None))
pyfunc_run_id = mlflow.active_run().info.run_id
pyfunc_model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=pyfunc_artifact_path))

pyfunc_model_path = _get_model_log_dir(pyfunc_artifact_path, pyfunc_run_id)
pyfunc_conf = _get_flavor_configuration(
model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import mlflow.sklearn
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.tracking.artifact_utils import _get_model_log_dir
from mlflow.tracking.artifact_utils import _download_artifact_from_uri


def _load_pyfunc(path):
Expand Down Expand Up @@ -83,9 +83,9 @@ def test_model_log_load(sklearn_knn_model, iris_data, tmpdir):
data_path=sk_model_path,
loader_module=os.path.basename(__file__)[:-3],
code_path=[__file__])
pyfunc_run_id = mlflow.active_run().info.run_id
pyfunc_model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=pyfunc_artifact_path))

pyfunc_model_path = _get_model_log_dir(pyfunc_artifact_path, pyfunc_run_id)
model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
assert mlflow.pyfunc.FLAVOR_NAME in model_config.flavors
assert mlflow.pyfunc.PY_VERSION in model_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
Expand Down
19 changes: 9 additions & 10 deletions tests/pytorch/test_pytorch_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mlflow.models import Model
from mlflow.pytorch import pickle_module as mlflow_pytorch_pickle_module
from mlflow.store.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration
Expand Down Expand Up @@ -301,8 +302,8 @@ def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
mlflow.pytorch.log_model(pytorch_model=sequential_model,
artifact_path=artifact_path,
conda_env=pytorch_custom_env)
run_id = mlflow.active_run().info.run_id
model_path = tracking.artifact_utils._get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down Expand Up @@ -337,8 +338,8 @@ def test_model_log_without_specified_conda_env_uses_default_env_with_expected_de
mlflow.pytorch.log_model(pytorch_model=sequential_model,
artifact_path=artifact_path,
conda_env=None)
run_id = mlflow.active_run().info.run_id
model_path = tracking.artifact_utils._get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
Expand Down Expand Up @@ -455,10 +456,8 @@ def predict(self, context, model_input):
artifacts={
"pytorch_model": model_path,
})
pyfunc_run_id = mlflow.active_run().info.run_id

pyfunc_model_path = tracking.artifact_utils._get_model_log_dir(pyfunc_artifact_path,
pyfunc_run_id)
pyfunc_model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=pyfunc_artifact_path))

# Deploy the custom pyfunc model and ensure that it is able to successfully load its
# constituent PyTorch model via `mlflow.pytorch.load_model`
Expand Down Expand Up @@ -583,8 +582,8 @@ def test_load_model_succeeds_when_data_is_model_file_instead_of_directory(
artifact_path=artifact_path,
pytorch_model=module_scoped_subclassed_model,
conda_env=None)
run_id = mlflow.active_run().info.run_id
model_path = tracking.artifact_utils._get_model_log_dir(artifact_path, run_id)
model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

model_conf_path = os.path.join(model_path, "MLmodel")
model_conf = Model.load(model_conf_path)
Expand Down

0 comments on commit fcc87c4

Please sign in to comment.