Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maintain right child's order in NestedLoopJoinExec #35

Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 182 additions & 13 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
right.equivalence_properties().clone(),
&join_type,
schema,
&[false, false],
&Self::maintains_input_order(join_type),
None,
// No on columns in nested loop join
&[],
Expand All @@ -238,6 +238,19 @@ impl NestedLoopJoinExec {

PlanProperties::new(eq_properties, output_partitioning, mode)
}

fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
alihan-synnada marked this conversation as resolved.
Show resolved Hide resolved
vec![
false,
matches!(
join_type,
JoinType::Inner
| JoinType::Right
| JoinType::RightAnti
| JoinType::RightSemi
),
]
}
}

impl DisplayAs for NestedLoopJoinExec {
Expand Down Expand Up @@ -278,6 +291,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
]
}

/// Reports, per child, whether this operator maintains its input's row
/// order; delegates to the associated helper keyed on `self.join_type`.
fn maintains_input_order(&self) -> Vec<bool> {
Self::maintains_input_order(self.join_type)
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
Expand Down Expand Up @@ -429,17 +446,17 @@ struct NestedLoopJoinStream {
}

fn build_join_indices(
left_row_index: usize,
right_batch: &RecordBatch,
right_row_index: usize,
left_batch: &RecordBatch,
right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
) -> Result<(UInt64Array, UInt32Array)> {
// left indices: [left_index, left_index, ...., left_index]
// right indices: [0, 1, 2, 3, 4,....,right_row_count]
// left indices: [0, 1, 2, 3, 4, ..., left_row_count]
// right indices: [right_index, right_index, ..., right_index]

let right_row_count = right_batch.num_rows();
let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
let left_row_count = left_batch.num_rows();
let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
// in the nested loop join, the filter can contain non-equal and equal condition.
if let Some(filter) = filter {
apply_join_filter_to_indices(
Expand Down Expand Up @@ -561,9 +578,9 @@ fn join_left_and_right_batch(
schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
) -> Result<RecordBatch> {
let indices = (0..left_batch.num_rows())
.map(|left_row_index| {
build_join_indices(left_row_index, right_batch, left_batch, filter)
let indices = (0..right_batch.num_rows())
.map(|right_row_index| {
build_join_indices(right_row_index, left_batch, right_batch, filter)
})
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
.map_err(|e| {
Expand Down Expand Up @@ -595,7 +612,7 @@ fn join_left_and_right_batch(
right_side,
0..right_batch.num_rows(),
join_type,
false,
true,
);

build_batch_from_indices(
Expand Down Expand Up @@ -643,27 +660,40 @@ mod tests {
};

use arrow::datatypes::{DataType, Field};
use arrow_array::Int32Array;
use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};

use rstest::rstest;

/// Test fixture: build a `MemoryExec` over the three `i32` columns `a`,
/// `b`, `c`, with the rows split evenly across `num_batches` batches
/// (the final batch absorbs any remainder).
fn build_table(
    a: (&str, &Vec<i32>),
    b: (&str, &Vec<i32>),
    c: (&str, &Vec<i32>),
    num_batches: usize,
) -> Arc<dyn ExecutionPlan> {
    let batch = build_table_i32(a, b, c);
    let schema = batch.schema();
    let total_rows = batch.num_rows();
    let rows_per_batch = total_rows / num_batches;
    // Zero-copy slices of the source batch; the last slice may be shorter.
    let batches: Vec<_> = (0..num_batches)
        .map(|i| {
            let start = i * rows_per_batch;
            batch.slice(start, rows_per_batch.min(total_rows - start))
        })
        .collect();
    Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
}

/// Test fixture: a three-row left table (`a1`, `b1`, `c1`) in one batch.
fn build_left_table() -> Arc<dyn ExecutionPlan> {
build_table(
("a1", &vec![5, 9, 11]),
("b1", &vec![5, 8, 8]),
("c1", &vec![50, 90, 110]),
1,
)
}

Expand All @@ -672,6 +702,7 @@ mod tests {
("a2", &vec![12, 2, 10]),
("b2", &vec![10, 2, 10]),
("c2", &vec![40, 80, 100]),
1,
)
}

Expand Down Expand Up @@ -999,11 +1030,13 @@ mod tests {
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1,
);
let right = build_table(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
1,
);
let filter = prepare_join_filter();

Expand Down Expand Up @@ -1044,6 +1077,142 @@ mod tests {
Ok(())
}

/// Build the join filter `left.b1 % 3 != 0 AND right.b2 % 5 != 0` over an
/// intermediate schema containing column index 1 from each join side.
fn prepare_mod_join_filter() -> JoinFilter {
    // One Int32 column is taken from each side (index 1 on both sides).
    let column_indices = vec![
        ColumnIndex {
            index: 1,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 1,
            side: JoinSide::Right,
        },
    ];
    let intermediate_schema = Schema::new(vec![
        Field::new("x", DataType::Int32, true),
        Field::new("x", DataType::Int32, true),
    ]);

    // Builds `x[col_index] % modulus != 0` against the intermediate schema.
    let mod_not_zero = |col_index: usize, modulus: i32| -> Arc<dyn PhysicalExpr> {
        let remainder = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("x", col_index)),
            Operator::Modulo,
            Arc::new(Literal::new(ScalarValue::Int32(Some(modulus)))),
        )) as Arc<dyn PhysicalExpr>;
        Arc::new(BinaryExpr::new(
            remainder,
            Operator::NotEq,
            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
        )) as Arc<dyn PhysicalExpr>
    };

    // filter = left.b1 % 3 != 0 AND right.b2 % 5 != 0
    let filter_expression = Arc::new(BinaryExpr::new(
        mod_not_zero(0, 3),
        Operator::And,
        mod_not_zero(1, 5),
    )) as Arc<dyn PhysicalExpr>;

    JoinFilter::new(filter_expression, column_indices, intermediate_schema)
}

/// Build `num_columns` identical columns, each holding the values
/// `1..=num_rows` in ascending order (empty columns when `num_rows == 0`).
fn generate_columns(num_columns: usize, num_rows: i32) -> Vec<Vec<i32>> {
    (0..num_columns)
        .map(|_| (1..=num_rows).collect())
        .collect()
}

/// Verify that `NestedLoopJoinExec` maintains the right child's row order
/// for the order-preserving join types, across several right-side batch
/// counts. The right-side columns are ascending on input, so every output
/// row must be component-wise >= its predecessor.
#[rstest]
#[tokio::test]
async fn join_maintains_right_order(
    #[values(
        JoinType::Inner,
        JoinType::Right,
        JoinType::RightAnti,
        JoinType::RightSemi
    )]
    join_type: JoinType,
    #[values(1, 5, 10, 30)] num_batches: usize,
) -> Result<()> {
    // Left: 20 rows in a single batch; right: 100 rows split into
    // `num_batches` batches.
    let left_columns = generate_columns(3, 20);
    let left = build_table(
        ("a1", &left_columns[0]),
        ("b1", &left_columns[1]),
        ("c1", &left_columns[2]),
        1,
    );
    let right_columns = generate_columns(3, 100);
    let right = build_table(
        ("a2", &right_columns[0]),
        ("b2", &right_columns[1]),
        ("c2", &right_columns[2]),
        num_batches,
    );

    let filter = prepare_mod_join_filter();
    let nested_loop_join =
        NestedLoopJoinExec::try_new(left, right, Some(filter), &join_type)?;

    let batches = nested_loop_join
        .execute(0, Arc::new(TaskContext::default()))?
        .try_collect::<Vec<_>>()
        .await?;

    // Inner/Right output left columns followed by right columns; the
    // semi/anti variants output only the right columns.
    let right_column_indices = match join_type {
        JoinType::Inner | JoinType::Right => vec![3, 4, 5],
        JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
        _ => unreachable!(),
    };

    let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
    for (batch_index, batch) in batches.iter().enumerate() {
        let columns: Vec<_> = right_column_indices
            .iter()
            .map(|&i| {
                batch
                    .column(i)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
            })
            .collect();

        for row in 0..batch.num_rows() {
            let current_values = [
                columns[0].value(row),
                columns[1].value(row),
                columns[2].value(row),
            ];
            // Non-decreasing in every right-side column => order maintained.
            assert!(
                current_values
                    .into_iter()
                    .zip(prev_values)
                    .all(|(current, prev)| current >= prev),
                "batch_index: {} row: {} current: {:?}, prev: {:?}",
                batch_index,
                row,
                current_values,
                prev_values
            );
            prev_values = current_values;
        }
    }

    Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,10 @@ LEFT JOIN department AS d
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand All @@ -853,10 +853,10 @@ RIGHT JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand All @@ -868,10 +868,10 @@ FULL JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2136,10 +2136,10 @@ FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1
RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2
ON join_t1.t1_id < join_t2.t2_id
----
NULL 22
33 44
33 55
44 55
NULL 22

#####
# Configuration teardown
Expand Down