diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index acf540d44465..826992e132ba 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1140,6 +1140,7 @@ message NestedLoopJoinExecNode {
 message CoalesceBatchesExecNode {
   PhysicalPlanNode input = 1;
   uint32 target_batch_size = 2;
+  optional uint32 fetch = 3;
 }
 
 message CoalescePartitionsExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 489b6c67534f..b4d63798f080 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -2000,6 +2000,9 @@ impl serde::Serialize for CoalesceBatchesExecNode {
         if self.target_batch_size != 0 {
             len += 1;
         }
+        if self.fetch.is_some() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.CoalesceBatchesExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -2007,6 +2010,9 @@ impl serde::Serialize for CoalesceBatchesExecNode {
         if self.target_batch_size != 0 {
             struct_ser.serialize_field("targetBatchSize", &self.target_batch_size)?;
         }
+        if let Some(v) = self.fetch.as_ref() {
+            struct_ser.serialize_field("fetch", v)?;
+        }
         struct_ser.end()
     }
 }
@@ -2020,12 +2026,14 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
             "input",
             "target_batch_size",
             "targetBatchSize",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             TargetBatchSize,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -2049,6 +2057,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
                         match value {
                             "input" => Ok(GeneratedField::Input),
                             "targetBatchSize" | "target_batch_size" => Ok(GeneratedField::TargetBatchSize),
+                            "fetch" => Ok(GeneratedField::Fetch),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -2070,6 +2079,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
             {
                 let mut input__ = None;
                 let mut target_batch_size__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -2086,11 +2096,20 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
                                 Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
                             ;
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ =
+                                map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
+                            ;
+                        }
                     }
                 }
                 Ok(CoalesceBatchesExecNode {
                     input: input__,
                     target_batch_size: target_batch_size__.unwrap_or_default(),
+                    fetch: fetch__,
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index c98c950d35f9..875d2af75dd7 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1813,6 +1813,8 @@ pub struct CoalesceBatchesExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(uint32, tag = "2")]
     pub target_batch_size: u32,
+    #[prost(uint32, optional, tag = "3")]
+    pub fetch: ::core::option::Option<u32>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 0f6722dd375b..96fb45eafe62 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -259,10 +259,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     runtime,
                     extension_codec,
                 )?;
-                Ok(Arc::new(CoalesceBatchesExec::new(
-                    input,
-                    coalesce_batches.target_batch_size as usize,
-                )))
+                Ok(Arc::new(
+                    CoalesceBatchesExec::new(
+                        input,
+                        coalesce_batches.target_batch_size as usize,
+                    )
+                    .with_fetch(coalesce_batches.fetch.map(|f| f as usize)),
+                ))
             }
             PhysicalPlanType::Merge(merge) => {
                 let input: Arc<dyn ExecutionPlan> =
@@ -1536,6 +1539,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     protobuf::CoalesceBatchesExecNode {
                         input: Some(Box::new(input)),
                         target_batch_size: coalesce_batches.target_batch_size() as u32,
+                        fetch: coalesce_batches.fetch().map(|n| n as u32),
                     },
                 ))),
             });
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 6766468ef443..0ffc494321fb 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -25,6 +25,7 @@ use std::vec;
 use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
 use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
+use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
 use datafusion_functions_aggregate::min_max::max_udaf;
@@ -629,6 +630,23 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> {
     ))
 }
 
+#[test]
+fn roundtrip_coalesce_with_fetch() -> Result<()> {
+    let field_a = Field::new("a", DataType::Boolean, false);
+    let field_b = Field::new("b", DataType::Int64, false);
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+    roundtrip_test(Arc::new(CoalesceBatchesExec::new(
+        Arc::new(EmptyExec::new(schema.clone())),
+        8096,
+    )))?;
+
+    roundtrip_test(Arc::new(
+        CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema.clone())), 8096)
+            .with_fetch(Some(10)),
+    ))
+}
+
 #[test]
 fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> {
     let scan_config = FileScanConfig {
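
For context, the new field can also be exercised outside the crate's internal roundtrip_test harness via the datafusion-proto byte helpers. The sketch below is illustrative only: it assumes physical_plan_to_bytes / physical_plan_from_bytes and a default SessionContext are sufficient to decode this plan, and the schema and fetch value (10) are arbitrary. It builds a CoalesceBatchesExec with a fetch limit, serializes it to protobuf bytes, and checks that the limit survives deserialization thanks to the new fetch field.

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::prelude::SessionContext;
use datafusion_proto::bytes::{physical_plan_from_bytes, physical_plan_to_bytes};

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

    // Build a CoalesceBatchesExec that carries a fetch (row limit).
    let plan = Arc::new(
        CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema)), 8192).with_fetch(Some(10)),
    );

    // Serialize the plan to its protobuf byte representation and decode it again.
    let bytes = physical_plan_to_bytes(plan)?;
    let restored = physical_plan_from_bytes(&bytes, &ctx)?;

    // With the new `fetch` field on CoalesceBatchesExecNode, the limit is
    // preserved across the round trip instead of being dropped.
    let coalesce = restored
        .as_any()
        .downcast_ref::<CoalesceBatchesExec>()
        .expect("round-tripped plan should still be a CoalesceBatchesExec");
    assert_eq!(coalesce.fetch(), Some(10));

    Ok(())
}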