Allow specifying HTTP auth via environment variables via REST store (m…
aarondav committed Aug 30, 2018
1 parent 07a7116 commit a3b2605
Showing 14 changed files with 472 additions and 273 deletions.
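
As a rough sketch of what this commit enables (the server URL and credential values below are placeholders, not part of the commit): HTTP basic or bearer auth for a plain REST tracking store can now come from environment variables, whose names appear in the mlflow/tracking/utils.py diff below.

import os
import mlflow

# Hypothetical credentials; the variable names come from mlflow/tracking/utils.py below.
os.environ["MLFLOW_TRACKING_USERNAME"] = "alice"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "example-password"
# Or, for bearer auth:
# os.environ["MLFLOW_TRACKING_TOKEN"] = "example-token"
# And, to skip TLS verification (e.g. self-signed certificates):
# os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true"

mlflow.set_tracking_uri("https://tracking.example.com")  # placeholder server URL
# Tracking calls now go through a RestStore whose credentials closure
# re-reads these variables on every request.
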
21 changes: 8 additions & 13 deletions mlflow/projects/databricks.py
@@ -10,13 +10,12 @@

from mlflow.entities import RunStatus
from mlflow.projects.submitted_run import SubmittedRun
from mlflow.utils import rest_utils, file_utils
from mlflow.utils import rest_utils, file_utils, databricks_utils
from mlflow.exceptions import ExecutionException
from mlflow.utils.logging_utils import eprint
from mlflow import tracking
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_RUN_URL, MLFLOW_DATABRICKS_SHELL_JOB_ID, \
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, MLFLOW_DATABRICKS_WEBAPP_URL
from mlflow.utils.rest_utils import get_databricks_http_request_kwargs_or_fail
from mlflow.version import VERSION

# Base directory within driver container for storing files related to MLflow
@@ -42,11 +41,9 @@ def __init__(self, databricks_profile):
self.databricks_profile = databricks_profile

def databricks_api_request(self, endpoint, method, **kwargs):
request_params = rest_utils.get_databricks_http_request_kwargs_or_fail(
self.databricks_profile)
request_params.update(kwargs)
host_creds = databricks_utils.get_databricks_host_creds(self.databricks_profile)
response = rest_utils.http_request(
endpoint=endpoint, method=method, **request_params)
host_creds=host_creds, endpoint=endpoint, method=method, **kwargs)
return json.loads(response.text)

def _jobs_runs_submit(self, json):
@@ -58,7 +55,7 @@ def _check_auth_available(self):
Verifies that information for making API requests to Databricks is available to MLflow,
raising an exception if not.
"""
rest_utils.get_databricks_http_request_kwargs_or_fail(self.databricks_profile)
databricks_utils.get_databricks_host_creds(self.databricks_profile)

def _upload_to_dbfs(self, src_path, dbfs_fuse_uri):
"""
@@ -67,11 +64,10 @@ def _upload_to_dbfs(self, src_path, dbfs_fuse_uri):
"""
eprint("=== Uploading project to DBFS path %s ===" % dbfs_fuse_uri)
http_endpoint = dbfs_fuse_uri
http_request_kwargs = \
rest_utils.get_databricks_http_request_kwargs_or_fail(self.databricks_profile)
host_creds = databricks_utils.get_databricks_host_creds(self.databricks_profile)
with open(src_path, 'rb') as f:
rest_utils.http_request(
endpoint=http_endpoint, method='POST', data=f, **http_request_kwargs)
host_creds=host_creds, endpoint=http_endpoint, method='POST', data=f)

def _dbfs_path_exists(self, dbfs_uri):
"""
@@ -308,14 +304,13 @@ def _print_description_and_log_tags(self):
run_info = self._job_runner.jobs_runs_get(self._databricks_run_id)
jobs_page_url = run_info["run_page_url"]
eprint("=== Check the run's status at %s ===" % jobs_page_url)
host_creds = databricks_utils.get_databricks_host_creds(self._job_runner.databricks_profile)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_RUN_URL, jobs_page_url)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, self._databricks_run_id)
tracking.get_service().set_tag(self._mlflow_run_id,
MLFLOW_DATABRICKS_WEBAPP_URL,
get_databricks_http_request_kwargs_or_fail(
profile=self._job_runner.databricks_profile)['hostname'])
MLFLOW_DATABRICKS_WEBAPP_URL, host_creds.host)
job_id = run_info.get('job_id')
# In some releases of Databricks we do not return the job ID. We start including it in DB
# releases 2.80 and above.
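
A minimal sketch of the request pattern this file now uses, assuming the imports shown in the diff; the endpoint path is shown here for illustration only:

from mlflow.utils import databricks_utils, rest_utils

def jobs_runs_get(run_id, profile=None):
    # Resolve credentials fresh for each call, as databricks_api_request now does.
    host_creds = databricks_utils.get_databricks_host_creds(profile)
    response = rest_utils.http_request(
        host_creds=host_creds,
        endpoint='/api/2.0/jobs/runs/get',  # illustrative endpoint path
        method='GET',
        json={'run_id': run_id})
    return response.text
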
1 change: 1 addition & 0 deletions mlflow/spark.py
@@ -1,6 +1,7 @@
"""
MLflow integration for Spark MLlib models.
This module enables the exporting of Spark MLlib models with the following flavors (formats):
1. Spark MLlib (native) format - Allows models to be loaded as Spark Transformers for scoring
in a Spark session. Models with this flavor can be loaded
back as PySpark PipelineModel objects in Python. This
8 changes: 4 additions & 4 deletions mlflow/store/artifact_repo.py
@@ -1,7 +1,7 @@
from abc import abstractmethod, ABCMeta

from mlflow.store.rest_store import RestStore
from mlflow.exceptions import MlflowException
from mlflow.store.rest_store import DatabricksStore


class ArtifactRepository:
@@ -83,9 +83,9 @@ def from_artifact_uri(artifact_uri, store):
return SFTPArtifactRepository(artifact_uri)
elif artifact_uri.startswith("dbfs:/"):
from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
if not isinstance(store, DatabricksStore):
raise MlflowException('`store` must be an instance of DatabricksStore.')
return DbfsArtifactRepository(artifact_uri, store.http_request_kwargs)
if not isinstance(store, RestStore):
raise MlflowException('`store` must be an instance of RestStore.')
return DbfsArtifactRepository(artifact_uri, store.get_host_creds)
else:
from mlflow.store.local_artifact_repo import LocalArtifactRepository
return LocalArtifactRepository(artifact_uri)
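
A hedged sketch of the dispatch above, assuming from_artifact_uri is the static factory whose signature is shown; the artifact path is a placeholder. A dbfs:/ URI now only requires the store to be a RestStore, whose get_host_creds callable is handed to the DBFS repository:

from mlflow.store.artifact_repo import ArtifactRepository
from mlflow.store.rest_store import RestStore
from mlflow.utils.databricks_utils import get_databricks_host_creds

# A RestStore built on Databricks CLI credentials satisfies the isinstance check.
store = RestStore(lambda: get_databricks_host_creds())  # default CLI profile
repo = ArtifactRepository.from_artifact_uri("dbfs:/example/artifacts", store)  # placeholder URI
# Roughly equivalent to DbfsArtifactRepository("dbfs:/example/artifacts", store.get_host_creds)
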
84 changes: 36 additions & 48 deletions mlflow/store/dbfs_artifact_repo.py
@@ -13,57 +13,45 @@
DOWNLOAD_CHUNK_SIZE = 1024


def _dbfs_list_api(json, http_request_kwargs):
"""
Pulled out to make it easier to mock.
"""
return http_request(endpoint=LIST_API_ENDPOINT, method='GET',
json=json, **http_request_kwargs)


def _dbfs_download(output_path, endpoint, http_request_kwargs):
"""
Pulled out to make it easier to mock.
"""
with open(output_path, 'wb') as f:
response = http_request(endpoint=endpoint, method='GET', stream=True,
**http_request_kwargs)
try:
for content in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
f.write(content)
finally:
response.close()


def _dbfs_is_dir(dbfs_path, http_request_kwargs):
response = http_request(endpoint=GET_STATUS_ENDPOINT, method='GET',
json={'path': dbfs_path}, **http_request_kwargs)
json_response = json.loads(response.text)
try:
return json_response['is_dir']
except KeyError:
raise Exception('DBFS path %s does not exist' % dbfs_path)


class DbfsArtifactRepository(ArtifactRepository):
"""
Stores artifacts on DBFS.
This repository is used with URIs of the form ``dbfs:/<path>``. The repository can only be used
together with the DatabricksStore.
together with the RestStore.
"""

def __init__(self, artifact_uri, http_request_kwargs):
"""
:param http_request_kwargs arguments to add to rest_utils.http_request for all requests
'hostname', 'headers', and 'verify' are required.
Should include authentication information to Databricks.
"""
def __init__(self, artifact_uri, get_host_creds):
cleaned_artifact_uri = artifact_uri.rstrip('/')
super(DbfsArtifactRepository, self).__init__(cleaned_artifact_uri)
self.get_host_creds = get_host_creds
if not cleaned_artifact_uri.startswith('dbfs:/'):
raise MlflowException('DbfsArtifactRepository URI must start with dbfs:/')
self.http_request_kwargs = http_request_kwargs

def _databricks_api_request(self, **kwargs):
host_creds = self.get_host_creds()
return http_request(host_creds, **kwargs)

def _dbfs_list_api(self, json):
return self._databricks_api_request(endpoint=LIST_API_ENDPOINT, method='GET', json=json)

def _dbfs_download(self, output_path, endpoint):
with open(output_path, 'wb') as f:
response = self._databricks_api_request(endpoint=endpoint, method='GET', stream=True)
try:
for content in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
f.write(content)
finally:
response.close()

def _dbfs_is_dir(self, dbfs_path):
response = self._databricks_api_request(
endpoint=GET_STATUS_ENDPOINT, method='GET', json={'path': dbfs_path})
json_response = json.loads(response.text)
try:
return json_response['is_dir']
except KeyError:
raise MlflowException('DBFS path %s does not exist' % dbfs_path)

def _get_dbfs_path(self, artifact_path):
return '/%s/%s' % (strip_prefix(self.artifact_uri, 'dbfs:/'),
@@ -81,8 +69,8 @@ def log_artifact(self, local_file, artifact_path=None):
else:
http_endpoint = self._get_dbfs_endpoint(os.path.basename(local_file))
with open(local_file, 'rb') as f:
response = http_request(endpoint=http_endpoint, method='POST', data=f,
allow_redirects=False, **self.http_request_kwargs)
response = self._databricks_api_request(
endpoint=http_endpoint, method='POST', data=f, allow_redirects=False)
if response.status_code == 409:
raise MlflowException('File already exists at {} and can\'t be overwritten.'
.format(http_endpoint))
@@ -103,8 +91,8 @@ def log_artifacts(self, local_dir, artifact_path=None):
for name in filenames:
endpoint = build_path(dir_http_endpoint, name)
with open(build_path(dirpath, name), 'rb') as f:
response = http_request(endpoint=endpoint, method='POST', data=f,
allow_redirects=False, **self.http_request_kwargs)
response = self._databricks_api_request(
endpoint=endpoint, method='POST', data=f, allow_redirects=False)
if response.status_code == 409:
raise MlflowException('File already exists at {} and can\'t be overwritten.'
.format(endpoint))
@@ -117,7 +105,7 @@ def list_artifacts(self, path=None):
dbfs_list_json = {'path': self._get_dbfs_path(path)}
else:
dbfs_list_json = {'path': self._get_dbfs_path('')}
response = _dbfs_list_api(dbfs_list_json, self.http_request_kwargs)
response = self._dbfs_list_api(dbfs_list_json)
json_response = json.loads(response.text)
# /api/2.0/dbfs/list will not have the 'files' key in the response for empty directories.
infos = []
@@ -141,13 +129,13 @@ def _download_artifacts_into(self, artifact_path, dest_dir):
basename = os.path.basename(artifact_path)
local_path = build_path(dest_dir, basename)
dbfs_path = self._get_dbfs_path(artifact_path)
if _dbfs_is_dir(dbfs_path, self.http_request_kwargs):
if self._dbfs_is_dir(dbfs_path):
# Artifact_path is a directory, so make a directory for it and download everything
if not os.path.exists(local_path):
os.mkdir(local_path)
for file_info in self.list_artifacts(artifact_path):
self._download_artifacts_into(file_info.path, local_path)
else:
_dbfs_download(output_path=local_path, endpoint=self._get_dbfs_endpoint(artifact_path),
http_request_kwargs=self.http_request_kwargs)
self._dbfs_download(output_path=local_path,
endpoint=self._get_dbfs_endpoint(artifact_path))
return local_path
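
A sketch of constructing the repository directly under the new signature; the URI and profile name are hypothetical. The second argument is a zero-argument callable, invoked inside _databricks_api_request on every request, so an expired token can be refreshed between calls:

from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
from mlflow.utils.databricks_utils import get_databricks_host_creds

repo = DbfsArtifactRepository(
    "dbfs:/example/artifacts",                       # placeholder URI
    lambda: get_databricks_host_creds("my-profile"))  # hypothetical CLI profile
repo.log_artifact("model.pkl")   # assuming ./model.pkl exists; POSTs it to the DBFS API
print(repo.list_artifacts())
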
32 changes: 8 additions & 24 deletions mlflow/store/rest_store.py
@@ -46,24 +46,24 @@ def __init__(self, json):
class RestStore(AbstractStore):
"""
Client for a remote tracking server accessed via REST API calls
:param http_request_kwargs arguments to add to rest_utils.http_request for all requests.
'hostname' is required.
:param get_host_creds: Method to be invoked prior to every REST request to get the
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
is a function so that we can obtain fresh credentials in the case of expiry.
"""

def __init__(self, http_request_kwargs):
def __init__(self, get_host_creds):
super(RestStore, self).__init__()
self.http_request_kwargs = http_request_kwargs
if not http_request_kwargs['hostname']:
raise Exception('hostname must be provided to RestStore')
self.get_host_creds = get_host_creds

def _call_endpoint(self, api, json_body):
endpoint, method = _METHOD_TO_INFO[api]
response_proto = api.Response()
# Convert json string to json dictionary, to pass to requests
if json_body:
json_body = json.loads(json_body)
response = http_request(endpoint=endpoint, method=method,
json=json_body, **self.http_request_kwargs)
host_creds = self.get_host_creds()
response = http_request(host_creds=host_creds, endpoint=endpoint, method=method,
json=json_body)
js_dict = json.loads(response.text)

if 'error_code' in js_dict:
@@ -241,19 +241,3 @@ def list_run_infos(self, experiment_id):
"""
runs = self.search_runs(experiment_ids=[experiment_id], search_expressions=[])
return [run.info for run in runs]


class DatabricksStore(RestStore):
"""
A specific type of RestStore which includes authentication information to Databricks.
:param http_request_kwargs arguments to add to rest_utils.http_request for all requests.
'hostname', 'headers', and 'secure_verify' are required.
"""
def __init__(self, http_request_kwargs):
if http_request_kwargs['hostname'] is None:
raise Exception('hostname must be provided to DatabricksStore')
if http_request_kwargs['headers'] is None:
raise Exception('headers must be provided to DatabricksStore')
if http_request_kwargs['verify'] is None:
raise Exception('verify must be provided to DatabricksStore')
super(DatabricksStore, self).__init__(http_request_kwargs)
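
A minimal sketch of why get_host_creds is a function rather than a fixed credentials object, using only names shown in this commit; the host and token values are placeholders:

import os
from mlflow.store.rest_store import RestStore
from mlflow.utils.rest_utils import MlflowHostCreds

def fresh_creds():
    # Invoked before every REST request, so a rotated token is picked up mid-session.
    return MlflowHostCreds("https://tracking.example.com",  # placeholder host
                           token=os.environ.get("MLFLOW_TRACKING_TOKEN"))

store = RestStore(fresh_creds)
run_infos = store.list_run_infos(experiment_id=0)  # each call re-invokes fresh_creds
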
24 changes: 20 additions & 4 deletions mlflow/tracking/utils.py
@@ -6,15 +6,24 @@
from six.moves import urllib

from mlflow.store.file_store import FileStore
from mlflow.store.rest_store import RestStore, DatabricksStore
from mlflow.store.rest_store import RestStore
from mlflow.store.artifact_repo import ArtifactRepository
from mlflow.utils import env, rest_utils
from mlflow.utils.databricks_utils import get_databricks_host_creds


_TRACKING_URI_ENV_VAR = "MLFLOW_TRACKING_URI"
_LOCAL_FS_URI_PREFIX = "file:///"
_REMOTE_URI_PREFIX = "http://"

# Extra environment variables which take precedence for setting the basic/bearer
# auth on http requests.
_TRACKING_USERNAME_ENV_VAR = "MLFLOW_TRACKING_USERNAME"
_TRACKING_PASSWORD_ENV_VAR = "MLFLOW_TRACKING_PASSWORD"
_TRACKING_TOKEN_ENV_VAR = "MLFLOW_TRACKING_TOKEN"
_TRACKING_INSECURE_TLS_ENV_VAR = "MLFLOW_TRACKING_INSECURE_TLS"


_tracking_uri = None


@@ -98,7 +107,15 @@ def _get_file_store(store_uri):


def _get_rest_store(store_uri):
return RestStore({'hostname': store_uri})
def get_default_host_creds():
return rest_utils.MlflowHostCreds(
host=store_uri,
username=os.environ.get(_TRACKING_USERNAME_ENV_VAR),
password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR),
token=os.environ.get(_TRACKING_TOKEN_ENV_VAR),
ignore_tls_verification=os.environ.get(_TRACKING_INSECURE_TLS_ENV_VAR) == 'true',
)
return RestStore(get_default_host_creds)


def get_db_profile_from_uri(uri):
Expand All @@ -114,8 +131,7 @@ def get_db_profile_from_uri(uri):

def _get_databricks_rest_store(store_uri):
profile = get_db_profile_from_uri(store_uri)
http_request_kwargs = rest_utils.get_databricks_http_request_kwargs_or_fail(profile)
return DatabricksStore(http_request_kwargs)
return RestStore(lambda: get_databricks_host_creds(profile))


def _get_model_log_dir(model_name, run_id):
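
A hedged sketch contrasting the two factory paths above; these are private helpers, and the databricks://my-profile URI form is illustrative. Both paths produce the same RestStore class, differing only in the credentials closure:

from mlflow.tracking import utils as tracking_utils

# HTTP(S) URI: the closure reads the MLFLOW_TRACKING_* variables per request.
rest_store = tracking_utils._get_rest_store("https://tracking.example.com")

# Databricks URI: credentials come from the Databricks CLI profile instead.
db_store = tracking_utils._get_databricks_rest_store("databricks://my-profile")

print(rest_store.get_host_creds().host)  # invokes the closure once
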
44 changes: 44 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,3 +1,9 @@
from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
from mlflow.utils.logging_utils import eprint
from databricks_cli.configure import provider


def _get_dbutils():
try:
import IPython
@@ -44,3 +50,41 @@ def get_notebook_path():
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook is true"""
return _get_extra_context("api_url")


def _fail_malformed_databricks_auth(profile):
raise MlflowException("Got malformed Databricks CLI profile '%s'. Please make sure the "
"Databricks CLI is properly configured as described at "
"https://github.com/databricks/databricks-cli." % profile)


def get_databricks_host_creds(profile=None):
"""
Reads in configuration necessary to make HTTP requests to a Databricks server. This
uses the Databricks CLI's ConfigProvider interface to load the DatabricksConfig object.
This method will throw an exception if sufficient auth cannot be found.
:param profile: Databricks CLI profile. If not provided, we will read the default profile.
:return: :py:class:`mlflow.rest_utils.MlflowHostCreds` which includes the hostname and
authentication information necessary to talk to the Databricks server.
"""
if not hasattr(provider, 'get_config'):
eprint("Warning: support for databricks-cli<0.8.0 is deprecated and will be removed"
" in a future version.")
config = provider.get_config_for_profile(profile)
elif profile:
config = provider.ProfileConfigProvider(profile).get_config()
else:
config = provider.get_config()

if not config or not config.host:
_fail_malformed_databricks_auth(profile)

insecure = hasattr(config, 'insecure') and config.insecure

if config.username is not None and config.password is not None:
return MlflowHostCreds(config.host, username=config.username, password=config.password,
ignore_tls_verification=insecure)
elif config.token:
return MlflowHostCreds(config.host, token=config.token, ignore_tls_verification=insecure)
_fail_malformed_databricks_auth(profile)
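
A short usage sketch of the new helper; the profile name and endpoint path are illustrative. The returned object plugs straight into rest_utils.http_request as shown throughout this commit:

from mlflow.utils import rest_utils
from mlflow.utils.databricks_utils import get_databricks_host_creds

# Reads the Databricks CLI configuration via its provider interface.
creds = get_databricks_host_creds(profile="my-profile")  # hypothetical profile
print(creds.host)  # e.g. a workspace URL (placeholder)

response = rest_utils.http_request(
    host_creds=creds,
    endpoint='/api/2.0/dbfs/get-status',  # illustrative endpoint path
    method='GET',
    json={'path': '/tmp'})
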