diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 30f4d6b8f3d5..4cf264f1524a 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -37,7 +37,7 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::expressions::Column; @@ -310,11 +310,9 @@ async fn test_spm_congestion() -> Result<()> { match result { Ok(Ok(Ok(_batches))) => Ok(()), Ok(Ok(Err(e))) => Err(e), - Ok(Err(_)) => Err(DataFusionError::Execution( - "SortPreservingMerge task panicked or was cancelled".to_string(), - )), - Err(_) => Err(DataFusionError::Execution( - "SortPreservingMerge caused a deadlock".to_string(), - )), + Ok(Err(e)) => { + exec_err!("SortPreservingMerge task panicked or was cancelled: {e}") + } + Err(e) => exec_err!("SortPreservingMerge caused a deadlock: {e}"), } } diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 811c93ce5892..443e882f217d 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -100,12 +100,6 @@ pub(crate) struct SortPreservingMergeStream { /// number of rows produced produced: usize, - - /// This vector contains partition indices in order. When a partition is polled and returns `Poll::Ready`, - /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the - /// vector to ensure the next iteration starts with a different partition, preventing the same partition - /// from being continuously polled. - uninitiated_partitions: Vec, } impl SortPreservingMergeStream { @@ -130,7 +124,6 @@ impl SortPreservingMergeStream { batch_size, fetch, produced: 0, - uninitiated_partitions: (0..stream_count).collect(), } } @@ -167,23 +160,28 @@ impl SortPreservingMergeStream { // try to initialize the loser tree if self.loser_tree.is_empty() { // Ensure all non-exhausted streams have a cursor from which rows can be pulled - let remaining_partitions = self.uninitiated_partitions.clone(); - for i in remaining_partitions { + let mut any_pending = false; + for i in 0..self.streams.partitions() { match self.maybe_poll_stream(cx, i) { Poll::Ready(Err(e)) => { self.aborted = true; return Poll::Ready(Some(Err(e))); } - Poll::Pending => { - self.uninitiated_partitions.rotate_left(1); - cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Ready(Ok(())) => { + // input i is ready } - _ => { - self.uninitiated_partitions.retain(|idx| *idx != i); + Poll::Pending => { + // input i is not ready + any_pending = true; } } } + if any_pending { + // If any stream is not ready, return pending and tell the executor to wake us up + // to try again when it is ready + cx.waker().wake_by_ref(); + return Poll::Pending; + } self.init_loser_tree(); }