diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 39fd2ee7c215..aece7ce6a29a 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -84,7 +84,8 @@ impl Array for NullArray { } fn with_validity(&self, _: Option) -> Box { - panic!("cannot set validity of a null array") + // Nulls with invalid nulls are also nulls. + self.clone().boxed() } } diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index dd99e8360b0c..08a56aa0fee1 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -194,7 +194,7 @@ impl StructArray { .for_each(|x| x.slice_unchecked(offset, length)); } - /// Set the outer nulls into the inner arrays, and clear the outer validity. + /// Set the outer nulls into the inner arrays. pub fn propagate_nulls(&self) -> StructArray { let has_nulls = self.null_count() > 0; let mut out = self.clone(); @@ -203,16 +203,10 @@ impl StructArray { }; for value_arr in &mut out.values { - let new = if has_nulls { - let new_validity = combine_validities_and(self.validity(), value_arr.validity()); - value_arr.with_validity(new_validity) - } else { - value_arr.clone() - }; - - *value_arr = new; + let new_validity = combine_validities_and(self.validity(), value_arr.validity()); + *value_arr = value_arr.with_validity(new_validity); } - out.with_validity(None) + out } impl_sliced!(); diff --git a/crates/polars-core/src/chunked_array/ops/downcast.rs b/crates/polars-core/src/chunked_array/ops/downcast.rs index a10386145291..99e469243c06 100644 --- a/crates/polars-core/src/chunked_array/ops/downcast.rs +++ b/crates/polars-core/src/chunked_array/ops/downcast.rs @@ -1,7 +1,6 @@ use std::marker::PhantomData; use arrow::array::*; -use arrow::bitmap::Bitmap; use arrow::compute::utils::combine_validities_and; use crate::prelude::*; @@ -161,43 +160,4 @@ impl ChunkedArray { } self.compute_len(); } - - pub(crate) fn set_outer_validity(&mut self, validity: Option) { - assert_eq!(self.chunks().len(), 1); - unsafe { - let arr = self.chunks_mut().iter_mut().next().unwrap(); - *arr = arr.with_validity(validity); - } - self.compute_len(); - } - - pub fn with_outer_validity(mut self, validity: Option) -> Self { - self.set_outer_validity(validity); - self - } - - pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self { - assert_eq!(self.len(), validity.len()); - if !self - .chunks - .iter() - .zip(validity.chunks.iter()) - .map(|(a, b)| a.len() == b.len()) - .all_equal() - || self.chunks.len() != validity.chunks().len() - { - let ca = self.rechunk(); - let validity = validity.rechunk(); - ca.with_outer_validity_chunked(validity) - } else { - unsafe { - for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) { - assert!(valid.validity().is_none()); - *arr = arr.with_validity(Some(valid.values().clone())) - } - } - self.compute_len(); - self - } - } } diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 008233dc3e97..ade2bb2c79f0 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -3,6 +3,7 @@ mod frame; use std::fmt::Write; use arrow::array::StructArray; +use arrow::bitmap::Bitmap; use arrow::compute::utils::combine_validities_and; use arrow::legacy::utils::CustomIterTools; use polars_error::{polars_ensure, PolarsResult}; @@ -300,12 +301,14 @@ impl StructChunked { } /// Set the outer nulls into the inner arrays, and clear the outer validity. - pub(crate) fn propagate_nulls(&mut self) { - // SAFETY: - // We keep length and dtypes the same. - unsafe { - for arr in self.downcast_iter_mut() { - *arr = arr.propagate_nulls() + fn propagate_nulls(&mut self) { + if self.null_count > 0 { + // SAFETY: + // We keep length and dtypes the same. + unsafe { + for arr in self.downcast_iter_mut() { + *arr = arr.propagate_nulls() + } } } } @@ -335,11 +338,10 @@ impl StructChunked { } } self.compute_len(); - } - - pub fn unnest(mut self) -> DataFrame { self.propagate_nulls(); + } + pub fn unnest(self) -> DataFrame { // SAFETY: invariants for struct are the same unsafe { DataFrame::new_no_checks(self.fields_as_series()) } } @@ -351,4 +353,44 @@ impl StructChunked { .find(|s| s.name() == name) .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name)) } + pub(crate) fn set_outer_validity(&mut self, validity: Option) { + assert_eq!(self.chunks().len(), 1); + unsafe { + let arr = self.chunks_mut().iter_mut().next().unwrap(); + *arr = arr.with_validity(validity); + } + self.compute_len(); + self.propagate_nulls(); + } + + pub fn with_outer_validity(mut self, validity: Option) -> Self { + self.set_outer_validity(validity); + self + } + + pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self { + assert_eq!(self.len(), validity.len()); + if !self + .chunks + .iter() + .zip(validity.chunks.iter()) + .map(|(a, b)| a.len() == b.len()) + .all_equal() + || self.chunks.len() != validity.chunks().len() + { + let ca = self.rechunk(); + let validity = validity.rechunk(); + ca.with_outer_validity_chunked(validity) + } else { + unsafe { + for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) { + assert!(valid.validity().is_none()); + *arr = arr.with_validity(Some(valid.values().clone())) + } + } + self.compute_len(); + self.propagate_nulls(); + self + } + } } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 9b7338c9eedb..f01c86bad937 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -2996,8 +2996,7 @@ impl DataFrame { let mut count = 0; for s in &self.columns { if cols.contains(s.name()) { - let mut ca = s.struct_()?.clone(); - ca.propagate_nulls(); + let ca = s.struct_()?.clone(); new_cols.extend_from_slice(&ca.fields_as_series()); count += 1; } else { diff --git a/crates/polars-core/src/series/implementations/struct__.rs b/crates/polars-core/src/series/implementations/struct__.rs index c2c9c6f3726d..a6c775a4245d 100644 --- a/crates/polars-core/src/series/implementations/struct__.rs +++ b/crates/polars-core/src/series/implementations/struct__.rs @@ -71,6 +71,18 @@ impl PrivateSeries for SeriesWrap { unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + + fn vec_hash(&self, build_hasher: RandomState, buf: &mut Vec) -> PolarsResult<()> { + let mut fields = self.0.fields_as_series().into_iter(); + + if let Some(s) = fields.next() { + s.vec_hash(build_hasher.clone(), buf)? + }; + for s in fields { + s.vec_hash_combine(build_hasher.clone(), buf)? + } + Ok(()) + } } impl SeriesTrait for SeriesWrap { diff --git a/py-polars/tests/unit/operations/test_hash.py b/py-polars/tests/unit/operations/test_hash.py new file mode 100644 index 000000000000..c2c85f90f877 --- /dev/null +++ b/py-polars/tests/unit/operations/test_hash.py @@ -0,0 +1,12 @@ +import polars as pl + + +def test_hash_struct() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = df.select(pl.struct(pl.all())) + + assert df.select(pl.col("a").hash())["a"].to_list() == [ + 8045264196180950307, + 14608543421872010777, + 12464129093563214397, + ]