Skip to content

Commit

Permalink
fix(python): Address read_database issue with batched reads from Snowflake (pola-rs#17688)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jul 17, 2024
1 parent f253f99 commit ebba58d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
24 changes: 16 additions & 8 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ def _fetch_arrow(
fetch_method = driver_properties["fetch_all"]
yield getattr(self.result, fetch_method)()
else:
size = batch_size if driver_properties["exact_batch_size"] else None
size = [batch_size] if driver_properties["exact_batch_size"] else []
repeat_batch_calls = driver_properties["repeat_batch_calls"]
fetchmany_arrow = getattr(self.result, fetch_batches)
if not repeat_batch_calls:
yield from fetchmany_arrow(size)
yield from fetchmany_arrow(*size)
else:
while True:
arrow = fetchmany_arrow(size)
arrow = fetchmany_arrow(*size)
if not arrow:
break
yield arrow
Expand Down Expand Up @@ -213,6 +213,13 @@ def _from_arrow(
if re.match(f"^{driver}$", self.driver_name):
if ver := driver_properties["minimum_version"]:
self._check_module_version(self.driver_name, ver)

if iter_batches and (
driver_properties["exact_batch_size"] and not batch_size
):
msg = f"Cannot set `iter_batches` for {self.driver_name} without also setting a non-zero `batch_size`"
raise ValueError(msg) # noqa: TRY301

frames = (
self._apply_overrides(batch, (schema_overrides or {}))
if isinstance(batch, DataFrame)
Expand Down Expand Up @@ -247,6 +254,12 @@ def _from_rows(
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

if iter_batches and not batch_size:
msg = (
"Cannot set `iter_batches` without also setting a non-zero `batch_size`"
)
raise ValueError(msg)

if is_async := isinstance(original_result := self.result, Coroutine):
self.result = _run_async(self.result)
try:
Expand Down Expand Up @@ -506,11 +519,6 @@ def to_polars(
if self.result is None:
msg = "Cannot return a frame before executing a query"
raise RuntimeError(msg)
elif iter_batches and not batch_size:
msg = (
"Cannot set `iter_batches` without also setting a non-zero `batch_size`"
)
raise ValueError(msg)

can_close = self.can_close_cursor

Expand Down
13 changes: 7 additions & 6 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,17 @@ def read_database(
data returned by the query; this can be useful for processing large resultsets
in a memory-efficient manner. If supported by the backend, this value is passed
to the underlying query execution method (note that very low values will
typically result in poor performance as it will result in many round-trips to
the database as the data is returned). If the backend does not support changing
typically result in poor performance as it will cause many round-trips to the
database as the data is returned). If the backend does not support changing
the batch size then a single DataFrame is yielded from the iterator.
batch_size
Indicate the size of each batch when `iter_batches` is True (note that you can
still set this when `iter_batches` is False, in which case the resulting
DataFrame is constructed internally using batched return before being returned
to you. Note that some backends may support batched operation but not allow for
an explicit size; in this case you will still receive batches, but their exact
size will be determined by the backend (so may not equal the value set here).
to you. Note that some backends (such as Snowflake) may support batch operation
but not allow for an explicit size to be set; in this case you will still
receive batches but their size is determined by the backend (in which case any
value set here will be ignored).
schema_overrides
A dictionary mapping column names to dtypes, used to override the schema
inferred from the query cursor or given by the incoming Arrow data (depending
Expand Down Expand Up @@ -242,7 +243,7 @@ def read_database(
connection = ODBCCursorProxy(connection)
elif "://" in connection:
# otherwise looks like a mistaken call to read_database_uri
msg = "Use of string URI is invalid here; call `read_database_uri` instead"
msg = "use of string URI is invalid here; call `read_database_uri` instead"
raise ValueError(msg)
else:
msg = "unable to identify string connection as valid ODBC (no driver)"
Expand Down
30 changes: 25 additions & 5 deletions py-polars/tests/unit/io/database/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __init__(
self,
driver: str,
batch_size: int | None,
exact_batch_size: bool,
test_data: pa.Table,
repeat_batch_calls: bool,
) -> None:
self.__class__.__module__ = driver
self._cursor = MockCursor(
repeat_batch_calls=repeat_batch_calls,
exact_batch_size=exact_batch_size,
batched=(batch_size is not None),
test_data=test_data,
)
Expand All @@ -69,10 +71,17 @@ class MockCursor:
def __init__(
self,
batched: bool,
exact_batch_size: bool,
test_data: pa.Table,
repeat_batch_calls: bool,
) -> None:
self.resultset = MockResultSet(test_data, batched, repeat_batch_calls)
self.resultset = MockResultSet(
test_data=test_data,
batched=batched,
exact_batch_size=exact_batch_size,
repeat_batch_calls=repeat_batch_calls,
)
self.exact_batch_size = exact_batch_size
self.called: list[str] = []
self.batched = batched
self.n_calls = 1
Expand All @@ -94,14 +103,21 @@ class MockResultSet:
"""Mock resultset class for databases we can't test in CI."""

def __init__(
self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False
self,
test_data: pa.Table,
batched: bool,
exact_batch_size: bool,
repeat_batch_calls: bool = False,
):
self.test_data = test_data
self.repeat_batched_calls = repeat_batch_calls
self.exact_batch_size = exact_batch_size
self.batched = batched
self.n_calls = 1

def __call__(self, *args: Any, **kwargs: Any) -> Any:
if not self.exact_batch_size:
assert len(args) == 0
if self.repeat_batched_calls:
res = self.test_data[: None if self.n_calls else 0]
self.n_calls -= 1
Expand Down Expand Up @@ -478,13 +494,17 @@ def test_read_database_mocked(
# since we don't have access to snowflake/databricks/etc from CI we
# mock them so we can check that we're calling the expected methods
arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

reg = ARROW_DRIVER_REGISTRY.get(driver, {}) # type: ignore[var-annotated]
exact_batch_size = reg.get("exact_batch_size", False)
repeat_batch_calls = reg.get("repeat_batch_calls", False)

mc = MockConnection(
driver,
batch_size,
test_data=arrow,
repeat_batch_calls=ARROW_DRIVER_REGISTRY.get(driver, {}).get( # type: ignore[call-overload]
"repeat_batch_calls", False
),
repeat_batch_calls=repeat_batch_calls,
exact_batch_size=exact_batch_size, # type: ignore[arg-type]
)
res = pl.read_database(
query="SELECT * FROM test_data",
Expand Down

0 comments on commit ebba58d

Please sign in to comment.