From 1c95533787dad50dc0394b46a19ac2e2e72dbbb6 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Tue, 31 Oct 2023 09:31:55 +0300 Subject: [PATCH 01/13] HashJoin partial batch emitting --- .../physical-plan/src/joins/hash_join.rs | 284 ++++++++++++------ .../src/joins/nested_loop_join.rs | 6 +- .../src/joins/symmetric_hash_join.rs | 4 +- datafusion/physical-plan/src/joins/utils.rs | 58 ++-- .../join_disable_repartition_joins.slt | 10 +- 5 files changed, 233 insertions(+), 129 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 9aa776fe054c..bdff46c49853 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -78,6 +78,9 @@ use futures::{ready, Stream, StreamExt, TryStreamExt}; type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); +/// Tuple representing last matched probe-build side indices for partial join output +type MatchedIndicesPair = Option<(usize, usize)>; + /// Join execution plan executes partitions in parallel and combines them into a set of /// partitions. 
/// @@ -465,6 +468,8 @@ impl ExecutionPlan for HashJoinExec { } }; + let batch_size = context.session_config().batch_size(); + let reservation = MemoryConsumer::new(format!("HashJoinStream[{partition}]")) .register(context.memory_pool()); @@ -487,6 +492,9 @@ impl ExecutionPlan for HashJoinExec { null_equals_null: self.null_equals_null, is_exhausted: false, reservation, + batch_size, + last_matched_indices: None, + probe_batch: None, })) } @@ -682,6 +690,15 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// Batch size + batch_size: usize, + /// Current probe batch + probe_batch: Option, + /// In case joining current probe batch with build side may produce more than `batch_size` records + /// (cross-join due to key duplication on build side) `HashJoinStream` saves last matched indices + /// and emits result batch to upstream operator. + /// On next poll these indices are used to skip already matched rows. + last_matched_indices: MatchedIndicesPair, } impl RecordBatchStream for HashJoinStream { @@ -734,7 +751,9 @@ pub fn build_equal_condition_join_indices( filter: Option<&JoinFilter>, build_side: JoinSide, deleted_offset: Option, -) -> Result<(UInt64Array, UInt32Array)> { + output_limit: usize, + start_indices: MatchedIndicesPair, +) -> Result<(UInt64Array, UInt32Array, MatchedIndicesPair)> { let keys_values = probe_on .iter() .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) @@ -783,16 +802,37 @@ pub fn build_equal_condition_join_indices( // With this approach, the lexicographic order on both the probe side and the build side is preserved. 
let hash_map = build_hashmap.get_map(); let next_chain = build_hashmap.get_list(); - for (row, hash_value) in hash_values.iter().enumerate().rev() { - // Get the hash and find it in the build index + + let mut output_tuples = 0_usize; + let mut last_matched_indices = None; + + let (initial_probe, initial_build) = + start_indices.map_or_else(|| (0, 0), |pair| pair); + + 'probe: for (row, hash_value) in hash_values.iter().enumerate().skip(initial_probe) { + let index = if start_indices.is_some() && row == initial_probe { + // in case of partially skipped input -- calculating next build index + // using last matched pair + let next = next_chain[initial_build]; + if next == 0 { + continue; + } + Some(next) + } else if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + // otherwise -- checking build hashmap for precense of current hash_value + Some(*index) + } else { + None + }; // For every item on the build and probe we check if it matches // This possibly contains rows with hash collisions, // So we have to check here whether rows are equal or not - if let Some((_, index)) = - hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - let mut i = *index - 1; + if let Some(index) = index { + let mut i = index - 1; + loop { let build_row_value = if let Some(offset) = deleted_offset { // This arguments means that we prune the next index way before here. 
@@ -806,8 +846,17 @@ pub fn build_equal_condition_join_indices( }; build_indices.append(build_row_value); probe_indices.append(row as u32); + + output_tuples += 1; + + if output_tuples >= output_limit { + last_matched_indices = Some((row, i as usize)); + break 'probe; + } + // Follow the chain to get the next index value let next = next_chain[build_row_value as usize]; + if next == 0 { // end of list break; @@ -816,9 +865,13 @@ pub fn build_equal_condition_join_indices( } } } - // Reversing both sets of indices - build_indices.as_slice_mut().reverse(); - probe_indices.as_slice_mut().reverse(); + + // if both probe and build sides have been scanned -- return None + if last_matched_indices + .is_some_and(|(probe, build)| probe == probe_batch.num_rows() - 1 && build == 0) + { + last_matched_indices = None + }; let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -837,13 +890,15 @@ pub fn build_equal_condition_join_indices( (left, right) }; - equal_rows_arr( + let matched_indices = equal_rows_arr( &left, &right, &build_join_values, &keys_values, null_equals_null, - ) + )?; + + Ok((matched_indices.0, matched_indices.1, last_matched_indices)) } // version of eq_dyn supporting equality on null arrays @@ -942,107 +997,140 @@ impl HashJoinStream { } }); let mut hashes_buffer = vec![]; - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - 
JoinSide::Left, - None, - ); - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - &left_data.1, - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result + // Fetch next probe batch + if self.probe_batch.is_none() { + match self.right.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.probe_batch = Some(batch); } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, + Poll::Ready(None) => { + self.probe_batch = None; + } + Poll::Ready(Some(err)) => return Poll::Ready(Some(err)), + Poll::Pending => return Poll::Pending, + } + } + + let output_batch = match &self.probe_batch { + // one right batch in the join loop + Some(batch) => { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + &left_data.0, + &left_data.1, + batch, + 
&self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + self.batch_size, + self.last_matched_indices, + ); + + let result = match left_right_indices { + Ok((left_side, right_side, last_matched_indices)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + visited_left_side.set_bit(x as usize, true); + }); + } + + // adjust the two side indices base on the join type + // no need to adjust `self.last_matched_indices.0` (if some) + // as it has been joined while previous iteration + let adjust_range = + match (self.last_matched_indices, last_matched_indices) { + (None, None) => 0..batch.num_rows(), + (None, Some((range_end, _))) => 0..range_end + 1, + (Some((range_start, _)), None) => { + range_start + 1..batch.num_rows() + } + (Some((range_start, _)), Some((range_end, _))) => { + range_start + 1..range_end + 1 + } + }; + + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + adjust_range, self.join_type, ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result + let result = build_batch_from_indices( &self.schema, &left_data.1, - &empty_right_batch, + batch, &left_side, &right_side, &self.column_indices, JoinSide::Left, ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); + if last_matched_indices.is_none() { + self.probe_batch = None; + }; + self.last_matched_indices = last_matched_indices; - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; 
Some(result) - } else { - // end of the join loop - None } + Err(err) => Some(exec_err!( + "Fail to build join indices in HashJoinExec, error:{err}" + )), + }; + + timer.done(); + result + } + None => { + let timer = self.join_metrics.join_time.timer(); + if need_produce_result_in_final(self.join_type) && !self.is_exhausted { + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + &left_data.1, + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + self.is_exhausted = true; + Some(result) + } else { + // end of the join loop + None } - Some(err) => Some(err), - }) + } + }; + + Poll::Ready(output_batch) } } @@ -2406,7 +2494,7 @@ mod tests { }, left, ); - let (l, r) = build_equal_condition_join_indices( + let (l, r, _) = build_equal_condition_join_indices( &left_data.0, &left_data.1, &right, @@ -2418,6 +2506,8 @@ mod tests { None, JoinSide::Left, None, + 64, + None, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index a113066e39d1..73fd5c1caec7 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -670,20 +670,20 @@ fn adjust_indices_by_join_type( // matched // unmatched right row will be produced in this batch let right_unmatched_indices = - 
get_anti_indices(count_right_batch, &right_indices); + get_anti_indices(0..count_right_batch, &right_indices); // combine the matched and unmatched right result together append_right_indices(left_indices, right_indices, right_unmatched_indices) } JoinType::RightSemi => { // need to remove the duplicated record in the right side - let right_indices = get_semi_indices(count_right_batch, &right_indices); + let right_indices = get_semi_indices(0..count_right_batch, &right_indices); // the left_indices will not be used later for the `right semi` join (left_indices, right_indices) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side - let right_indices = get_anti_indices(count_right_batch, &right_indices); + let right_indices = get_anti_indices(0..count_right_batch, &right_indices); // the left_indices will not be used later for the `right anti` join (left_indices, right_indices) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 00d43aead434..5f0db8117846 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -818,7 +818,7 @@ pub(crate) fn join_with_probe_batch( if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); } - let (build_indices, probe_indices) = build_equal_condition_join_indices( + let (build_indices, probe_indices, _) = build_equal_condition_join_indices( &build_hash_joiner.hashmap, &build_hash_joiner.input_buffer, probe_batch, @@ -830,6 +830,8 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), + usize::MAX, + None, )?; if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs index cf150ddf575f..01d43c2b51b1 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::future::Future; +use std::ops::Range; use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; @@ -847,7 +848,7 @@ pub(crate) fn build_batch_from_indices( pub(crate) fn adjust_indices_by_join_type( left_indices: UInt64Array, right_indices: UInt32Array, - count_right_batch: usize, + adjust_range: Range, join_type: JoinType, ) -> (UInt64Array, UInt32Array) { match join_type { @@ -863,21 +864,20 @@ pub(crate) fn adjust_indices_by_join_type( JoinType::Right | JoinType::Full => { // matched // unmatched right row will be produced in this batch - let right_unmatched_indices = - get_anti_indices(count_right_batch, &right_indices); + let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); // combine the matched and unmatched right result together append_right_indices(left_indices, right_indices, right_unmatched_indices) } JoinType::RightSemi => { // need to remove the duplicated record in the right side - let right_indices = get_semi_indices(count_right_batch, &right_indices); + let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join (left_indices, right_indices) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side - let right_indices = get_anti_indices(count_right_batch, &right_indices); + let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join (left_indices, right_indices) } @@ -919,20 +919,26 @@ pub(crate) fn append_right_indices( } } -/// Get unmatched and deduplicated indices +/// Get unmatched and deduplicated indices for specified range of indices pub(crate) fn 
get_anti_indices( - row_count: usize, + rg: Range, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let mut bitmap = BooleanBufferBuilder::new(rg.len()); + bitmap.append_n(rg.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v as usize) + .filter(|v| rg.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - rg.start, true); + }); + + let offset = rg.start; // get the anti index - (0..row_count) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32)) + (rg).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::() } @@ -953,20 +959,26 @@ pub(crate) fn get_anti_u64_indices( .collect::() } -/// Get matched and deduplicated indices +/// Get matched and deduplicated indices for specified range of indices pub(crate) fn get_semi_indices( - row_count: usize, + rg: Range, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let mut bitmap = BooleanBufferBuilder::new(rg.len()); + bitmap.append_n(rg.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v as usize) + .filter(|v| rg.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - rg.start, true); + }); + + let offset = rg.start; // get the semi index - (0..row_count) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32)) + (rg).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::() } diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index 1312f2916ed6..e5d4c25f48c8 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ 
b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -72,11 +72,11 @@ SELECT t1.a, t1.b, t1.c, t2.a as a2 ON t1.d = t2.d ORDER BY a2, t2.b LIMIT 5 ---- -0 0 0 0 -0 0 2 0 -0 0 3 0 -0 0 6 0 -0 0 20 0 +1 3 95 0 +1 3 93 0 +1 3 92 0 +1 3 81 0 +1 3 76 0 query TT EXPLAIN SELECT t2.a as a2, t2.b From 5acdd6d1302f9be01d3b2a9188d815264edc2a73 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Thu, 2 Nov 2023 11:28:58 +0300 Subject: [PATCH 02/13] batch splitting tests --- .../physical-plan/src/joins/hash_join.rs | 140 +++++++++++++++++- .../src/joins/nested_loop_join.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 76 ++++++---- 3 files changed, 188 insertions(+), 34 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index a20e643b2e4a..265315d2555e 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1266,7 +1266,9 @@ mod tests { use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::Literal; use hashbrown::raw::RawTable; @@ -2973,6 +2975,142 @@ mod tests { } } + #[tokio::test] + async fn join_splitted_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3, 4]), + ("b1", &vec![1, 1, 1, 1]), + ("c1", &vec![0, 0, 0, 0]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40, 50]), + ("b2", &vec![1, 1, 1, 1, 1]), + ("c2", &vec![0, 0, 0, 0, 0]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + 
JoinType::Right, + JoinType::Full, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + let expected_resultset_records = 20; + let common_result = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 4 | 1 | 0 | 10 | 1 | 0 |", + "| 3 | 1 | 0 | 10 | 1 | 0 |", + "| 2 | 1 | 0 | 10 | 1 | 0 |", + "| 1 | 1 | 0 | 10 | 1 | 0 |", + "| 4 | 1 | 0 | 20 | 1 | 0 |", + "| 3 | 1 | 0 | 20 | 1 | 0 |", + "| 2 | 1 | 0 | 20 | 1 | 0 |", + "| 1 | 1 | 0 | 20 | 1 | 0 |", + "| 4 | 1 | 0 | 30 | 1 | 0 |", + "| 3 | 1 | 0 | 30 | 1 | 0 |", + "| 2 | 1 | 0 | 30 | 1 | 0 |", + "| 1 | 1 | 0 | 30 | 1 | 0 |", + "| 4 | 1 | 0 | 40 | 1 | 0 |", + "| 3 | 1 | 0 | 40 | 1 | 0 |", + "| 2 | 1 | 0 | 40 | 1 | 0 |", + "| 1 | 1 | 0 | 40 | 1 | 0 |", + "| 4 | 1 | 0 | 50 | 1 | 0 |", + "| 3 | 1 | 0 | 50 | 1 | 0 |", + "| 2 | 1 | 0 | 50 | 1 | 0 |", + "| 1 | 1 | 0 | 50 | 1 | 0 |", + "+----+----+----+----+----+----+", + ]; + let left_batch = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 1 | 0 |", + "| 2 | 1 | 0 |", + "| 3 | 1 | 0 |", + "| 4 | 1 | 0 |", + "+----+----+----+", + ]; + let right_batch = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "| 10 | 1 | 0 |", + "| 20 | 1 | 0 |", + "| 30 | 1 | 0 |", + "| 40 | 1 | 0 |", + "| 50 | 1 | 0 |", + "+----+----+----+", + ]; + let right_empty = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "+----+----+----+", + ]; + let left_empty = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "+----+----+----+", + ]; + + // validation of partial join results output for different batch_size setting + for join_type in join_types { + for batch_size in (1..21).rev() { + let session_config = SessionConfig::default().with_batch_size(batch_size); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + let join = + join(left.clone(), 
right.clone(), on.clone(), &join_type, false) + .unwrap(); + + let stream = join.execute(0, task_ctx).unwrap(); + let batches = common::collect(stream).await.unwrap(); + + // For inner/right join expected batch count equals ceil_div result, + // as there is no need to append non-joined build side data. + // For other join types it'll be ceil_div + 1 -- for additional batch + // containing not visited build side rows (empty in this test case). + let expected_batch_count = match join_type { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { + (expected_resultset_records + batch_size - 1) / batch_size + } + _ => (expected_resultset_records + batch_size - 1) / batch_size + 1, + }; + assert_eq!( + batches.len(), + expected_batch_count, + "expected {} output batches for {} join with batch_size = {}", + expected_batch_count, + join_type, + batch_size + ); + + let expected = match join_type { + JoinType::RightSemi => right_batch.to_vec(), + JoinType::RightAnti => right_empty.to_vec(), + JoinType::LeftSemi => left_batch.to_vec(), + JoinType::LeftAnti => left_empty.to_vec(), + _ => common_result.to_vec(), + }; + assert_batches_eq!(expected, &batches); + } + } + } + #[tokio::test] async fn single_partition_join_overallocation() -> Result<()> { let left = build_table( diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 73fd5c1caec7..e629ab59278a 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -648,20 +648,20 @@ fn adjust_indices_by_join_type( // matched // unmatched left row will be produced in this batch let left_unmatched_indices = - get_anti_u64_indices(count_left_batch, &left_indices); + get_anti_u64_indices(0..count_left_batch, &left_indices); // combine the matched and unmatched left result together append_left_indices(left_indices, right_indices, left_unmatched_indices) 
} JoinType::LeftSemi => { // need to remove the duplicated record in the left side - let left_indices = get_semi_u64_indices(count_left_batch, &left_indices); + let left_indices = get_semi_u64_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left semi` join (left_indices, right_indices) } JoinType::LeftAnti => { // need to remove the duplicated record in the left side // get the anti index for the left side - let left_indices = get_anti_u64_indices(count_left_batch, &left_indices); + let left_indices = get_anti_u64_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left anti` join (left_indices, right_indices) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 01d43c2b51b1..695111d79dd1 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -921,81 +921,97 @@ pub(crate) fn append_right_indices( /// Get unmatched and deduplicated indices for specified range of indices pub(crate) fn get_anti_indices( - rg: Range, + range: Range, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(rg.len()); - bitmap.append_n(rg.len(), false); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); input_indices .iter() .flatten() .map(|v| v as usize) - .filter(|v| rg.contains(v)) + .filter(|v| range.contains(v)) .for_each(|v| { - bitmap.set_bit(v - rg.start, true); + bitmap.set_bit(v - range.start, true); }); - let offset = rg.start; + let offset = range.start; // get the anti index - (rg).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) + (range) + .filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::() } /// Get unmatched and deduplicated indices pub(crate) fn get_anti_u64_indices( - row_count: usize, + range: Range, input_indices: 
&UInt64Array, ) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices + .iter() + .flatten() + .map(|v| v as usize) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); + + let offset = range.start; // get the anti index - (0..row_count) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64)) + (range) + .filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u64)) .collect::() } /// Get matched and deduplicated indices for specified range of indices pub(crate) fn get_semi_indices( - rg: Range, + range: Range, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(rg.len()); - bitmap.append_n(rg.len(), false); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); input_indices .iter() .flatten() .map(|v| v as usize) - .filter(|v| rg.contains(v)) + .filter(|v| range.contains(v)) .for_each(|v| { - bitmap.set_bit(v - rg.start, true); + bitmap.set_bit(v - range.start, true); }); - let offset = rg.start; + let offset = range.start; // get the semi index - (rg).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) + (range) + .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::() } /// Get matched and deduplicated indices pub(crate) fn get_semi_u64_indices( - row_count: usize, + range: Range, input_indices: &UInt64Array, ) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); - }); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + 
input_indices + .iter() + .flatten() + .map(|v| v as usize) + .filter(|v| range.contains(v)) + .for_each(|v| { + bitmap.set_bit(v - range.start, true); + }); + + let offset = range.start; // get the semi index - (0..row_count) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64)) + (range) + .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u64)) .collect::() } From 7fbe9182adac8cc4719f76b8fe87c502381cf45e Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Fri, 3 Nov 2023 22:46:14 +0300 Subject: [PATCH 03/13] stream state & extended tests --- datafusion/physical-plan/Cargo.toml | 1 + .../physical-plan/src/joins/hash_join.rs | 360 ++++++++++++------ .../src/joins/symmetric_hash_join.rs | 15 +- datafusion/physical-plan/src/lib.rs | 2 + 4 files changed, 262 insertions(+), 116 deletions(-) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 82c8f49a764f..b8d8b6d2d61b 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -59,5 +59,6 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] rstest = { workspace = true } +rstest_reuse = "0.6.0" termtree = "0.4.1" tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 265315d2555e..6aab0b3c358f 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -19,6 +19,7 @@ use std::fmt; use std::mem::size_of; +use std::ops::Range; use std::sync::Arc; use std::task::Poll; use std::{any::Any, usize, vec}; @@ -77,9 +78,6 @@ use futures::{ready, Stream, StreamExt, TryStreamExt}; type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); -/// Tuple representing last matched probe-build side indices for partial join output -type MatchedIndicesPair = Option<(usize, usize)>; - /// Join 
execution plan: Evaluates eqijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post /// join. @@ -607,8 +605,8 @@ impl ExecutionPlan for HashJoinExec { is_exhausted: false, reservation, batch_size, - last_matched_indices: None, probe_batch: None, + state: HashJoinStreamState::default(), })) } @@ -774,6 +772,97 @@ where Ok(()) } +// State for storing left/right side indices used for partial batch output +// & producing ranges for adjusting indices +#[derive(Debug, Default)] +pub(crate) struct HashJoinStreamState { + // total rows in current probe batch + probe_rows: usize, + // saved probe-build indices to resume matching from + last_matched_indices: Option<(usize, usize)>, + // current iteration has been updated + matched_indices_updated: bool, + // tracking last joined probe side index seen for further indices adjustment + last_joined_probe_index: Option, + // tracking last joined probe side index seen for further indices adjustment + prev_joined_probe_index: Option, +} + +impl HashJoinStreamState { + // set total probe rows to process + pub(crate) fn set_probe_rows(&mut self, probe_rows: usize) { + self.probe_rows = probe_rows; + } + // obtain last_matched_indices -- initial point to resume matching + fn start_mathching_iteration(&self) -> (usize, usize) { + self.last_matched_indices + .map_or_else(|| (0, 0), |pair| pair) + } + + // if current probe batch processing is in partial-output state + fn partial_output(&self) -> bool { + self.last_matched_indices.is_some() + } + + // if current probe batch processing completed -- all probe rows have been joined to build rows + pub(crate) fn is_completed(&self) -> bool { + self.last_matched_indices + .is_some_and(|(probe, build)| probe + 1 >= self.probe_rows && build == 0) + } + + // saving next probe-build indices to start next iteration of matching + fn update_matching_iteration(&mut self, probe_idx: usize, build_idx: usize) { + self.last_matched_indices = 
Some((probe_idx, build_idx)); + self.matched_indices_updated = true; + } + + // updating state after matching iteration has been performed + fn finalize_matching_iteration(&mut self, joined_right_side: &UInt32Array) { + // if there were no intermediate updates of matched inidices, during current iteration, + // setting indices like whole current batch has been scanned + if !self.matched_indices_updated { + self.last_matched_indices = Some((self.probe_rows, 0)); + } + self.matched_indices_updated = false; + + // advancing joined probe-side indices + self.prev_joined_probe_index = self.last_joined_probe_index; + if !joined_right_side.is_empty() { + self.last_joined_probe_index = + Some(joined_right_side.value(joined_right_side.len() - 1) as usize); + } + } + + pub(crate) fn reset_state(&mut self) { + self.probe_rows = 0; + self.last_matched_indices = None; + self.last_joined_probe_index = None; + self.matched_indices_updated = false; + } + + // The goals for different join types are: + // 1) Right & FullJoin -- to append all missing probe-side indices between + // previous (excluding) and current joined indices. + // 2) SemiJoin -- deduplicate probe indices in range between previous + // (excluding) and current joined indices. + // 3) AntiJoin -- return only missing indices in range between + // previous and current joined indices. + // Inclusion/exclusion of the indices themselves don't matter + // As a result -- partial adjustment range can be produced based only on + // joined (matched with filters applied) probe side indices, excluding starting one + // (left from previous iteration) + pub(crate) fn adjust_range(&self) -> Range { + let rg_start = self.prev_joined_probe_index.map_or(0, |v| v + 1); + let rg_end = if self.is_completed() { + self.probe_rows + } else { + self.last_joined_probe_index.map_or(0, |v| v + 1) + }; + + rg_start..rg_end + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
struct HashJoinStream { /// Input schema @@ -809,10 +898,10 @@ struct HashJoinStream { /// Current probe batch probe_batch: Option, /// In case joining current probe batch with build side may produce more than `batch_size` records - /// (cross-join due to key duplication on build side) `HashJoinStream` saves last matched indices + /// (cross-join due to key duplication on build side) `HashJoinStream` saves its state /// and emits result batch to upstream operator. - /// On next poll these indices are used to skip already matched rows. - last_matched_indices: MatchedIndicesPair, + /// On next poll these indices are used to skip already matched rows and adjusted probe-side indices. + state: HashJoinStreamState, } impl RecordBatchStream for HashJoinStream { @@ -853,7 +942,7 @@ impl RecordBatchStream for HashJoinStream { // Build indices: 4, 5, 6, 6 // Probe indices: 3, 3, 4, 5 #[allow(clippy::too_many_arguments)] -pub fn build_equal_condition_join_indices( +pub(crate) fn build_equal_condition_join_indices( build_hashmap: &T, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, @@ -866,8 +955,8 @@ pub fn build_equal_condition_join_indices( build_side: JoinSide, deleted_offset: Option, output_limit: usize, - start_indices: MatchedIndicesPair, -) -> Result<(UInt64Array, UInt32Array, MatchedIndicesPair)> { + state: &mut HashJoinStreamState, +) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) @@ -918,20 +1007,18 @@ pub fn build_equal_condition_join_indices( let next_chain = build_hashmap.get_list(); let mut output_tuples = 0_usize; - let mut last_matched_indices = None; - let (initial_probe, initial_build) = - start_indices.map_or_else(|| (0, 0), |pair| pair); + // Get starting point in case resuming current probe-batch + let (initial_probe, initial_build) = state.start_mathching_iteration(); 'probe: for (row, hash_value) in 
hash_values.iter().enumerate().skip(initial_probe) { - let index = if start_indices.is_some() && row == initial_probe { - // in case of partially skipped input -- calculating next build index - // using last matched pair - let next = next_chain[initial_build]; - if next == 0 { + let index = if state.partial_output() && row == initial_probe { + // using build index from state for the first row + // in case of partially skipped input + if initial_build == 0 { continue; } - Some(next) + Some(initial_build as u64) } else if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) { @@ -960,17 +1047,16 @@ pub fn build_equal_condition_join_indices( }; build_indices.append(build_row_value); probe_indices.append(row as u32); - output_tuples += 1; + // Follow the chain to get the next index value + let next = next_chain[build_row_value as usize]; + if output_tuples >= output_limit { - last_matched_indices = Some((row, i as usize)); + state.update_matching_iteration(row, next as usize); break 'probe; } - // Follow the chain to get the next index value - let next = next_chain[build_row_value as usize]; - if next == 0 { // end of list break; @@ -980,13 +1066,6 @@ pub fn build_equal_condition_join_indices( } } - // if both probe and build sides have been scanned -- return None - if last_matched_indices - .is_some_and(|(probe, build)| probe == probe_batch.num_rows() - 1 && build == 0) - { - last_matched_indices = None - }; - let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -1012,7 +1091,10 @@ pub fn build_equal_condition_join_indices( null_equals_null, )?; - Ok((matched_indices.0, matched_indices.1, last_matched_indices)) + // try set completed & update with last joined probe-side idx + state.finalize_matching_iteration(&matched_indices.1); + + Ok((matched_indices.0, matched_indices.1)) } // version of eq_dyn supporting equality on null 
arrays @@ -1116,6 +1198,7 @@ impl HashJoinStream { if self.probe_batch.is_none() { match self.right.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { + self.state.set_probe_rows(batch.num_rows()); self.probe_batch = Some(batch); } Poll::Ready(None) => { @@ -1147,11 +1230,11 @@ impl HashJoinStream { JoinSide::Left, None, self.batch_size, - self.last_matched_indices, + &mut self.state, ); let result = match left_right_indices { - Ok((left_side, right_side, last_matched_indices)) => { + Ok((left_side, right_side)) => { // set the left bitmap // and only left, full, left semi, left anti need the left bitmap if need_produce_result_in_final(self.join_type) { @@ -1161,24 +1244,11 @@ impl HashJoinStream { } // adjust the two side indices base on the join type - // no need to adjust `self.last_matched_indices.0` (if some) - // as it has been joined while previous iteration - let adjust_range = - match (self.last_matched_indices, last_matched_indices) { - (None, None) => 0..batch.num_rows(), - (None, Some((range_end, _))) => 0..range_end + 1, - (Some((range_start, _)), None) => { - range_start + 1..batch.num_rows() - } - (Some((range_start, _)), Some((range_end, _))) => { - range_start + 1..range_end + 1 - } - }; let (left_side, right_side) = adjust_indices_by_join_type( left_side, right_side, - adjust_range, + self.state.adjust_range(), self.join_type, ); @@ -1194,10 +1264,10 @@ impl HashJoinStream { self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(batch.num_rows()); - if last_matched_indices.is_none() { + if self.state.is_completed() { self.probe_batch = None; - }; - self.last_matched_indices = last_matched_indices; + self.state.reset_state(); + } Some(result) } @@ -1272,6 +1342,8 @@ mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::Literal; use hashbrown::raw::RawTable; + use rstest::*; + use rstest_reuse::{self, *}; use crate::{ common, expressions::Column, hash_utils::create_hashes, @@ -1284,6 +1356,19 
@@ mod tests { use super::*; + fn div_ceil(a: usize, b: usize) -> usize { + (a + b - 1) / b + } + + #[template] + #[rstest] + fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {} + + fn prepare_task_ctx(batch_size: usize) -> Arc { + let session_config = SessionConfig::default().with_batch_size(batch_size); + Arc::new(TaskContext::default().with_session_config(session_config)) + } + fn build_table( a: (&str, &Vec), b: (&str, &Vec), @@ -1401,9 +1486,10 @@ mod tests { Ok((columns, batches)) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1446,9 +1532,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_inner_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1528,9 +1615,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_two() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_two(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1557,7 +1645,13 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - assert_eq!(batches.len(), 1); + // expected joined records = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), 
expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1575,9 +1669,10 @@ mod tests { } /// Test where the left has 2 parts, the right with 1 part => 1 part + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one_two_parts_left() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1611,7 +1706,13 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - assert_eq!(batches.len(), 1); + // expected joined records = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1629,9 +1730,10 @@ mod tests { } /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_one_two_parts_right() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1663,7 +1765,14 @@ mod tests { // first part let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); + + // expected joined records = 1 (first right batch) + // and additional empty batch for non-joined 20-6-80 + let mut expected_batch_count = div_ceil(1, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + assert_eq!(batches.len(), expected_batch_count); let expected = [ "+----+----+----+----+----+----+", @@ -1677,7 +1786,11 @@ mod tests { // 
second part let stream = join.execute(1, task_ctx.clone())?; let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); + + // expected joined records = 2 (second right batch) + let expected_batch_count = div_ceil(2, batch_size); + assert_eq!(batches.len(), expected_batch_count); + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -1704,9 +1817,10 @@ mod tests { ) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_multi_batch() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_multi_batch(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1745,9 +1859,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_multi_batch() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_multi_batch(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1789,9 +1904,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_empty_right() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_empty_right(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1825,9 +1941,10 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_empty_right() { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_empty_right(batch_size: usize) { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1861,9 +1978,10 @@ mod tests { 
assert_batches_sorted_eq!(expected, &batches); } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1904,9 +2022,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_left_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_left_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1967,9 +2086,10 @@ mod tests { ) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_semi() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_semi(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left semi join right_table on left_table.b1 = right_table.b2 @@ -2001,9 +2121,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2087,9 +2208,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_semi() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_semi(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = 
build_semi_anti_right_table(); @@ -2121,9 +2243,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2206,9 +2329,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_anti() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_anti(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 @@ -2239,9 +2363,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 @@ -2332,9 +2457,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_anti() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_anti(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( @@ -2363,9 +2489,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_anti_with_filter(batch_size: usize) 
-> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 @@ -2459,9 +2586,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2497,9 +2625,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn partitioned_join_right_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn partitioned_join_right_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2536,9 +2665,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_one() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_one(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2610,7 +2740,7 @@ mod tests { }, left, ); - let (l, r, _) = build_equal_condition_join_indices( + let (l, r) = build_equal_condition_join_indices( &left_data.0, &left_data.1, &right, @@ -2623,7 +2753,7 @@ mod tests { JoinSide::Left, None, 64, - None, + &mut HashJoinStreamState::default(), )?; let mut left_ids = UInt64Builder::with_capacity(0); @@ -2705,9 +2835,10 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_inner_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn 
join_inner_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2745,9 +2876,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_left_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2788,9 +2920,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_right_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2830,9 +2963,10 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] #[tokio::test] - async fn join_full_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_with_filter(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3066,9 +3200,7 @@ mod tests { // validation of partial join results output for different batch_size setting for join_type in join_types { for batch_size in (1..21).rev() { - let session_config = SessionConfig::default().with_batch_size(batch_size); - let task_ctx = TaskContext::default().with_session_config(session_config); - let task_ctx = Arc::new(task_ctx); + let task_ctx = prepare_task_ctx(batch_size); let join = join(left.clone(), right.clone(), on.clone(), &join_type, false) @@ -3077,9 +3209,9 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); - // For inner/right join expected batch count equals 
ceil_div result, + // For inner/right join expected batch count equals dev_ceil result, // as there is no need to append non-joined build side data. - // For other join types it'll be ceil_div + 1 -- for additional batch + // For other join types it'll be div_ceil + 1 -- for additional batch // containing not visited build side rows (empty in this test case). let expected_batch_count = match join_type { JoinType::Inner diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 5f0db8117846..2a83660f0968 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -72,6 +72,8 @@ use futures::{Stream, StreamExt}; use hashbrown::HashSet; use parking_lot::Mutex; +use super::hash_join::HashJoinStreamState; + const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; /// A symmetric hash join with range conditions is when both streams are hashed on the @@ -550,6 +552,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { null_equals_null: self.null_equals_null, final_result: false, reservation, + state: HashJoinStreamState::default(), })) } } @@ -586,6 +589,8 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// Flag indicating whether there is nothing to process anymore final_result: bool, + /// Stream state for compatibility with HashJoinExec + state: HashJoinStreamState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -814,11 +819,12 @@ pub(crate) fn join_with_probe_batch( column_indices: &[ColumnIndex], random_state: &RandomState, null_equals_null: bool, + hash_join_stream_state: &mut HashJoinStreamState, ) -> Result> { if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); } - let (build_indices, probe_indices, _) = build_equal_condition_join_indices( + let (build_indices, probe_indices) = build_equal_condition_join_indices( &build_hash_joiner.hashmap, 
&build_hash_joiner.input_buffer, probe_batch, @@ -831,8 +837,12 @@ pub(crate) fn join_with_probe_batch( build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), usize::MAX, - None, + hash_join_stream_state, )?; + + // Resetting state to avoid potential overflows + hash_join_stream_state.reset_state(); + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( &mut build_hash_joiner.visited_rows, @@ -1099,6 +1109,7 @@ impl SymmetricHashJoinStream { &self.column_indices, &self.random_state, self.null_equals_null, + &mut self.state, )?; // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 8ae2a8686674..adce5565f802 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -517,4 +517,6 @@ pub fn unbounded_output(plan: &Arc) -> bool { } #[cfg(test)] +use rstest_reuse; + pub mod test; From 5bdfe3f6a1e53fb8617f66d2a97824b700f3956e Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sat, 4 Nov 2023 15:44:36 +0300 Subject: [PATCH 04/13] fmt & clippy warns fixed --- datafusion/physical-plan/src/joins/hash_join.rs | 4 +++- datafusion/physical-plan/src/lib.rs | 1 + datafusion/physical-plan/src/test/exec.rs | 6 +++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 521509697469..ef1f93b1cd0c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1323,7 +1323,9 @@ mod tests { use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, 
assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 16b7f860b421..081916f4f42d 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -555,6 +555,7 @@ pub fn unbounded_output(plan: &Arc) -> bool { } #[cfg(test)] +#[allow(clippy::single_component_path_imports)] use rstest_reuse; pub mod test; diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 71e6cba6741e..fcc0cf6b7af8 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -61,7 +61,7 @@ impl BatchIndex { /// Iterator over batches #[derive(Debug, Default)] -pub(crate) struct TestStream { +pub struct TestStream { /// Vector of record batches data: Vec, /// Index into the data that has been returned so far @@ -684,7 +684,7 @@ pub struct PanicExec { } impl PanicExec { - /// Create new [`PanickingExec`] with a give schema and number of + /// Create new [`PanicExec`] with a give schema and number of /// partitions, which will each panic immediately. 
pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { Self { @@ -708,7 +708,7 @@ impl DisplayAs for PanicExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "PanickingExec",) + write!(f, "PanicExec",) } } } From b1dc37d1b280b8b9ae81453ef42c065c99c53fd5 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Tue, 7 Nov 2023 10:25:39 +0300 Subject: [PATCH 05/13] Apply suggestions from code review Co-authored-by: Andrew Lamb --- datafusion/physical-plan/src/joins/hash_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index ef1f93b1cd0c..6e797342385b 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1003,7 +1003,7 @@ pub(crate) fn build_equal_condition_join_indices( } else if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) { - // otherwise -- checking build hashmap for precense of current hash_value + // otherwise -- checking build hashmap for presence of current hash_value Some(*index) } else { None From 02651f601f51426d858a591f308dc3df3bdbc983 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Tue, 7 Nov 2023 11:57:28 +0300 Subject: [PATCH 06/13] review comments --- .../physical-plan/src/joins/hash_join.rs | 35 ++++---- .../src/joins/nested_loop_join.rs | 12 +-- .../src/joins/symmetric_hash_join.rs | 18 ++-- datafusion/physical-plan/src/joins/utils.rs | 88 ++++++------------- 4 files changed, 58 insertions(+), 95 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 6e797342385b..80376865740c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -593,7 +593,7 @@ impl ExecutionPlan for 
HashJoinExec { reservation, batch_size, probe_batch: None, - state: HashJoinStreamState::default(), + output_state: HashJoinOutputState::default(), })) } @@ -756,20 +756,20 @@ where // State for storing left/right side indices used for partial batch output // & producing ranges for adjusting indices #[derive(Debug, Default)] -pub(crate) struct HashJoinStreamState { +pub(crate) struct HashJoinOutputState { // total rows in current probe batch probe_rows: usize, // saved probe-build indices to resume matching from last_matched_indices: Option<(usize, usize)>, // current iteration has been updated matched_indices_updated: bool, - // tracking last joined probe side index seen for further indices adjustment + // last probe side index, joined during current iteration last_joined_probe_index: Option, - // tracking last joined probe side index seen for further indices adjustment + // last probe side index, joined during previous iteration prev_joined_probe_index: Option, } -impl HashJoinStreamState { +impl HashJoinOutputState { // set total probe rows to process pub(crate) fn set_probe_rows(&mut self, probe_rows: usize) { self.probe_rows = probe_rows; @@ -882,7 +882,7 @@ struct HashJoinStream { /// (cross-join due to key duplication on build side) `HashJoinStream` saves its state /// and emits result batch to upstream operator. /// On next poll these indices are used to skip already matched rows and adjusted probe-side indices. 
- state: HashJoinStreamState, + output_state: HashJoinOutputState, } impl RecordBatchStream for HashJoinStream { @@ -936,7 +936,7 @@ pub(crate) fn build_equal_condition_join_indices( build_side: JoinSide, deleted_offset: Option, output_limit: usize, - state: &mut HashJoinStreamState, + state: &mut HashJoinOutputState, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() @@ -1177,16 +1177,15 @@ impl HashJoinStream { // Fetch next probe batch if self.probe_batch.is_none() { - match self.right.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - self.state.set_probe_rows(batch.num_rows()); + match ready!(self.right.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + self.output_state.set_probe_rows(batch.num_rows()); self.probe_batch = Some(batch); } - Poll::Ready(None) => { + None => { self.probe_batch = None; } - Poll::Ready(Some(err)) => return Poll::Ready(Some(err)), - Poll::Pending => return Poll::Pending, + Some(err) => return Poll::Ready(Some(err)), } } @@ -1211,7 +1210,7 @@ impl HashJoinStream { JoinSide::Left, None, self.batch_size, - &mut self.state, + &mut self.output_state, ); let result = match left_right_indices { @@ -1229,7 +1228,7 @@ impl HashJoinStream { let (left_side, right_side) = adjust_indices_by_join_type( left_side, right_side, - self.state.adjust_range(), + self.output_state.adjust_range(), self.join_type, ); @@ -1245,9 +1244,9 @@ impl HashJoinStream { self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(batch.num_rows()); - if self.state.is_completed() { + if self.output_state.is_completed() { self.probe_batch = None; - self.state.reset_state(); + self.output_state.reset_state(); } Some(result) @@ -2732,7 +2731,7 @@ mod tests { JoinSide::Left, None, 64, - &mut HashJoinStreamState::default(), + &mut HashJoinOutputState::default(), )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs 
b/datafusion/physical-plan/src/joins/nested_loop_join.rs index f82bf3cf8d82..f89a2445fd07 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -28,9 +28,9 @@ use crate::coalesce_batches::concat_batches; use crate::joins::utils::{ append_right_indices, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, estimate_join_statistics, get_anti_indices, - get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, - get_semi_u64_indices, partitioned_join_output_partitioning, BuildProbeJoinMetrics, - ColumnIndex, JoinFilter, OnceAsync, OnceFut, + get_final_indices_from_bit_map, get_semi_indices, + partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ @@ -649,20 +649,20 @@ fn adjust_indices_by_join_type( // matched // unmatched left row will be produced in this batch let left_unmatched_indices = - get_anti_u64_indices(0..count_left_batch, &left_indices); + get_anti_indices(0..count_left_batch, &left_indices); // combine the matched and unmatched left result together append_left_indices(left_indices, right_indices, left_unmatched_indices) } JoinType::LeftSemi => { // need to remove the duplicated record in the left side - let left_indices = get_semi_u64_indices(0..count_left_batch, &left_indices); + let left_indices = get_semi_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left semi` join (left_indices, right_indices) } JoinType::LeftAnti => { // need to remove the duplicated record in the left side // get the anti index for the left side - let left_indices = get_anti_u64_indices(0..count_left_batch, &left_indices); + let left_indices = get_anti_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left anti` join 
(left_indices, right_indices) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 4cc35c1ebe2f..a7e0877537cf 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -33,7 +33,9 @@ use std::vec; use std::{any::Any, usize}; use crate::common::SharedMemoryReservation; -use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; +use crate::joins::hash_join::{ + build_equal_condition_join_indices, update_hash, HashJoinOutputState, +}; use crate::joins::hash_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, @@ -72,8 +74,6 @@ use futures::{Stream, StreamExt}; use hashbrown::HashSet; use parking_lot::Mutex; -use super::hash_join::HashJoinStreamState; - const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; /// A symmetric hash join with range conditions is when both streams are hashed on the @@ -553,7 +553,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { null_equals_null: self.null_equals_null, final_result: false, reservation, - state: HashJoinStreamState::default(), + output_state: HashJoinOutputState::default(), })) } } @@ -591,7 +591,7 @@ struct SymmetricHashJoinStream { /// Flag indicating whether there is nothing to process anymore final_result: bool, /// Stream state for compatibility with HashJoinExec - state: HashJoinStreamState, + output_state: HashJoinOutputState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -820,7 +820,7 @@ pub(crate) fn join_with_probe_batch( column_indices: &[ColumnIndex], random_state: &RandomState, null_equals_null: bool, - hash_join_stream_state: &mut HashJoinStreamState, + output_state: &mut HashJoinOutputState, ) -> Result> { if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); @@ -838,11 +838,11 @@ pub(crate) fn 
join_with_probe_batch( build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), usize::MAX, - hash_join_stream_state, + output_state, )?; // Resetting state to avoid potential overflows - hash_join_stream_state.reset_state(); + output_state.reset_state(); if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( @@ -1110,7 +1110,7 @@ impl SymmetricHashJoinStream { &self.column_indices, &self.random_state, self.null_equals_null, - &mut self.state, + &mut self.output_state, )?; // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 3c7f0b00e4b7..53c762ff9511 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -35,6 +35,8 @@ use arrow::array::{ use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; +use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{ @@ -920,17 +922,20 @@ pub(crate) fn append_right_indices( } } -/// Get unmatched and deduplicated indices for specified range of indices -pub(crate) fn get_anti_indices( +/// Returns `range` indices which are not present in `input_indices` +pub(crate) fn get_anti_indices( range: Range, - input_indices: &UInt32Array, -) -> UInt32Array { + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ let mut bitmap = BooleanBufferBuilder::new(range.len()); bitmap.append_n(range.len(), false); input_indices .iter() .flatten() - .map(|v| v as usize) + .map(|v| v.as_usize()) .filter(|v| range.contains(v)) .for_each(|v| { bitmap.set_bit(v - range.start, true); @@ -940,69 +945,26 @@ 
pub(crate) fn get_anti_indices( // get the anti index (range) - .filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) - .collect::() + .filter_map(|idx| { + (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) + }) + .collect::>() } -/// Get unmatched and deduplicated indices -pub(crate) fn get_anti_u64_indices( +/// Returns intersection of `range` and `input_indices` omitting duplicates +pub(crate) fn get_semi_indices( range: Range, - input_indices: &UInt64Array, -) -> UInt64Array { + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ let mut bitmap = BooleanBufferBuilder::new(range.len()); bitmap.append_n(range.len(), false); input_indices .iter() .flatten() - .map(|v| v as usize) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - - let offset = range.start; - - // get the anti index - (range) - .filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u64)) - .collect::() -} - -/// Get matched and deduplicated indices for specified range of indices -pub(crate) fn get_semi_indices( - range: Range, - input_indices: &UInt32Array, -) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v as usize) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - - let offset = range.start; - - // get the semi index - (range) - .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) - .collect::() -} - -/// Get matched and deduplicated indices -pub(crate) fn get_semi_u64_indices( - range: Range, - input_indices: &UInt64Array, -) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v as usize) + .map(|v| v.as_usize()) .filter(|v| range.contains(v)) .for_each(|v| 
{ bitmap.set_bit(v - range.start, true); @@ -1012,8 +974,10 @@ pub(crate) fn get_semi_u64_indices( // get the semi index (range) - .filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u64)) - .collect::() + .filter_map(|idx| { + (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) + }) + .collect::>() } /// Metrics for build & probe joins From c8914a82faf0e280b0395213ea571cd5e0a5ead2 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Mon, 8 Jan 2024 22:43:33 +0200 Subject: [PATCH 07/13] ported join limited output --- .../physical-plan/src/joins/hash_join.rs | 310 +++++++++--------- .../src/joins/symmetric_hash_join.rs | 131 +++++++- datafusion/physical-plan/src/joins/utils.rs | 81 +++++ .../join_disable_repartition_joins.slt | 10 +- 4 files changed, 369 insertions(+), 163 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 408a2b13c57b..adcac4bca064 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -26,7 +26,7 @@ use std::{any::Any, usize, vec}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, calculate_join_output_ordering, get_final_indices_from_bit_map, - need_produce_result_in_final, JoinHashMap, JoinHashMapType, + need_produce_result_in_final, JoinHashMap, JoinHashMapOffset, JoinHashMapType, }; use crate::{ coalesce_partitions::CoalescePartitionsExec, @@ -61,7 +61,8 @@ use arrow::util::bit_util; use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; use datafusion_common::{ - exec_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + internal_datafusion_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -911,16 +912,10 @@ enum 
HashJoinStreamState { Completed, } -/// Container for HashJoinStreamState::ProcessProbeBatch related data -struct ProcessProbeBatchState { - /// Current probe-side batch - batch: RecordBatch, -} - impl HashJoinStreamState { /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. /// Returns an error if state is not ProcessProbeBatchState. - fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { match self { HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), @@ -928,6 +923,25 @@ impl HashJoinStreamState { } } +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, + /// Matching offset + offset: JoinHashMapOffset, + /// Max joined probe-side index from current batch + joined_probe_idx: Option, +} + +impl ProcessProbeBatchState { + fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { + self.offset = offset; + if joined_probe_idx.is_some() { + self.joined_probe_idx = joined_probe_idx; + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -973,7 +987,9 @@ impl RecordBatchStream for HashJoinStream { } } -/// Returns build/probe indices satisfying the equality condition. +/// Lookups by hash agaist JoinHashMap and resolves potential hash collisions. +/// Returns build/probe indices satisfying the equality condition, along with +/// starting point for next iteration. 
/// /// # Example /// @@ -1019,7 +1035,7 @@ impl RecordBatchStream for HashJoinStream { /// Probe indices: 3, 3, 4, 5 /// ``` #[allow(clippy::too_many_arguments)] -pub(crate) fn build_equal_condition_join_indices( +fn lookup_join_hashmap( build_hashmap: &T, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, @@ -1027,12 +1043,9 @@ pub(crate) fn build_equal_condition_join_indices( probe_on: &[Column], random_state: &RandomState, null_equals_null: bool, - hashes_buffer: &mut Vec, - filter: Option<&JoinFilter>, - build_side: JoinSide, - deleted_offset: Option, - fifo_hashmap: bool, -) -> Result<(UInt64Array, UInt32Array)> { + limit: usize, + offset: JoinHashMapOffset, +) -> Result<(UInt64Array, UInt32Array, Option)> { let keys_values = probe_on .iter() .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) @@ -1044,78 +1057,32 @@ pub(crate) fn build_equal_condition_join_indices( .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; - hashes_buffer.clear(); - hashes_buffer.resize(probe_batch.num_rows(), 0); - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm - // will return build indices for each probe row in a reverse order as such: - // Build Indices: [5, 4, 3] - // Probe Indices: [1, 1, 1] - // - // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side. 
- // Let's consider probe rows [0,1] as an example: - // - // When the probe iteration sequence is reversed, the following pairings can be derived: - // - // For probe row 1: - // (5, 1) - // (4, 1) - // (3, 1) - // - // For probe row 0: - // (5, 0) - // (4, 0) - // (3, 0) - // - // After reversing both sets of indices, we obtain reversed indices: - // - // (3,0) - // (4,0) - // (5,0) - // (3,1) - // (4,1) - // (5,1) - // - // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let (mut probe_indices, mut build_indices) = if fifo_hashmap { - build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) - } else { - let (mut matched_probe, mut matched_build) = build_hashmap - .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + let mut hashes_buffer = vec![0; probe_batch.num_rows()]; + let hash_values = create_hashes(&keys_values, random_state, &mut hashes_buffer)?; - matched_probe.as_slice_mut().reverse(); - matched_build.as_slice_mut().reverse(); - - (matched_probe, matched_build) - }; - - let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); - let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); + let (mut probe_builder, mut build_builder, next_offset) = build_hashmap + .get_matched_indices_with_limit_offset( + hash_values.iter().enumerate(), + None, + limit, + offset, + ); - let (left, right) = if let Some(filter) = filter { - // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - build_input_buffer, - probe_batch, - left, - right, - filter, - build_side, - )? 
- } else { - (left, right) - }; + let build_indices: UInt64Array = + PrimitiveArray::new(build_builder.finish().into(), None); + let probe_indices: UInt32Array = + PrimitiveArray::new(probe_builder.finish().into(), None); - let matched_indices = equal_rows_arr( - &left, - &right, + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, &build_join_values, &keys_values, null_equals_null, )?; - Ok((matched_indices.0, matched_indices.1)) + Ok((build_indices, probe_indices, next_offset)) } // version of eq_dyn supporting equality on null arrays @@ -1263,6 +1230,8 @@ impl HashJoinStream { self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { batch, + offset: (0, None), + joined_probe_idx: None, }); } Some(Err(err)) => return Poll::Ready(Err(err)), @@ -1277,16 +1246,15 @@ impl HashJoinStream { fn process_probe_batch( &mut self, ) -> Result>> { - let state = self.state.try_as_process_probe_batch()?; + let state = self.state.try_as_process_probe_batch_mut()?; let build_side = self.build_side.try_as_ready_mut()?; self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(state.batch.num_rows()); let timer = self.join_metrics.join_time.timer(); - let mut hashes_buffer = vec![]; - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( + // get the matched by join keys indices + let (left_indices, right_indices, next_offset) = lookup_join_hashmap( build_side.left_data.hash_map(), build_side.left_data.batch(), &state.batch, @@ -1294,53 +1262,102 @@ impl HashJoinStream { &self.on_right, &self.random_state, self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - true, - ); + self.batch_size, + state.offset, + )?; - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if 
need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - build_side.visited_left_side.set_bit(x as usize, true); - }); - } + // apply join filters if exists + let (left_indices, right_indices) = if let Some(filter) = &self.filter { + // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` + apply_join_filter_to_indices( + build_side.left_data.batch(), + &state.batch, + left_indices, + right_indices, + filter, + JoinSide::Left, + )? + } else { + (left_indices, right_indices) + }; - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - 0..state.batch.num_rows(), - self.join_type, - ); + // mark joined left-side indices as visited, if required by join type + if need_produce_result_in_final(self.join_type) { + left_indices.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); + } - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(state.batch.num_rows()); - result - } - Err(err) => { - exec_err!("Fail to build join indices in HashJoinExec, error:{err}") - } + // check if probe batch scanned based on `next_offset` returned from lookup function + let probe_batch_scanned = next_offset.is_none() + || next_offset.is_some_and(|(probe_idx, build_idx)| { + probe_idx + 1 >= state.batch.num_rows() + && build_idx.is_some_and(|v| v == 0) + }); + + // The goals of index alignment for different join types are: + // + // 1) Right & FullJoin -- to append all missing probe-side indices between + // previous (excluding) and current joined indices. + // 2) SemiJoin -- deduplicate probe indices in range between previous + // (excluding) and current joined indices. 
+ // 3) AntiJoin -- return only missing indices in range between + // previous and current joined indices. + // Inclusion/exclusion of the indices themselves don't matter + // + // As a summary -- alignment range can be produced based only on + // joined (matched with filters applied) probe side indices, excluding starting one + // (left from previous iteration). + + // if any rows have been joined -- get last joined probe-side (right) row + // it's important that index counts as "joined" after hash collisions checks + // and join filters applied. + let last_joined_right_idx = match right_indices.len() { + 0 => None, + n => Some(right_indices.value(n - 1) as usize), + }; + + // Calculate range and perform alignment. + // In case probe batch has been processed -- align all remaining rows. + let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); + let index_alignment_range_end = if probe_batch_scanned { + state.batch.num_rows() + } else { + last_joined_right_idx.map_or(0, |v| v + 1) }; + + let (left_indices, right_indices) = adjust_indices_by_join_type( + left_indices, + right_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Left, + )?; + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); timer.done(); - self.state = HashJoinStreamState::FetchProbeBatch; + if probe_batch_scanned { + self.state = HashJoinStreamState::FetchProbeBatch; + } else { + state.advance( + next_offset + .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?, + last_joined_right_idx, + ) + }; - Ok(StatefulStreamResult::Ready(Some(result?))) + Ok(StatefulStreamResult::Ready(Some(result))) } /// Processes unmatched build-side rows for certain join types and produces output batch @@ -1406,15 
+1423,15 @@ mod tests { use super::*; use crate::{ - common, expressions::Column, hash_utils::create_hashes, - joins::hash_join::build_equal_condition_join_indices, memory::MemoryExec, + common, expressions::Column, hash_utils::create_hashes, memory::MemoryExec, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, + ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; @@ -2914,7 +2931,7 @@ mod tests { let join_hash_map = JoinHashMap::new(hashmap_left, next); - let (l, r) = build_equal_condition_join_indices( + let (l, r, _) = lookup_join_hashmap( &join_hash_map, &left, &right, @@ -2922,11 +2939,8 @@ mod tests { &[Column::new("a", 0)], &random_state, false, - &mut vec![0; right.num_rows()], - None, - JoinSide::Left, - None, - false, + 8192, + (0, None), )?; let mut left_ids = UInt64Builder::with_capacity(0); @@ -3314,26 +3328,26 @@ mod tests { "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", - "| 4 | 1 | 0 | 10 | 1 | 0 |", - "| 3 | 1 | 0 | 10 | 1 | 0 |", - "| 2 | 1 | 0 | 10 | 1 | 0 |", "| 1 | 1 | 0 | 10 | 1 | 0 |", - "| 4 | 1 | 0 | 20 | 1 | 0 |", - "| 3 | 1 | 0 | 20 | 1 | 0 |", - "| 2 | 1 | 0 | 20 | 1 | 0 |", + "| 2 | 1 | 0 | 10 | 1 | 0 |", + "| 3 | 1 | 0 | 10 | 1 | 0 |", + "| 4 | 1 | 0 | 10 | 1 | 0 |", "| 1 | 1 | 0 | 20 | 1 | 0 |", - "| 4 | 1 | 0 | 30 | 1 | 0 |", - "| 3 | 1 | 0 | 30 | 1 | 0 |", - "| 2 | 1 | 0 | 30 | 1 | 0 |", + "| 2 | 1 | 0 | 20 | 1 | 0 |", + "| 3 | 1 | 0 | 20 | 1 | 0 |", + "| 4 | 1 | 0 | 20 | 1 | 0 |", "| 1 | 1 | 0 | 30 | 1 | 0 |", - "| 4 | 1 | 0 | 40 | 1 | 0 |", - "| 3 | 1 | 0 | 40 | 
1 | 0 |", - "| 2 | 1 | 0 | 40 | 1 | 0 |", + "| 2 | 1 | 0 | 30 | 1 | 0 |", + "| 3 | 1 | 0 | 30 | 1 | 0 |", + "| 4 | 1 | 0 | 30 | 1 | 0 |", "| 1 | 1 | 0 | 40 | 1 | 0 |", - "| 4 | 1 | 0 | 50 | 1 | 0 |", - "| 3 | 1 | 0 | 50 | 1 | 0 |", - "| 2 | 1 | 0 | 50 | 1 | 0 |", + "| 2 | 1 | 0 | 40 | 1 | 0 |", + "| 3 | 1 | 0 | 40 | 1 | 0 |", + "| 4 | 1 | 0 | 40 | 1 | 0 |", "| 1 | 1 | 0 | 50 | 1 | 0 |", + "| 2 | 1 | 0 | 50 | 1 | 0 |", + "| 3 | 1 | 0 | 50 | 1 | 0 |", + "| 4 | 1 | 0 | 50 | 1 | 0 |", "+----+----+----+----+----+----+", ]; let left_batch = [ diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 7a3db04d8255..e7c267817708 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,7 +32,7 @@ use std::task::Poll; use std::{usize, vec}; use crate::common::SharedMemoryReservation; -use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; +use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, @@ -41,22 +41,26 @@ use crate::joins::stream_join_utils::{ StreamJoinMetrics, }; use crate::joins::utils::{ - build_batch_from_indices, build_join_schema, check_join_is_valid, - partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, - StatefulStreamResult, + apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, + check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter, + JoinHashMapType, JoinOn, StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, - Partitioning, RecordBatchStream, 
SendableRecordBatchStream, Statistics, + Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder}; +use arrow::array::{ + ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array, + UInt64Array, +}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{ internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -759,7 +763,7 @@ pub(crate) fn join_with_probe_batch( if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); } - let (build_indices, probe_indices) = build_equal_condition_join_indices( + let (build_indices, probe_indices) = lookup_join_hashmap( &build_hash_joiner.hashmap, &build_hash_joiner.input_buffer, probe_batch, @@ -768,12 +772,22 @@ pub(crate) fn join_with_probe_batch( random_state, null_equals_null, &mut build_hash_joiner.hashes_buffer, - filter, - build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), - false, )?; + let (build_indices, probe_indices) = if let Some(filter) = filter { + apply_join_filter_to_indices( + &build_hash_joiner.input_buffer, + probe_batch, + build_indices, + probe_indices, + filter, + build_hash_joiner.build_side, + )? + } else { + (build_indices, probe_indices) + }; + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( &mut build_hash_joiner.visited_rows, @@ -810,6 +824,103 @@ pub(crate) fn join_with_probe_batch( } } +/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential +/// hash collisions. +/// +/// # Arguments +/// +/// * `build_hashmap` - hashmap collected from build side data. +/// * `build_batch` - Build side record batch. 
+/// * `probe_batch` - Probe side record batch. +/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join. +/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join. +/// * `random_state` - The random state for the join. +/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `hashes_buffer` - Buffer used for probe side keys hash calculation. +/// * `probe_batch` - The second record batch to be joined. +/// * `column_indices` - An array of columns to be selected for the result of the join. +/// * `deleted_offset` - deleted offset for build side data. +/// +/// # Returns +/// +/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side, +/// matched by join key columns. +#[allow(clippy::too_many_arguments)] +fn lookup_join_hashmap( + build_hashmap: &PruningJoinHashMap, + build_batch: &RecordBatch, + probe_batch: &RecordBatch, + build_on: &[Column], + probe_on: &[Column], + random_state: &RandomState, + null_equals_null: bool, + hashes_buffer: &mut Vec, + deleted_offset: Option, +) -> Result<(UInt64Array, UInt32Array)> { + let keys_values = probe_on + .iter() + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) + .collect::>>()?; + let build_join_values = build_on + .iter() + .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows())) + .collect::>>()?; + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + + // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: + // Build Indices: [5, 4, 3] + // Probe Indices: [1, 1, 1] + // + // This affects the output sequence. 
Hypothetically, it's possible to preserve the lexicographic order on the build side. + // Let's consider probe rows [0,1] as an example: + // + // When the probe iteration sequence is reversed, the following pairings can be derived: + // + // For probe row 1: + // (5, 1) + // (4, 1) + // (3, 1) + // + // For probe row 0: + // (5, 0) + // (4, 0) + // (3, 0) + // + // After reversing both sets of indices, we obtain reversed indices: + // + // (3,0) + // (4,0) + // (5,0) + // (3,1) + // (4,1) + // (5,1) + // + // With this approach, the lexicographic order on both the probe side and the build side is preserved. + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); + + let build_indices: UInt64Array = + PrimitiveArray::new(matched_build.finish().into(), None); + let probe_indices: UInt32Array = + PrimitiveArray::new(matched_probe.finish().into(), None); + + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + &build_join_values, + &keys_values, + null_equals_null, + )?; + + Ok((build_indices, probe_indices)) +} + pub struct OneSideHashJoiner { /// Build side build_side: JoinSide, diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 36c7143ee0d8..a6d784502cdc 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -138,6 +138,8 @@ impl JoinHashMap { } } +pub(crate) type JoinHashMapOffset = (usize, Option); + // Trait defining methods that must be implemented by a hash map type to be used for joins. pub trait JoinHashMapType { /// The type of list used to store the next list @@ -226,6 +228,85 @@ pub trait JoinHashMapType { (input_indices, match_indices) } + + /// Matches hashes with taking limit and offset into account. 
+ /// Returns pairs of matched indices along with the starting point for next + /// matching iteration (`None` if limit has not been reached). + /// + /// This method only compares hashes, so additional further check for actual values + /// equality may be required. + fn get_matched_indices_with_limit_offset<'a>( + &self, + iter: impl Iterator, + deleted_offset: Option, + limit: usize, + offset: JoinHashMapOffset, + ) -> ( + UInt32BufferBuilder, + UInt64BufferBuilder, + Option, + ) { + let mut input_indices = UInt32BufferBuilder::new(0); + let mut match_indices = UInt64BufferBuilder::new(0); + + let mut output_tuples = 0_usize; + let mut next_offset = None; + + let hash_map: &RawTable<(u64, u64)> = self.get_map(); + let next_chain = self.get_list(); + + let (initial_idx, initial_next_idx) = offset; + 'probe: for (row_idx, hash_value) in iter.skip(initial_idx) { + let index = if initial_next_idx.is_some() && row_idx == initial_idx { + // If `initial_next_idx` is zero, then input index has been processed + // during previous iteration, and it can be skipped now + if let Some(0) = initial_next_idx { + continue; + } + // Otherwise, use `initial_next_idx` as-is + initial_next_idx + } else if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + Some(*index) + } else { + None + }; + + if let Some(index) = index { + let mut i = index - 1; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before here. 
+ if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(row_idx as u32); + output_tuples += 1; + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + + if output_tuples >= limit { + next_offset = Some((row_idx, Some(next))); + break 'probe; + } + if next == 0 { + // end of list + break; + } + i = next - 1; + } + } + } + + (input_indices, match_indices, next_offset) + } } /// Implementation of `JoinHashMapType` for `JoinHashMap`. diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index e5d4c25f48c8..1312f2916ed6 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -72,11 +72,11 @@ SELECT t1.a, t1.b, t1.c, t2.a as a2 ON t1.d = t2.d ORDER BY a2, t2.b LIMIT 5 ---- -1 3 95 0 -1 3 93 0 -1 3 92 0 -1 3 81 0 -1 3 76 0 +0 0 0 0 +0 0 2 0 +0 0 3 0 +0 0 6 0 +0 0 20 0 query TT EXPLAIN SELECT t2.a as a2, t2.b From 879cd416d95197553fc860d4448101907e69c74a Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 14 Jan 2024 18:33:55 +0200 Subject: [PATCH 08/13] comments & formatting --- datafusion/physical-plan/src/joins/hash_join.rs | 6 +++--- datafusion/physical-plan/src/joins/symmetric_hash_join.rs | 1 + datafusion/physical-plan/src/joins/utils.rs | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index adcac4bca064..8886a3ac5588 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -987,7 +987,8 @@ impl RecordBatchStream for HashJoinStream { } } -/// Lookups by hash agaist JoinHashMap and resolves potential hash collisions. 
+/// Executes lookups by hash against JoinHashMap and resolves potential +/// hash collisions. /// Returns build/probe indices satisfying the equality condition, along with /// starting point for next iteration. /// @@ -1266,9 +1267,8 @@ impl HashJoinStream { state.offset, )?; - // apply join filters if exists + // apply join filter if exists let (left_indices, right_indices) = if let Some(filter) = &self.filter { - // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` apply_join_filter_to_indices( build_side.left_data.batch(), &state.batch, diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index e974ffa81ccd..e6478b759335 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -891,6 +891,7 @@ fn lookup_join_hashmap( .iter() .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows())) .collect::>>()?; + hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index a6d784502cdc..a4ffde87645b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -138,6 +138,7 @@ impl JoinHashMap { } } +// Type of offsets for obtaining indices from JoinHashMap. pub(crate) type JoinHashMapOffset = (usize, Option); // Trait defining methods that must be implemented by a hash map type to be used for joins. 
From 302c2235cdda1dbb795da9135757669e3942a7b9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 15 Jan 2024 07:23:10 -0500 Subject: [PATCH 09/13] Reuse hashes buffer --- datafusion/physical-plan/src/joins/hash_join.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 8886a3ac5588..5e0d8759a7d4 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -669,6 +669,7 @@ impl ExecutionPlan for HashJoinExec { state: HashJoinStreamState::WaitBuildSide, build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, + hashes_buffer: vec![], })) } @@ -979,6 +980,8 @@ struct HashJoinStream { build_side: BuildSide, /// Maximum output batch size batch_size: usize, + /// Scratch space for computing hashes + hashes_buffer: Vec, } impl RecordBatchStream for HashJoinStream { @@ -1044,6 +1047,7 @@ fn lookup_join_hashmap( probe_on: &[Column], random_state: &RandomState, null_equals_null: bool, + hashes_buffer: &mut Vec, limit: usize, offset: JoinHashMapOffset, ) -> Result<(UInt64Array, UInt32Array, Option)> { @@ -1059,8 +1063,9 @@ fn lookup_join_hashmap( }) .collect::>>()?; - let mut hashes_buffer = vec![0; probe_batch.num_rows()]; - let hash_values = create_hashes(&keys_values, random_state, &mut hashes_buffer)?; + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; let (mut probe_builder, mut build_builder, next_offset) = build_hashmap .get_matched_indices_with_limit_offset( @@ -1263,6 +1268,7 @@ impl HashJoinStream { &self.on_right, &self.random_state, self.null_equals_null, + &mut self.hashes_buffer, self.batch_size, state.offset, )?; @@ -2930,6 +2936,7 @@ mod tests { ); let join_hash_map = JoinHashMap::new(hashmap_left, next); + let mut hashes_buffer = vec![0]; let 
(l, r, _) = lookup_join_hashmap( &join_hash_map, @@ -2939,6 +2946,7 @@ mod tests { &[Column::new("a", 0)], &random_state, false, + &mut hashes_buffer, 8192, (0, None), )?; From 7adc66731aec5a4c8be310b1fa0aa8a898195bc1 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Mon, 15 Jan 2024 21:11:11 +0200 Subject: [PATCH 10/13] Apply suggestions from code review Co-authored-by: Andrew Lamb --- datafusion/physical-plan/src/joins/hash_join.rs | 2 +- datafusion/physical-plan/src/joins/symmetric_hash_join.rs | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 8886a3ac5588..27d265f41d52 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -990,7 +990,7 @@ impl RecordBatchStream for HashJoinStream { /// Executes lookups by hash against JoinHashMap and resolves potential /// hash collisions. /// Returns build/probe indices satisfying the equality condition, along with -/// starting point for next iteration. +/// (optional) starting point for next iteration. /// /// # Example /// diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index e6478b759335..00950f082582 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -863,8 +863,6 @@ pub(crate) fn join_with_probe_batch( /// * `random_state` - The random state for the join. /// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. /// * `hashes_buffer` - Buffer used for probe side keys hash calculation. -/// * `probe_batch` - The second record batch to be joined. -/// * `column_indices` - An array of columns to be selected for the result of the join. 
/// * `deleted_offset` - deleted offset for build side data. /// /// # Returns From cdd7f1293628c6bef79cee4ce04ce167892f936c Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Mon, 15 Jan 2024 21:16:32 +0200 Subject: [PATCH 11/13] fixed metrics and updated comment --- datafusion/physical-plan/src/joins/hash_join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 27d265f41d52..195d06883cf2 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -927,7 +927,7 @@ impl HashJoinStreamState { struct ProcessProbeBatchState { /// Current probe-side batch batch: RecordBatch, - /// Matching offset + /// Starting offset for JoinHashMap lookups offset: JoinHashMapOffset, /// Max joined probe-side index from current batch joined_probe_idx: Option, @@ -1344,7 +1344,7 @@ impl HashJoinStream { )?; self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(state.batch.num_rows()); + self.join_metrics.output_rows.add(result.num_rows()); timer.done(); if probe_batch_scanned { From 9e353e7a409be691fd8458b0dc5caffe71e363ce Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 21 Jan 2024 15:23:11 +0200 Subject: [PATCH 12/13] precalculate hashes & remove iterators --- .../physical-plan/src/joins/hash_join.rs | 48 +++++++------ datafusion/physical-plan/src/joins/utils.rs | 67 +++++++++++++------ 2 files changed, 72 insertions(+), 43 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 648feaf69d9b..552dffa5acf3 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1039,15 +1039,14 @@ impl RecordBatchStream for HashJoinStream { /// Probe indices: 3, 3, 4, 5 /// ``` #[allow(clippy::too_many_arguments)] -fn lookup_join_hashmap( - 
build_hashmap: &T, +fn lookup_join_hashmap( + build_hashmap: &JoinHashMap, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, build_on: &[Column], probe_on: &[Column], - random_state: &RandomState, null_equals_null: bool, - hashes_buffer: &mut Vec, + hashes_buffer: &[u64], limit: usize, offset: JoinHashMapOffset, ) -> Result<(UInt64Array, UInt32Array, Option)> { @@ -1063,17 +1062,8 @@ fn lookup_join_hashmap( }) .collect::>>()?; - hashes_buffer.clear(); - hashes_buffer.resize(probe_batch.num_rows(), 0); - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - let (mut probe_builder, mut build_builder, next_offset) = build_hashmap - .get_matched_indices_with_limit_offset( - hash_values.iter().enumerate(), - None, - limit, - offset, - ); + .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); let build_indices: UInt64Array = PrimitiveArray::new(build_builder.finish().into(), None); @@ -1233,6 +1223,17 @@ impl HashJoinStream { self.state = HashJoinStreamState::ExhaustedProbeSide; } Some(Ok(batch)) => { + // Precalculate hash values for fetched batch + let keys_values = self + .on_right + .iter() + .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) + .collect::>>()?; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { batch, @@ -1266,9 +1267,8 @@ impl HashJoinStream { &state.batch, &self.on_left, &self.on_right, - &self.random_state, self.null_equals_null, - &mut self.hashes_buffer, + &self.hashes_buffer, self.batch_size, state.offset, )?; @@ -2935,18 +2935,24 @@ mod tests { ("c", &vec![30, 40]), ); + // Join key column for both join sides + let key_column = Column::new("a", 0); + let join_hash_map = JoinHashMap::new(hashmap_left, next); - let mut hashes_buffer = vec![0]; + + let right_keys_values = + 
key_column.evaluate(&right)?.into_array(right.num_rows())?; + let mut hashes_buffer = vec![0; right.num_rows()]; + create_hashes(&[right_keys_values], &random_state, &mut hashes_buffer)?; let (l, r, _) = lookup_join_hashmap( &join_hash_map, &left, &right, - &[Column::new("a", 0)], - &[Column::new("a", 0)], - &random_state, + &[key_column.clone()], + &[key_column], false, - &mut hashes_buffer, + &hashes_buffer, 8192, (0, None), )?; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index a4ffde87645b..5f7d9b66e6c9 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -236,9 +236,9 @@ pub trait JoinHashMapType { /// /// This method only compares hashes, so additional further check for actual values /// equality may be required. - fn get_matched_indices_with_limit_offset<'a>( + fn get_matched_indices_with_limit_offset( &self, - iter: impl Iterator, + iter: &[u64], deleted_offset: Option, limit: usize, offset: JoinHashMapOffset, @@ -251,30 +251,52 @@ pub trait JoinHashMapType { let mut match_indices = UInt64BufferBuilder::new(0); let mut output_tuples = 0_usize; - let mut next_offset = None; let hash_map: &RawTable<(u64, u64)> = self.get_map(); let next_chain = self.get_list(); - let (initial_idx, initial_next_idx) = offset; - 'probe: for (row_idx, hash_value) in iter.skip(initial_idx) { - let index = if initial_next_idx.is_some() && row_idx == initial_idx { - // If `initial_next_idx` is zero, then input index has been processed - // during previous iteration, and it can be skipped now - if let Some(0) = initial_next_idx { - continue; + let to_skip = match offset { + (initial_idx, None) => initial_idx, + (initial_idx, Some(0)) => initial_idx + 1, + (initial_idx, Some(initial_next_idx)) => { + let mut i = initial_next_idx - 1; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before 
here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(initial_idx as u32); + output_tuples += 1; + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + + if output_tuples >= limit { + let next_offset = Some((initial_idx, Some(next))); + return (input_indices, match_indices, next_offset); + } + if next == 0 { + // end of list + break; + } + i = next - 1; } - // Otherwise, use `initial_next_idx` as-is - initial_next_idx - } else if let Some((_, index)) = + + initial_idx + 1 + } + }; + + let mut row_idx = to_skip; + for hash_value in &iter[to_skip..] { + if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) { - Some(*index) - } else { - None - }; - - if let Some(index) = index { let mut i = index - 1; loop { let match_row_idx = if let Some(offset) = deleted_offset { @@ -294,8 +316,8 @@ pub trait JoinHashMapType { let next = next_chain[match_row_idx as usize]; if output_tuples >= limit { - next_offset = Some((row_idx, Some(next))); - break 'probe; + let next_offset = Some((row_idx, Some(next))); + return (input_indices, match_indices, next_offset); } if next == 0 { // end of list @@ -304,9 +326,10 @@ pub trait JoinHashMapType { i = next - 1; } } + row_idx += 1; } - (input_indices, match_indices, next_offset) + (input_indices, match_indices, None) } } From 54f471485e87f4d4a3e847c4b7a211578219613f Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 21 Jan 2024 19:52:12 +0200 Subject: [PATCH 13/13] draft: preparing for review --- .../physical-plan/src/joins/hash_join.rs | 16 +-- datafusion/physical-plan/src/joins/utils.rs | 132 ++++++++++-------- 2 files changed, 78 insertions(+), 70 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 552dffa5acf3..0c213f425785 100644 --- 
a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1234,6 +1234,9 @@ impl HashJoinStream { self.hashes_buffer.resize(batch.num_rows(), 0); create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.state = HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { batch, @@ -1256,8 +1259,6 @@ impl HashJoinStream { let state = self.state.try_as_process_probe_batch_mut()?; let build_side = self.build_side.try_as_ready_mut()?; - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(state.batch.num_rows()); let timer = self.join_metrics.join_time.timer(); // get the matched by join keys indices @@ -1294,13 +1295,6 @@ impl HashJoinStream { }); } - // check if probe batch scanned based on `next_offset` returned from lookup function - let probe_batch_scanned = next_offset.is_none() - || next_offset.is_some_and(|(probe_idx, build_idx)| { - probe_idx + 1 >= state.batch.num_rows() - && build_idx.is_some_and(|v| v == 0) - }); - // The goals of index alignment for different join types are: // // 1) Right & FullJoin -- to append all missing probe-side indices between @@ -1326,7 +1320,7 @@ impl HashJoinStream { // Calculate range and perform alignment. // In case probe batch has been processed -- align all remaining rows. 
let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if probe_batch_scanned { + let index_alignment_range_end = if next_offset.is_none() { state.batch.num_rows() } else { last_joined_right_idx.map_or(0, |v| v + 1) @@ -1353,7 +1347,7 @@ impl HashJoinStream { self.join_metrics.output_rows.add(result.num_rows()); timer.done(); - if probe_batch_scanned { + if next_offset.is_none() { self.state = HashJoinStreamState::FetchProbeBatch; } else { state.advance( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5f7d9b66e6c9..6ab08d3db022 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -141,6 +141,50 @@ impl JoinHashMap { // Type of offsets for obtaining indices from JoinHashMap. pub(crate) type JoinHashMapOffset = (usize, Option); +// Macro for traversing chained values with limit. +// Early returns in case of reaching output tuples limit. +macro_rules! chain_traverse { + ( + $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, + $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident + ) => { + let mut i = $chain_idx - 1; + loop { + let match_row_idx = if let Some(offset) = $deleted_offset { + // This arguments means that we prune the next index way before here. 
+ if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + $match_indices.append(match_row_idx); + $input_indices.append($input_idx as u32); + $remaining_output -= 1; + // Follow the chain to get the next index value + let next = $next_chain[match_row_idx as usize]; + + if $remaining_output == 0 { + // In case current input index is the last, and no more chain values left + // returning None as whole input has been scanned + let next_offset = if $input_idx == $hash_values.len() - 1 && next == 0 { + None + } else { + Some(($input_idx, Some(next))) + }; + return ($input_indices, $match_indices, next_offset); + } + if next == 0 { + // end of list + break; + } + i = next - 1; + } + }; +} + // Trait defining methods that must be implemented by a hash map type to be used for joins. pub trait JoinHashMapType { /// The type of list used to store the next list @@ -238,7 +282,7 @@ pub trait JoinHashMapType { /// equality may be required. fn get_matched_indices_with_limit_offset( &self, - iter: &[u64], + hash_values: &[u64], deleted_offset: Option, limit: usize, offset: JoinHashMapOffset, @@ -250,81 +294,51 @@ pub trait JoinHashMapType { let mut input_indices = UInt32BufferBuilder::new(0); let mut match_indices = UInt64BufferBuilder::new(0); - let mut output_tuples = 0_usize; + let mut remaining_output = limit; let hash_map: &RawTable<(u64, u64)> = self.get_map(); let next_chain = self.get_list(); + // Calculate initial `hash_values` index before iterating let to_skip = match offset { + // None `initial_next_idx` indicates that `initial_idx` processing hasn't been started (initial_idx, None) => initial_idx, + // Zero `initial_next_idx` indicates that `initial_idx` has been processed during + // previous iteration, and it should be skipped (initial_idx, Some(0)) => initial_idx + 1, + // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, + // to start with the next index (initial_idx, 
Some(initial_next_idx)) => { - let mut i = initial_next_idx - 1; - loop { - let match_row_idx = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - match_indices.append(match_row_idx); - input_indices.append(initial_idx as u32); - output_tuples += 1; - // Follow the chain to get the next index value - let next = next_chain[match_row_idx as usize]; - - if output_tuples >= limit { - let next_offset = Some((initial_idx, Some(next))); - return (input_indices, match_indices, next_offset); - } - if next == 0 { - // end of list - break; - } - i = next - 1; - } + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + initial_idx, + initial_next_idx, + deleted_offset, + remaining_output + ); initial_idx + 1 } }; let mut row_idx = to_skip; - for hash_value in &iter[to_skip..] { + for hash_value in &hash_values[to_skip..] { if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) { - let mut i = index - 1; - loop { - let match_row_idx = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - match_indices.append(match_row_idx); - input_indices.append(row_idx as u32); - output_tuples += 1; - // Follow the chain to get the next index value - let next = next_chain[match_row_idx as usize]; - - if output_tuples >= limit { - let next_offset = Some((row_idx, Some(next))); - return (input_indices, match_indices, next_offset); - } - if next == 0 { - // end of list - break; - } - i = next - 1; - } + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + row_idx, + index, + deleted_offset, + remaining_output + ); } row_idx += 1; }