diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index f526268bf49e..a81e7fc1af4a 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -24,7 +24,9 @@ use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_schema::*; use std::borrow::BorrowMut; +use std::cmp::{self, Ordering}; use std::ops::{BitAnd, BitOr, BitXor}; +use types::ByteViewType; /// An accumulator for primitive numeric values. trait NumericAccumulator: Copy + Default { @@ -425,6 +427,47 @@ where } } +/// Helper to compute min/max of [`GenericByteViewArray`]. +/// The specialized min/max leverages the inlined values to compare the byte views. +/// `swap_cond` is the condition to swap current min/max with the new value. +/// For example, `Ordering::Greater` for max and `Ordering::Less` for min. +fn min_max_view_helper( + array: &GenericByteViewArray, + swap_cond: cmp::Ordering, +) -> Option<&T::Native> { + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + let target_idx = (0..array.len()).reduce(|acc, item| { + // SAFETY: array's length is correct so item is within bounds + let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) }; + if cmp == swap_cond { + item + } else { + acc + } + }); + // SAFETY: idx came from valid range `0..array.len()` + unsafe { target_idx.map(|idx| array.value_unchecked(idx)) } + } else { + let nulls = array.nulls().unwrap(); + + let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| { + let cmp = + unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) }; + if cmp == swap_cond { + idx + } else { + acc_idx + } + }); + + // SAFETY: idx came from valid range `0..array.len()` + unsafe { target_idx.map(|idx| array.value_unchecked(idx)) } + } +} + /// Returns the maximum value in the binary array, according to the natural order. pub fn max_binary(array: &GenericBinaryArray) -> Option<&[u8]> { min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b) @@ -432,7 +475,7 @@ pub fn max_binary(array: &GenericBinaryArray) -> Option<& /// Returns the maximum value in the binary view array, according to the natural order. pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> { - min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b) + min_max_view_helper(array, Ordering::Greater) } /// Returns the minimum value in the binary array, according to the natural order. @@ -442,7 +485,7 @@ pub fn min_binary(array: &GenericBinaryArray) -> Option<& /// Returns the minimum value in the binary view array, according to the natural order. pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> { - min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b) + min_max_view_helper(array, Ordering::Less) } /// Returns the maximum value in the string array, according to the natural order. @@ -452,7 +495,7 @@ pub fn max_string(array: &GenericStringArray) -> Option<& /// Returns the maximum value in the string view array, according to the natural order. pub fn max_string_view(array: &StringViewArray) -> Option<&str> { - min_max_helper::<&str, _, _>(array, |a, b| *a < *b) + min_max_view_helper(array, Ordering::Greater) } /// Returns the minimum value in the string array, according to the natural order. @@ -462,7 +505,7 @@ pub fn min_string(array: &GenericStringArray) -> Option<& /// Returns the minimum value in the string view array, according to the natural order. pub fn min_string_view(array: &StringViewArray) -> Option<&str> { - min_max_helper::<&str, _, _>(array, |a, b| *a > *b) + min_max_view_helper(array, Ordering::Less) } /// Returns the sum of values in the array. diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs index 7017add49722..bd8c0cebc056 100644 --- a/arrow-array/src/array/byte_view_array.rs +++ b/arrow-array/src/array/byte_view_array.rs @@ -336,6 +336,66 @@ impl GenericByteViewArray { builder.finish() } + + /// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx` + /// + /// Comparing two ByteView types are non-trivial. + /// It takes a bit of patience to understand why we don't just compare two &[u8] directly. + /// + /// ByteView types give us the following two advantages, and we need to be careful not to lose them: + /// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view. + /// Meaning that reading one array element requires only one memory access + /// (two memory access required for StringArray, one for offset buffer, the other for value buffer). + /// + /// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray, + /// thanks to the inlined 4 bytes. + /// Consider equality check: + /// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access). + /// + /// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary. + /// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer, + /// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string. + /// + /// # Order check flow + /// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view. + /// (2) if any of the string is larger than 12 bytes, we need to compare the full string. + /// (2.1) if the inlined 4 bytes are different, we can return the result immediately. + /// (2.2) o.w., we need to compare the full string. + /// + /// # Safety + /// The left/right_idx must within range of each array + pub unsafe fn compare_unchecked( + left: &GenericByteViewArray, + left_idx: usize, + right: &GenericByteViewArray, + right_idx: usize, + ) -> std::cmp::Ordering { + let l_view = left.views().get_unchecked(left_idx); + let l_len = *l_view as u32; + + let r_view = right.views().get_unchecked(right_idx); + let r_len = *r_view as u32; + + if l_len <= 12 && r_len <= 12 { + let l_data = unsafe { GenericByteViewArray::::inline_value(l_view, l_len as usize) }; + let r_data = unsafe { GenericByteViewArray::::inline_value(r_view, r_len as usize) }; + return l_data.cmp(r_data); + } + + // one of the string is larger than 12 bytes, + // we then try to compare the inlined data first + let l_inlined_data = unsafe { GenericByteViewArray::::inline_value(l_view, 4) }; + let r_inlined_data = unsafe { GenericByteViewArray::::inline_value(r_view, 4) }; + if r_inlined_data != l_inlined_data { + return l_inlined_data.cmp(r_inlined_data); + } + + // unfortunately, we need to compare the full data + let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() }; + let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() }; + + l_full_data.cmp(r_full_data) + } } impl Debug for GenericByteViewArray { diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 26eb0d8d6e41..9d7874c6444d 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray { return false; } - unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() } + unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() } } fn is_lt(l: Self::Item, r: Self::Item) -> bool { // # Safety // The index is within bounds as it is checked in value() - unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() } + unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() } } fn len(&self) -> usize { @@ -626,7 +626,7 @@ pub fn compare_byte_view( ) -> std::cmp::Ordering { assert!(left_idx < left.len()); assert!(right_idx < right.len()); - unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) } + unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) } } /// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx` @@ -656,6 +656,7 @@ pub fn compare_byte_view( /// /// # Safety /// The left/right_idx must within range of each array +#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")] pub unsafe fn compare_byte_view_unchecked( left: &GenericByteViewArray, left_idx: usize, diff --git a/arrow/benches/aggregate_kernels.rs b/arrow/benches/aggregate_kernels.rs index 9bb866f364a4..434bb4778d3d 100644 --- a/arrow/benches/aggregate_kernels.rs +++ b/arrow/benches/aggregate_kernels.rs @@ -57,8 +57,8 @@ fn add_benchmark(c: &mut Criterion) { primitive_benchmark::(c, "int64"); { - let nonnull_strings = create_string_array::(BATCH_SIZE, 0.0); - let nullable_strings = create_string_array::(BATCH_SIZE, 0.5); + let nonnull_strings = create_string_array_with_len::(BATCH_SIZE, 0.0, 16); + let nullable_strings = create_string_array_with_len::(BATCH_SIZE, 0.5, 16); c.benchmark_group("string") .throughput(Throughput::Elements(BATCH_SIZE as u64)) .bench_function("min nonnull", |b| b.iter(|| min_string(&nonnull_strings))) @@ -67,6 +67,25 @@ fn add_benchmark(c: &mut Criterion) { .bench_function("max nullable", |b| b.iter(|| max_string(&nullable_strings))); } + { + let nonnull_strings = create_string_view_array_with_len(BATCH_SIZE, 0.0, 16, false); + let nullable_strings = create_string_view_array_with_len(BATCH_SIZE, 0.5, 16, false); + c.benchmark_group("string view") + .throughput(Throughput::Elements(BATCH_SIZE as u64)) + .bench_function("min nonnull", |b| { + b.iter(|| min_string_view(&nonnull_strings)) + }) + .bench_function("max nonnull", |b| { + b.iter(|| max_string_view(&nonnull_strings)) + }) + .bench_function("min nullable", |b| { + b.iter(|| min_string_view(&nullable_strings)) + }) + .bench_function("max nullable", |b| { + b.iter(|| max_string_view(&nullable_strings)) + }); + } + { let nonnull_bools_mixed = create_boolean_array(BATCH_SIZE, 0.0, 0.5); let nonnull_bools_all_false = create_boolean_array(BATCH_SIZE, 0.0, 0.0);