Skip to content

Commit

Permalink
fix: Propagate struct outer nullability eagerly (#17697)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 18, 2024
1 parent ebba58d commit ab5f8c1
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 62 deletions.
3 changes: 2 additions & 1 deletion crates/polars-arrow/src/array/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ impl Array for NullArray {
}

fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {
panic!("cannot set validity of a null array")
// Nulls with invalid nulls are also nulls.
self.clone().boxed()
}
}

Expand Down
14 changes: 4 additions & 10 deletions crates/polars-arrow/src/array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl StructArray {
.for_each(|x| x.slice_unchecked(offset, length));
}

/// Set the outer nulls into the inner arrays, and clear the outer validity.
/// Set the outer nulls into the inner arrays.
pub fn propagate_nulls(&self) -> StructArray {
let has_nulls = self.null_count() > 0;
let mut out = self.clone();
Expand All @@ -203,16 +203,10 @@ impl StructArray {
};

for value_arr in &mut out.values {
let new = if has_nulls {
let new_validity = combine_validities_and(self.validity(), value_arr.validity());
value_arr.with_validity(new_validity)
} else {
value_arr.clone()
};

*value_arr = new;
let new_validity = combine_validities_and(self.validity(), value_arr.validity());
*value_arr = value_arr.with_validity(new_validity);
}
out.with_validity(None)
out
}

impl_sliced!();
Expand Down
40 changes: 0 additions & 40 deletions crates/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::marker::PhantomData;

use arrow::array::*;
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;

use crate::prelude::*;
Expand Down Expand Up @@ -161,43 +160,4 @@ impl<T: PolarsDataType> ChunkedArray<T> {
}
self.compute_len();
}

/// Replace the validity of the single chunk with `validity` and recompute
/// the cached length information.
///
/// # Panics
///
/// Panics when the array does not consist of exactly one chunk.
pub(crate) fn set_outer_validity(&mut self, validity: Option<Bitmap>) {
    assert_eq!(self.chunks().len(), 1);
    // SAFETY: only the validity of the chunk is swapped.
    // NOTE(review): relies on `with_validity` keeping length and dtype — confirm.
    unsafe {
        let only_chunk = self.chunks_mut().iter_mut().next().unwrap();
        let replaced = only_chunk.with_validity(validity);
        *only_chunk = replaced;
    }
    self.compute_len();
}

/// Owned variant of `set_outer_validity`: applies `validity` to `self` and
/// hands the updated array back.
pub fn with_outer_validity(self, validity: Option<Bitmap>) -> Self {
    let mut updated = self;
    updated.set_outer_validity(validity);
    updated
}

/// Replace the validity of every chunk with the values of the matching chunk
/// of `validity`.
///
/// When the chunk layouts differ (different number of chunks, or any pair of
/// chunks with different lengths), both sides are rechunked and the operation
/// is retried on the rechunked arrays.
///
/// # Panics
///
/// Panics if `self` and `validity` differ in total length, or if a chunk of
/// `validity` carries a validity bitmap of its own.
pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self {
    assert_eq!(self.len(), validity.len());
    // Every chunk pair has to line up. Comparing the per-pair booleans with
    // `all_equal` would also accept the case where *every* pair mismatches.
    let aligned = self.chunks.len() == validity.chunks().len()
        && self
            .chunks
            .iter()
            .zip(validity.chunks.iter())
            .all(|(a, b)| a.len() == b.len());
    if !aligned {
        let ca = self.rechunk();
        let validity = validity.rechunk();
        ca.with_outer_validity_chunked(validity)
    } else {
        // SAFETY: only the validity of each chunk is swapped.
        // NOTE(review): relies on `with_validity` keeping length and dtype — confirm.
        unsafe {
            for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) {
                assert!(valid.validity().is_none());
                *arr = arr.with_validity(Some(valid.values().clone()))
            }
        }
        self.compute_len();
        self
    }
}
}
60 changes: 51 additions & 9 deletions crates/polars-core/src/chunked_array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod frame;
use std::fmt::Write;

use arrow::array::StructArray;
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::legacy::utils::CustomIterTools;
use polars_error::{polars_ensure, PolarsResult};
Expand Down Expand Up @@ -300,12 +301,14 @@ impl StructChunked {
}

