fix: Several scan_parquet(parallel='prefiltered') problems (#18278)
coastalwhite authored Aug 22, 2024
1 parent f7e63dc commit 9833887
Showing 2 changed files with 79 additions and 6 deletions.
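
For context: the prefiltered strategy first decodes only the columns the predicate references (the "live" columns), computes the filter mask, and then decodes the remaining "dead" columns for the surviving rows only. A minimal sketch of the kind of query the fix affects, assuming an illustrative file name and data (not from the commit):

import polars as pl

df = pl.DataFrame({"a": [1, None, 3], "b": [10, 20, 30]})
df.write_parquet("test.parquet", row_group_size=2)

# "a" appears twice in the predicate (deduplication matters) and the
# comparison is null for the missing value (validity matters).
out = (
    pl.scan_parquet("test.parquet", parallel="prefiltered")
    .filter((pl.col("a") > 0) & (pl.col("a") < 3))
    .collect()
)
print(out)  # expected: the single row where a == 1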
22 changes: 17 additions & 5 deletions crates/polars-io/src/parquet/read/read_impl.rs
@@ -252,14 +252,16 @@ fn rg_to_dfs_prefiltered(
         })
         .collect::<PolarsResult<Vec<_>>>()?;
 
-    let num_live_columns = live_variables.len();
-    let num_dead_columns = projection.len() - num_live_columns;
-
     // Deduplicate the live variables
     let live_variables = live_variables
         .iter()
         .map(Deref::deref)
         .collect::<PlHashSet<_>>();
+
+    // Get the number of live columns
+    let num_live_columns = live_variables.len();
+    let num_dead_columns = projection.len() - num_live_columns;
+
     // We create two look-up tables that map index offsets into the live- and dead-set onto
     // column indexes of the schema.
     let mut live_idx_to_col_idx = Vec::with_capacity(num_live_columns);
@@ -271,7 +273,8 @@ fn rg_to_dfs_prefiltered(
             dead_idx_to_col_idx.push(i);
         }
     }
-    debug_assert_eq!(live_variables.len(), num_live_columns);
+
+    debug_assert_eq!(live_idx_to_col_idx.len(), num_live_columns);
     debug_assert_eq!(dead_idx_to_col_idx.len(), num_dead_columns);
 
     POOL.install(|| {
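
The two look-up tables partition the projected column indexes by whether a column occurs in the predicate, which is what the corrected debug asserts check. A rough Python analogue of the bookkeeping, with illustrative names and values:

projection = ["a", "b", "c"]  # projected columns, in schema order
live = {"b"}                  # columns referenced by the predicate, deduplicated
live_idx_to_col_idx = [i for i, c in enumerate(projection) if c in live]
dead_idx_to_col_idx = [i for i, c in enumerate(projection) if c not in live]
assert len(live_idx_to_col_idx) == len(live)
assert len(dead_idx_to_col_idx) == len(projection) - len(live)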
@@ -316,8 +319,12 @@ fn rg_to_dfs_prefiltered(

         let mut bitmap = MutableBitmap::with_capacity(mask.len());
 
+        // We need to account for the validity of the items
         for chunk in mask.downcast_iter() {
-            bitmap.extend_from_bitmap(chunk.values());
+            match chunk.validity() {
+                None => bitmap.extend_from_bitmap(chunk.values()),
+                Some(validity) => bitmap.extend_from_bitmap(&(validity & chunk.values())),
+            }
         }
 
         let bitmap = bitmap.freeze();
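
The match on chunk.validity() ANDs the mask's validity into the filter bitmap, so a null predicate result now deselects its row, consistent with in-memory filter semantics. A small sketch of the behavior being preserved (values illustrative):

import polars as pl

df = pl.DataFrame({"a": [1, None, 3]})
# `a > 1` evaluates to [False, null, True]; the null must not select a row.
assert df.filter(pl.col("a") > 1)["a"].to_list() == [3]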
@@ -341,6 +348,11 @@ fn rg_to_dfs_prefiltered(
                 .ok_or(ROW_COUNT_OVERFLOW_ERR)?;
         }
 
+        // We don't need to do any further work if there are no dead columns
+        if num_dead_columns == 0 {
+            return Ok(dfs.into_iter().map(|(_, df)| df).collect());
+        }
+
         // @TODO: Incorporate this if we know how we can properly use it. The problem here is that
         // different columns really have a different cost when it comes to collecting them. We
         // would need a cost model to properly estimate this.
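
The early return covers predicates that reference every projected column: with no dead columns left to decode after the filter pass, the per-row-group frames can be returned directly. A sketch of such a query, assuming illustrative data and file name:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [3, 2, 1]})
df.write_parquet("all_live.parquet", row_group_size=2)

# Both projected columns appear in the predicate, so no dead columns remain.
out = (
    pl.scan_parquet("all_live.parquet", parallel="prefiltered")
    .filter(pl.col("a") < pl.col("b"))
    .collect()
)
print(out)  # expected: the single row (1, 3)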
63 changes: 62 additions & 1 deletion py-polars/tests/unit/io/test_parquet.py
@@ -3,7 +3,7 @@
 import io
 from datetime import datetime, time, timezone
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 import fsspec
 import numpy as np
@@ -1510,3 +1510,64 @@ def test_delta_strings_encoding_roundtrip(

     f.seek(0)
     assert_frame_equal(pl.read_parquet(f), df)
+
+
+EQUALITY_OPERATORS = ["__eq__", "__lt__", "__le__", "__gt__", "__ge__"]
+BOOLEAN_OPERATORS = ["__or__", "__and__"]
+
+
+@given(
+    df=dataframes(
+        min_size=0, max_size=100, min_cols=2, max_cols=5, allowed_dtypes=[pl.Int32]
+    ),
+    first_op=st.sampled_from(EQUALITY_OPERATORS),
+    second_op=st.sampled_from(
+        [None]
+        + [
+            (booljoin, eq)
+            for booljoin in BOOLEAN_OPERATORS
+            for eq in EQUALITY_OPERATORS
+        ]
+    ),
+    l1=st.integers(min_value=0, max_value=1000),
+    l2=st.integers(min_value=0, max_value=1000),
+    r1=st.integers(min_value=0, max_value=1000),
+    r2=st.integers(min_value=0, max_value=1000),
+)
+@pytest.mark.parametrize("parallel_st", ["auto", "prefiltered"])
+@settings(
+    deadline=None,
+    suppress_health_check=[HealthCheck.function_scoped_fixture],
+)
+@pytest.mark.write_disk()
+def test_predicate_filtering(
+    tmp_path: Path,
+    df: pl.DataFrame,
+    first_op: str,
+    second_op: None | tuple[str, str],
+    l1: int,
+    l2: int,
+    r1: int,
+    r2: int,
+    parallel_st: Literal["auto", "prefiltered"],
+) -> None:
+    tmp_path.mkdir(exist_ok=True)
+    f = tmp_path / "test.parquet"
+
+    df.write_parquet(f, row_group_size=5)
+
+    cols = df.columns
+
+    l1s = cols[l1 % len(cols)]
+    l2s = cols[l2 % len(cols)]
+    expr = (getattr(pl.col(l1s), first_op))(pl.col(l2s))
+
+    if second_op is not None:
+        r1s = cols[r1 % len(cols)]
+        r2s = cols[r2 % len(cols)]
+        expr = getattr(expr, second_op[0])(
+            (getattr(pl.col(r1s), second_op[1]))(pl.col(r2s))
+        )
+
+    result = pl.scan_parquet(f, parallel=parallel_st).filter(expr).collect()
+    assert_frame_equal(result, df.filter(expr))
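
The test drives the expression API through dunder names sampled by Hypothesis. One concrete instantiation of a generated predicate, for example first_op="__lt__" with second_op=("__and__", "__ge__"), would look like this (illustrative):

import polars as pl

df = pl.DataFrame({"x": [1, 2, 3], "y": [3, 2, 1]})
expr = getattr(pl.col("x"), "__lt__")(pl.col("y"))  # x < y
expr = getattr(expr, "__and__")(
    getattr(pl.col("x"), "__ge__")(pl.col("y"))  # ... & (x >= y)
)
# Equivalent to (x < y) & (x >= y), which is never true, so no rows survive.
assert df.filter(expr).height == 0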
