Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1926] Create a non pyarrow flow to handle small results for the column set #440

Merged
merged 6 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
try:
import pyarrow
except ImportError:
pyarrow = None
import requests
import json
import os
Expand All @@ -22,6 +25,8 @@
ParamEscaper,
inject_parameters,
transform_paramstyle,
ColumnTable,
ColumnQueue
)
from databricks.sql.parameters.native import (
DbsqlParameterBase,
Expand Down Expand Up @@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]:
else:
raise Error("There is no active result set")

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchall_arrow()
else:
raise Error("There is no active result set")

def fetchmany_arrow(self, size) -> pyarrow.Table:
def fetchmany_arrow(self, size) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchmany_arrow(size)
Expand Down Expand Up @@ -1143,6 +1148,18 @@ def _fill_results_buffer(self):
self.results = results
self.has_more_rows = has_more_rows

def _convert_columnar_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
result = []
for row_index in range(table.num_rows):
curr_row = []
for col_index in range(table.num_columns):
curr_row.append(table.get_item(col_index, row_index))
result.append(ResultRow(*curr_row))

return result

def _convert_arrow_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
Expand Down Expand Up @@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table):
def rownumber(self):
return self._next_row_index

def fetchmany_arrow(self, size: int) -> pyarrow.Table:
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.

Expand All @@ -1210,7 +1227,46 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

return results

def fetchall_arrow(self) -> pyarrow.Table:
def merge_columnar(self, result1, result2):
"""
Function to merge / combining the columnar results into a single result
:param result1:
:param result2:
:return:
"""

if result1.column_names != result2.column_names:
raise ValueError("The columns in the results don't match")

merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)]
return ColumnTable(merged_result, result1.column_names)

def fetchmany_columnar(self, size: int):
"""
Fetch the next set of rows of a query result, returning a Columnar Table.
An empty sequence is returned when no more rows are available.
"""
if size < 0:
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)

results = self.results.next_n_rows(size)
n_remaining_rows = size - results.num_rows
self._next_row_index += results.num_rows

while (
n_remaining_rows > 0
and not self.has_been_closed_server_side
and self.has_more_rows
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
results = self.merge_columnar(results, partial_results)
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

return results

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
Expand All @@ -1223,12 +1279,30 @@ def fetchall_arrow(self) -> pyarrow.Table:

return results

def fetchall_columnar(self):
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows

while not self.has_been_closed_server_side and self.has_more_rows:
self._fill_results_buffer()
partial_results = self.results.remaining_rows()
results = self.merge_columnar(results, partial_results)
self._next_row_index += partial_results.num_rows

return results

def fetchone(self) -> Optional[Row]:
"""
Fetch the next row of a query result set, returning a single sequence,
or None when no more data is available.
"""
res = self._convert_arrow_table(self.fetchmany_arrow(1))

if isinstance(self.results, ColumnQueue):
res = self._convert_columnar_table(self.fetchmany_columnar(1))
else:
res = self._convert_arrow_table(self.fetchmany_arrow(1))

if len(res) > 0:
return res[0]
else:
Expand All @@ -1238,15 +1312,21 @@ def fetchall(self) -> List[Row]:
"""
Fetch all (remaining) rows of a query result, returning them as a list of rows.
"""
return self._convert_arrow_table(self.fetchall_arrow())
if isinstance(self.results, ColumnQueue):
return self._convert_columnar_table(self.fetchall_columnar())
else:
return self._convert_arrow_table(self.fetchall_arrow())

def fetchmany(self, size: int) -> List[Row]:
"""
Fetch the next set of rows of a query result, returning a list of rows.

An empty sequence is returned when no more rows are available.
"""
return self._convert_arrow_table(self.fetchmany_arrow(size))
if isinstance(self.results, ColumnQueue):
return self._convert_columnar_table(self.fetchmany_columnar(size))
else:
return self._convert_arrow_table(self.fetchmany_arrow(size))

def close(self) -> None:
"""
Expand Down
25 changes: 17 additions & 8 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import threading
from typing import List, Union

import pyarrow
try:
import pyarrow
except ImportError:
pyarrow = None
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
Expand Down Expand Up @@ -621,6 +624,7 @@ def _get_metadata_resp(self, op_handle):

@staticmethod
def _hive_schema_to_arrow_schema(t_table_schema):

def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
Expand Down Expand Up @@ -726,12 +730,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema
)
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)

if pyarrow:
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)
else:
schema_bytes = None

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
Expand Down Expand Up @@ -827,7 +836,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
Expand Down
Loading
Loading