From 9ab42e249336ccec5abeaccf75f436957f044c1b Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Sat, 2 Mar 2024 21:28:49 -0600 Subject: [PATCH 01/11] initial try --- .../physical-expr/src/scalar_function.rs | 9 +- .../proto/src/physical_plan/to_proto.rs | 391 +++++++++--------- 2 files changed, 207 insertions(+), 193 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index bfe0fdb279f5..131e6be52f45 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -44,12 +44,12 @@ use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionImplementation, + ScalarFunctionDefinition, ScalarFunctionImplementation, }; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, @@ -97,7 +97,8 @@ impl ScalarFunctionExpr { /// Get the scalar function implementation pub fn fun(&self) -> &ScalarFunctionImplementation { - &self.fun + // + todo!() } /// The name for this expression diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ce3df8183dc9..669e8a739f41 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -30,6 +30,7 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; +use datafusion_expr::ScalarUDF; use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; use datafusion::datasource::{ @@ -70,14 +71,17 @@ use datafusion_common::{ DataFusionError, FileTypeWriterOptions, JoinSide, Result, }; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let expressions: Vec = a .expressions() .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_expr(e.clone(), &codec)) .collect::>>()?; let ordering_req: Vec = a @@ -237,16 +241,16 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - + let codec = DefaultPhysicalExtensionCodec {}; let args = args .into_iter() - .map(|e| e.try_into()) + .map(|e| serialize_expr(e, &codec)) .collect::>>()?; let partition_by = window_expr .partition_by() .iter() - .map(|p| p.clone().try_into()) + .map(|p| serialize_expr(p.clone(), &codec)) .collect::>>()?; let order_by = window_expr @@ -374,184 +378,192 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; +fn serialize_expr( + value: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = value.as_any(); - fn try_from(value: Arc) -> Result { - let expr = value.as_any(); - - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(expr.left().to_owned().try_into()?)), - r: Some(Box::new(expr.right().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); + if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { + l: Some(Box::new(serialize_expr(expr.left().clone(), codec)?)), + r: Some(Box::new(serialize_expr(expr.right().clone(), codec)?)), + op: format!("{:?}", expr.op()), + }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, - Self::Error, - >>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ), + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + binary_expr, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .map(|exp| { + serialize_expr(exp.clone(), codec).map(Box::new) + }) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr, codec) + }) + .collect::, + DataFusionError, + >>()?, + else_expr: expr + .else_expr() + .map(|a| serialize_expr(a.clone(), codec).map(Box::new)) + .transpose()?, + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( - Box::new(protobuf::PhysicalNot { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - list: expr - .list() - .iter() - .map(|a| a.clone().try_into()) - .collect::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::InList( + Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_expr( + expr.expr().to_owned(), + codec, + )?)), + list: expr + .list() + .iter() + .map(|a| serialize_expr(a.clone(), codec)) + .collect::, - Self::Error, + DataFusionError, >>()?, - negated: expr.negated(), - }, - ), + negated: expr.negated(), + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( - Box::new(protobuf::PhysicalNegativeNode { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { + ), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + }, + ))), + }) + } else if let Some(lit) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + lit.value().try_into()?, + )), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(serialize_expr(cast.expr().to_owned(), codec)?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(serialize_expr(cast.expr().to_owned(), codec)?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let args: Vec = expr + .args() + .iter() + .map(|e| serialize_expr(e.to_owned(), codec)) + .collect::, _>>()?; + if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { + let fun: protobuf::ScalarFunction = (&fun).try_into()?; + Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), + fun: fun.into(), + args, + return_type: Some(expr.return_type().try_into()?), }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( - Box::new(protobuf::PhysicalTryCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), - }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| e.to_owned().try_into()) - .collect::, _>>()?; - if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { - let fun: protobuf::ScalarFunction = (&fun).try_into()?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::ScalarFunction( - protobuf::PhysicalScalarFunctionNode { - name: expr.name().to_string(), - fun: fun.into(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - ), - ), - }) - } else { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( - protobuf::PhysicalScalarUdfNode { - name: expr.name().to_string(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - )), - }) - } - } else if let Some(expr) = expr.downcast_ref::() { + } else { + let mut buf = Vec::new(); + // let _ = codec.try_encode_udf(, &mut buf); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr( - Box::new(protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - pattern: Some(Box::new(expr.pattern().to_owned().try_into()?)), - }), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( + protobuf::PhysicalScalarUdfNode { + name: expr.name().to_string(), + args, + return_type: Some(expr.return_type().try_into()?), + }, )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let field = match expr.field() { + } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_expr(expr.expr().to_owned(), codec)?)), + pattern: Some(Box::new(serialize_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let field = match expr.field() { GetFieldAccessExpr::NamedStructField{name} => Some( protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(protobuf::NamedStructFieldExpr { name: Some(ScalarValue::try_from(name)?) @@ -559,43 +571,41 @@ impl TryFrom> for protobuf::PhysicalExprNode { ), GetFieldAccessExpr::ListIndex{key} => Some( protobuf::physical_get_indexed_field_expr_node::Field::ListIndexExpr(Box::new(protobuf::ListIndexExpr { - key: Some(Box::new(key.to_owned().try_into()?)) + key: Some(Box::new(serialize_expr(key.to_owned(), codec)?)) })) ), GetFieldAccessExpr::ListRange { start, stop, stride } => { Some( protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr { - start: Some(Box::new(start.to_owned().try_into()?)), - stop: Some(Box::new(stop.to_owned().try_into()?)), - stride: Some(Box::new(stride.to_owned().try_into()?)), + start: Some(Box::new(serialize_expr(start.to_owned(), codec)?)), + stop: Some(Box::new(serialize_expr(stop.to_owned(), codec)?)), + stride: Some(Box::new(serialize_expr(stride.to_owned(), codec)?)), })) ) } }; - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( - Box::new(protobuf::PhysicalGetIndexedFieldExprNode { - arg: Some(Box::new(expr.arg().to_owned().try_into()?)), - field, - }), - ), - ), - }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") - } + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( + Box::new(protobuf::PhysicalGetIndexedFieldExprNode { + arg: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + field, + }), + )), + }) + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(when_expr.clone().try_into()?), - then_expr: Some(then_expr.clone().try_into()?), + when_expr: Some(serialize_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_expr(then_expr.clone(), codec)?), }) } @@ -716,6 +726,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { fn try_from( conf: &FileScanConfig, ) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let file_groups = conf .file_groups .iter() @@ -727,7 +738,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = sort_expr.expr.clone().try_into()?; + let expr = serialize_expr(sort_expr.expr, &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, @@ -790,10 +801,11 @@ impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; fn try_from(expr: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(expr.try_into()?), + expr: Some(serialize_expr(expr, &codec)?), }), } } @@ -819,8 +831,9 @@ impl TryFrom for protobuf::PhysicalSortExprNode { type Error = DataFusionError; fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(sort_expr.expr.try_into()?)), + expr: Some(Box::new(serialize_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) From 8d7c86644681457d7927c13b26c9c2f6569c63ac Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Sun, 3 Mar 2024 11:17:59 -0600 Subject: [PATCH 02/11] revert --- datafusion/physical-expr/src/scalar_function.rs | 7 +++---- datafusion/proto/src/physical_plan/to_proto.rs | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 131e6be52f45..03682f099679 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -49,7 +49,7 @@ use datafusion_expr::{ /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionDefinition, + fun: ScalarFunctionImplementation, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionDefinition, + fun: ScalarFunctionImplementation, args: Vec>, return_type: DataType, monotonicity: Option, @@ -97,8 +97,7 @@ impl ScalarFunctionExpr { /// Get the scalar function implementation pub fn fun(&self) -> &ScalarFunctionImplementation { - // - todo!() + &self.fun } /// The name for this expression diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 669e8a739f41..c4896a3b31d5 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -30,7 +30,7 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{create_udf, ScalarUDF}; use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; use datafusion::datasource::{ @@ -537,6 +537,7 @@ fn serialize_expr( }) } else { let mut buf = Vec::new(); + let udf = create_udf(expr.name(), expr., return_type, volatility, fun); // let _ = codec.try_encode_udf(, &mut buf); Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( From 6b154e002f7f43156db1e8af08b6fab62c14246a Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 13:12:46 -0500 Subject: [PATCH 03/11] stage commit --- datafusion/physical-expr/src/functions.rs | 9 ++++--- .../physical-expr/src/scalar_function.rs | 26 ++++++++++++++----- datafusion/physical-expr/src/udf.rs | 11 ++++++-- .../proto/src/physical_plan/to_proto.rs | 2 +- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5d13f945692a..fd55cf0d280b 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -44,6 +44,7 @@ use arrow_array::Array; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, @@ -69,18 +70,20 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = - create_physical_fun(fun, execution_props)?; + // let fun_expr: ScalarFunctionImplementation = + // create_physical_fun(fun, execution_props)?; let monotonicity = fun.monotonicity(); + let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), - fun_expr, + fun_def, input_phy_exprs.to_vec(), data_type, monotonicity, fun.signature().type_signature.supports_zero_argument(), + execution_props, ))) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1fa17b955ac2..8c04b70ea2e5 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,14 +34,15 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::functions::out_ordering; +use crate::functions::{create_physical_fun, out_ordering}; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, ScalarFunctionDefinition, ScalarFunctionImplementation, @@ -49,7 +50,7 @@ use datafusion_expr::{ /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -60,6 +61,8 @@ pub struct ScalarFunctionExpr { monotonicity: Option, // Whether this function can be invoked with zero arguments supports_zero_argument: bool, + // Execution properties + execution_props: ExecutionProps, } impl Debug for ScalarFunctionExpr { @@ -79,11 +82,12 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, supports_zero_argument: bool, + execution_props: &ExecutionProps, ) -> Self { Self { fun, @@ -92,11 +96,12 @@ impl ScalarFunctionExpr { return_type, monotonicity, supports_zero_argument, + execution_props: execution_props.clone(), } } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { + pub fn fun(&self) -> &ScalarFunctionDefinition { &self.fun } @@ -171,8 +176,16 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?, }; + let fun_implementation = match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + create_physical_fun(fun, &self.execution_props)? + } + _ => { + todo!("User-defined functions are not supported yet") + } + }; // evaluate the function - let fun = self.fun.as_ref(); + let fun = fun_implementation.as_ref(); (fun)(&inputs) } @@ -191,6 +204,7 @@ impl PhysicalExpr for ScalarFunctionExpr { self.return_type().clone(), self.monotonicity.clone(), self.supports_zero_argument, + &self.execution_props, ))) } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index ede3e5badbb1..0df2ab2182a1 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -20,7 +20,10 @@ use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{type_coercion::functions::data_types, Expr}; +use datafusion_expr::{ + execution_props::ExecutionProps, type_coercion::functions::data_types, Expr, + ScalarFunctionDefinition, +}; use std::sync::Arc; /// Create a physical expression of the UDF. @@ -45,13 +48,17 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + let execution_props = ExecutionProps::new(); + + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun(), + fun_def, input_phy_exprs.to_vec(), return_type, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), + &execution_props, ))) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 1e6bfd830a6e..2105d72b464f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -542,7 +542,7 @@ fn serialize_expr( }) } else { let mut buf = Vec::new(); - let udf = create_udf(expr.name(), expr., return_type, volatility, fun); + let udf = create_udf(expr.name(), expr, return_type, volatility, fun); // let _ = codec.try_encode_udf(, &mut buf); Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( From b33656a50b5c0633b99fd974ef32c10430b1b88a Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 14:52:54 -0500 Subject: [PATCH 04/11] use ScalarFunctionDefinition to rewrite PhysicalExpr proto --- .../physical_optimizer/projection_pushdown.rs | 58 +++++++- datafusion/physical-expr/src/functions.rs | 7 +- .../physical-expr/src/scalar_function.rs | 14 +- datafusion/physical-expr/src/udf.rs | 6 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 21 +++ datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/physical_plan/from_proto.rs | 44 ++++-- datafusion/proto/src/physical_plan/mod.rs | 131 +++++++++++++----- .../proto/src/physical_plan/to_proto.rs | 37 +++-- .../tests/cases/roundtrip_physical_plan.rs | 16 +-- 11 files changed, 239 insertions(+), 98 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e8f3bf01ecaa..08a78013b298 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1287,6 +1287,7 @@ fn new_join_children( #[cfg(test)] mod tests { use super::*; + use std::any::Any; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -1313,7 +1314,10 @@ mod tests { use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_expr::{ + ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -1329,6 +1333,42 @@ mod tests { use itertools::Itertools; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1345,7 +1385,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1412,7 +1454,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1482,7 +1526,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1549,7 +1595,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index f1636c1e55a4..3c1c0663235a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -58,7 +58,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + _execution_props: &ExecutionProps, ) -> Result> { let input_expr_types = input_phy_exprs .iter() @@ -70,9 +70,6 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - // let fun_expr: ScalarFunctionImplementation = - // create_physical_fun(fun, execution_props)?; - let monotonicity = fun.monotonicity(); let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); @@ -83,7 +80,6 @@ pub fn create_physical_expr( data_type, monotonicity, fun.signature().type_signature.supports_zero_argument(), - execution_props, ))) } @@ -198,7 +194,6 @@ where /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - _execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 8c04b70ea2e5..f83dfe2ada1b 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -41,11 +41,10 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, Result}; -use datafusion_expr::execution_props::ExecutionProps; +use datafusion_common::Result; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionDefinition, ScalarFunctionImplementation, + ScalarFunctionDefinition, }; /// Physical expression of a scalar function @@ -61,8 +60,6 @@ pub struct ScalarFunctionExpr { monotonicity: Option, // Whether this function can be invoked with zero arguments supports_zero_argument: bool, - // Execution properties - execution_props: ExecutionProps, } impl Debug for ScalarFunctionExpr { @@ -87,7 +84,6 @@ impl ScalarFunctionExpr { return_type: DataType, monotonicity: Option, supports_zero_argument: bool, - execution_props: &ExecutionProps, ) -> Self { Self { fun, @@ -96,7 +92,6 @@ impl ScalarFunctionExpr { return_type, monotonicity, supports_zero_argument, - execution_props: execution_props.clone(), } } @@ -177,9 +172,7 @@ impl PhysicalExpr for ScalarFunctionExpr { }; let fun_implementation = match self.fun { - ScalarFunctionDefinition::BuiltIn(ref fun) => { - create_physical_fun(fun, &self.execution_props)? - } + ScalarFunctionDefinition::BuiltIn(ref fun) => create_physical_fun(fun)?, _ => { todo!("User-defined functions are not supported yet") } @@ -204,7 +197,6 @@ impl PhysicalExpr for ScalarFunctionExpr { self.return_type().clone(), self.monotonicity.clone(), self.supports_zero_argument, - &self.execution_props, ))) } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 0df2ab2182a1..4fc94bfa15ec 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -21,8 +21,7 @@ use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; use datafusion_expr::{ - execution_props::ExecutionProps, type_coercion::functions::data_types, Expr, - ScalarFunctionDefinition, + type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, }; use std::sync::Arc; @@ -48,8 +47,6 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; - let execution_props = ExecutionProps::new(); - let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), @@ -58,7 +55,6 @@ pub fn create_physical_expr( return_type, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), - &execution_props, ))) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b5683dc1425e..1d25463cdb85 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1458,6 +1458,7 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; + optional bytes fun_definition = 3; ArrowType return_type = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f5be49dc9de7..d93a49be8e1e 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20391,6 +20391,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.return_type.is_some() { len += 1; } @@ -20401,6 +20404,10 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.return_type.as_ref() { struct_ser.serialize_field("returnType", v)?; } @@ -20416,6 +20423,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { const FIELDS: &[&str] = &[ "name", "args", + "fun_definition", + "funDefinition", "return_type", "returnType", ]; @@ -20424,6 +20433,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { enum GeneratedField { Name, Args, + FunDefinition, ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -20448,6 +20458,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { match value { "name" => Ok(GeneratedField::Name), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -20470,6 +20481,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { let mut name__ = None; let mut args__ = None; + let mut fun_definition__ = None; let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -20485,6 +20497,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); @@ -20496,6 +20516,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { Ok(PhysicalScalarUdfNode { name: name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, return_type: return_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e1c9af105bbd..8b025028dc6b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2092,6 +2092,8 @@ pub struct PhysicalScalarUdfNode { pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] pub return_type: ::core::option::Option, } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 184c048c1bdd..ca54d4e803ca 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -59,9 +59,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; use chrono::{TimeZone, Utc}; +use datafusion_expr::ScalarFunctionDefinition; use object_store::path::Path; use object_store::ObjectMeta; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { Column::new(&c.name, c.index as usize) @@ -82,7 +85,8 @@ pub fn parse_physical_sort_expr( input_schema: &Schema, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -110,7 +114,9 @@ pub fn parse_physical_sort_exprs( .iter() .map(|sort_expr| { if let Some(expr) = &sort_expr.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = + parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -137,16 +143,17 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { + let codec = DefaultPhysicalExtensionCodec {}; let window_node_expr = proto .args .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>>()?; let partition_by = proto .partition_by .iter() - .map(|p| parse_physical_expr(p, registry, input_schema)) + .map(|p| parse_physical_expr(p, registry, input_schema, &codec)) .collect::>>()?; let order_by = proto @@ -191,6 +198,7 @@ pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { let expr_type = proto .expr_type @@ -270,7 +278,7 @@ pub fn parse_physical_expr( )?, e.list .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?, &e.negated, input_schema, @@ -278,7 +286,7 @@ pub fn parse_physical_expr( ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -301,7 +309,7 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -334,7 +342,7 @@ pub fn parse_physical_expr( let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; // TODO Do not create new the ExecutionProps @@ -348,19 +356,22 @@ pub fn parse_physical_expr( )? } ExprType::ScalarUdf(e) => { - let udf = registry.udf(e.name.as_str())?; + let udf = match &e.fun_definition { + Some(buf) => codec.try_decode_udf(&e.name, buf)?, + None => registry.udf(e.name.as_str())?, + }; let signature = udf.signature(); - let scalar_fun = udf.fun().clone(); + let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), - scalar_fun, + scalar_fun_def, args, convert_required!(e.return_type)?, None, @@ -394,7 +405,8 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema)) + let codec = DefaultPhysicalExtensionCodec {}; + expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -439,10 +451,11 @@ pub fn parse_protobuf_hash_partitioning( ) -> Result> { match partitioning { Some(hash_part) => { + let codec = DefaultPhysicalExtensionCodec {}; let expr = hash_part .hash_expr .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>, _>>()?; Ok(Some(Partitioning::Hash( @@ -503,6 +516,7 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { + let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = node_collection .physical_sort_expr_nodes .iter() @@ -510,7 +524,7 @@ pub fn parse_protobuf_file_scan_config( let expr = node .expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, &schema, &codec)) .unwrap()?; Ok(PhysicalSortExpr { expr, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 004948da938f..6264aa8e4060 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,6 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; +use self::to_proto::serialize_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -138,7 +139,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr(expr, registry, input.schema().as_ref())?, + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + )?, name.to_string(), )) }) @@ -156,7 +162,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? .ok_or_else(|| { @@ -208,6 +219,7 @@ impl AsExecutionPlan for PhysicalPlanNode { expr, registry, base_config.file_schema.as_ref(), + extension_codec, ) }) .transpose()?; @@ -254,7 +266,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .hash_expr .iter() .map(|e| { - parse_physical_expr(e, registry, input.schema().as_ref()) + parse_physical_expr( + e, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>, _>>()?; @@ -329,7 +346,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -396,8 +418,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -406,8 +433,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -434,7 +466,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| { expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .map(|e| { + parse_physical_expr( + e, + registry, + &physical_schema, + extension_codec, + ) + }) .transpose() }) .collect::, _>>()?; @@ -451,7 +490,7 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { @@ -524,11 +563,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -555,6 +596,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -635,11 +677,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -666,6 +710,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -805,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -852,7 +897,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -916,6 +961,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -1088,7 +1134,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| expr.0.clone().try_into()) + .map(|expr| serialize_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1128,7 +1174,10 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(exec.predicate().clone().try_into()?), + expr: Some(serialize_expr( + exec.predicate().clone(), + extension_codec, + )?), default_filter_selectivity: exec.default_selectivity() as u32, }, ))), @@ -1183,8 +1232,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1196,7 +1245,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = + serialize_expr(f.expression().to_owned(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1254,8 +1304,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1267,7 +1317,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = + serialize_expr(f.expression().to_owned(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1304,7 +1355,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1321,7 +1375,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1423,14 +1480,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1512,7 +1569,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| pred.clone().try_into()) + .map(|pred| serialize_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1559,7 +1616,7 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.clone().try_into()) + .map(|expr| serialize_expr(expr.clone(), extension_codec)) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1592,7 +1649,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1658,7 +1718,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1695,7 +1758,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = + serialize_expr(f.expression().to_owned(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1743,7 +1807,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1773,7 +1837,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1816,7 +1880,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 7dbe7ed3530b..6dbb7efbc541 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, @@ -32,10 +31,9 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion_expr::{create_udf, ScalarUDF}; - -use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; +use datafusion_expr::ScalarFunctionDefinition; +use crate::logical_plan::csv_writer_options_to_proto; use datafusion::datasource::{ file_format::csv::CsvSink, file_format::json::JsonSink, @@ -51,10 +49,10 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, - Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, - TryCastExpr, Variance, VariancePop, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, + Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -382,7 +380,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -fn serialize_expr( +pub fn serialize_expr( value: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { @@ -541,20 +539,31 @@ fn serialize_expr( }) } else { let mut buf = Vec::new(); - let udf = create_udf(expr.name(), expr, return_type, volatility, fun); - // let _ = codec.try_encode_udf(, &mut buf); + match expr.fun() { + ScalarFunctionDefinition::UDF(udf) => { + codec.try_encode_udf(udf, &mut buf)?; + } + _ => { + return not_impl_err!( + "Proto serialization error: Trying to serialize a unresolved function" + ); + } + } + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), args, + fun_definition, return_type: Some(expr.return_type().try_into()?), }, )), }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } @@ -698,7 +707,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = serialize_expr(sort_expr.expr, &codec)?; + let expr = serialize_expr(sort_expr.expr.clone(), &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f0c6286a19d..75abcb8da684 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -32,7 +32,6 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; -use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -49,7 +48,6 @@ use datafusion::physical_plan::expressions::{ NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -75,8 +73,7 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, - SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionDefinition, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; @@ -603,14 +600,11 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let execution_props = ExecutionProps::new(); - - let fun_expr = - functions::create_physical_fun(&BuiltinScalarFunction::Sin, &execution_props)?; + let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Sin); let expr = ScalarFunctionExpr::new( "sin", - fun_expr, + fun_def, vec![col("a", &schema)?], DataType::Float64, None, @@ -646,9 +640,11 @@ fn roundtrip_scalar_udf() -> Result<()> { scalar_fn.clone(), ); + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let expr = ScalarFunctionExpr::new( "dummy", - scalar_fn, + fun_def, vec![col("a", &schema)?], DataType::Int64, None, From 1738529de00dc8c571fd0b8d8c8e02359ab241a6 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 14:55:16 -0500 Subject: [PATCH 05/11] cargo fmt --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 75abcb8da684..d43d845f7656 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -73,7 +73,9 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionDefinition, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionDefinition, Signature, SimpleAggregateUDF, WindowFrame, + WindowFrameBound, }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; From 93d97505bd4f0912cc92594e556efe40b5df4124 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 15:34:45 -0500 Subject: [PATCH 06/11] feat : add test --- .../tests/cases/roundtrip_physical_plan.rs | 142 +++++++++++++++++- 1 file changed, 138 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index d43d845f7656..34981feec7eb 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -32,6 +33,7 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -71,14 +73,19 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, - ScalarFunctionDefinition, Signature, SimpleAggregateUDF, WindowFrame, - WindowFrameBound, + ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_expr; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -663,6 +670,133 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), ctx) } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + #[derive(Debug)] + struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + } + + impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } + } + + /// Implement the ScalarUDFImpl trait for MyRegexUdf + impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } + + #[derive(Debug)] + pub struct ScalarUDFExtensionCodec {} + + impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode regex_udf: {}", + err + )) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + Ok(()) + } + } + + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = ScalarFunctionExpr::new( + udf.name(), + ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + vec![], + DataType::Int32, + None, + false, + ); + let fmt_expr = format!("{test_expr:?}"); + let ctx = SessionContext::new(); + + ctx.register_udf(udf.clone()); + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::PhysicalExprNode = + match serialize_expr(Arc::new(test_expr), &extension_codec) { + Ok(proto) => proto, + Err(e) => panic!("failed to serialize expr: {e:?}"), + }; + let field_a = Field::new("a", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let round_trip = + parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); + assert_eq!(fmt_expr, format!("{round_trip:?}")); +} #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From 376905b0fac8ebbc5451d1bc10990e42e8900adc Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 15:47:02 -0500 Subject: [PATCH 07/11] fix bug --- datafusion/physical-expr/src/scalar_function.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index f83dfe2ada1b..6bb452ddec61 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -41,7 +41,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, ScalarFunctionDefinition, @@ -173,8 +173,11 @@ impl PhysicalExpr for ScalarFunctionExpr { let fun_implementation = match self.fun { ScalarFunctionDefinition::BuiltIn(ref fun) => create_physical_fun(fun)?, - _ => { - todo!("User-defined functions are not supported yet") + ScalarFunctionDefinition::UDF(ref fun) => fun.fun(), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Name function must be resolved to one of the other variants prior to physical planning" + ); } }; // evaluate the function From d8d1eb2bdb17d14a9cdb77d037ded230c888cc11 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 16:03:39 -0500 Subject: [PATCH 08/11] fix wrong delete code when resolve conflict --- .../proto/src/physical_plan/to_proto.rs | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6dbb7efbc541..d5eac2eb077f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -34,13 +34,6 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion_expr::ScalarFunctionDefinition; use crate::logical_plan::csv_writer_options_to_proto; -use datafusion::datasource::{ - file_format::csv::CsvSink, - file_format::json::JsonSink, - listing::{FileRange, PartitionedFile}, - physical_plan::FileScanConfig, - physical_plan::FileSinkConfig, -}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -59,6 +52,14 @@ use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindow use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::{ + datasource::{ + file_format::{csv::CsvSink, json::JsonSink}, + listing::{FileRange, PartitionedFile}, + physical_plan::{FileScanConfig, FileSinkConfig}, + }, + physical_plan::expressions::LikeExpr, +}; use datafusion_common::config::{ ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, TableParquetOptions, @@ -562,6 +563,20 @@ pub fn serialize_expr( )), }) } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_expr(expr.expr().to_owned(), codec)?)), + pattern: Some(Box::new(serialize_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) } else { internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } From 623e56c8c86ea694f3b70dd686906154883e3653 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 19 Mar 2024 10:03:42 -0500 Subject: [PATCH 09/11] Update datafusion/proto/src/physical_plan/to_proto.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> --- datafusion/proto/src/physical_plan/to_proto.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d5eac2eb077f..18d63d6111cb 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -381,7 +381,11 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -pub fn serialize_expr( +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) +pub fn serialize_physical_expr( value: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { From 1dfce4aedfa204447bd06a603f8b7b60ee02b13d Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 19 Mar 2024 10:04:19 -0500 Subject: [PATCH 10/11] Update datafusion/proto/tests/cases/roundtrip_physical_plan.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> --- .../proto/tests/cases/roundtrip_physical_plan.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 34981feec7eb..76a73a09dea9 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -760,13 +760,14 @@ fn roundtrip_scalar_udf_extension_codec() { fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - let udf = binding.as_any().downcast_ref::().unwrap(); - let proto = MyRegexUdfNode { - pattern: udf.pattern.clone(), - }; - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode udf: {e:?}")) - })?; + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + } Ok(()) } } From 767b53df0f0f604ece7cd54015e1118c4cf3f570 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 19 Mar 2024 10:18:36 -0500 Subject: [PATCH 11/11] address the comment --- .../physical-expr/src/scalar_function.rs | 19 ++--- datafusion/proto/src/physical_plan/mod.rs | 56 ++++++++------ .../proto/src/physical_plan/to_proto.rs | 77 +++++++++++++------ .../tests/cases/roundtrip_physical_plan.rs | 4 +- 4 files changed, 98 insertions(+), 58 deletions(-) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 6bb452ddec61..d34084236690 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -171,18 +171,19 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?, }; - let fun_implementation = match self.fun { - ScalarFunctionDefinition::BuiltIn(ref fun) => create_physical_fun(fun)?, - ScalarFunctionDefinition::UDF(ref fun) => fun.fun(), + // evaluate the function + match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + let fun = create_physical_fun(fun)?; + (fun)(&inputs) + } + ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), ScalarFunctionDefinition::Name(_) => { - return internal_err!( + internal_err!( "Name function must be resolved to one of the other variants prior to physical planning" - ); + ) } - }; - // evaluate the function - let fun = fun_implementation.as_ref(); - (fun)(&inputs) + } } fn children(&self) -> Vec> { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 6264aa8e4060..da31c5e762bc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; -use self::to_proto::serialize_expr; +use self::to_proto::serialize_physical_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -1134,7 +1134,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| serialize_expr(expr.0.clone(), extension_codec)) + .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1174,7 +1174,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(serialize_physical_expr( exec.predicate().clone(), extension_codec, )?), @@ -1232,8 +1232,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = serialize_expr(tuple.0.to_owned(), extension_codec)?; - let r = serialize_expr(tuple.1.to_owned(), extension_codec)?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1245,8 +1245,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = - serialize_expr(f.expression().to_owned(), extension_codec)?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1304,8 +1306,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = serialize_expr(tuple.0.to_owned(), extension_codec)?; - let r = serialize_expr(tuple.1.to_owned(), extension_codec)?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1317,8 +1319,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = - serialize_expr(f.expression().to_owned(), extension_codec)?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1355,7 +1359,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr.to_owned(), extension_codec, )?)), @@ -1375,7 +1379,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr.to_owned(), extension_codec, )?)), @@ -1480,14 +1484,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| serialize_expr(expr.0.to_owned(), extension_codec)) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_expr(expr.0.to_owned(), extension_codec)) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1569,7 +1573,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| serialize_expr(pred.clone(), extension_codec)) + .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1616,7 +1620,9 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| serialize_expr(expr.clone(), extension_codec)) + .map(|expr| { + serialize_physical_expr(expr.clone(), extension_codec) + }) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1649,7 +1655,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr.to_owned(), extension_codec, )?)), @@ -1718,7 +1724,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr.to_owned(), extension_codec, )?)), @@ -1758,8 +1764,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = - serialize_expr(f.expression().to_owned(), extension_codec)?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1807,7 +1815,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1837,7 +1845,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1880,7 +1888,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr.to_owned(), extension_codec, )?)), diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 18d63d6111cb..b66709d0c5bd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -82,7 +82,7 @@ impl TryFrom> for protobuf::PhysicalExprNode { let expressions: Vec = a .expressions() .iter() - .map(|e| serialize_expr(e.clone(), &codec)) + .map(|e| serialize_physical_expr(e.clone(), &codec)) .collect::>>()?; let ordering_req: Vec = a @@ -247,13 +247,13 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { let codec = DefaultPhysicalExtensionCodec {}; let args = args .into_iter() - .map(|e| serialize_expr(e, &codec)) + .map(|e| serialize_physical_expr(e, &codec)) .collect::>>()?; let partition_by = window_expr .partition_by() .iter() - .map(|p| serialize_expr(p.clone(), &codec)) + .map(|p| serialize_physical_expr(p.clone(), &codec)) .collect::>>()?; let order_by = window_expr @@ -402,8 +402,14 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_expr(expr.left().clone(), codec)?)), - r: Some(Box::new(serialize_expr(expr.right().clone(), codec)?)), + l: Some(Box::new(serialize_physical_expr( + expr.left().clone(), + codec, + )?)), + r: Some(Box::new(serialize_physical_expr( + expr.right().clone(), + codec, + )?)), op: format!("{:?}", expr.op()), }); @@ -421,7 +427,8 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_expr(exp.clone(), codec).map(Box::new) + serialize_physical_expr(exp.clone(), codec) + .map(Box::new) }) .transpose()?, when_then_expr: expr @@ -436,7 +443,10 @@ pub fn serialize_physical_expr( >>()?, else_expr: expr .else_expr() - .map(|a| serialize_expr(a.clone(), codec).map(Box::new)) + .map(|a| { + serialize_physical_expr(a.clone(), codec) + .map(Box::new) + }) .transpose()?, }, ), @@ -447,7 +457,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { - expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), }, ))), }) @@ -455,7 +468,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), }), )), }) @@ -463,7 +479,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), }), )), }) @@ -473,14 +492,14 @@ pub fn serialize_physical_expr( protobuf::physical_expr_node::ExprType::InList( Box::new( protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( expr.expr().to_owned(), codec, )?)), list: expr .list() .iter() - .map(|a| serialize_expr(a.clone(), codec)) + .map(|a| serialize_physical_expr(a.clone(), codec)) .collect::, DataFusionError, @@ -495,7 +514,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_expr(expr.arg().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), }, ))), }) @@ -509,7 +531,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_expr(cast.expr().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), arrow_type: Some(cast.cast_type().try_into()?), }, ))), @@ -518,7 +543,10 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_expr(cast.expr().to_owned(), codec)?)), + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), arrow_type: Some(cast.cast_type().try_into()?), }, ))), @@ -527,7 +555,7 @@ pub fn serialize_physical_expr( let args: Vec = expr .args() .iter() - .map(|e| serialize_expr(e.to_owned(), codec)) + .map(|e| serialize_physical_expr(e.to_owned(), codec)) .collect::, _>>()?; if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { let fun: protobuf::ScalarFunction = (&fun).try_into()?; @@ -573,8 +601,11 @@ pub fn serialize_physical_expr( protobuf::PhysicalLikeExprNode { negated: expr.negated(), case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_expr(expr.expr().to_owned(), codec)?)), - pattern: Some(Box::new(serialize_expr( + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( expr.pattern().to_owned(), codec, )?)), @@ -592,8 +623,8 @@ fn try_parse_when_then_expr( codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_expr(when_expr.clone(), codec)?), - then_expr: Some(serialize_expr(then_expr.clone(), codec)?), + when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), }) } @@ -726,7 +757,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = serialize_expr(sort_expr.expr.clone(), &codec)?; + let expr = serialize_physical_expr(sort_expr.expr.clone(), &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, @@ -793,7 +824,7 @@ impl TryFrom>> for protobuf::MaybeFilter { match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_expr(expr, &codec)?), + expr: Some(serialize_physical_expr(expr, &codec)?), }), } } @@ -821,7 +852,7 @@ impl TryFrom for protobuf::PhysicalSortExprNode { fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(serialize_expr(sort_expr.expr, &codec)?)), + expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 76a73a09dea9..4924128ae190 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -80,7 +80,7 @@ use datafusion_expr::{ WindowFrame, WindowFrameBound, }; use datafusion_proto::physical_plan::from_proto::parse_physical_expr; -use datafusion_proto::physical_plan::to_proto::serialize_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -788,7 +788,7 @@ fn roundtrip_scalar_udf_extension_codec() { ctx.register_udf(udf.clone()); let extension_codec = ScalarUDFExtensionCodec {}; let proto: protobuf::PhysicalExprNode = - match serialize_expr(Arc::new(test_expr), &extension_codec) { + match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { Ok(proto) => proto, Err(e) => panic!("failed to serialize expr: {e:?}"), };