Skip to content

Commit

Permalink
Resolve circular dependencies
Browse files Browse the repository at this point in the history
Signed-off-by: Levko Kravets <levko.ne@gmail.com>
  • Loading branch information
kravets-levko committed Aug 14, 2024
1 parent a7be4cc commit 6f224b3
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
from urllib3.util import make_headers
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
from databricks.sql.utils import SSLOptions
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
ParamEscaper,
inject_parameters,
transform_paramstyle,
SSLOptions,
)
from databricks.sql.parameters.native import (
DbsqlParameterBase,
Expand All @@ -36,7 +35,7 @@
)


from databricks.sql.types import Row
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence

Expand Down
3 changes: 1 addition & 2 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

from ssl import SSLContext
from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union

Expand All @@ -9,7 +8,7 @@
DownloadableResultSettings,
DownloadedFile,
)
from databricks.sql.utils import SSLOptions
from databricks.sql.types import SSLOptions

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
from databricks.sql.exc import Error
from databricks.sql.utils import SSLOptions
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import uuid
import threading
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
Expand Down Expand Up @@ -35,8 +34,8 @@
convert_arrow_based_set_to_arrow_table,
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
SSLOptions,
)
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down
48 changes: 48 additions & 0 deletions src/databricks/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,54 @@
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
import datetime
import decimal
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context


class SSLOptions:
tls_verify: bool
tls_verify_hostname: bool
tls_trusted_ca_file: Optional[str]
tls_client_cert_file: Optional[str]
tls_client_cert_key_file: Optional[str]
tls_client_cert_key_password: Optional[str]

def __init__(
self,
tls_verify: Optional[bool] = True,
tls_verify_hostname: Optional[bool] = True,
tls_trusted_ca_file: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
tls_client_cert_key_file: Optional[str] = None,
tls_client_cert_key_password: Optional[str] = None,
):
self.tls_verify = tls_verify
self.tls_verify_hostname = tls_verify_hostname
self.tls_trusted_ca_file = tls_trusted_ca_file
self.tls_client_cert_file = tls_client_cert_file
self.tls_client_cert_key_file = tls_client_cert_key_file
self.tls_client_cert_key_password = tls_client_cert_key_password

def create_ssl_context(self) -> SSLContext:
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)

if self.tls_verify is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif self.tls_verify_hostname is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

if self.tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=self.tls_client_cert_file,
keyfile=self.tls_client_cert_key_file,
password=self.tls_client_cert_key_password,
)

return ssl_context


class Row(tuple):
Expand Down
49 changes: 1 addition & 48 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import re
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context

import lz4.frame
import pyarrow
Expand All @@ -21,6 +20,7 @@
TSparkArrowResultLink,
TSparkRowSetType,
)
from databricks.sql.types import SSLOptions

from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter

Expand All @@ -31,53 +31,6 @@
logger = logging.getLogger(__name__)


class SSLOptions:
tls_verify: bool
tls_verify_hostname: bool
tls_trusted_ca_file: Optional[str]
tls_client_cert_file: Optional[str]
tls_client_cert_key_file: Optional[str]
tls_client_cert_key_password: Optional[str]

def __init__(
self,
tls_verify: Optional[bool] = True,
tls_verify_hostname: Optional[bool] = True,
tls_trusted_ca_file: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
tls_client_cert_key_file: Optional[str] = None,
tls_client_cert_key_password: Optional[str] = None,
):
self.tls_verify = tls_verify
self.tls_verify_hostname = tls_verify_hostname
self.tls_trusted_ca_file = tls_trusted_ca_file
self.tls_client_cert_file = tls_client_cert_file
self.tls_client_cert_key_file = tls_client_cert_key_file
self.tls_client_cert_key_password = tls_client_cert_key_password

def create_ssl_context(self) -> SSLContext:
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)

if self.tls_verify is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif self.tls_verify_hostname is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

if self.tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=self.tls_client_cert_file,
keyfile=self.tls_client_cert_key_file,
password=self.tls_client_cert_key_password,
)

return ssl_context


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
Expand Down

0 comments on commit 6f224b3

Please sign in to comment.