Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1857] Use SSL options with HTTPS connection pool #425

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import base64
import logging
import urllib.parse
from typing import Dict, Union
from typing import Dict, Union, Optional

import six
import thrift

logger = logging.getLogger(__name__)

import ssl
import warnings
from http.client import HTTPResponse
Expand All @@ -16,6 +14,9 @@
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
from urllib3.util import make_headers
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)


class THttpClient(thrift.transport.THttpClient.THttpClient):
Expand All @@ -25,13 +26,12 @@ def __init__(
uri_or_host,
port=None,
path=None,
cafile=None,
cert_file=None,
key_file=None,
ssl_context=None,
ssl_options: Optional[SSLOptions] = None,
max_connections: int = 1,
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
):
self._ssl_options = ssl_options

if port is not None:
warnings.warn(
"Please use the THttpClient('http{s}://host:port/path') constructor",
Expand All @@ -48,13 +48,11 @@ def __init__(
self.scheme = parsed.scheme
assert self.scheme in ("http", "https")
if self.scheme == "https":
self.certfile = cert_file
self.keyfile = key_file
self.context = (
ssl.create_default_context(cafile=cafile)
if (cafile and not ssl_context)
else ssl_context
)
if self._ssl_options is not None:
# TODO: Not sure if those options are used anywhere - need to double-check
kravets-levko marked this conversation as resolved.
Show resolved Hide resolved
self.certfile = self._ssl_options.tls_client_cert_file
self.keyfile = self._ssl_options.tls_client_cert_key_file
self.context = self._ssl_options.create_ssl_context()
self.port = parsed.port
self.host = parsed.hostname
self.path = parsed.path
Expand Down Expand Up @@ -109,12 +107,23 @@ def startRetryTimer(self):
def open(self):

# self.__pool replaces the self.__http used by the original THttpClient
_pool_kwargs = {"maxsize": self.max_connections}

if self.scheme == "http":
pool_class = HTTPConnectionPool
elif self.scheme == "https":
pool_class = HTTPSConnectionPool

_pool_kwargs = {"maxsize": self.max_connections}
_pool_kwargs.update(
{
"cert_reqs": ssl.CERT_REQUIRED
if self._ssl_options.tls_verify
else ssl.CERT_NONE,
"ca_certs": self._ssl_options.tls_trusted_ca_file,
"cert_file": self._ssl_options.tls_client_cert_file,
"key_file": self._ssl_options.tls_client_cert_key_file,
"key_password": self._ssl_options.tls_client_cert_key_password,
}
)

if self.using_proxy():
proxy_manager = ProxyManager(
Expand Down
18 changes: 16 additions & 2 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)


from databricks.sql.types import Row
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence

Expand Down Expand Up @@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]:
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _retry_stop_after_attempts_count
# The maximum number of attempts during a request retry sequence (defaults to 24)
# _socket_timeout
Expand Down Expand Up @@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]:

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)
Expand Down
9 changes: 5 additions & 4 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

from ssl import SSLContext
from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union

