Skip to content

Commit

Permalink
Validate backend store URI before starting tracking server (mlflow#1218)
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke Zhu authored and aarondav committed May 31, 2019
1 parent a11a02d commit bb8c760
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
15 changes: 15 additions & 0 deletions mlflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from mlflow.utils.process import ShellCommandException
from mlflow.utils import cli_args
from mlflow.server import _run_server
from mlflow.server.handlers import _get_store
from mlflow.store import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
from mlflow import tracking
import mlflow.store.cli
Expand Down Expand Up @@ -171,6 +172,13 @@ def ui(backend_store_uri, default_artifact_root, port):
else:
default_artifact_root = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH

try:
_get_store(backend_store_uri, default_artifact_root)
except Exception as e: # pylint: disable=broad-except
_logger.error("Error initializing backend store")
_logger.exception(e)
sys.exit(1)

# TODO: We eventually want to disable the write path in this version of the server.
try:
_run_server(backend_store_uri, default_artifact_root, "127.0.0.1", port, 1, None, [])
Expand Down Expand Up @@ -235,6 +243,13 @@ def server(backend_store_uri, default_artifact_root, host, port,
"local file based.")
sys.exit(1)

try:
_get_store(backend_store_uri, default_artifact_root)
except Exception as e: # pylint: disable=broad-except
_logger.error("Error initializing backend store")
_logger.exception(e)
sys.exit(1)

try:
_run_server(backend_store_uri, default_artifact_root, host, port, workers, static_prefix,
gunicorn_opts)
Expand Down
6 changes: 3 additions & 3 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
_store = None


def _get_store():
def _get_store(backend_store_uri=None, default_artifact_root=None):
from mlflow.server import BACKEND_STORE_URI_ENV_VAR, ARTIFACT_ROOT_ENV_VAR
global _store
if _store is None:
store_dir = os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
artifact_root = os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
store_dir = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
if _is_database_uri(store_dir):
from mlflow.store.sqlalchemy_store import SqlAlchemyStore
_store = SqlAlchemyStore(store_dir, artifact_root)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from click.testing import CliRunner
from mock import mock
import pytest

from mlflow.cli import server, run
from mlflow.cli import run, server, ui
from mlflow.server import handlers


def test_server_static_prefix_validation():
Expand All @@ -21,6 +23,24 @@ def test_server_static_prefix_validation():
run_server_mock.assert_not_called()


def test_server_default_artifact_root_validation():
with mock.patch("mlflow.cli._run_server") as run_server_mock:
result = CliRunner().invoke(server, ["--backend-store-uri", "postgresql://"])
assert result.output.startswith("Option 'default-artifact-root' is required")
run_server_mock.assert_not_called()


@pytest.mark.parametrize("command", [server, ui])
def test_tracking_uri_validation(command):
handlers._store = None
with mock.patch("mlflow.cli._run_server") as run_server_mock:
# SQLAlchemy expects postgresql:// not postgres://
CliRunner().invoke(command,
["--backend-store-uri", "postgres://user:pwd@host:5432/mydb",
"--default-artifact-root", "./mlruns"])
run_server_mock.assert_not_called()


def test_mlflow_run():
with mock.patch("mlflow.cli.projects") as mock_projects:
result = CliRunner().invoke(run)
Expand Down

0 comments on commit bb8c760

Please sign in to comment.