diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index 5304d4e4..fd2ad7c2 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -25,7 +25,6 @@
 )
 from databricks.sql.utils import (
-    ArrowQueue,
     ExecuteResponse,
     _bound,
     RequestErrorInfo,
@@ -560,14 +559,9 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
             (
                 arrow_table,
                 num_rows,
-            ) = convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
+            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
@@ -656,7 +650,9 @@ def _hive_schema_to_description(t_table_schema):
             ThriftBackend._col_to_description(col) for col in t_table_schema.columns
         ]

-    def _results_message_to_execute_response(self, resp, operation_state, max_download_threads):
+    def _results_message_to_execute_response(
+        self, resp, operation_state, max_download_threads
+    ):
         if resp.directResults and resp.directResults.resultSetMetadata:
             t_result_set_metadata_resp = resp.directResults.resultSetMetadata
         else:
@@ -758,7 +754,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
         )

     def execute_command(
-        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor, use_cloud_fetch=False
+        self,
+        operation,
+        session_handle,
+        max_rows,
+        max_bytes,
+        lz4_compression,
+        cursor,
+        use_cloud_fetch=False,
     ):
         assert session_handle is not None
@@ -886,7 +889,9 @@ def _handle_execute_response(self, resp, cursor):

         max_download_threads = cursor.connection.max_download_threads

-        return self._results_message_to_execute_response(resp, final_operation_state, max_download_threads)
+        return self._results_message_to_execute_response(
+            resp, final_operation_state, max_download_threads
+        )

     def fetch_results(
         self,
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index 3848a137..255eb7f4 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -11,7 +11,11 @@
 from databricks.sql import exc, OperationalError
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
-from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink, TSparkRowSetType, TRowSet
+from databricks.sql.thrift_api.TCLIService.ttypes import (
+    TSparkArrowResultLink,
+    TSparkRowSetType,
+    TRowSet,
+)

 DEFAULT_MAX_DOWNLOAD_THREADS = 10
 BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
@@ -32,27 +36,45 @@ class ResultSetQueueFactory(ABC):
     def build_queue(
         row_set_type: TSparkRowSetType,
         t_row_set: TRowSet,
-        arrow_schema_bytes,
+        arrow_schema_bytes: bytes,
         lz4_compressed: bool = True,
-        description: str = None,
+        description: List[List[any]] = None,
         max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS,
     ) -> ResultSetQueue:
+        """
+        Factory method to build a result set queue.
+
+        Args:
+            row_set_type (enum): Row set type (Arrow, Column, or URL).
+            t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
+            arrow_schema_bytes (bytes): Bytes representing the arrow schema.
+            lz4_compressed (bool): Whether result data has been lz4 compressed.
+            description (List[List[any]]): Hive table schema description.
+            max_download_threads (int): Maximum number of downloader thread pool threads.
+
+        Returns:
+            ResultSetQueue
+        """
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
             )
-            converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description)
+            converted_arrow_table = convert_decimals_in_arrow_table(
+                arrow_table, description
+            )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
             arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
                 t_row_set.columns, description
             )
-            converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description)
+            converted_arrow_table = convert_decimals_in_arrow_table(
+                arrow_table, description
+            )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.URL_BASED_SET:
             return CloudFetchQueue(
                 arrow_schema_bytes,
-                start_row_index=t_row_set.startRowOffset,
+                start_row_offset=t_row_set.startRowOffset,
                 result_links=t_row_set.resultLinks,
                 lz4_compressed=lz4_compressed,
                 description=description,
@@ -102,37 +124,59 @@ def __init__(
         self,
         schema_bytes,
         max_download_threads: int,
-        start_row_index: int = 0,
+        start_row_offset: int = 0,
         result_links: List[TSparkArrowResultLink] = None,
         lz4_compressed: bool = True,
-        description: str = None,
+        description: List[List[any]] = None,
     ):
         """
-        A queue-like wrapper over CloudFetch arrow batches
+        A queue-like wrapper over CloudFetch arrow batches.
+
+        Attributes:
+            schema_bytes (bytes): Table schema in bytes.
+            max_download_threads (int): Maximum number of downloader thread pool threads.
+            start_row_offset (int): The offset of the first row of the cloud fetch links.
+            result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata.
+            lz4_compressed (bool): Whether the files are lz4 compressed.
+            description (List[List[any]]): Hive table schema description.
         """
         self.schema_bytes = schema_bytes
         self.max_download_threads = max_download_threads
-        self.start_row_index = start_row_index
+        self.start_row_index = start_row_offset
         self.result_links = result_links
         self.lz4_compressed = lz4_compressed
         self.description = description
-        self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed)
-        self.download_manager.add_file_links(result_links, start_row_index)
+        self.download_manager = ResultFileDownloadManager(
+            self.max_download_threads, self.lz4_compressed
+        )
+        self.download_manager.add_file_links(result_links)

         self.table = self._create_next_table()
         self.table_row_index = 0

     def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+        """
+        Get up to the next n rows of the cloud fetch Arrow dataframes.
+
+        Args:
+            num_rows (int): Number of rows to retrieve.
+
+        Returns:
+            pyarrow.Table
+        """
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
         results = self.table.slice(0, 0)
         while num_rows > 0 and self.table:
+            # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
             results = pyarrow.concat_tables([results, table_slice])
             self.table_row_index += table_slice.num_rows
+
+            # Replace current table with the next table if we are at the end of the current table
             if self.table_row_index == self.table.num_rows:
                 self.table = self._create_next_table()
                 self.table_row_index = 0
@@ -140,6 +184,12 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         return results

     def remaining_rows(self) -> pyarrow.Table:
+        """
+        Get all remaining rows of the cloud fetch Arrow dataframes.
+
+        Returns:
+            pyarrow.Table
+        """
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
             return self._create_empty_table()
@@ -155,18 +205,30 @@ def remaining_rows(self) -> pyarrow.Table:
         return results

     def _create_next_table(self) -> Union[pyarrow.Table, None]:
-        downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index)
+        # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
+        downloaded_file = self.download_manager.get_next_downloaded_file(
+            self.start_row_index
+        )
         if not downloaded_file:
+            # None signals no more Arrow tables can be built from the remaining handlers if any remain
            return None
-        arrow_table = create_arrow_table_from_arrow_file(downloaded_file.file_bytes, self.description)
+        arrow_table = create_arrow_table_from_arrow_file(
+            downloaded_file.file_bytes, self.description
+        )
+
+        # The server rarely prepares the exact number of rows requested by the client in cloud fetch.
+        # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested
         if arrow_table.num_rows > downloaded_file.row_count:
             self.start_row_index += downloaded_file.row_count
             return arrow_table.slice(0, downloaded_file.row_count)
+
+        # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
         assert downloaded_file.row_count == arrow_table.num_rows
         self.start_row_index += arrow_table.num_rows
         return arrow_table

     def _create_empty_table(self) -> pyarrow.Table:
+        # Create a 0-row table with just the schema bytes
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
@@ -315,7 +377,9 @@ def inject_parameters(operation: str, parameters: Dict[str, str]):
     return operation % parameters


-def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> (pyarrow.Table, int):
+def create_arrow_table_from_arrow_file(
+    file_bytes: bytes, description
+) -> (pyarrow.Table, int):
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)
@@ -327,15 +391,17 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
         raise RuntimeError("Failure to convert arrow based file to arrow table", e)


-def convert_arrow_based_set_to_arrow_table(
-    arrow_batches, lz4_compressed, schema_bytes
-):
+def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
     ba = bytearray()
     ba += schema_bytes
     n_rows = 0
     for arrow_batch in arrow_batches:
         n_rows += arrow_batch.rowCount
-        ba += lz4.frame.decompress(arrow_batch.batch) if lz4_compressed else arrow_batch.batch
+        ba += (
+            lz4.frame.decompress(arrow_batch.batch)
+            if lz4_compressed
+            else arrow_batch.batch
+        )
     arrow_table = pyarrow.ipc.open_stream(ba).read_all()
     return arrow_table, n_rows
@@ -388,9 +454,7 @@ def _convert_column_to_arrow_array(t_col):
     for field in field_name_to_arrow_type.keys():
         wrapper = getattr(t_col, field)
         if wrapper:
-            return _create_arrow_array(
-                wrapper, field_name_to_arrow_type[field]
-            )
+            return _create_arrow_array(wrapper, field_name_to_arrow_type[field])

     raise OperationalError("Empty TColumn instance {}".format(t_col))
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index 4ffa8e13..7c8e4bf4 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -47,7 +47,7 @@ def get_schema_bytes():
     def test_initializer_adds_links(self, mock_create_next_table):
         schema_bytes = MagicMock()
         result_links = self.create_result_links(10)
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

         assert len(queue.download_manager.download_handlers) == 10
         mock_create_next_table.assert_called()
@@ -55,14 +55,14 @@ def test_initializer_no_links_to_add(self):
         schema_bytes = MagicMock()
         result_links = []
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

         assert len(queue.download_manager.download_handlers) == 0
         assert queue.table is None

     @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None)
     def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
-        queue = utils.CloudFetchQueue(MagicMock(), result_links=[])
+        queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10)

         assert queue._create_next_table() is None
         assert mock_get_next_downloaded_file.called_with(0)
@@ -72,7 +72,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
            return_value=MagicMock(file_bytes=b"1234567890", row_count=4))
     def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         expected_result = self.make_arrow_table()

         assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description)
@@ -90,7 +90,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table())
     def test_next_n_rows_0_rows(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -103,7 +103,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table())
     def test_next_n_rows_partial_table(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -116,7 +116,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table())
     def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -129,7 +129,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table())
     def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -142,7 +142,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None])
     def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -155,7 +155,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
     def test_next_n_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
         description = MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table is None

         result = queue.next_n_rows(100)
@@ -164,7 +164,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None, 0])
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 4
@@ -176,7 +176,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None])
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 2
@@ -188,7 +188,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
     @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None])
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -201,7 +201,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
            side_effect=[make_arrow_table(), make_arrow_table(), None])
     def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table):
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 3
@@ -215,7 +215,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta
     def test_remaining_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
         description = MagicMock()
-        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description)
+        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
         assert queue.table is None

         result = queue.remaining_rows()
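
Reviewer note (not part of the patch): a minimal, self-contained sketch of the paging behaviour that `CloudFetchQueue.next_n_rows` implements above — slice the current Arrow table, advance `table_row_index`, and swap in the next table once the current one is exhausted. `ToyQueue` and `make_table` are hypothetical stand-ins, not the driver's API; the real queue pulls tables from `ResultFileDownloadManager` instead of a fixed list.

```python
# Standalone sketch only -- mimics the paging logic of CloudFetchQueue.next_n_rows
# from this PR without any downloads. ToyQueue and make_table are hypothetical.
import pyarrow


def make_table(start: int, count: int) -> pyarrow.Table:
    # One-column Arrow table standing in for a downloaded CloudFetch file.
    return pyarrow.Table.from_pydict({"id": list(range(start, start + count))})


class ToyQueue:
    def __init__(self, tables):
        self._tables = iter(tables)
        self.table = self._create_next_table()
        self.table_row_index = 0

    def _create_next_table(self):
        # Returns None when no tables remain, like the real _create_next_table.
        return next(self._tables, None)

    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
        results = self.table.slice(0, 0)  # empty table with the right schema
        while num_rows > 0 and self.table is not None:
            # Take the smaller of: rows still wanted, rows left in the current table.
            length = min(num_rows, self.table.num_rows - self.table_row_index)
            table_slice = self.table.slice(self.table_row_index, length)
            results = pyarrow.concat_tables([results, table_slice])
            self.table_row_index += table_slice.num_rows
            num_rows -= table_slice.num_rows
            # Swap in the next table once the current one is exhausted.
            if self.table_row_index == self.table.num_rows:
                self.table = self._create_next_table()
                self.table_row_index = 0
        return results


queue = ToyQueue([make_table(0, 4), make_table(4, 4)])
print(queue.next_n_rows(6).num_rows)  # 6 -- spans both 4-row tables
print(queue.next_n_rows(6).num_rows)  # 2 -- only two rows remain
```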