Skip to content

Commit

Permalink
perf: speedup boolean filter (#13905)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 22, 2024
1 parent 27a4c58 commit 5059bcb
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 182 deletions.
160 changes: 160 additions & 0 deletions crates/polars-compute/src/filter/boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use super::*;

pub(super) fn filter_bitmap_and_validity(
values: &Bitmap,
validity: Option<&Bitmap>,
mask: &Bitmap,
) -> (MutableBitmap, Option<MutableBitmap>) {
if let Some(validity) = validity {
let (values, validity) = null_filter(values, validity, mask);
(values, Some(validity))
} else {
(nonnull_filter(values, mask), None)
}
}

/// # Safety
/// This assumes that the `mask_chunks` contains a number of set/true items equal
/// to `filter_count`
unsafe fn nonnull_filter_impl<I>(
values: &Bitmap,
mut mask_chunks: I,
filter_count: usize,
) -> MutableBitmap
where
I: BitChunkIterExact<u64>,
{
// TODO! we might use ChunksExact here if offset = 0.
let mut chunks = values.chunks::<u64>();
let mut new = MutableBitmap::with_capacity(filter_count);

chunks
.by_ref()
.zip(mask_chunks.by_ref())
.for_each(|(chunk, mask_chunk)| {
let ones = mask_chunk.count_ones();
let leading_ones = get_leading_ones(mask_chunk);

if ones == leading_ones {
let size = leading_ones as usize;
unsafe { new.extend_from_slice_unchecked(chunk.to_ne_bytes().as_ref(), 0, size) };
return;
}

let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
for pos in ones_iter {
new.push_unchecked(chunk & (1 << pos) > 0);
}
});

chunks
.remainder_iter()
.zip(mask_chunks.remainder_iter())
.for_each(|(value, is_selected)| {
if is_selected {
unsafe {
new.push_unchecked(value);
};
}
});

new
}

/// # Safety
/// This assumes that the `mask_chunks` contains a number of set/true items equal
/// to `filter_count`
unsafe fn null_filter_impl<I>(
values: &Bitmap,
validity: &Bitmap,
mut mask_chunks: I,
filter_count: usize,
) -> (MutableBitmap, MutableBitmap)
where
I: BitChunkIterExact<u64>,
{
let mut chunks = values.chunks::<u64>();
let mut validity_chunks = validity.chunks::<u64>();

let mut new = MutableBitmap::with_capacity(filter_count);
let mut new_validity = MutableBitmap::with_capacity(filter_count);

chunks
.by_ref()
.zip(validity_chunks.by_ref())
.zip(mask_chunks.by_ref())
.for_each(|((chunk, validity_chunk), mask_chunk)| {
let ones = mask_chunk.count_ones();
let leading_ones = get_leading_ones(mask_chunk);

if ones == leading_ones {
let size = leading_ones as usize;

unsafe {
new.extend_from_slice_unchecked(chunk.to_ne_bytes().as_ref(), 0, size);

// safety: invariant offset + length <= slice.len()
new_validity.extend_from_slice_unchecked(
validity_chunk.to_ne_bytes().as_ref(),
0,
size,
);
}
return;
}

// this triggers a bitcount
let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
for pos in ones_iter {
new.push_unchecked(chunk & (1 << pos) > 0);
new_validity.push_unchecked(validity_chunk & (1 << pos) > 0);
}
});

chunks
.remainder_iter()
.zip(validity_chunks.remainder_iter())
.zip(mask_chunks.remainder_iter())
.for_each(|((value, is_valid), is_selected)| {
if is_selected {
unsafe {
new.push_unchecked(value);
new_validity.push_unchecked(is_valid);
};
}
});

(new, new_validity)
}

fn null_filter(
values: &Bitmap,
validity: &Bitmap,
mask: &Bitmap,
) -> (MutableBitmap, MutableBitmap) {
assert_eq!(values.len(), mask.len());
let filter_count = mask.len() - mask.unset_bits();

let (slice, offset, length) = mask.as_slice();
if offset == 0 {
let mask_chunks = BitChunksExact::<u64>::new(slice, length);
unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
} else {
let mask_chunks = mask.chunks::<u64>();
unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
}
}

fn nonnull_filter(values: &Bitmap, mask: &Bitmap) -> MutableBitmap {
assert_eq!(values.len(), mask.len());
let filter_count = mask.len() - mask.unset_bits();

let (slice, offset, length) = mask.as_slice();
if offset == 0 {
let mask_chunks = BitChunksExact::<u64>::new(slice, length);
unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
} else {
let mask_chunks = mask.chunks::<u64>();
unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
}
}
198 changes: 16 additions & 182 deletions crates/polars-compute/src/filter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Contains operators to filter arrays such as [`filter`].
mod boolean;
mod primitive;

use arrow::array::growable::make_growable;
use arrow::array::*;
use arrow::bitmap::utils::{BitChunkIterExact, BitChunksExact, SlicesIterator};
Expand All @@ -7,7 +10,9 @@ use arrow::datatypes::ArrowDataType;
use arrow::types::simd::Simd;
use arrow::types::{BitChunkOnes, NativeType};
use arrow::with_match_primitive_type_full;
use boolean::*;
use polars_error::*;
use primitive::*;

/// Function that can filter arbitrary arrays
pub type Filter<'a> = Box<dyn Fn(&dyn Array) -> Box<dyn Array> + 'a + Send + Sync>;
Expand All @@ -21,188 +26,6 @@ fn get_leading_ones(chunk: u64) -> u32 {
}
}

