Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maintain right child's order in NestedLoopJoinExec #35

Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 167 additions & 13 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
right.equivalence_properties().clone(),
&join_type,
schema,
&[false, false],
&Self::maintains_input_order(join_type),
None,
// No on columns in nested loop join
&[],
Expand All @@ -238,6 +238,19 @@ impl NestedLoopJoinExec {

PlanProperties::new(eq_properties, output_partitioning, mode)
}

fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
alihan-synnada marked this conversation as resolved.
Show resolved Hide resolved
vec![
false,
matches!(
join_type,
JoinType::Inner
| JoinType::Right
| JoinType::RightAnti
| JoinType::RightSemi
),
]
}
}

impl DisplayAs for NestedLoopJoinExec {
Expand Down Expand Up @@ -278,6 +291,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
]
}

fn maintains_input_order(&self) -> Vec<bool> {
Self::maintains_input_order(self.join_type)
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
Expand Down Expand Up @@ -429,17 +446,17 @@ struct NestedLoopJoinStream {
}

fn build_join_indices(
left_row_index: usize,
right_batch: &RecordBatch,
right_row_index: usize,
left_batch: &RecordBatch,
right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
) -> Result<(UInt64Array, UInt32Array)> {
// left indices: [left_index, left_index, ...., left_index]
// right indices: [0, 1, 2, 3, 4,....,right_row_count]
// left indices: [0, 1, 2, 3, 4, ..., left_row_count]
// right indices: [right_index, right_index, ..., right_index]

let right_row_count = right_batch.num_rows();
let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
let left_row_count = left_batch.num_rows();
let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
// in the nested loop join, the filter can contain non-equal and equal condition.
if let Some(filter) = filter {
apply_join_filter_to_indices(
Expand Down Expand Up @@ -561,9 +578,9 @@ fn join_left_and_right_batch(
schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
) -> Result<RecordBatch> {
let indices = (0..left_batch.num_rows())
.map(|left_row_index| {
build_join_indices(left_row_index, right_batch, left_batch, filter)
let indices = (0..right_batch.num_rows())
.map(|right_row_index| {
build_join_indices(right_row_index, left_batch, right_batch, filter)
})
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
.map_err(|e| {
Expand Down Expand Up @@ -595,7 +612,7 @@ fn join_left_and_right_batch(
right_side,
0..right_batch.num_rows(),
join_type,
false,
true,
);

build_batch_from_indices(
Expand Down Expand Up @@ -643,27 +660,40 @@ mod tests {
};

use arrow::datatypes::{DataType, Field};
use arrow_array::Int32Array;
use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};

use rstest::rstest;

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
num_batches: usize,
) -> Arc<dyn ExecutionPlan> {
let batch = build_table_i32(a, b, c);
let rows_per_batch = batch.num_rows() / num_batches;
let batches = (0..num_batches)
.map(|i| {
let start = i * rows_per_batch;
let remaining_rows = batch.num_rows() - start;
batch.slice(start, rows_per_batch.min(remaining_rows))
})
.collect::<Vec<_>>();
let schema = batch.schema();
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
}

fn build_left_table() -> Arc<dyn ExecutionPlan> {
build_table(
("a1", &vec![5, 9, 11]),
("b1", &vec![5, 8, 8]),
("c1", &vec![50, 90, 110]),
1,
)
}

Expand All @@ -672,6 +702,7 @@ mod tests {
("a2", &vec![12, 2, 10]),
("b2", &vec![10, 2, 10]),
("c2", &vec![40, 80, 100]),
1,
)
}

Expand Down Expand Up @@ -999,11 +1030,13 @@ mod tests {
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1,
);
let right = build_table(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
1,
);
let filter = prepare_join_filter();

Expand Down Expand Up @@ -1044,6 +1077,127 @@ mod tests {
Ok(())
}

fn prepare_mod_join_filter() -> JoinFilter {
let column_indices = vec![
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
Field::new("x", DataType::Int32, true),
Field::new("x", DataType::Int32, true),
]);

// left.b1 % 3
let left_mod = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 0)),
Operator::Modulo,
Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
)) as Arc<dyn PhysicalExpr>;
// left.b1 % 3 != 0
let left_filter = Arc::new(BinaryExpr::new(
left_mod,
Operator::NotEq,
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
)) as Arc<dyn PhysicalExpr>;

// right.b2 % 5
let right_mod = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 1)),
Operator::Modulo,
Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
)) as Arc<dyn PhysicalExpr>;
// right.b2 % 5 != 0
let right_filter = Arc::new(BinaryExpr::new(
right_mod,
Operator::NotEq,
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
)) as Arc<dyn PhysicalExpr>;
// filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
let filter_expression =
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
as Arc<dyn PhysicalExpr>;

JoinFilter::new(filter_expression, column_indices, intermediate_schema)
}

fn generate_columns(num_columns:usize, num_rows:i32) -> Vec<Vec<i32>> {
alihan-synnada marked this conversation as resolved.
Show resolved Hide resolved
let column = (1..=num_rows).collect();
vec![column; num_columns]
}

#[rstest]
#[tokio::test]
async fn join_maintains_right_order(
#[values(JoinType::Inner, JoinType::Right, JoinType::RightAnti, JoinType::RightSemi)] join_type: JoinType,
#[values(1, 5, 10, 30)] num_batches: usize,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining number of rows per batch is a more common style rather than number of batches. You can also increase the test coverage with different sized left and right batches (like left_batch size: 1, 5, 10..., right_batch size: 1, 5, 10...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 4b962ab

) -> Result<()> {
let left_columns = generate_columns(3, 20);
let left = build_table(
("a1", &left_columns[0]),
("b1", &left_columns[1]),
("c1", &left_columns[2]),
1,
);

let right_columns = generate_columns(3, 100);
let right = build_table(
("a2", &right_columns[0]),
("b2", &right_columns[1]),
("c2", &right_columns[2]),
num_batches,
);

let filter = prepare_mod_join_filter();

let nested_loop_join = NestedLoopJoinExec::try_new(
left,
right,
Some(filter),
&join_type,
)?;

let batches = nested_loop_join
.execute(0, Arc::new(TaskContext::default()))?
.try_collect::<Vec<_>>()
.await?;

// Make sure that the order of the right side is maintained
let mut prev_values = (i32::MIN, i32::MIN, i32::MIN);

for (batch_index, batch) in batches.iter().enumerate() {
let right_column_indices = match join_type {
JoinType::Inner | JoinType::Right => vec![3, 4, 5],
JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
_ => unreachable!(),
};
let columns: Vec<_> = right_column_indices
.into_iter().map(|i| batch.column(i).as_any().downcast_ref::<Int32Array>().unwrap())
.collect();

for row in 0..batch.num_rows() {
let current_values = (
columns[0].value(row),
columns[1].value(row),
columns[2].value(row),
);

assert!(current_values.0 >= prev_values.0, "batch_index: {} row: {} current.0: {}, prev.0: {}", batch_index, row, current_values.0, prev_values.0);
assert!(current_values.1 >= prev_values.1, "batch_index: {} row: {} current.1: {}, prev.1: {}", batch_index, row, current_values.1, prev_values.1);
assert!(current_values.2 >= prev_values.2, "batch_index: {} row: {} current.2: {}, prev.2: {}", batch_index, row, current_values.2, prev_values.2);

prev_values = current_values;
}
alihan-synnada marked this conversation as resolved.
Show resolved Hide resolved
}

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading