Fix Possible Congestion Scenario in SortPreservingMergeExec #12302

Merged · 13 commits · Sep 6, 2024
160 changes: 159 additions & 1 deletion datafusion/core/tests/fuzz_cases/merge_fuzz.rs
@@ -16,22 +16,41 @@
// under the License.

//! Fuzz Test for various corner cases merging streams of RecordBatches
use std::sync::Arc;

use std::any::Any;
use std::fmt::Formatter;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;

use arrow::{
array::{ArrayRef, Int32Array},
compute::SortOptions,
record_batch::RecordBatch,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_plan::{
collect,
expressions::{col, PhysicalSortExpr},
memory::MemoryExec,
sorts::sort_preserving_merge::SortPreservingMergeExec,
};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties,
};
use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed};

use futures::Stream;
use tokio::time::timeout;

#[tokio::test]
async fn test_merge_2() {
run_merge_test(vec![
@@ -160,3 +179,142 @@ fn concat(mut v1: Vec<RecordBatch>, v2: Vec<RecordBatch>) -> Vec<RecordBatch> {
v1.extend(v2);
v1
}

/// It returns `Poll::Pending` for the 1st partition until the 2nd partition is polled.
#[derive(Debug, Clone)]
struct CongestedExec {
schema: Schema,
cache: PlanProperties,
congestion_cleared: Arc<Mutex<bool>>,
}

impl CongestedExec {
fn compute_properties(schema: SchemaRef) -> PlanProperties {
let columns = schema
.fields
.iter()
.enumerate()
.map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
.collect::<Vec<_>>();
let mut eq_properties = EquivalenceProperties::new(schema);
eq_properties.add_new_orderings(vec![columns
.iter()
.map(|expr| PhysicalSortExpr::new(expr.clone(), SortOptions::default()))
.collect::<Vec<_>>()]);
let mode = ExecutionMode::Unbounded;
PlanProperties::new(eq_properties, Partitioning::Hash(columns, 2), mode)
}
}

impl ExecutionPlan for CongestedExec {
fn name(&self) -> &'static str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.cache
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(CongestedStream {
schema: Arc::new(self.schema.clone()),
congestion_cleared: self.congestion_cleared.clone(),
partition,
}))
}
}

impl DisplayAs for CongestedExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "CongestedExec",).unwrap()
}
}
Ok(())
}
}

/// It returns `Poll::Pending` for the 1st partition until the 2nd partition is polled.
#[derive(Debug)]
pub struct CongestedStream {
schema: SchemaRef,
congestion_cleared: Arc<Mutex<bool>>,
partition: usize,
}

impl Stream for CongestedStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.partition {
0 => {
let cleared = self.congestion_cleared.lock().unwrap();
if *cleared {
Poll::Ready(None)
} else {
Poll::Pending
}
}
1 => {
let mut cleared = self.congestion_cleared.lock().unwrap();
*cleared = true;
Poll::Ready(None)
}
_ => unreachable!(),
}
}
}

impl RecordBatchStream for CongestedStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}

#[tokio::test]
async fn test_spm_congestion() -> Result<()> {
Contributor: I read this test a bit more -- it doesn't seem like it is actually a fuzz test (it doesn't seem to have any random inputs, for example). I think it would make more sense to put it with the other sort preserving merge tests.

let task_ctx = Arc::new(TaskContext::default());
let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
let source = CongestedExec {
schema: schema.clone(),
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
congestion_cleared: Arc::new(Mutex::new(false)),
};
let spm = SortPreservingMergeExec::new(
vec![PhysicalSortExpr::new(
Arc::new(Column::new("c1", 0)),
SortOptions::default(),
)],
Arc::new(source),
);
let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

let result = timeout(Duration::from_secs(3), spm_task.join()).await;
match result {
Ok(Ok(Ok(_batches))) => Ok(()),
Ok(Ok(Err(e))) => Err(e),
Ok(Err(_)) => Err(DataFusionError::Execution(
"SortPreservingMerge task panicked or was cancelled".to_string(),
)),
Err(_) => Err(DataFusionError::Execution(
"SortPreservingMerge caused a deadlock".to_string(),
)),
}
}
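
The hang this test guards against is the contract violation at the heart of the congestion scenario: a stream returns `Poll::Pending` without arranging for its waker to be invoked, so the runtime parks the task forever. A minimal standalone sketch of that failure mode (illustrative only, not part of this PR):

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

/// A future that reports `Pending` but never schedules a wakeup.
struct NeverWakes;

impl Future for NeverWakes {
    type Output = ();
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        // No `cx.waker().wake_by_ref()` here, so the executor has no reason
        // to poll this future again: the task is parked indefinitely.
        Poll::Pending
    }
}

#[tokio::main]
async fn main() {
    // Without the timeout, `.await` would hang forever, which is exactly
    // what `test_spm_congestion` detects with its 3-second bound.
    let result = tokio::time::timeout(Duration::from_millis(100), NeverWakes).await;
    assert!(result.is_err(), "only the timeout can fire");
}
```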
36 changes: 27 additions & 9 deletions datafusion/physical-plan/src/sorts/merge.rs
@@ -18,19 +18,22 @@
//! Merge that deals with an arbitrary size of streaming inputs.
//! This is an order-preserving merge.

use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};

use crate::metrics::BaselineMetrics;
use crate::sorts::builder::BatchBuilder;
use crate::sorts::cursor::{Cursor, CursorValues};
use crate::sorts::stream::PartitionedStream;
use crate::RecordBatchStream;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};

/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
@@ -97,6 +100,10 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {

/// number of rows produced
produced: usize,

/// Uninitiated partitions. They are stored in a vector to keep them in
/// a priority order to visit the partitions in a round-robin fashion.
/// (A partition counts as uninitiated while its stream has so far
/// returned only `Poll::Pending` when polled.)
uninitiated_partitions: Vec<usize>,

Contributor: Can you please document what an "uninitiated partition" means in this context? I think it means partitions whose streams have been polled but haven't been ready yet.
}

impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -121,6 +128,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
batch_size,
fetch,
produced: 0,
uninitiated_partitions: (0..stream_count).collect(),
}
}

@@ -156,12 +164,22 @@
}
// try to initialize the loser tree
if self.loser_tree.is_empty() {
// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
for i in 0..self.streams.partitions() {
if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
// Ensure all non-exhausted streams have a cursor from which rows can be pulled
Contributor: This comment implies to me that the code would / should poll all the streams. However, the code now seems to ensure that only streams that had not previously returned Ready for a poll are polled.

Contributor Author (@berkaysynnada, Sep 4, 2024): IMO, the behavior is more correct now. In the previous version, let's assume the 1st partition is exhausted and returns None without setting its cursor. Then, the 2nd partition returns Pending. When poll_next_inner() is polled again, the iteration starts from the 1st partition, which has already returned None (and AFAIK, polling exhausted streams could cause problems). Therefore, I track which streams have returned a result (either None or Some), and which ones have returned only Pending.

Contributor: See synnada-ai#34 for an alternate idea.

Contributor Author: I've tried to explain my concern with that: synnada-ai#34 (comment)

Contributor: 👍 -- response in synnada-ai#34 (comment)

let remaining_partitions = self.uninitiated_partitions.clone();
for i in remaining_partitions {
match self.maybe_poll_stream(cx, i) {
Poll::Ready(Err(e)) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
self.uninitiated_partitions.rotate_left(1);
cx.waker().wake_by_ref();
Contributor Author: I am not sure if this usage has some side-effects or decreases performance, but I cannot wake the SPM poll again once it receives a Pending from its first partition.

Contributor: I did some research -- see https://github.com/synnada-ai/datafusion-upstream/pull/34/files#r1743621057. I think calling wake_by_ref effectively tells tokio to schedule this poll loop again after handling other tasks, which makes sense to me (I am not sure how else we would signal to tokio that the merge is ready to go). But I share your concern that this may cause some sort of performance issue.
return Poll::Pending;
}
_ => {
self.uninitiated_partitions.retain(|idx| *idx != i);
}
}
}
self.init_loser_tree();
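
To make the two threads above concrete, here is a self-contained sketch of the round-robin initialization pattern in this hunk (illustrative only; `poll_partition` is a stand-in for `maybe_poll_stream`, and the real loop lives inside `poll_next_inner`): ready partitions are dropped from the uninitiated list, the first pending partition is rotated to the back, and `wake_by_ref` asks the runtime to schedule another poll pass.

```rust
use std::task::{Context, Poll};

/// Stand-in for `maybe_poll_stream`: reports which partitions are ready.
fn poll_partition(ready: &[bool], i: usize) -> Poll<()> {
    if ready[i] {
        Poll::Ready(())
    } else {
        Poll::Pending
    }
}

/// One initialization pass in the style of the new `poll_next_inner` loop.
fn init_pass(
    uninitiated: &mut Vec<usize>,
    ready: &[bool],
    cx: &mut Context<'_>,
) -> Poll<()> {
    for i in uninitiated.clone() {
        match poll_partition(ready, i) {
            Poll::Pending => {
                // Every earlier entry was just retained out, so `i` is now
                // at the front; rotating moves it to the back of the queue.
                uninitiated.rotate_left(1);
                // Tell the runtime to schedule this poll loop again.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(()) => {
                // This partition produced a result; stop tracking it.
                uninitiated.retain(|idx| *idx != i);
            }
        }
    }
    Poll::Ready(())
}

fn main() {
    let waker = futures::task::noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut uninitiated = vec![0, 1, 2];
    // Partition 0 stays pending; 1 and 2 are ready.
    let ready = [false, true, true];

    // First pass yields on partition 0 and rotates it to the back.
    assert!(init_pass(&mut uninitiated, &ready, &mut cx).is_pending());
    assert_eq!(uninitiated, vec![1, 2, 0]);

    // On the next pass, 1 and 2 are retained out; 0 yields again.
    assert!(init_pass(&mut uninitiated, &ready, &mut cx).is_pending());
    assert_eq!(uninitiated, vec![0]);
}
```

This also mirrors the performance concern raised in the thread: `wake_by_ref` immediately re-schedules the task, so a partition that stays pending for a long time keeps the initialization loop being re-polled rather than sleeping until a genuine wakeup.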