Cloud fetch queue and integration (#151)
* Cloud fetch queue and integration
* Enable cloudfetch with direct results
* Typing and style changes
* Client-settable max_download_threads
* Docstrings and comments
* Increase default buffer size bytes to 104857600
* Move max_download_threads to kwargs of ThriftBackend, fix unit tests
* Fix tests: staticmethod make_arrow_table mock not callable
* cancel_futures in shutdown() only available in python >=3.9.0
* Black linting
* Fix typing errors

---------

Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
mattdeekay authored Jul 5, 2023
1 parent 01b7a8d commit 5a34a4a
Showing 6 changed files with 596 additions and 136 deletions.
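Taken together, the changes in this commit add an opt-in cloud fetch path: Connection now reads a use_cloud_fetch flag (defaulting to False) and ThriftBackend accepts a max_download_threads kwarg (defaulting to 10). Below is a minimal usage sketch, assuming keyword arguments passed to sql.connect() are forwarded through Connection to ThriftBackend; the hostname, HTTP path, token, and table name are placeholders.

```python
from databricks import sql

# Hypothetical connection values; only use_cloud_fetch and
# max_download_threads relate to this commit.
connection = sql.connect(
    server_hostname="<workspace-hostname>",
    http_path="<http-path>",
    access_token="<personal-access-token>",
    use_cloud_fetch=True,       # opt in to cloud fetch (defaults to False)
    max_download_threads=10,    # parallel download threads (ThriftBackend default)
)

cursor = connection.cursor()
cursor.execute("SELECT * FROM some_large_table")
rows = cursor.fetchall()

cursor.close()
connection.close()
```

With use_cloud_fetch left at its default of False, execute_command keeps sending canDownloadResult=False and results continue to flow through the existing Arrow/column-based path.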
src/databricks/sql/client.py (7 changes: 6 additions & 1 deletion)
@@ -17,7 +17,7 @@

logger = logging.getLogger(__name__)

DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
DEFAULT_ARRAY_SIZE = 100000


@@ -153,6 +153,8 @@ def read(self) -> Optional[OAuthToken]:
# _use_arrow_native_timestamps
# Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
# (True by default)
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage

if access_token:
access_token_kv = {"access_token": access_token}
@@ -189,6 +191,7 @@ def read(self) -> Optional[OAuthToken]:
self._session_handle = self.thrift_backend.open_session(
session_configuration, catalog, schema
)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
self._cursors = [] # type: List[Cursor]
@@ -497,6 +500,7 @@ def execute(
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
)
self.active_result_set = ResultSet(
self.connection,
@@ -822,6 +826,7 @@ def __iter__(self):
break

def _fill_results_buffer(self):
# At initialization or if the server does not have cloud fetch result links available
results, has_more_rows = self.thrift_backend.fetch_results(
op_handle=self.command_id,
max_rows=self.arraysize,
src/databricks/sql/cloudfetch/download_manager.py (4 changes: 2 additions & 2 deletions)
@@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool to cancel pending futures
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
self.thread_pool.shutdown(wait=False, cancel_futures=True)
self.thread_pool.shutdown(wait=False)
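Per the commit message, cancel_futures is only available on ThreadPoolExecutor.shutdown() in Python >= 3.9, so the call above simply drops it. A version-guarded alternative (a sketch of one option, not what this commit does) would keep the cancellation on newer interpreters:

```python
import sys

def _shutdown_manager(self):
    # Clear download handlers and shut down the thread pool.
    self.download_handlers = []
    # cancel_futures was added to ThreadPoolExecutor.shutdown() in Python 3.9,
    # so only pass it where it is supported.
    if sys.version_info >= (3, 9):
        self.thread_pool.shutdown(wait=False, cancel_futures=True)
    else:
        self.thread_pool.shutdown(wait=False)
```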
src/databricks/sql/thrift_backend.py (151 changes: 39 additions & 112 deletions)
@@ -5,7 +5,6 @@
import time
import uuid
import threading
import lz4.frame
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

@@ -26,11 +25,14 @@
)

from databricks.sql.utils import (
ArrowQueue,
ExecuteResponse,
_bound,
RequestErrorInfo,
NoRetryReason,
ResultSetQueueFactory,
convert_arrow_based_set_to_arrow_table,
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
)

logger = logging.getLogger(__name__)
@@ -67,7 +69,6 @@
class ThriftBackend:
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]

def __init__(
self,
@@ -115,6 +116,8 @@ def __init__(
# _socket_timeout
# The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
# (defaults to 900)
# max_download_threads
# Number of threads for handling cloud fetch downloads. Defaults to 10

port = port or 443
if kwargs.get("_connection_uri"):
@@ -136,6 +139,9 @@
"_use_arrow_native_timestamps", True
)

# Cloud fetch
self.max_download_threads = kwargs.get("max_download_threads", 10)

# Configure tls context
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
if kwargs.get("_tls_no_verify") is True:
@@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
(
arrow_table,
num_rows,
) = ThriftBackend._convert_column_based_set_to_arrow_table(
t_row_set.columns, description
)
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
elif t_row_set.arrowBatches is not None:
(
arrow_table,
num_rows,
) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows

@staticmethod
def _convert_decimals_in_arrow_table(table, description):
for (i, col) in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
lambda v: v if v is None else Decimal(v)
)
precision, scale = description[i][4], description[i][5]
assert scale is not None
assert precision is not None
# Spark limits decimal to a maximum scale of 38,
# so 128 is guaranteed to be big enough
dtype = pyarrow.decimal128(precision, scale)
col_data = pyarrow.array(decimal_col, type=dtype)
field = table.field(i).with_type(dtype)
table = table.set_column(i, field, col_data)
return table

@staticmethod
def _convert_arrow_based_set_to_arrow_table(
arrow_batches, lz4_compressed, schema_bytes
):
ba = bytearray()
ba += schema_bytes
n_rows = 0
if lz4_compressed:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += lz4.frame.decompress(arrow_batch.batch)
else:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += arrow_batch.batch
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
return arrow_table, n_rows

@staticmethod
def _convert_column_based_set_to_arrow_table(columns, description):
arrow_table = pyarrow.Table.from_arrays(
[ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
# Only use the column names from the schema, the types are determined by the
# physical types used in column based set, as they can differ from the
# mapping used in _hive_schema_to_arrow_schema.
names=[c[0] for c in description],
)
return arrow_table, arrow_table.num_rows

@staticmethod
def _convert_column_to_arrow_array(t_col):
"""
Return a pyarrow array from the values in a TColumn instance.
Note that ColumnBasedSet has no native support for complex types, so they will be converted
to strings server-side.
"""
field_name_to_arrow_type = {
"boolVal": pyarrow.bool_(),
"byteVal": pyarrow.int8(),
"i16Val": pyarrow.int16(),
"i32Val": pyarrow.int32(),
"i64Val": pyarrow.int64(),
"doubleVal": pyarrow.float64(),
"stringVal": pyarrow.string(),
"binaryVal": pyarrow.binary(),
}
for field in field_name_to_arrow_type.keys():
wrapper = getattr(t_col, field)
if wrapper:
return ThriftBackend._create_arrow_array(
wrapper, field_name_to_arrow_type[field]
)

raise OperationalError("Empty TColumn instance {}".format(t_col))

@staticmethod
def _create_arrow_array(t_col_value_wrapper, arrow_type):
result = t_col_value_wrapper.values
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
assert isinstance(nulls, bytes)

# The number of bits in nulls can be both larger or smaller than the number of
# elements in result, so take the minimum of both to iterate over.
length = min(len(result), len(nulls) * 8)

for i in range(length):
if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
result[i] = None

return pyarrow.array(result, type=arrow_type)
return convert_decimals_in_arrow_table(arrow_table, description), num_rows

def _get_metadata_resp(self, op_handle):
req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
@@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
if t_result_set_metadata_resp.resultFormat not in [
ttypes.TSparkRowSetType.ARROW_BASED_SET,
ttypes.TSparkRowSetType.COLUMN_BASED_SET,
ttypes.TSparkRowSetType.URL_BASED_SET,
]:
raise OperationalError(
"Expected results to be in Arrow or column based format, "
@@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

arrow_results, n_rows = self._create_arrow_table(
direct_results.resultSet.results,
lz4_compressed,
schema_bytes,
description,
arrow_queue_opt = ResultSetQueueFactory.build_queue(
row_set_type=t_result_set_metadata_resp.resultFormat,
t_row_set=direct_results.resultSet.results,
arrow_schema_bytes=schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
)
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
else:
arrow_queue_opt = None
return ExecuteResponse(
@@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
)

def execute_command(
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
self,
operation,
session_handle,
max_rows,
max_bytes,
lz4_compression,
cursor,
use_cloud_fetch=False,
):
assert session_handle is not None

@@ -864,7 +785,7 @@ def execute_command(
),
canReadArrowResult=True,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=False,
canDownloadResult=use_cloud_fetch,
confOverlay={
# We want to receive proper Timestamp arrow types.
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
@@ -993,6 +914,7 @@ def fetch_results(
maxRows=max_rows,
maxBytes=max_bytes,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
includeResultSetMetadata=True,
)

resp = self.make_request(self._client.FetchResults, req)
@@ -1002,12 +924,17 @@
expected_row_start_offset, resp.results.startRowOffset
)
)
arrow_results, n_rows = self._create_arrow_table(
resp.results, lz4_compressed, arrow_schema_bytes, description

queue = ResultSetQueueFactory.build_queue(
row_set_type=resp.resultSetMetadata.resultFormat,
t_row_set=resp.results,
arrow_schema_bytes=arrow_schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
)
arrow_queue = ArrowQueue(arrow_results, n_rows)

return arrow_queue, resp.hasMoreRows
return queue, resp.hasMoreRows

def close_command(self, op_handle):
req = ttypes.TCloseOperationReq(operationHandle=op_handle)
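The ResultSetQueueFactory.build_queue calls above replace the inline _create_arrow_table/ArrowQueue plumbing. The factory itself lives in databricks/sql/utils.py, which is not part of the loaded diff; the sketch below shows one way such a factory could dispatch on the row set type, reusing the converters imported at the top of thrift_backend.py. The CloudFetchQueue name, its constructor arguments, and the resultLinks field usage are assumptions for illustration, not taken from this diff.

```python
from databricks.sql.thrift_api.TCLIService import ttypes

# ArrowQueue, CloudFetchQueue and the convert_* helpers are assumed to be
# defined alongside this factory in databricks/sql/utils.py.

class ResultSetQueueFactory:
    @staticmethod
    def build_queue(
        row_set_type,
        t_row_set,
        arrow_schema_bytes,
        max_download_threads,
        lz4_compressed=True,
        description=None,
    ):
        # Arrow batches sent inline: decode locally into an in-memory queue.
        if row_set_type == ttypes.TSparkRowSetType.ARROW_BASED_SET:
            arrow_table, n_rows = convert_arrow_based_set_to_arrow_table(
                t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
            )
            converted = convert_decimals_in_arrow_table(arrow_table, description)
            return ArrowQueue(converted, n_rows)
        # Column-based sets: build an Arrow table from the TColumn values.
        if row_set_type == ttypes.TSparkRowSetType.COLUMN_BASED_SET:
            arrow_table, n_rows = convert_column_based_set_to_arrow_table(
                t_row_set.columns, description
            )
            converted = convert_decimals_in_arrow_table(arrow_table, description)
            return ArrowQueue(converted, n_rows)
        # URL-based sets: hand the presigned result links to the cloud fetch
        # download manager via a queue type assumed here as CloudFetchQueue.
        if row_set_type == ttypes.TSparkRowSetType.URL_BASED_SET:
            return CloudFetchQueue(
                schema_bytes=arrow_schema_bytes,
                result_links=t_row_set.resultLinks,
                max_download_threads=max_download_threads,
                lz4_compressed=lz4_compressed,
                description=description,
            )
        raise AssertionError("Unsupported TRowSet type {}".format(row_set_type))
```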
(Diffs for the remaining 3 changed files were not loaded.)