Expand All @@ -9,6 +8,8 @@
DownloadableResultSettings,
DownloadedFile,
)
from databricks.sql.types import SSLOptions

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +21,7 @@ def __init__(
links: List[TSparkArrowResultLink],
max_download_threads: int,
lz4_compressed: bool,
ssl_context: SSLContext,
ssl_options: SSLOptions,
):
self._pending_links: List[TSparkArrowResultLink] = []
for link in links:
Expand All @@ -38,7 +39,7 @@ def __init__(
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self._ssl_context = ssl_context
self._ssl_options = ssl_options

def get_next_downloaded_file(
self, next_row_offset: int
Expand Down Expand Up @@ -95,7 +96,7 @@ def _schedule_downloads(self):
handler = ResultSetDownloadHandler(
settings=self._downloadable_result_settings,
link=link,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)
Expand Down
12 changes: 5 additions & 7 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import requests
from requests.adapters import HTTPAdapter, Retry
from ssl import SSLContext, CERT_NONE
import lz4.frame
import time

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

from databricks.sql.exc import Error
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,11 +65,11 @@ def __init__(
self,
settings: DownloadableResultSettings,
link: TSparkArrowResultLink,
ssl_context: SSLContext,
ssl_options: SSLOptions,
):
self.settings = settings
self.link = link
self._ssl_context = ssl_context
self._ssl_options = ssl_options

def run(self) -> DownloadedFile:
"""
Expand All @@ -95,14 +94,13 @@ def run(self) -> DownloadedFile:
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))

ssl_verify = self._ssl_context.verify_mode != CERT_NONE

try:
# Get the file via HTTP request
response = session.get(
self.link.fileLink,
timeout=self.settings.download_timeout,
verify=ssl_verify,
verify=self._ssl_options.tls_verify,
# TODO: Pass cert from `self._ssl_options`
)
response.raise_for_status()

Expand Down
43 changes: 6 additions & 37 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import uuid
import threading
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
Expand Down Expand Up @@ -36,6 +35,7 @@
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
)
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
http_path: str,
http_headers,
auth_provider: AuthProvider,
ssl_options: SSLOptions,
staging_allowed_local_path: Union[None, str, List[str]] = None,
**kwargs,
):
Expand All @@ -93,16 +94,6 @@ def __init__(
# Tag to add to User-Agent header. For use by partners.
# _username, _password
# Username and password Basic authentication (no official support)
# _tls_no_verify
# Set to True (Boolean) to completely disable SSL verification.
# _tls_verify_hostname
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _connection_uri
# Overrides server_hostname and http_path.
# RETRY/ATTEMPT POLICY
Expand Down Expand Up @@ -162,29 +153,7 @@ def __init__(
# Cloud fetch
self.max_download_threads = kwargs.get("max_download_threads", 10)

# Configure tls context
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
if kwargs.get("_tls_no_verify") is True:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif kwargs.get("_tls_verify_hostname") is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

tls_client_cert_file = kwargs.get("_tls_client_cert_file")
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
if tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=tls_client_cert_file,
keyfile=tls_client_cert_key_file,
password=tls_client_cert_key_password,
)

self._ssl_context = ssl_context
self._ssl_options = ssl_options

self._auth_provider = auth_provider

Expand Down Expand Up @@ -225,7 +194,7 @@ def __init__(
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
**additional_transport_args, # type: ignore
)

Expand Down Expand Up @@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
else:
arrow_queue_opt = None
Expand Down Expand Up @@ -1008,7 +977,7 @@ def fetch_results(
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)

return queue, resp.hasMoreRows
Expand Down
48 changes: 48 additions & 0 deletions src/databricks/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,54 @@
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
import datetime
import decimal
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context


class SSLOptions:
tls_verify: bool
tls_verify_hostname: bool
tls_trusted_ca_file: Optional[str]
tls_client_cert_file: Optional[str]
tls_client_cert_key_file: Optional[str]
tls_client_cert_key_password: Optional[str]

def __init__(
self,
tls_verify: bool = True,
tls_verify_hostname: bool = True,
tls_trusted_ca_file: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
tls_client_cert_key_file: Optional[str] = None,
tls_client_cert_key_password: Optional[str] = None,
):
self.tls_verify = tls_verify
self.tls_verify_hostname = tls_verify_hostname
self.tls_trusted_ca_file = tls_trusted_ca_file
self.tls_client_cert_file = tls_client_cert_file
self.tls_client_cert_key_file = tls_client_cert_key_file
self.tls_client_cert_key_password = tls_client_cert_key_password

def create_ssl_context(self) -> SSLContext:
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)

if self.tls_verify is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif self.tls_verify_hostname is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

if self.tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=self.tls_client_cert_file,
keyfile=self.tls_client_cert_key_file,
password=self.tls_client_cert_key_password,
)

return ssl_context


class Row(tuple):
Expand Down
Loading
Loading