/// Set the outer nulls into the inner arrays.
pub(crate) fn propagate_nulls(&mut self) {
// SAFETY:
// We keep length and dtypes the same.
unsafe {
for arr in self.downcast_iter_mut() {
*arr = arr.propagate_nulls()
fn propagate_nulls(&mut self) {
if self.null_count > 0 {
// SAFETY:
// We keep length and dtypes the same.
unsafe {
for arr in self.downcast_iter_mut() {
*arr = arr.propagate_nulls()
}
}
}
}
Expand Down Expand Up @@ -335,11 +338,10 @@ impl StructChunked {
}
}
self.compute_len();
}

pub fn unnest(mut self) -> DataFrame {
self.propagate_nulls();
}

pub fn unnest(self) -> DataFrame {
// SAFETY: invariants for struct are the same
unsafe { DataFrame::new_no_checks(self.fields_as_series()) }
}
Expand All @@ -351,4 +353,44 @@ impl StructChunked {
.find(|s| s.name() == name)
.ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))
}
/// Replace the validity of the single chunk with `validity`, recompute the
/// cached length information and then call `propagate_nulls`.
///
/// # Panics
///
/// Panics when the array does not consist of exactly one chunk.
pub(crate) fn set_outer_validity(&mut self, validity: Option<Bitmap>) {
    assert_eq!(self.chunks().len(), 1);
    // SAFETY: only the validity of the chunk is swapped.
    // NOTE(review): relies on `with_validity` keeping length and dtype — confirm.
    unsafe {
        let only_chunk = self.chunks_mut().iter_mut().next().unwrap();
        let replaced = only_chunk.with_validity(validity);
        *only_chunk = replaced;
    }
    self.compute_len();
    self.propagate_nulls();
}

/// Owned variant of `set_outer_validity`: applies `validity` to `self` and
/// hands the updated array back.
pub fn with_outer_validity(self, validity: Option<Bitmap>) -> Self {
    let mut updated = self;
    updated.set_outer_validity(validity);
    updated
}

/// Replace the outer validity of every chunk with the values of the matching
/// chunk of `validity`, then call `propagate_nulls`.
///
/// When the chunk layouts differ (different number of chunks, or any pair of
/// chunks with different lengths), both sides are rechunked and the operation
/// is retried on the rechunked arrays.
///
/// # Panics
///
/// Panics if `self` and `validity` differ in total length, or if a chunk of
/// `validity` carries a validity bitmap of its own.
pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self {
    assert_eq!(self.len(), validity.len());
    // Every chunk pair has to line up. Comparing the per-pair booleans with
    // `all_equal` would also accept the case where *every* pair mismatches.
    let aligned = self.chunks.len() == validity.chunks().len()
        && self
            .chunks
            .iter()
            .zip(validity.chunks.iter())
            .all(|(a, b)| a.len() == b.len());
    if !aligned {
        let ca = self.rechunk();
        let validity = validity.rechunk();
        ca.with_outer_validity_chunked(validity)
    } else {
        // SAFETY: only the validity of each chunk is swapped.
        // NOTE(review): relies on `with_validity` keeping length and dtype — confirm.
        unsafe {
            for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) {
                assert!(valid.validity().is_none());
                *arr = arr.with_validity(Some(valid.values().clone()))
            }
        }
        self.compute_len();
        self.propagate_nulls();
        self
    }
}
}
3 changes: 1 addition & 2 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2996,8 +2996,7 @@ impl DataFrame {
let mut count = 0;
for s in &self.columns {
if cols.contains(s.name()) {
let mut ca = s.struct_()?.clone();
ca.propagate_nulls();
let ca = s.struct_()?.clone();
new_cols.extend_from_slice(&ca.fields_as_series());
count += 1;
} else {
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/series/implementations/struct__.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ impl PrivateSeries for SeriesWrap<StructChunked> {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
self.0.agg_list(groups)
}

/// Hash the struct by hashing its fields: the first field is hashed into
/// `buf` with `vec_hash`, every following field is folded in with
/// `vec_hash_combine`. With no fields, `buf` is left untouched.
fn vec_hash(&self, build_hasher: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
    let mut remaining = self.0.fields_as_series().into_iter();

    match remaining.next() {
        Some(first) => first.vec_hash(build_hasher.clone(), buf)?,
        None => return Ok(()),
    }
    remaining.try_for_each(|field| field.vec_hash_combine(build_hasher.clone(), buf))
}
}

impl SeriesTrait for SeriesWrap<StructChunked> {
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import polars as pl


def test_hash_struct() -> None:
    """Hashing a struct column yields the pinned per-row hash values."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    # Pack every column into one struct column; it is addressed as "a" below.
    packed = frame.select(pl.struct(pl.all()))

    expected = [
        8045264196180950307,
        14608543421872010777,
        12464129093563214397,
    ]
    hashed = packed.select(pl.col("a").hash())["a"].to_list()
    assert hashed == expected

0 comments on commit ab5f8c1

Please sign in to comment.