Fix Possible Congestion Scenario in SortPreservingMergeExec #12302

Merged
merged 13 commits on Sep 6, 2024
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/merge_fuzz.rs
@@ -16,6 +16,7 @@
// under the License.

//! Fuzz Test for various corner cases merging streams of RecordBatches

use std::sync::Arc;

use arrow::{
55 changes: 44 additions & 11 deletions datafusion/physical-plan/src/sorts/merge.rs
@@ -18,19 +18,23 @@
//! Merge that deals with an arbitrary size of streaming inputs.
//! This is an order-preserving merge.

use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};

use crate::metrics::BaselineMetrics;
use crate::sorts::builder::BatchBuilder;
use crate::sorts::cursor::{Cursor, CursorValues};
use crate::sorts::stream::PartitionedStream;
use crate::RecordBatchStream;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};

/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
@@ -86,7 +90,7 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
/// been updated
loser_tree_adjusted: bool,

/// target batch size
/// Target batch size
batch_size: usize,

/// Cursors for each input partition. `None` means the input is exhausted
@@ -97,6 +101,12 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {

/// number of rows produced
produced: usize,

/// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
/// it is removed from the queue. If a partition returns `Poll::Pending`, it is moved to the end of the
/// queue to ensure the next iteration starts with a different partition, preventing the same partition
/// from being continuously polled.
uninitiated_partitions: VecDeque<usize>,
}

impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -121,6 +131,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
batch_size,
fetch,
produced: 0,
uninitiated_partitions: (0..stream_count).collect(),
}
}

@@ -154,14 +165,36 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
if self.aborted {
return Poll::Ready(None);
}
// try to initialize the loser tree
// Once all partitions have set their corresponding cursors for the loser tree,
// we skip the following block. Until then, this function may be called multiple
// times and can return Poll::Pending if any partition returns Poll::Pending.
if self.loser_tree.is_empty() {
// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
for i in 0..self.streams.partitions() {
if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
let remaining_partitions = self.uninitiated_partitions.clone();
for i in remaining_partitions {
match self.maybe_poll_stream(cx, i) {
Poll::Ready(Err(e)) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
// If a partition returns Poll::Pending, to avoid continuously polling it
// and potentially increasing upstream buffer sizes, we move it to the
// back of the polling queue.
if let Some(front) = self.uninitiated_partitions.pop_front() {
// This pop_front can never return `None`.
self.uninitiated_partitions.push_back(front);
}
// This function could remain in a pending state, so we manually wake it here.
// However, this approach can be investigated further to find a more natural way
// to avoid disrupting the runtime scheduler.
cx.waker().wake_by_ref();

Review comment (PR author):
I am not sure if this usage has some side effects or decreases performance, but I cannot wake the SPM poll again once it receives a pending from its first partition.

Review comment (reviewer):
I did some research -- see https://github.com/synnada-ai/datafusion-upstream/pull/34/files#r1743621057

I think calling wake_by_ref effectively tells tokio to schedule this poll loop again after handling other tasks, which makes sense to me (as I am not sure how else we would signal to tokio that the merge is ready to go).

But I share your concern that this will cause some sort of performance issue.

return Poll::Pending;
}
_ => {
// If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
// we remove this partition from the queue so it is not polled again.
self.uninitiated_partitions.retain(|idx| *idx != i);
}
}
}
self.init_loser_tree();
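The hunk above reduces to a small pattern: keep the indices of uninitiated partitions in a VecDeque, rotate an index to the back of the queue when its stream returns Poll::Pending, and call wake_by_ref so the runtime reschedules the poll loop instead of letting the task hang. Below is a minimal, self-contained sketch of that pattern, assuming a tokio runtime; PendingOnce and InitAll are illustrative stand-ins, not DataFusion types.

use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Illustrative input: Pending on the first poll, Ready(42) afterwards.
struct PendingOnce {
    polled: bool,
}

impl Future for PendingOnce {
    type Output = u32;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<u32> {
        if self.polled {
            Poll::Ready(42)
        } else {
            self.polled = true;
            // Ask the runtime to poll again; without a wake-up the task could hang.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

/// Mirrors the initialization loop above: poll each uninitiated input, and on
/// Pending rotate it to the back of the queue so the next pass starts with a
/// different input instead of re-polling the congested one.
struct InitAll {
    inputs: Vec<PendingOnce>,
    uninitiated: VecDeque<usize>,
    results: Vec<Option<u32>>,
}

impl Future for InitAll {
    type Output = Vec<u32>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Iterate over a snapshot, since the queue is mutated inside the loop.
        let remaining: Vec<usize> = self.uninitiated.iter().copied().collect();
        for i in remaining {
            match Pin::new(&mut self.inputs[i]).poll(cx) {
                Poll::Ready(v) => {
                    // Initialized: drop this index from the queue.
                    self.results[i] = Some(v);
                    self.uninitiated.retain(|idx| *idx != i);
                }
                Poll::Pending => {
                    // Rotate the congested input to the back, then ask the
                    // runtime to schedule this poll loop again.
                    if let Some(front) = self.uninitiated.pop_front() {
                        self.uninitiated.push_back(front);
                    }
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
            }
        }
        Poll::Ready(self.results.iter().flatten().copied().collect())
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let n = 3;
    let fut = InitAll {
        inputs: (0..n).map(|_| PendingOnce { polled: false }).collect(),
        uninitiated: (0..n).collect(),
        results: vec![None; n],
    };
    // Every input needs two polls; the rotation guarantees each one gets a
    // turn instead of the first Pending input being polled repeatedly.
    assert_eq!(fut.await, vec![42, 42, 42]);
}

Note the invariant the real code also relies on: when a Pending is encountered, every earlier index in the snapshot has already been removed from the queue, so the front entry that gets rotated is exactly the congested one.
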
171 changes: 168 additions & 3 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -300,6 +300,11 @@ impl ExecutionPlan for SortPreservingMergeExec {

#[cfg(test)]
mod tests {
use std::fmt::Formatter;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::time::Duration;

use super::*;
use crate::coalesce_partitions::CoalescePartitionsExec;
@@ -310,16 +315,23 @@ mod tests {
use crate::stream::RecordBatchReceiverStream;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::test::{self, assert_is_pending, make_partition};
use crate::{collect, common};
use crate::{collect, common, ExecutionMode};

use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{assert_batches_eq, assert_contains};
use arrow_schema::SchemaRef;
use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::RecordBatchStream;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use futures::{FutureExt, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use tokio::time::timeout;

#[tokio::test]
async fn test_merge_interleave() {
@@ -1141,4 +1153,157 @@ mod tests {
collected.as_slice()
);
}

/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
/// partition is exhausted from the start, and if it is polled more than once, it panics.
#[derive(Debug, Clone)]
struct CongestedExec {
schema: Schema,
cache: PlanProperties,
congestion_cleared: Arc<Mutex<bool>>,
}

impl CongestedExec {
fn compute_properties(schema: SchemaRef) -> PlanProperties {
let columns = schema
.fields
.iter()
.enumerate()
.map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
.collect::<Vec<_>>();
let mut eq_properties = EquivalenceProperties::new(schema);
eq_properties.add_new_orderings(vec![columns
.iter()
.map(|expr| {
PhysicalSortExpr::new(Arc::clone(expr), SortOptions::default())
})
.collect::<Vec<_>>()]);
let mode = ExecutionMode::Unbounded;
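// Three hash partitions, one for each CongestedStream behavior exercised below.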
PlanProperties::new(eq_properties, Partitioning::Hash(columns, 3), mode)
}
}

impl ExecutionPlan for CongestedExec {
fn name(&self) -> &'static str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.cache
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(CongestedStream {
schema: Arc::new(self.schema.clone()),
none_polled_once: false,
congestion_cleared: Arc::clone(&self.congestion_cleared),
partition,
}))
}
}

impl DisplayAs for CongestedExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "CongestedExec",).unwrap()
}
}
Ok(())
}
}

/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
/// partition is exhausted from the start, and if it is polled more than once, it panics.
#[derive(Debug)]
pub struct CongestedStream {
schema: SchemaRef,
none_polled_once: bool,
congestion_cleared: Arc<Mutex<bool>>,
partition: usize,
}

impl Stream for CongestedStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.partition {
0 => {
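// Partition 0 is exhausted from the start: it yields None once and panics on any further poll.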
if self.none_polled_once {
panic!("Exhausted stream is polled more than once")
} else {
self.none_polled_once = true;
Poll::Ready(None)
}
}
1 => {
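// Partition 1 stays Pending until partition 2 has cleared the congestion flag.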
let cleared = self.congestion_cleared.lock().unwrap();
if *cleared {
Poll::Ready(None)
} else {
Poll::Pending
}
}
2 => {
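// Partition 2 clears the congestion flag, letting partition 1 finish on its next poll.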
let mut cleared = self.congestion_cleared.lock().unwrap();
*cleared = true;
Poll::Ready(None)
}
_ => unreachable!(),
}
}
}

impl RecordBatchStream for CongestedStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}

#[tokio::test]
async fn test_spm_congestion() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
let source = CongestedExec {
schema: schema.clone(),
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
congestion_cleared: Arc::new(Mutex::new(false)),
};
let spm = SortPreservingMergeExec::new(
vec![PhysicalSortExpr::new(
Arc::new(Column::new("c1", 0)),
SortOptions::default(),
)],
Arc::new(source),
);
let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
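
// Collect in a separate task so the timeout below can detect a hang instead of blocking the test.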

let result = timeout(Duration::from_secs(3), spm_task.join()).await;
match result {
Ok(Ok(Ok(_batches))) => Ok(()),
Ok(Ok(Err(e))) => Err(e),
Ok(Err(_)) => Err(DataFusionError::Execution(
"SortPreservingMerge task panicked or was cancelled".to_string(),
)),
Err(_) => Err(DataFusionError::Execution(
"SortPreservingMerge caused a deadlock".to_string(),
)),
}
}
}