Disallow running projects on Databricks with a local tracking URI (mlflow#522)

Perform validation upfront to verify that the tracking URI is not a local URI when running projects on Databricks.
smurching committed Oct 1, 2018
1 parent 4d36406 commit 9b1de17
Showing 3 changed files with 24 additions and 31 deletions.
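
For context (not part of the diff below), a rough sketch of the user-facing effect, assuming an MLflow build containing this change: launching a Databricks run against a local tracking URI now fails fast, before any Databricks API request is made. The project URI and cluster-spec file name here are placeholders.

import mlflow
import mlflow.projects

# Point the client at a local, file-based tracking store -- exactly the case the new
# validation rejects for Databricks runs.
mlflow.set_tracking_uri("/tmp/local-mlruns")

try:
    mlflow.projects.run(
        uri="https://github.com/mlflow/mlflow-example",  # placeholder project URI
        mode="databricks",
        cluster_spec="cluster-spec.json",  # placeholder cluster spec file
    )
except Exception as exc:  # surfaces as an ExecutionException before any Databricks API call
    # The message now asks for 'databricks', 'databricks://profile', or a remote HTTP URI,
    # set via mlflow.set_tracking_uri or the MLFLOW_TRACKING_URI environment variable.
    print(exc)
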
mlflow/projects/__init__.py (4 additions, 0 deletions)
@@ -20,6 +20,7 @@
 from mlflow.tracking.fluent import _get_experiment_id, _get_git_commit
 
 
+import mlflow.projects.databricks
 from mlflow.utils import process
 from mlflow.utils.logging_utils import eprint
 from mlflow.utils.mlflow_tags import MLFLOW_GIT_BRANCH_NAME
@@ -38,6 +39,9 @@ def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=N
     Helper that delegates to the project-running method corresponding to the passed-in mode.
     Returns a ``SubmittedRun`` corresponding to the project run.
     """
+    if mode == "databricks":
+        mlflow.projects.databricks.before_run_validations(mlflow.get_tracking_uri(), cluster_spec)
+
     exp_id = experiment_id or _get_experiment_id()
     parameters = parameters or {}
     work_dir = _fetch_project(uri=uri, force_tempdir=False, version=version,
mlflow/projects/databricks.py (14 additions, 22 deletions)
@@ -39,6 +39,20 @@ def _check_response_status_code(response):
         raise ExecutionException("The Databricks request failed. Error: {}".format(response.text))
 
 
+def before_run_validations(tracking_uri, cluster_spec):
+    """Validations to perform before running a project on Databricks."""
+    if cluster_spec is None:
+        raise ExecutionException("Cluster spec must be provided when launching MLflow project "
+                                 "runs on Databricks.")
+    if tracking.utils._is_local_uri(tracking_uri):
+        raise ExecutionException(
+            "When running on Databricks, the MLflow tracking URI must be of the form "
+            "'databricks' or 'databricks://profile', or a remote HTTP URI accessible to both the "
+            "current client and code running on Databricks. Got local tracking URI %s. "
+            "Please specify a valid tracking URI via mlflow.set_tracking_uri or by setting the "
+            "MLFLOW_TRACKING_URI environment variable." % tracking_uri)
+
+
 class DatabricksJobRunner(object):
     """
     Helper class for running an MLflow project as a Databricks Job.
Expand All @@ -59,13 +73,6 @@ def _jobs_runs_submit(self, req_body):
_check_response_status_code(response)
return json.loads(response.text)

def _check_auth_available(self):
"""
Verify that information for making API requests to Databricks is available to MLflow,
raising an exception if not.
"""
databricks_utils.get_databricks_host_creds(self.databricks_profile)

def _upload_to_dbfs(self, src_path, dbfs_fuse_uri):
"""
Upload the file at `src_path` to the specified DBFS URI within the Databricks workspace
@@ -159,22 +166,9 @@ def _run_shell_command_job(self, project_uri, command, env_vars, cluster_spec):
         databricks_run_id = run_submit_res["run_id"]
         return databricks_run_id
 
-    def _before_run_validations(self, tracking_uri, cluster_spec):
-        """Validations to perform before running a project on Databricks."""
-        self._check_auth_available()
-        if cluster_spec is None:
-            raise ExecutionException("Cluster spec must be provided when launching MLflow project "
-                                     "runs on Databricks.")
-        if tracking.utils._is_local_uri(tracking_uri):
-            raise ExecutionException(
-                "When running on Databricks, the MLflow tracking URI must be set to a remote URI "
-                "accessible to both the current client and code running on Databricks. Got local "
-                "tracking URI %s." % tracking_uri)
-
     def run_databricks(self, uri, entry_point, work_dir, parameters, experiment_id, cluster_spec,
                        run_id):
         tracking_uri = _get_tracking_uri_for_run()
-        self._before_run_validations(tracking_uri, cluster_spec)
         dbfs_fuse_uri = self._upload_project_to_dbfs(work_dir, experiment_id)
         env_vars = {
             tracking._TRACKING_URI_ENV_VAR: tracking_uri,
@@ -229,8 +223,6 @@ def jobs_runs_get(self, databricks_run_id):
 
 
 def _get_tracking_uri_for_run():
-    if not tracking.utils.is_tracking_uri_set():
-        return "databricks"
     uri = tracking.get_tracking_uri()
     if uri.startswith("databricks"):
         return "databricks"
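
Not part of the commit, but a minimal sketch of how the new module-level helper above behaves, assuming the module layout shown in this file: the cluster spec dict is a placeholder, and any non-None value satisfies the first check.

from mlflow.projects import databricks

cluster_spec = {"spark_version": "4.3.x-scala2.11", "num_workers": 2}  # placeholder spec

databricks.before_run_validations("databricks", cluster_spec)                    # passes
databricks.before_run_validations("https://tracking.example.com", cluster_spec)  # passes

try:
    databricks.before_run_validations("/tmp/local-mlruns", cluster_spec)  # local path
except Exception as exc:  # raised as an ExecutionException
    print(exc)
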
tests/projects/test_databricks.py (6 additions, 9 deletions)
@@ -88,7 +88,7 @@ def dbfs_mocks(dbfs_path_exists_mock, upload_to_dbfs_mock): # pylint: disable=u
 
 @pytest.fixture()
 def before_run_validations_mock(): # pylint: disable=unused-argument
-    with mock.patch("mlflow.projects.databricks.DatabricksJobRunner._before_run_validations"):
+    with mock.patch("mlflow.projects.databricks.before_run_validations"):
         yield
 
 
@@ -156,8 +156,7 @@ def test_run_databricks_validations(
     """
     Tests that running on Databricks fails before making any API requests if validations fail.
     """
-    with mock.patch("mlflow.projects.databricks.DatabricksJobRunner._check_auth_available"),\
-            mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
+    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
             mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
             as db_api_req_mock:
         # Test bad tracking URI
@@ -166,6 +165,8 @@
             run_databricks_project(cluster_spec_mock, block=True)
         assert db_api_req_mock.call_count == 0
         db_api_req_mock.reset_mock()
+        mlflow_service = mlflow.tracking.MlflowClient()
+        assert len(mlflow_service.list_run_infos(experiment_id=0)) == 0
         tracking_uri_mock.return_value = "http://"
         # Test misspecified parameters
         with pytest.raises(ExecutionException):
@@ -180,9 +181,8 @@
         assert db_api_req_mock.call_count == 0
         db_api_req_mock.reset_mock()
         # Test that validations pass with good tracking URIs
-        runner = DatabricksJobRunner(databricks_profile="DEFAULT")
-        runner._before_run_validations("http://", cluster_spec_mock)
-        runner._before_run_validations("databricks", cluster_spec_mock)
+        databricks.before_run_validations("http://", cluster_spec_mock)
+        databricks.before_run_validations("databricks", cluster_spec_mock)
 
 
 def test_run_databricks(
@@ -233,9 +233,6 @@ def test_run_databricks_cancel(
 
 
 def test_get_tracking_uri_for_run():
-    with mock.patch.dict(os.environ, {}):
-        mlflow.set_tracking_uri(None)
-        assert databricks._get_tracking_uri_for_run() == "databricks"
     mlflow.set_tracking_uri("http://some-uri")
     assert databricks._get_tracking_uri_for_run() == "http://some-uri"
     mlflow.set_tracking_uri("databricks://profile")
