Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Propagate struct outer nullability eagerly #17697

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/polars-arrow/src/array/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ impl Array for NullArray {
}

fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {
panic!("cannot set validity of a null array")
// A null array is all-null regardless of the validity applied, so the mask can be ignored.
self.clone().boxed()
}
}

Expand Down
14 changes: 4 additions & 10 deletions crates/polars-arrow/src/array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl StructArray {
.for_each(|x| x.slice_unchecked(offset, length));
}

/// Set the outer nulls into the inner arrays, and clear the outer validity.
/// Set the outer nulls into the inner arrays.
pub fn propagate_nulls(&self) -> StructArray {
let has_nulls = self.null_count() > 0;
let mut out = self.clone();
Expand All @@ -203,16 +203,10 @@ impl StructArray {
};

for value_arr in &mut out.values {
let new = if has_nulls {
let new_validity = combine_validities_and(self.validity(), value_arr.validity());
value_arr.with_validity(new_validity)
} else {
value_arr.clone()
};

*value_arr = new;
let new_validity = combine_validities_and(self.validity(), value_arr.validity());
*value_arr = value_arr.with_validity(new_validity);
}
out.with_validity(None)
out
}

impl_sliced!();
Expand Down
40 changes: 0 additions & 40 deletions crates/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::marker::PhantomData;

use arrow::array::*;
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;

use crate::prelude::*;
Expand Down Expand Up @@ -161,43 +160,4 @@ impl<T: PolarsDataType> ChunkedArray<T> {
}
self.compute_len();
}

/// Replace the validity of the array's only chunk with `validity`, then
/// recompute the cached length information.
///
/// # Panics
///
/// Panics if the array does not consist of exactly one chunk.
pub(crate) fn set_outer_validity(&mut self, validity: Option<Bitmap>) {
    assert_eq!(self.chunks().len(), 1);
    // SAFETY (assumed, confirm against `chunks_mut`'s contract): the chunk is
    // swapped for itself with only the validity changed.
    let only_chunk = unsafe { self.chunks_mut() }.iter_mut().next().unwrap();
    *only_chunk = only_chunk.with_validity(validity);
    self.compute_len();
}

/// Consuming variant of `set_outer_validity`: installs `validity` on the
/// array and hands the array back.
///
/// # Panics
///
/// Panics under the same condition as `set_outer_validity` (the array must
/// consist of exactly one chunk).
pub fn with_outer_validity(mut self, validity: Option<Bitmap>) -> Self {
    self.set_outer_validity(validity);
    self
}

/// Install the values of the boolean array `validity` as the validity of
/// this array, chunk by chunk.
///
/// When the two arrays are not chunked identically (different number of
/// chunks, or any pair of chunks with different lengths), both are rechunked
/// first and the operation is retried on the rechunked arrays.
///
/// # Panics
///
/// Panics if the total lengths differ, or if a chunk of `validity` carries a
/// validity bitmap of its own.
pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self {
    assert_eq!(self.len(), validity.len());
    // Every pair of chunks must have matching lengths. Checking that the
    // comparison results are merely "all equal" would also accept the case
    // where every pair mismatches.
    let chunks_aligned = self.chunks.len() == validity.chunks().len()
        && self
            .chunks
            .iter()
            .zip(validity.chunks.iter())
            .all(|(a, b)| a.len() == b.len());
    if !chunks_aligned {
        let ca = self.rechunk();
        let validity = validity.rechunk();
        ca.with_outer_validity_chunked(validity)
    } else {
        // SAFETY (assumed, confirm against `chunks_mut`'s contract): each chunk
        // is swapped for itself with only the validity changed.
        unsafe {
            for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) {
                // The mask's values are used directly as the validity bitmap, so
                // the mask itself must not contain nulls.
                assert!(valid.validity().is_none());
                *arr = arr.with_validity(Some(valid.values().clone()))
            }
        }
        self.compute_len();
        self
    }
}
}
60 changes: 51 additions & 9 deletions crates/polars-core/src/chunked_array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod frame;
use std::fmt::Write;

use arrow::array::StructArray;
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::legacy::utils::CustomIterTools;
use polars_error::{polars_ensure, PolarsResult};
Expand Down Expand Up @@ -300,12 +301,14 @@ impl StructChunked {
}

