Client-settable max_download_threads
Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
mattdeekay committed Jun 27, 2023
1 parent f6403a1 commit 4bdbd98
Showing 3 changed files with 13 additions and 5 deletions.
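In effect, this commit lets a client tune CloudFetch download parallelism per connection instead of always using the module-level DEFAULT_MAX_DOWNLOAD_THREADS. A minimal usage sketch follows; the diff covers only the read path (connection → thrift backend → fetch queue), so how the Connection constructor accepts and stores the keyword is an assumption here, not something this commit shows.

    from databricks import sql

    # Hypothetical usage sketch: max_download_threads as a connect() keyword
    # is assumed; this commit only confirms that the fetch path reads
    # connection.max_download_threads.
    connection = sql.connect(
        server_hostname="...",    # placeholder
        http_path="...",          # placeholder
        access_token="...",       # placeholder
        max_download_threads=16,  # client-settable CloudFetch parallelism
    )
    cursor = connection.cursor()
    cursor.execute("SELECT 1")
    print(cursor.fetchall())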
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
@@ -839,6 +839,7 @@ def _fill_results_buffer(self):
             lz4_compressed=self.lz4_compressed,
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
+            max_download_threads=self.connection.max_download_threads,
         )
         self.results = results
         self.has_more_rows = has_more_rows
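The hunk above is only the read side of the new setting. A sketch of the assumed write side — how Connection might store the keyword so that _fill_results_buffer can read it on every fetch — is below; none of this Connection code is part of the commit, and the default value is an assumption.

    DEFAULT_MAX_DOWNLOAD_THREADS = 10  # assumed default; the real constant lives in the library

    class Connection:
        def __init__(self, max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS, **kwargs):
            # Assumed plumbing (not in this diff): persist the client-supplied
            # value so the result-set fetch path can read it per request.
            self.max_download_threads = max_download_threads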
9 changes: 7 additions & 2 deletions src/databricks/sql/thrift_backend.py
@@ -656,7 +656,7 @@ def _hive_schema_to_description(t_table_schema):
             ThriftBackend._col_to_description(col) for col in t_table_schema.columns
         ]
 
-    def _results_message_to_execute_response(self, resp, operation_state):
+    def _results_message_to_execute_response(self, resp, operation_state, max_download_threads):
         if resp.directResults and resp.directResults.resultSetMetadata:
             t_result_set_metadata_resp = resp.directResults.resultSetMetadata
         else:
@@ -703,6 +703,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 arrow_schema_bytes=schema_bytes,
                 lz4_compressed=lz4_compressed,
                 description=description,
+                max_download_threads=max_download_threads,
             )
         else:
             arrow_queue_opt = None
@@ -883,7 +884,9 @@ def _handle_execute_response(self, resp, cursor):
             resp.directResults and resp.directResults.operationStatus,
         )
 
-        return self._results_message_to_execute_response(resp, final_operation_state)
+        max_download_threads = cursor.connection.max_download_threads
+
+        return self._results_message_to_execute_response(resp, final_operation_state, max_download_threads)
 
     def fetch_results(
         self,
@@ -894,6 +897,7 @@ def fetch_results(
         lz4_compressed,
         arrow_schema_bytes,
         description,
+        max_download_threads,
     ):
         assert op_handle is not None
 
@@ -924,6 +928,7 @@
             arrow_schema_bytes=arrow_schema_bytes,
             lz4_compressed=lz4_compressed,
             description=description,
+            max_download_threads=max_download_threads,
         )
 
         return queue, resp.hasMoreRows
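Taken together, these hunks hand the value explicitly from the cursor's connection down to the queue factory instead of letting the queue layer fall back to a shared default. A self-contained illustration of that pattern is below; FakeConnection, handle_execute_response, and build_queue are stand-in names, not the library's real classes.

    # Illustration of explicit configuration threading, as applied by this commit.
    DEFAULT_MAX_DOWNLOAD_THREADS = 10

    class FakeConnection:
        def __init__(self, max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS):
            self.max_download_threads = max_download_threads

    def handle_execute_response(connection: FakeConnection) -> dict:
        # Read the client-settable value once, at the boundary with the client...
        max_download_threads = connection.max_download_threads
        return build_queue(max_download_threads=max_download_threads)

    def build_queue(max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS) -> dict:
        # ...so this layer uses the caller's value and falls back to the
        # default only when no caller supplies one.
        return {"max_download_threads": max_download_threads}

    assert handle_execute_response(FakeConnection(16)) == {"max_download_threads": 16}
    assert build_queue() == {"max_download_threads": 10}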
8 changes: 5 additions & 3 deletions src/databricks/sql/utils.py
@@ -35,6 +35,7 @@ def build_queue(
         arrow_schema_bytes,
         lz4_compressed: bool = True,
         description: str = None,
+        max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS,
     ) -> ResultSetQueue:
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
@@ -54,7 +55,8 @@
                 start_row_index=t_row_set.startRowOffset,
                 result_links=t_row_set.resultLinks,
                 lz4_compressed=lz4_compressed,
-                description=description
+                description=description,
+                max_download_threads=max_download_threads,
             )
         else:
             raise AssertionError("Row set type is not valid")
@@ -99,6 +101,7 @@ class CloudFetchQueue(ResultSetQueue):
     def __init__(
         self,
         schema_bytes,
+        max_download_threads: int,
        start_row_index: int = 0,
        result_links: List[TSparkArrowResultLink] = None,
        lz4_compressed: bool = True,
@@ -108,11 +111,11 @@ def __init__(
         A queue-like wrapper over CloudFetch arrow batches
         """
         self.schema_bytes = schema_bytes
+        self.max_download_threads = max_download_threads
         self.start_row_index = start_row_index
         self.result_links = result_links
         self.lz4_compressed = lz4_compressed
         self.description = description
-        self.max_download_threads = DEFAULT_MAX_DOWNLOAD_THREADS
 
         self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed)
         self.download_manager.add_file_links(result_links, start_row_index)
@@ -152,7 +155,6 @@ def remaining_rows(self) -> pyarrow.Table:
         return results
 
     def _create_next_table(self) -> Union[pyarrow.Table, None]:
-        # TODO: add retry logic from _fill_results_buffer_cloudfetch
         downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index)
         if not downloaded_file:
             return None
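Note that CloudFetchQueue previously overwrote the setting with DEFAULT_MAX_DOWNLOAD_THREADS unconditionally; it now hands the client's value straight to ResultFileDownloadManager, whose internals this diff does not show — it only confirms the (max_download_threads, lz4_compressed) constructor arguments. As an illustration only, a manager like it would typically size its worker pool from that value; IllustrativeDownloadManager below is a hypothetical stand-in, not the library's implementation.

    from concurrent.futures import ThreadPoolExecutor

    class IllustrativeDownloadManager:
        """Hypothetical stand-in for ResultFileDownloadManager."""

        def __init__(self, max_download_threads: int, lz4_compressed: bool):
            self.lz4_compressed = lz4_compressed
            # Bound concurrent CloudFetch downloads by the client-settable value.
            self.pool = ThreadPoolExecutor(max_workers=max_download_threads)
            self.futures = []

        def add_file_links(self, result_links, start_row_index):
            # Schedule one download task per presigned result link.
            for link in result_links or []:
                self.futures.append(self.pool.submit(self.download_file, link))

        def download_file(self, link):
            # Placeholder body: a real manager would HTTP-GET the link's URL and
            # lz4-decompress the Arrow batch when lz4_compressed is True.
            return link

    manager = IllustrativeDownloadManager(max_download_threads=16, lz4_compressed=True)
    manager.add_file_links(result_links=[], start_row_index=0)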
