Skip to content

Commit

Permalink
Add HTTP header with MLflow version in Python client (mlflow#1131)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed Apr 17, 2019
1 parent 9199dfc commit 44a501b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
8 changes: 7 additions & 1 deletion mlflow/utils/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy
import requests

from mlflow import __version__
from mlflow.utils.string_utils import strip_suffix
from mlflow.exceptions import MlflowException, RestException

Expand All @@ -17,6 +18,11 @@
_logger = logging.getLogger(__name__)


_DEFAULT_HEADERS = {
'User-Agent': 'mlflow-python-client/%s' % __version__
}


def http_request(host_creds, endpoint, retries=3, retry_interval=3, **kwargs):
"""
Makes an HTTP request with the specified method to the specified hostname/endpoint. Retries
Expand All @@ -36,7 +42,7 @@ def http_request(host_creds, endpoint, retries=3, retry_interval=3, **kwargs):
elif host_creds.token:
auth_str = "Bearer %s" % host_creds.token

headers = {}
headers = dict(_DEFAULT_HEADERS)
if auth_str:
headers['Authorization'] = auth_str

Expand Down
8 changes: 5 additions & 3 deletions tests/projects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_RUN_URL, \
MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID, \
MLFLOW_DATABRICKS_WEBAPP_URL
from mlflow.utils.rest_utils import _DEFAULT_HEADERS


from tests.projects.utils import validate_exit_status, TEST_PROJECT_DIR
from tests.projects.utils import tracking_uri_mock # pylint: disable=unused-import
Expand Down Expand Up @@ -273,12 +275,12 @@ def get_config(self):
def test_databricks_http_request_integration(get_config, request):
"""Confirms that the databricks http request params can in fact be used as an HTTP request"""
def confirm_request_params(**kwargs):
headers = dict(_DEFAULT_HEADERS)
headers['Authorization'] = 'Basic dXNlcjpwYXNz'
assert kwargs == {
'method': 'PUT',
'url': 'host/clusters/list',
'headers': {
'Authorization': 'Basic dXNlcjpwYXNz'
},
'headers': headers,
'verify': True,
'json': {'a': 'b'}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/store/test_rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mlflow.store.rest_store import RestStore
from mlflow.utils.proto_json_utils import message_to_json

from mlflow.utils.rest_utils import MlflowHostCreds
from mlflow.utils.rest_utils import MlflowHostCreds, _DEFAULT_HEADERS


class TestRestStore(unittest.TestCase):
Expand All @@ -25,7 +25,7 @@ def mock_request(**kwargs):
'method': 'GET',
'params': {'view_type': 'ACTIVE_ONLY'},
'url': 'https://hello/api/2.0/preview/mlflow/experiments/list',
'headers': {},
'headers': _DEFAULT_HEADERS,
'verify': True,
}
response = mock.MagicMock
Expand Down
22 changes: 11 additions & 11 deletions tests/utils/test_rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from mlflow.utils.rest_utils import NumpyEncoder, http_request, http_request_safe,\
MlflowHostCreds
MlflowHostCreds, _DEFAULT_HEADERS
from mlflow.exceptions import MlflowException, RestException


Expand All @@ -19,7 +19,7 @@ def test_http_request_hostonly(request):
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=True,
headers={},
headers=_DEFAULT_HEADERS,
)


Expand All @@ -34,7 +34,7 @@ def test_http_request_cleans_hostname(request):
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=True,
headers={},
headers=_DEFAULT_HEADERS,
)


Expand All @@ -45,12 +45,12 @@ def test_http_request_with_basic_auth(request):
response.status_code = 200
request.return_value = response
http_request(host_only, '/my/endpoint')
headers = dict(_DEFAULT_HEADERS)
headers['Authorization'] = 'Basic dXNlcjpwYXNz'
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=True,
headers={
'Authorization': 'Basic dXNlcjpwYXNz'
},
headers=headers,
)


Expand All @@ -61,12 +61,12 @@ def test_http_request_with_token(request):
response.status_code = 200
request.return_value = response
http_request(host_only, '/my/endpoint')
headers = dict(_DEFAULT_HEADERS)
headers['Authorization'] = 'Bearer my-token'
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=True,
headers={
'Authorization': 'Bearer my-token'
},
headers=headers,
)


Expand All @@ -80,7 +80,7 @@ def test_http_request_with_insecure(request):
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=False,
headers={},
headers=_DEFAULT_HEADERS,
)


Expand All @@ -94,7 +94,7 @@ def test_http_request_wrapper(request):
request.assert_called_with(
url='http://my-host/my/endpoint',
verify=False,
headers={},
headers=_DEFAULT_HEADERS,
)
response.status_code = 400
response.text = ""
Expand Down

0 comments on commit 44a501b

Please sign in to comment.