Docstrings and comments
Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
mattdeekay committed Jun 27, 2023
1 parent 4bdbd98 commit 3c4e3ac
Showing 3 changed files with 117 additions and 48 deletions.
27 changes: 16 additions & 11 deletions src/databricks/sql/thrift_backend.py
@@ -25,7 +25,6 @@
)

from databricks.sql.utils import (
ArrowQueue,
ExecuteResponse,
_bound,
RequestErrorInfo,
@@ -560,14 +559,9 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
(
arrow_table,
num_rows,
) = convert_column_based_set_to_arrow_table(
t_row_set.columns, description
)
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
elif t_row_set.arrowBatches is not None:
(
arrow_table,
num_rows,
) = convert_arrow_based_set_to_arrow_table(
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
@@ -656,7 +650,9 @@ def _hive_schema_to_description(t_table_schema):
ThriftBackend._col_to_description(col) for col in t_table_schema.columns
]

def _results_message_to_execute_response(self, resp, operation_state, max_download_threads):
def _results_message_to_execute_response(
self, resp, operation_state, max_download_threads
):
if resp.directResults and resp.directResults.resultSetMetadata:
t_result_set_metadata_resp = resp.directResults.resultSetMetadata
else:
@@ -758,7 +754,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
)

def execute_command(
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor, use_cloud_fetch=False
self,
operation,
session_handle,
max_rows,
max_bytes,
lz4_compression,
cursor,
use_cloud_fetch=False,
):
assert session_handle is not None

@@ -886,7 +889,9 @@ def _handle_execute_response(self, resp, cursor):

max_download_threads = cursor.connection.max_download_threads

return self._results_message_to_execute_response(resp, final_operation_state, max_download_threads)
return self._results_message_to_execute_response(
resp, final_operation_state, max_download_threads
)

def fetch_results(
self,
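Aside from the black-style reflow, the thrift_backend.py hunks thread two new values through the execute path: execute_command gains a use_cloud_fetch flag (default False), and _results_message_to_execute_response now receives max_download_threads taken from cursor.connection.max_download_threads. A rough call sketch, in which thrift_backend, session_handle, and cursor are placeholders for objects not shown in this diff:

thrift_backend.execute_command(
    operation="SELECT 1",
    session_handle=session_handle,  # placeholder: an open Thrift session handle
    max_rows=100000,
    max_bytes=10 * 1024 * 1024,
    lz4_compression=True,
    cursor=cursor,                  # cursor.connection.max_download_threads feeds
                                    # _results_message_to_execute_response
    use_cloud_fetch=True,           # new keyword; defaults to False
)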
108 changes: 86 additions & 22 deletions src/databricks/sql/utils.py
@@ -11,7 +11,11 @@

from databricks.sql import exc, OperationalError
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink, TSparkRowSetType, TRowSet
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkArrowResultLink,
TSparkRowSetType,
TRowSet,
)

DEFAULT_MAX_DOWNLOAD_THREADS = 10
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
@@ -32,27 +36,45 @@ class ResultSetQueueFactory(ABC):
def build_queue(
row_set_type: TSparkRowSetType,
t_row_set: TRowSet,
arrow_schema_bytes,
arrow_schema_bytes: bytes,
lz4_compressed: bool = True,
description: str = None,
description: List[List[any]] = None,
max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS,
) -> ResultSetQueue:
"""
Factory method to build a result set queue.
Args:
row_set_type (enum): Row set type (Arrow, Column, or URL).
t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
arrow_schema_bytes (bytes): Bytes representing the arrow schema.
lz4_compressed (bool): Whether result data has been lz4 compressed.
description (List[List[any]]): Hive table schema description.
max_download_threads (int): Maximum number of downloader thread pool threads.
Returns:
ResultSetQueue
"""
if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
)
converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description)
converted_arrow_table = convert_decimals_in_arrow_table(
arrow_table, description
)
return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
t_row_set.columns, description
)
converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description)
converted_arrow_table = convert_decimals_in_arrow_table(
arrow_table, description
)
return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
return CloudFetchQueue(
arrow_schema_bytes,
start_row_index=t_row_set.startRowOffset,
start_row_offset=t_row_set.startRowOffset,
result_links=t_row_set.resultLinks,
lz4_compressed=lz4_compressed,
description=description,
@@ -102,44 +124,72 @@ def __init__(
self,
schema_bytes,
max_download_threads: int,
start_row_index: int = 0,
start_row_offset: int = 0,
result_links: List[TSparkArrowResultLink] = None,
lz4_compressed: bool = True,
description: str = None,
description: List[List[any]] = None,
):
"""
A queue-like wrapper over CloudFetch arrow batches
A queue-like wrapper over CloudFetch arrow batches.
Attributes:
schema_bytes (bytes): Table schema in bytes.
max_download_threads (int): Maximum number of downloader thread pool threads.
start_row_offset (int): The offset of the first row of the cloud fetch links.
result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata.
lz4_compressed (bool): Whether the files are lz4 compressed.
description (List[List[any]]): Hive table schema description.
"""
self.schema_bytes = schema_bytes
self.max_download_threads = max_download_threads
self.start_row_index = start_row_index
self.start_row_index = start_row_offset
self.result_links = result_links
self.lz4_compressed = lz4_compressed
self.description = description

self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed)
self.download_manager.add_file_links(result_links, start_row_index)
self.download_manager = ResultFileDownloadManager(
self.max_download_threads, self.lz4_compressed
)
self.download_manager.add_file_links(result_links)

self.table = self._create_next_table()
self.table_row_index = 0

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
"""
Get up to the next n rows of the cloud fetch Arrow dataframes.
Args:
num_rows (int): Number of rows to retrieve.
Returns:
pyarrow.Table
"""
if not self.table:
# Return empty pyarrow table to cause retry of fetch
return self._create_empty_table()
results = self.table.slice(0, 0)
while num_rows > 0 and self.table:
# Get remaining of num_rows or the rest of the current table, whichever is smaller
length = min(num_rows, self.table.num_rows - self.table_row_index)
table_slice = self.table.slice(self.table_row_index, length)
results = pyarrow.concat_tables([results, table_slice])
self.table_row_index += table_slice.num_rows

# Replace current table with the next table if we are at the end of the current table
if self.table_row_index == self.table.num_rows:
self.table = self._create_next_table()
self.table_row_index = 0
num_rows -= table_slice.num_rows
return results
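The loop in next_n_rows is ordinary pyarrow slicing and concatenation. A standalone sketch of the same walk-and-accumulate pattern over toy tables (no CloudFetch downloads involved):

import pyarrow

tables = [pyarrow.table({"v": [0, 1, 2, 3]}), pyarrow.table({"v": [4, 5, 6]})]
table_iter = iter(tables)

current = next(table_iter)
row_index = 0
results = current.slice(0, 0)  # empty table with the right schema
num_rows = 5                   # ask for five rows spanning both tables
while num_rows > 0 and current is not None:
    length = min(num_rows, current.num_rows - row_index)
    chunk = current.slice(row_index, length)
    results = pyarrow.concat_tables([results, chunk])
    row_index += chunk.num_rows
    if row_index == current.num_rows:  # current table exhausted, move to the next
        current = next(table_iter, None)
        row_index = 0
    num_rows -= chunk.num_rows

assert results.column("v").to_pylist() == [0, 1, 2, 3, 4]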

def remaining_rows(self) -> pyarrow.Table:
"""
Get all remaining rows of the cloud fetch Arrow dataframes.
Returns:
pyarrow.Table
"""
if not self.table:
# Return empty pyarrow table to cause retry of fetch
return self._create_empty_table()
@@ -155,18 +205,30 @@ def remaining_rows(self) -> pyarrow.Table:
return results

def _create_next_table(self) -> Union[pyarrow.Table, None]:
downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index)
# Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
downloaded_file = self.download_manager.get_next_downloaded_file(
self.start_row_index
)
if not downloaded_file:
# None signals no more Arrow tables can be built from the remaining handlers if any remain
return None
arrow_table = create_arrow_table_from_arrow_file(downloaded_file.file_bytes, self.description)
arrow_table = create_arrow_table_from_arrow_file(
downloaded_file.file_bytes, self.description
)

# The server rarely prepares the exact number of rows requested by the client in cloud fetch.
# Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested
if arrow_table.num_rows > downloaded_file.row_count:
self.start_row_index += downloaded_file.row_count
return arrow_table.slice(0, downloaded_file.row_count)

# At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
assert downloaded_file.row_count == arrow_table.num_rows
self.start_row_index += arrow_table.num_rows
return arrow_table

def _create_empty_table(self) -> pyarrow.Table:
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
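_create_empty_table leans on the fact that schema bytes alone decode to a zero-row table, which is the sentinel next_n_rows and remaining_rows hand back to cause a retry of the fetch. A standalone check of that behaviour with pyarrow only:

import pyarrow

schema = pyarrow.schema([("id", pyarrow.int64()), ("name", pyarrow.string())])
schema_bytes = schema.serialize().to_pybytes()

empty = pyarrow.ipc.open_stream(schema_bytes).read_all()  # schema message, no batches
assert empty.num_rows == 0
assert empty.column_names == ["id", "name"]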


@@ -315,7 +377,9 @@ def inject_parameters(operation: str, parameters: Dict[str, str]):
return operation % parameters
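inject_parameters is printf-style substitution into the SQL text: values are spliced in verbatim, so the caller is responsible for rendering them as SQL literals beforehand. A minimal illustration with a made-up query:

operation = "SELECT * FROM t WHERE id = %(id)s AND name = %(name)s"
parameters = {"id": "42", "name": "'alice'"}
assert (operation % parameters) == "SELECT * FROM t WHERE id = 42 AND name = 'alice'"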


def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> (pyarrow.Table, int):
def create_arrow_table_from_arrow_file(
file_bytes: bytes, description
) -> (pyarrow.Table, int):
arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
return convert_decimals_in_arrow_table(arrow_table, description)

@@ -327,15 +391,17 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
raise RuntimeError("Failure to convert arrow based file to arrow table", e)


def convert_arrow_based_set_to_arrow_table(
arrow_batches, lz4_compressed, schema_bytes
):
def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
ba = bytearray()
ba += schema_bytes
n_rows = 0
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += lz4.frame.decompress(arrow_batch.batch) if lz4_compressed else arrow_batch.batch
ba += (
lz4.frame.decompress(arrow_batch.batch)
if lz4_compressed
else arrow_batch.batch
)
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
return arrow_table, n_rows
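The bytearray concatenation above works because an Arrow IPC stream is simply a serialized schema followed by serialized record batches, which is exactly how the function reassembles schema_bytes plus the (possibly lz4-compressed) arrowBatches. A standalone round trip with pyarrow only:

import pyarrow

source = pyarrow.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
batch = source.to_batches()[0]

schema_bytes = source.schema.serialize().to_pybytes()  # analogous to schema_bytes
batch_bytes = batch.serialize().to_pybytes()           # analogous to one arrow_batch.batch

ba = bytearray()
ba += schema_bytes
ba += batch_bytes
roundtripped = pyarrow.ipc.open_stream(ba).read_all()
assert roundtripped.num_rows == 3
assert roundtripped.column_names == ["id", "name"]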

@@ -388,9 +454,7 @@ def _convert_column_to_arrow_array(t_col):
for field in field_name_to_arrow_type.keys():
wrapper = getattr(t_col, field)
if wrapper:
return _create_arrow_array(
wrapper, field_name_to_arrow_type[field]
)
return _create_arrow_array(wrapper, field_name_to_arrow_type[field])

raise OperationalError("Empty TColumn instance {}".format(t_col))
