Skip to content

Commit

Permalink
Reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
jprakash-db committed Sep 18, 2024
1 parent 3318b04 commit 2470581
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
12 changes: 8 additions & 4 deletions databricks_sql_connector_core/src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
TSparkParameter,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -977,14 +981,14 @@ def fetchmany(self, size: int) -> List[Row]:
else:
raise Error("There is no active result set")

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchall_arrow()
else:
raise Error("There is no active result set")

def fetchmany_arrow(self, size) -> pyarrow.Table:
def fetchmany_arrow(self, size) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchmany_arrow(size)
Expand Down Expand Up @@ -1171,7 +1175,7 @@ def _convert_arrow_table(self, table):
def rownumber(self):
return self._next_row_index

def fetchmany_arrow(self, size: int) -> pyarrow.Table:
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.
Expand All @@ -1196,7 +1200,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

return results

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
convert_column_based_set_to_arrow_table,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)

unsafe_logger = logging.getLogger("databricks.sql.unsafe")
Expand Down Expand Up @@ -652,6 +657,12 @@ def _get_metadata_resp(self, op_handle):

@staticmethod
def _hive_schema_to_arrow_schema(t_table_schema):

if pyarrow is None:
raise ImportError(
"pyarrow is required to convert Hive schema to Arrow schema"
)

def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
Expand Down Expand Up @@ -858,7 +869,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
Expand Down
27 changes: 16 additions & 11 deletions databricks_sql_connector_core/src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@

import logging

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int):
pass

@abstractmethod
def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self):
pass


Expand Down Expand Up @@ -100,7 +105,7 @@ def build_queue(
class ArrowQueue(ResultSetQueue):
def __init__(
self,
arrow_table: pyarrow.Table,
arrow_table: "pyarrow.Table",
n_valid_rows: int,
start_row_index: int = 0,
):
Expand All @@ -115,7 +120,7 @@ def __init__(
self.arrow_table = arrow_table
self.n_valid_rows = n_valid_rows

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""Get upto the next n rows of the Arrow dataframe"""
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
# Note that the table.slice API is not the same as Python's slice
Expand All @@ -124,7 +129,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
self.cur_row_index += slice.num_rows
return slice

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
slice = self.arrow_table.slice(
self.cur_row_index, self.n_valid_rows - self.cur_row_index
)
Expand Down Expand Up @@ -184,7 +189,7 @@ def __init__(
self.table = self._create_next_table()
self.table_row_index = 0

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""
Get up to the next n rows of the cloud fetch Arrow dataframes.
Expand Down Expand Up @@ -216,7 +221,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
return results

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
"""
Get all remaining rows of the cloud fetch Arrow dataframes.
Expand All @@ -237,7 +242,7 @@ def remaining_rows(self) -> pyarrow.Table:
self.table_row_index = 0
return results

def _create_next_table(self) -> Union[pyarrow.Table, None]:
def _create_next_table(self) -> Union["pyarrow.Table", None]:
logger.debug(
"CloudFetchQueue: Trying to get downloaded file for row {}".format(
self.start_row_index
Expand Down Expand Up @@ -276,7 +281,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:

return arrow_table

def _create_empty_table(self) -> pyarrow.Table:
def _create_empty_table(self) -> "pyarrow.Table":
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

Expand Down Expand Up @@ -515,7 +520,7 @@ def transform_paramstyle(
return output


def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
return convert_decimals_in_arrow_table(arrow_table, description)

Expand All @@ -542,7 +547,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
return arrow_table, n_rows


def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
for i, col in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
Expand Down

0 comments on commit 2470581

Please sign in to comment.