Maintain right child's order in NestedLoopJoinExec #35

Closed
260 changes: 247 additions & 13 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
right.equivalence_properties().clone(),
&join_type,
schema,
&[false, false],
&Self::maintains_input_order(join_type),
None,
// No on columns in nested loop join
&[],
@@ -238,6 +238,31 @@ impl NestedLoopJoinExec {

PlanProperties::new(eq_properties, output_partitioning, mode)
}

/// Returns a vector indicating whether the output of this join preserves the order of
/// each of its inputs. The first element corresponds to the left input, and the second
/// to the right.
///
/// The left (build-side) input's order may change, but the right (probe-side) input's
/// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
///
/// Maintaining the right input's order helps optimize nodes further down the pipeline
/// (see [`ExecutionPlan::maintains_input_order`]).
///
/// This is an associated function (taking a [`JoinType`] rather than `&self`) because it
/// is also needed when computing plan properties, before a [`NestedLoopJoinExec`] is
/// created.
fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
vec![
false,
matches!(
join_type,
JoinType::Inner
| JoinType::Right
| JoinType::RightAnti
| JoinType::RightSemi
),
]
}
}

impl DisplayAs for NestedLoopJoinExec {
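With this change, the plan's properties are built from the actual order-maintenance flags instead of a hard-coded `&[false, false]`. A minimal sketch of the expected behavior (the `left`/`right`/`left2`/`right2` inputs and the surrounding error handling are assumptions, not part of this diff):

    // Right-flavored joins now report that the right child's order is maintained.
    let join = NestedLoopJoinExec::try_new(left, right, None, &JoinType::RightSemi)?;
    assert_eq!(join.maintains_input_order(), vec![false, true]);

    // Other join types, e.g. LEFT, still report that neither input's order is maintained.
    let join = NestedLoopJoinExec::try_new(left2, right2, None, &JoinType::Left)?;
    assert_eq!(join.maintains_input_order(), vec![false, false]);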
@@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
]
}

fn maintains_input_order(&self) -> Vec<bool> {
Self::maintains_input_order(self.join_type)
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
@@ -429,17 +458,17 @@ struct NestedLoopJoinStream {
}

fn build_join_indices(
left_row_index: usize,
right_batch: &RecordBatch,
right_row_index: usize,
left_batch: &RecordBatch,
right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
) -> Result<(UInt64Array, UInt32Array)> {
// left indices: [left_index, left_index, ...., left_index]
// right indices: [0, 1, 2, 3, 4,....,right_row_count]
// left indices: [0, 1, 2, 3, ..., left_row_count - 1]
// right indices: [right_index, right_index, ..., right_index]

let right_row_count = right_batch.num_rows();
let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
let left_row_count = left_batch.num_rows();
let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
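// For example (illustrative values): with a 3-row left batch and right_row_index = 7, the
// unfiltered pairs are left indices [0, 1, 2] against right indices [7, 7, 7].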
// In a nested loop join, the filter can contain both equality and non-equality conditions.
if let Some(filter) = filter {
apply_join_filter_to_indices(
@@ -561,9 +590,9 @@ fn join_left_and_right_batch(
schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
) -> Result<RecordBatch> {
let indices = (0..left_batch.num_rows())
.map(|left_row_index| {
build_join_indices(left_row_index, right_batch, left_batch, filter)
let indices = (0..right_batch.num_rows())
.map(|right_row_index| {
build_join_indices(right_row_index, left_batch, right_batch, filter)
})
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
.map_err(|e| {
@@ -595,7 +624,7 @@ fn join_left_and_right_batch(
right_side,
0..right_batch.num_rows(),
join_type,
false,
true,
);
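// Passing `true` above preserves the right (probe) side's row order while the indices are
// adjusted for the join type, which is what lets the right child's ordering survive the join.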

build_batch_from_indices(
@@ -643,27 +672,68 @@ mod tests {
};

use arrow::datatypes::{DataType, Field};
use arrow_array::Int32Array;
use arrow_schema::SortOptions;
use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

use rstest::rstest;

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
batch_size: Option<usize>,
sorted_column_names: Vec<&str>,
) -> Arc<dyn ExecutionPlan> {
let batch = build_table_i32(a, b, c);
let schema = batch.schema();
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())

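// When a batch size is given, split the rows into multiple record batches so tests can
// exercise joins that consume more than one input batch; otherwise keep a single batch.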
let batches = if let Some(batch_size) = batch_size {
let num_batches = batch.num_rows().div_ceil(batch_size);
(0..num_batches)
.map(|i| {
let start = i * batch_size;
let remaining_rows = batch.num_rows() - start;
batch.slice(start, batch_size.min(remaining_rows))
})
.collect::<Vec<_>>()
} else {
vec![batch]
};

let mut exec =
MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
if !sorted_column_names.is_empty() {
let mut sort_info = Vec::new();
for name in sorted_column_names {
let index = schema.index_of(name).unwrap();
let sort_expr = PhysicalSortExpr {
expr: Arc::new(Column::new(name, index)),
options: SortOptions {
descending: false,
nulls_first: false,
},
};
sort_info.push(sort_expr);
}
exec = exec.with_sort_information(vec![sort_info]);
}

Arc::new(exec)
}

fn build_left_table() -> Arc<dyn ExecutionPlan> {
build_table(
("a1", &vec![5, 9, 11]),
("b1", &vec![5, 8, 8]),
("c1", &vec![50, 90, 110]),
None,
Vec::new(),
)
}

@@ -672,6 +742,8 @@ mod tests {
("a2", &vec![12, 2, 10]),
("b2", &vec![10, 2, 10]),
("c2", &vec![40, 80, 100]),
None,
Vec::new(),
)
}

@@ -999,11 +1071,15 @@ mod tests {
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
None,
Vec::new(),
);
let right = build_table(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
None,
Vec::new(),
);
let filter = prepare_join_filter();

@@ -1044,6 +1120,164 @@ mod tests {
Ok(())
}

fn prepare_mod_join_filter() -> JoinFilter {
let column_indices = vec![
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
Field::new("x", DataType::Int32, true),
Field::new("x", DataType::Int32, true),
]);

// left.b1 % 3
let left_mod = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 0)),
Operator::Modulo,
Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
)) as Arc<dyn PhysicalExpr>;
// left.b1 % 3 != 0
let left_filter = Arc::new(BinaryExpr::new(
left_mod,
Operator::NotEq,
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
)) as Arc<dyn PhysicalExpr>;

// right.b2 % 5
let right_mod = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 1)),
Operator::Modulo,
Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
)) as Arc<dyn PhysicalExpr>;
// right.b2 % 5 != 0
let right_filter = Arc::new(BinaryExpr::new(
right_mod,
Operator::NotEq,
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
)) as Arc<dyn PhysicalExpr>;
// filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
let filter_expression =
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
as Arc<dyn PhysicalExpr>;

JoinFilter::new(filter_expression, column_indices, intermediate_schema)
}

fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
let column = (1..=num_rows).map(|x| x as i32).collect();
vec![column; num_columns]
}

#[rstest]
#[tokio::test]
async fn join_maintains_right_order(
#[values(
JoinType::Inner,
JoinType::Right,
JoinType::RightAnti,
JoinType::RightSemi
)]
join_type: JoinType,
#[values(1, 100, 1000)] left_batch_size: usize,
#[values(1, 100, 1000)] right_batch_size: usize,
) -> Result<()> {
let left_columns = generate_columns(3, 1000);
let left = build_table(
("a1", &left_columns[0]),
("b1", &left_columns[1]),
("c1", &left_columns[2]),
Some(left_batch_size),
Vec::new(),
);

let right_columns = generate_columns(3, 1000);
let right = build_table(
("a2", &right_columns[0]),
("b2", &right_columns[1]),
("c2", &right_columns[2]),
Some(right_batch_size),
vec!["a2", "b2", "c2"],
);

let filter = prepare_mod_join_filter();
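// The filter (left.b1 % 3 != 0 AND right.b2 % 5 != 0) drops some pairs so the join is not a
// plain cross product, while the surviving right-side rows keep their relative order.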

let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
left,
Arc::clone(&right),
Some(filter),
&join_type,
)?) as Arc<dyn ExecutionPlan>;
assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);

let right_column_indices = match join_type {
JoinType::Inner | JoinType::Right => vec![3, 4, 5],
JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
_ => unreachable!(),
};

let right_ordering = right.output_ordering().unwrap();
let join_ordering = nested_loop_join.output_ordering().unwrap();
for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
assert_eq!(right_column.name(), join_column.name());
assert_eq!(
right_column_indices[right_column.index()],
join_column.index()
);
assert_eq!(right.options, join.options);
}

let batches = nested_loop_join
.execute(0, Arc::new(TaskContext::default()))?
.try_collect::<Vec<_>>()
.await?;

// Make sure that the order of the right side is maintained
let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];

for (batch_index, batch) in batches.iter().enumerate() {
let columns: Vec<_> = right_column_indices
.iter()
.map(|&i| {
batch
.column(i)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
})
.collect();

for row in 0..batch.num_rows() {
let current_values = [
columns[0].value(row),
columns[1].value(row),
columns[2].value(row),
];
assert!(
current_values
.into_iter()
.zip(prev_values)
.all(|(current, prev)| current >= prev),
"batch_index: {} row: {} current: {:?}, prev: {:?}",
batch_index,
row,
current_values,
prev_values
);
prev_values = current_values;
}
}

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
12 changes: 6 additions & 6 deletions datafusion/sqllogictest/test_files/join.slt
@@ -838,10 +838,10 @@ LEFT JOIN department AS d
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

@@ -853,10 +853,10 @@
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL

@@ -868,10 +868,10 @@
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
2 Bob HR
1 Alice Engineering
2 Bob Engineering
1 Alice Sales
2 Bob HR
2 Bob Engineering
2 Bob Sales
3 Carol NULL