/// # Safety
/// This assumes that the `mask_chunks` contains a number of set/true items equal
/// to `filter_count`
unsafe fn nonnull_filter_impl<T, I>(values: &[T], mut mask_chunks: I, filter_count: usize) -> Vec<T>
where
T: NativeType,
I: BitChunkIterExact<u64>,
{
let mut chunks = values.chunks_exact(64);
let mut new = Vec::<T>::with_capacity(filter_count);
let mut dst = new.as_mut_ptr();

chunks
.by_ref()
.zip(mask_chunks.by_ref())
.for_each(|(chunk, mask_chunk)| {
let ones = mask_chunk.count_ones();
let leading_ones = get_leading_ones(mask_chunk);

if ones == leading_ones {
let size = leading_ones as usize;
unsafe {
std::ptr::copy(chunk.as_ptr(), dst, size);
dst = dst.add(size);
}
return;
}

let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
for pos in ones_iter {
dst.write(*chunk.get_unchecked(pos));
dst = dst.add(1);
}
});

chunks
.remainder()
.iter()
.zip(mask_chunks.remainder_iter())
.for_each(|(value, b)| {
if b {
unsafe {
dst.write(*value);
dst = dst.add(1);
};
}
});

unsafe { new.set_len(filter_count) };
new
}

/// # Safety
/// This assumes that the `mask_chunks` contains a number of set/true items equal
/// to `filter_count`
unsafe fn null_filter_impl<T, I>(
values: &[T],
validity: &Bitmap,
mut mask_chunks: I,
filter_count: usize,
) -> (Vec<T>, MutableBitmap)
where
T: NativeType,
I: BitChunkIterExact<u64>,
{
let mut chunks = values.chunks_exact(64);

let mut validity_chunks = validity.chunks::<u64>();

let mut new = Vec::<T>::with_capacity(filter_count);
let mut dst = new.as_mut_ptr();
let mut new_validity = MutableBitmap::with_capacity(filter_count);

chunks
.by_ref()
.zip(validity_chunks.by_ref())
.zip(mask_chunks.by_ref())
.for_each(|((chunk, validity_chunk), mask_chunk)| {
let ones = mask_chunk.count_ones();
let leading_ones = get_leading_ones(mask_chunk);

if ones == leading_ones {
let size = leading_ones as usize;
unsafe {
std::ptr::copy(chunk.as_ptr(), dst, size);
dst = dst.add(size);

// safety: invariant offset + length <= slice.len()
new_validity.extend_from_slice_unchecked(
validity_chunk.to_ne_bytes().as_ref(),
0,
size,
);
}
return;
}

// this triggers a bitcount
let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
for pos in ones_iter {
dst.write(*chunk.get_unchecked(pos));
dst = dst.add(1);
new_validity.push_unchecked(validity_chunk & (1 << pos) > 0);
}
});

chunks
.remainder()
.iter()
.zip(validity_chunks.remainder_iter())
.zip(mask_chunks.remainder_iter())
.for_each(|((value, is_valid), is_selected)| {
if is_selected {
unsafe {
dst.write(*value);
dst = dst.add(1);
new_validity.push_unchecked(is_valid);
};
}
});

unsafe { new.set_len(filter_count) };
(new, new_validity)
}

fn null_filter<T: NativeType>(
values: &[T],
validity: &Bitmap,
mask: &Bitmap,
) -> (Vec<T>, MutableBitmap) {
assert_eq!(values.len(), mask.len());
let filter_count = mask.len() - mask.unset_bits();

let (slice, offset, length) = mask.as_slice();
if offset == 0 {
let mask_chunks = BitChunksExact::<u64>::new(slice, length);
unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
} else {
let mask_chunks = mask.chunks::<u64>();
unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
}
}

fn nonnull_filter<T: NativeType>(values: &[T], mask: &Bitmap) -> Vec<T> {
assert_eq!(values.len(), mask.len());
let filter_count = mask.len() - mask.unset_bits();

let (slice, offset, length) = mask.as_slice();
if offset == 0 {
let mask_chunks = BitChunksExact::<u64>::new(slice, length);
unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
} else {
let mask_chunks = mask.chunks::<u64>();
unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
}
}

fn filter_values_and_validity<T: NativeType>(
values: &[T],
validity: Option<&Bitmap>,
mask: &Bitmap,
) -> (Vec<T>, Option<MutableBitmap>) {
if let Some(validity) = validity {
let (values, validity) = null_filter(values, validity, mask);
(values, Some(validity))
} else {
(nonnull_filter(values, mask), None)
}
}

fn filter_primitive<T: NativeType + Simd>(
array: &PrimitiveArray<T>,
mask: &Bitmap,
) -> PrimitiveArray<T> {
assert_eq!(array.len(), mask.len());
let (values, validity) = filter_values_and_validity(array.values(), array.validity(), mask);
let validity = validity.map(|validity| validity.freeze());
unsafe {
PrimitiveArray::<T>::new_unchecked(array.data_type().clone(), values.into(), validity)
}
}

pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Array>> {
// The validities may be masking out `true` bits, making the filter operation
// based on the values incorrect
Expand All @@ -229,6 +52,17 @@ pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Ar
let array = array.as_any().downcast_ref().unwrap();
Ok(Box::new(filter_primitive::<$T>(array, mask.values())))
}),
Boolean => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
let (values, validity) =
filter_bitmap_and_validity(array.values(), array.validity(), mask.values());
Ok(BooleanArray::new(
array.data_type().clone(),
values.freeze(),
validity.map(|v| v.freeze()),
)
.boxed())
},
BinaryView => {
let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
let views = array.views();
Expand Down
Loading

0 comments on commit 5059bcb

Please sign in to comment.