diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs
index 1effa83bae08..dbf41ee7579c 100644
--- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs
+++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs
@@ -4,14 +4,12 @@ use polars_error::PolarsResult;

 use super::utils::{self, BatchableCollector};
 use super::{BasicDecompressor, Filter};
-use crate::parquet::encoding::hybrid_rle::gatherer::{
-    HybridRleGatherer, ZeroCount, ZeroCountGatherer,
-};
+use crate::parquet::encoding::hybrid_rle::gatherer::HybridRleGatherer;
 use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;
 use crate::parquet::error::ParquetResult;
 use crate::parquet::page::{split_buffer, DataPage};
 use crate::parquet::read::levels::get_bit_width;
-use crate::read::deserialize::utils::BatchedCollector;
+use crate::read::deserialize::utils::{hybrid_rle_count_zeros, BatchedCollector};

 #[derive(Debug)]
 pub struct Nested {
@@ -537,6 +535,7 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>(
                     )?;
                 }
             }
+
             Ok(())
         },
     }
@@ -806,54 +805,147 @@ impl<D: utils::NestedDecoder> PageNestedDecoder<D> {
                 }
             },
             Some(mut filter) => {
+                enum PageStartAction {
+                    Skip,
+                    Collect,
+                }
+
+                // We may have an action (skip / collect) for one row value left over from the
+                // previous page. Every page may state what the next page needs to do until the
+                // first of its own row values (rep_lvl = 0).
+                let mut last_row_value_action = PageStartAction::Skip;
                 let mut num_rows_remaining = filter.num_rows();

-                loop {
+                while num_rows_remaining > 0
+                    || matches!(last_row_value_action, PageStartAction::Collect)
+                {
                     let Some(page) = self.iter.next() else {
                         break;
                     };
                     let page = page?;

-                    // We cannot lazily decompress because we don't have the number of leaf values
-                    // at this point. This is encoded within the `definition level` values. *sign*.
-                    // In general, lazy decompression is quite difficult with nested values.
+                    // We cannot lazily decompress because we don't have the number of row values
+                    // at this point. We need repetition levels for that. *sigh*. In general, lazy
+                    // decompression is quite difficult with nested values.
+                    //
+                    // @TODO
+                    // Lazy decompression is quite doable in the V2 specification since that does
+                    // not compress the repetition and definition levels. However, not a lot of
+                    // people use the V2 specification. So let us ignore that for now.
                     let page = page.decompress(&mut self.iter)?;

-                    let (def_iter, rep_iter) = level_iters(&page)?;
+                    let (mut def_iter, mut rep_iter) = level_iters(&page)?;

-                    let mut count = ZeroCount::default();
-                    rep_iter
-                        .clone()
-                        .gather_into(&mut count, &ZeroCountGatherer)?;
+                    let mut state;
+                    let mut batched_collector;

-                    let is_fully_read = count.num_zero > num_rows_remaining;
-                    let state_filter;
-                    (state_filter, filter) = Filter::split_at(&filter, count.num_zero);
-                    let state_filter = if count.num_zero > 0 {
-                        Some(state_filter)
-                    } else {
-                        None
-                    };
+                    let start_length = nested_state.len();

-                    let mut state =
-                        utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?;
+                    // rep lvl == 0 ==> row value
+                    let num_row_values = hybrid_rle_count_zeros(&rep_iter)?;

-                    let start_length = nested_state.len();
+                    let state_filter;
+                    (state_filter, filter) = Filter::split_at(&filter, num_row_values);
+
+                    match last_row_value_action {
+                        PageStartAction::Skip => {
+                            // Fast path: skip the whole page.
+                            // No new row values or we don't care about any of the row values.
+                            if num_row_values == 0 && state_filter.num_rows() == 0 {
+                                self.iter.reuse_page_buffer(page);
+                                continue;
+                            }

-                    // @TODO: move this to outside the loop.
-                    let mut batched_collector = BatchedCollector::new(
-                        BatchedNestedDecoder {
-                            state: &mut state,
-                            decoder: &mut self.decoder,
+                            let limit = idx_to_limit(&rep_iter, 0)?;
+
+                            // We just saw that we had at least one row value.
+                            debug_assert!(limit < rep_iter.len());
+
+                            state =
+                                utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?;
+                            batched_collector = BatchedCollector::new(
+                                BatchedNestedDecoder {
+                                    state: &mut state,
+                                    decoder: &mut self.decoder,
+                                },
+                                &mut target,
+                            );
+
+                            let num_leaf_values =
+                                limit_to_num_values(&def_iter, &def_levels, limit)?;
+                            batched_collector.skip_in_place(num_leaf_values)?;
+                            rep_iter.skip_in_place(limit)?;
+                            def_iter.skip_in_place(limit)?;
                         },
-                        &mut target,
-                    );
+                        PageStartAction::Collect => {
+                            let limit = if num_row_values == 0 {
+                                rep_iter.len()
+                            } else {
+                                idx_to_limit(&rep_iter, 0)?
+                            };
+
+                            // Fast path: we are not interested in any of the row values in this
+                            // page.
+                            if limit == 0 && state_filter.num_rows() == 0 {
+                                self.iter.reuse_page_buffer(page);
+                                last_row_value_action = PageStartAction::Skip;
+                                continue;
+                            }
+
+                            state =
+                                utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?;
+                            batched_collector = BatchedCollector::new(
+                                BatchedNestedDecoder {
+                                    state: &mut state,
+                                    decoder: &mut self.decoder,
+                                },
+                                &mut target,
+                            );
+
+                            extend_offsets_limited(
+                                &mut def_iter,
+                                &mut rep_iter,
+                                &mut batched_collector,
+                                &mut nested_state.nested,
+                                limit,
+                                &def_levels,
+                                &rep_levels,
+                            )?;
+
+                            // No new row values. Keep collecting.
+                            if rep_iter.len() == 0 {
+                                batched_collector.finalize()?;
+
+                                let num_done = nested_state.len() - start_length;
+                                debug_assert!(num_done <= num_rows_remaining);
+                                debug_assert!(num_done <= num_row_values);
+                                num_rows_remaining -= num_done;
+
+                                drop(state);
+                                self.iter.reuse_page_buffer(page);
+
+                                continue;
+                            }
+                        },
+                    }
+
+                    // Two cases:
+                    // 1. First page: Must always start with a row value.
+                    // 2. Other pages: If they did not have a row value, they would have been
+                    //    handled by the last_row_value_action.
+                    debug_assert!(num_row_values > 0);
+
+                    last_row_value_action = if state_filter.do_include_at(num_row_values - 1) {
+                        PageStartAction::Collect
+                    } else {
+                        PageStartAction::Skip
+                    };

                     extend_offsets2(
                         def_iter,
                         rep_iter,
                         &mut batched_collector,
                         &mut nested_state.nested,
-                        state_filter,
+                        Some(state_filter),
                         &def_levels,
                         &rep_levels,
                     )?;
@@ -862,15 +954,11 @@

                     let num_done = nested_state.len() - start_length;
                     debug_assert!(num_done <= num_rows_remaining);
-                    debug_assert!(num_done <= count.num_zero);
+                    debug_assert!(num_done <= num_row_values);
                     num_rows_remaining -= num_done;

                     drop(state);
                     self.iter.reuse_page_buffer(page);
-
-                    if is_fully_read {
-                        break;
-                    }
                 }
             },
         }
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs
index b7a9c6645701..03e641634467 100644
--- a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs
+++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs
@@ -22,6 +22,13 @@ impl Filter {
         Filter::Mask(mask)
     }

+    pub fn do_include_at(&self, at: usize) -> bool {
+        match self {
+            Filter::Range(range) => range.contains(&at),
+            Filter::Mask(bitmap) => bitmap.get_bit(at),
+        }
+    }
+
     pub(crate) fn num_rows(&self) -> usize {
         match self {
             Filter::Range(range) => range.len(),
diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
index 6abb7612307c..7a18a0c16a85 100644
--- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
+++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs
@@ -760,3 +760,13 @@ pub fn freeze_validity(validity: MutableBitmap) -> Option<Bitmap> {

     Some(validity)
 }
+
+pub(crate) fn hybrid_rle_count_zeros(
+    decoder: &hybrid_rle::HybridRleDecoder<'_>,
+) -> ParquetResult<usize> {
+    let mut count = ZeroCount::default();
+    decoder
+        .clone()
+        .gather_into(&mut count, &ZeroCountGatherer)?;
+    Ok(count.num_zero)
+}
diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py
index f93090ad8302..0b108d6b5508 100644
--- a/py-polars/tests/unit/io/test_parquet.py
+++ b/py-polars/tests/unit/io/test_parquet.py
@@ -1659,6 +1659,62 @@ def test_nested_skip_18303(
     assert_frame_equal(scanned, pl.DataFrame(tb).slice(1, 1))


+def test_nested_span_multiple_pages_18400() -> None:
+    width = 4100
+    df = pl.DataFrame(
+        [
+            pl.Series(
+                "a",
+                [
+                    list(range(width)),
+                    list(range(width)),
+                ],
+                pl.Array(pl.Int64, width),
+            ),
+        ]
+    )
+
+    f = io.BytesIO()
+    pq.write_table(
+        df.to_arrow(),
+        f,
+        use_dictionary=False,
+        data_page_size=1024,
+        column_encoding={"a": "PLAIN"},
+    )
+
+    f.seek(0)
+    assert_frame_equal(df.head(1), pl.read_parquet(f, n_rows=1))
+
+
+@given(
+    df=dataframes(
+        min_size=0,
+        max_size=1000,
+        min_cols=2,
+        max_cols=5,
+        excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum, pl.Array],
+        include_cols=[column("filter_col", pl.Boolean, allow_null=False)],
+    ),
+)
+@pytest.mark.write_disk()
+@settings(
+    suppress_health_check=[HealthCheck.function_scoped_fixture],
+)
+def test_parametric_small_page_mask_filtering(
+    tmp_path: Path,
+    df: pl.DataFrame,
+) -> None:
+    tmp_path.mkdir(exist_ok=True)
+    f = tmp_path / "test.parquet"
+
+    df.write_parquet(f, data_page_size=1024)
+
+    expr = pl.col("filter_col")
+    result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect()
+    assert_frame_equal(result, df.filter(expr))
+
+
 @given(
     df=dataframes(
         min_size=0,