Implemented the columnar flow for non arrow users
jprakash-db committed Sep 13, 2024
1 parent d31063c commit eeaee96
Showing 5 changed files with 240 additions and 32 deletions.
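For users who install the connector without pyarrow, the existing DB-API fetch methods keep working; results simply flow through the new ColumnQueue path instead of Arrow tables. A minimal usage sketch (the hostname, HTTP path, and token below are placeholders, not part of this commit):

from databricks import sql

# Behaves the same with or without pyarrow installed; without it, rows are
# materialized from the columnar result queue added in this commit.
with sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="dapi-example-token",               # placeholder
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 AS n")
        for row in cursor.fetchall():
            print(row.n)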
92 changes: 84 additions & 8 deletions src/databricks/sql/client.py
@@ -1,7 +1,10 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
try:
import pyarrow
except ImportError:
pyarrow = None
import requests
import json
import os
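With the pyarrow import made optional, Arrow-specific call sites guard on the module object, and the type hints switch to string form ("pyarrow.Table") so they are not evaluated at import time. A compact illustration of the guarded pattern (the helper below is illustrative, not part of the commit):

try:
    import pyarrow
except ImportError:
    pyarrow = None

def to_arrow_table(columns, names):
    # Illustrative helper: fail with a clear error when the optional
    # dependency is missing instead of a late AttributeError on None.
    if pyarrow is None:
        raise ImportError("pyarrow is required for Arrow result conversion")
    return pyarrow.Table.from_arrays(
        [pyarrow.array(col) for col in columns], names=names
    )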
@@ -22,6 +25,8 @@
ParamEscaper,
inject_parameters,
transform_paramstyle,
ArrowQueue,
ColumnQueue
)
from databricks.sql.parameters.native import (
DbsqlParameterBase,
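utils.py is among the five changed files but its diff is collapsed here; the client code below only relies on ColumnQueue exposing next_n_rows and remaining_rows over a column-major table (a list of equally long columns). A rough sketch of that interface, stated as an assumption rather than the actual implementation:

class ColumnQueue:
    # Assumed shape: wraps a columnar table, i.e. a list of columns,
    # where each column is a list of values and all columns have equal length.
    def __init__(self, table, start_row_index=0):
        self.table = table
        self.cur_row_index = start_row_index
        self.n_valid_rows = len(table[0]) if table else 0

    def next_n_rows(self, num_rows):
        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
        sliced = [col[self.cur_row_index : self.cur_row_index + length] for col in self.table]
        self.cur_row_index += length
        return sliced

    def remaining_rows(self):
        sliced = [col[self.cur_row_index :] for col in self.table]
        self.cur_row_index = self.n_valid_rows
        return sliced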
@@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]:
else:
raise Error("There is no active result set")

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchall_arrow()
else:
raise Error("There is no active result set")

def fetchmany_arrow(self, size) -> pyarrow.Table:
def fetchmany_arrow(self, size) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchmany_arrow(size)
@@ -1143,6 +1148,18 @@ def _fill_results_buffer(self):
self.results = results
self.has_more_rows = has_more_rows

def _convert_columnar_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
result = []
for row_index in range(len(table[0])):
curr_row = []
for col_index in range(len(table)):
curr_row.append(table[col_index][row_index])
result.append(ResultRow(*curr_row))

return result
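
_convert_columnar_table transposes the column-major table into Row objects, one per row index. The same transposition in standalone form (namedtuple stands in for the connector's Row factory; the data is made up):

from collections import namedtuple

table = [[1, 2, 3], ["a", "b", "c"]]  # one list per column
ResultRow = namedtuple("ResultRow", ["id", "label"])

rows = [ResultRow(*[col[i] for col in table]) for i in range(len(table[0]))]
# rows == [ResultRow(id=1, label='a'), ResultRow(id=2, label='b'), ResultRow(id=3, label='c')]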

def _convert_arrow_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
@@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table):
def rownumber(self):
return self._next_row_index

def fetchmany_arrow(self, size: int) -> pyarrow.Table:
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1210,7 +1227,42 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

return results

def fetchall_arrow(self) -> pyarrow.Table:
def merge_columnar(self, result1, result2):
"""
Function to merge / combining the columnar results into a single result
:param result1:
:param result2:
:return:
"""
merged_result = [result1[i] + result2[i] for i in range(len(result1))]
return merged_result
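
merge_columnar concatenates results column by column, which assumes both inputs carry the same columns in the same order. A tiny worked example with made-up data:

result1 = [[1, 2], ["a", "b"]]
result2 = [[3], ["c"]]
merged = [result1[i] + result2[i] for i in range(len(result1))]
# merged == [[1, 2, 3], ["a", "b", "c"]]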

def fetchmany_columnar(self, size: int):
"""
Fetch the next set of rows of a query result, returning them in columnar form (a list of columns).
An empty sequence is returned when no more rows are available.
"""
if size < 0:
raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)

results = self.results.next_n_rows(size)
n_remaining_rows = size - len(results[0])
self._next_row_index += len(results[0])

while (
n_remaining_rows > 0
and not self.has_been_closed_server_side
and self.has_more_rows
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
results = self.merge_columnar(results, partial_results)
n_remaining_rows -= len(partial_results[0])
self._next_row_index += len(partial_results[0])

return results
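
fetchmany_columnar mirrors fetchmany_arrow: take what the local buffer holds, then keep refilling from the server and merging columnar chunks until the requested size is met or the result set is exhausted. The loop can be exercised in isolation with the assumed ColumnQueue sketch above (batches stand in for successive buffer fills; names and data are illustrative):

def drain(batches, size):
    # Illustrative only: `batches` plays the role of successive _fill_results_buffer calls.
    results = batches[0].next_n_rows(size)
    n_remaining = size - len(results[0])
    i = 1
    while n_remaining > 0 and i < len(batches):
        partial = batches[i].next_n_rows(n_remaining)
        results = [results[c] + partial[c] for c in range(len(results))]
        n_remaining -= len(partial[0])
        i += 1
    return results

batches = [ColumnQueue([[1, 2], ["a", "b"]]), ColumnQueue([[3, 4], ["c", "d"]])]
print(drain(batches, 3))  # [[1, 2, 3], ['a', 'b', 'c']]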

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
@@ -1223,12 +1275,30 @@ def fetchall_arrow(self) -> pyarrow.Table:

return results

def fetchall_columnar(self):
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
results = self.results.remaining_rows()
self._next_row_index += len(results[0])

while not self.has_been_closed_server_side and self.has_more_rows:
self._fill_results_buffer()
partial_results = self.results.remaining_rows()
results = self.merge_columnar(results, partial_results)
self._next_row_index += len(partial_results[0])

return results

def fetchone(self) -> Optional[Row]:
"""
Fetch the next row of a query result set, returning a single sequence,
or None when no more data is available.
"""
res = self._convert_arrow_table(self.fetchmany_arrow(1))

if isinstance(self.results, ColumnQueue):
res = self._convert_columnar_table(self.fetchmany_columnar(1))
else:
res = self._convert_arrow_table(self.fetchmany_arrow(1))

if len(res) > 0:
return res[0]
else:
@@ -1238,15 +1308,21 @@ def fetchall(self) -> List[Row]:
"""
Fetch all (remaining) rows of a query result, returning them as a list of rows.
"""
return self._convert_arrow_table(self.fetchall_arrow())
if isinstance(self.results, ColumnQueue):
return self._convert_columnar_table(self.fetchall_columnar())
else:
return self._convert_arrow_table(self.fetchall_arrow())

def fetchmany(self, size: int) -> List[Row]:
"""
Fetch the next set of rows of a query result, returning a list of rows.
An empty sequence is returned when no more rows are available.
"""
return self._convert_arrow_table(self.fetchmany_arrow(size))
if isinstance(self.results, ColumnQueue):
return self._convert_columnar_table(self.fetchmany_columnar(size))
else:
return self._convert_arrow_table(self.fetchmany_arrow(size))

def close(self) -> None:
"""
30 changes: 22 additions & 8 deletions src/databricks/sql/thrift_backend.py
@@ -7,7 +7,10 @@
import threading
from typing import List, Union

import pyarrow
try:
import pyarrow
except ImportError:
pyarrow = None
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
@@ -621,6 +624,12 @@ def _get_metadata_resp(self, op_handle):

@staticmethod
def _hive_schema_to_arrow_schema(t_table_schema):

if pyarrow is None:
raise ImportError(
"pyarrow is required to convert Hive schema to Arrow schema"
)

def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
@@ -726,12 +735,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema
)
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)

if pyarrow:
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)
else:
schema_bytes = None

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
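
When pyarrow is available, schema_bytes carries the serialized Arrow schema (either returned by the server or derived from the Hive schema); when it is not, schema_bytes is None and the columnar result path never touches it. A hedged sketch of round-tripping such bytes, assuming pyarrow is installed (the fields are made up):

import pyarrow

schema = pyarrow.schema([("id", pyarrow.int64()), ("label", pyarrow.string())])
schema_bytes = schema.serialize().to_pybytes()

# Reading the bytes back, e.g. on the result-set side:
restored = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
assert restored.equals(schema)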
@@ -827,7 +841,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