/// Set the outer nulls into the inner arrays.
pub(crate) fn propagate_nulls(&mut self) {
// SAFETY:
// We keep length and dtypes the same.
unsafe {
for arr in self.downcast_iter_mut() {
*arr = arr.propagate_nulls()
fn propagate_nulls(&mut self) {
if self.null_count > 0 {
// SAFETY:
// We keep length and dtypes the same.
unsafe {
for arr in self.downcast_iter_mut() {
*arr = arr.propagate_nulls()
}
}
}
}
Expand Down Expand Up @@ -335,11 +338,10 @@ impl StructChunked {
}
}
self.compute_len();
}

pub fn unnest(mut self) -> DataFrame {
self.propagate_nulls();
}

pub fn unnest(self) -> DataFrame {
// SAFETY: invariants for struct are the same
unsafe { DataFrame::new_no_checks(self.fields_as_series()) }
}
Expand All @@ -351,4 +353,44 @@ impl StructChunked {
.find(|s| s.name() == name)
.ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))
}
/// Replace the outer validity of the struct's only chunk with `validity`,
/// recompute the cached length information, and push the resulting outer
/// nulls into the fields via `propagate_nulls`.
///
/// # Panics
///
/// Panics if the array does not consist of exactly one chunk.
pub(crate) fn set_outer_validity(&mut self, validity: Option<Bitmap>) {
    assert_eq!(self.chunks().len(), 1);
    // SAFETY (assumed, confirm against `chunks_mut`'s contract): the chunk is
    // swapped for itself with only the validity changed.
    let only_chunk = unsafe { self.chunks_mut() }.iter_mut().next().unwrap();
    *only_chunk = only_chunk.with_validity(validity);
    self.compute_len();
    self.propagate_nulls();
}

/// Consuming variant of `set_outer_validity`: installs `validity` as the
/// outer validity and hands the struct array back.
///
/// # Panics
///
/// Panics under the same condition as `set_outer_validity` (the array must
/// consist of exactly one chunk).
pub fn with_outer_validity(mut self, validity: Option<Bitmap>) -> Self {
    self.set_outer_validity(validity);
    self
}

/// Install the values of the boolean array `validity` as the outer validity
/// of this struct array, chunk by chunk, and then push the outer nulls into
/// the fields via `propagate_nulls`.
///
/// When the two arrays are not chunked identically (different number of
/// chunks, or any pair of chunks with different lengths), both are rechunked
/// first and the operation is retried on the rechunked arrays.
///
/// # Panics
///
/// Panics if the total lengths differ, or if a chunk of `validity` carries a
/// validity bitmap of its own.
pub fn with_outer_validity_chunked(mut self, validity: BooleanChunked) -> Self {
    assert_eq!(self.len(), validity.len());
    // Every pair of chunks must have matching lengths. Checking that the
    // comparison results are merely "all equal" would also accept the case
    // where every pair mismatches.
    let chunks_aligned = self.chunks.len() == validity.chunks().len()
        && self
            .chunks
            .iter()
            .zip(validity.chunks.iter())
            .all(|(a, b)| a.len() == b.len());
    if !chunks_aligned {
        let ca = self.rechunk();
        let validity = validity.rechunk();
        ca.with_outer_validity_chunked(validity)
    } else {
        // SAFETY (assumed, confirm against `chunks_mut`'s contract): each chunk
        // is swapped for itself with only the validity changed.
        unsafe {
            for (arr, valid) in self.chunks_mut().iter_mut().zip(validity.downcast_iter()) {
                // The mask's values are used directly as the validity bitmap, so
                // the mask itself must not contain nulls.
                assert!(valid.validity().is_none());
                *arr = arr.with_validity(Some(valid.values().clone()))
            }
        }
        self.compute_len();
        self.propagate_nulls();
        self
    }
}
}
3 changes: 1 addition & 2 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2996,8 +2996,7 @@ impl DataFrame {
let mut count = 0;
for s in &self.columns {
if cols.contains(s.name()) {
let mut ca = s.struct_()?.clone();
ca.propagate_nulls();
let ca = s.struct_()?.clone();
new_cols.extend_from_slice(&ca.fields_as_series());
count += 1;
} else {
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/series/implementations/struct__.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ impl PrivateSeries for SeriesWrap<StructChunked> {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
self.0.agg_list(groups)
}

/// Hash the struct by hashing its fields in order: the first field is hashed
/// into `buf` with `vec_hash`, and every following field is folded into
/// `buf` with `vec_hash_combine`.
///
/// If the struct has no fields, `buf` is left untouched.
fn vec_hash(&self, build_hasher: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
    let fields = self.0.fields_as_series();

    if let Some((first, rest)) = fields.split_first() {
        first.vec_hash(build_hasher.clone(), buf)?;
        for field in rest {
            field.vec_hash_combine(build_hasher.clone(), buf)?;
        }
    }
    Ok(())
}
}

impl SeriesTrait for SeriesWrap<StructChunked> {
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import polars as pl


def test_hash_struct() -> None:
    """Hashing a struct column built from two integer columns yields fixed values."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).select(pl.struct(pl.all()))

    expected = [
        8045264196180950307,
        14608543421872010777,
        12464129093563214397,
    ]
    hashed = frame.select(pl.col("a").hash())["a"].to_list()
    assert hashed == expected
Loading