From 2eb38028130966a79e5b2eafdc38530d4215689d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 10 Mar 2024 09:22:23 +0800 Subject: [PATCH 1/3] Extend argument types for udf return type function Signed-off-by: jayzhan211 --- .../user_defined_scalar_functions.rs | 1 + datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/udf.rs | 18 ++++++++---- datafusion/physical-expr/src/planner.rs | 17 +++++------ datafusion/physical-expr/src/udf.rs | 29 +++++++++++++++---- 5 files changed, 45 insertions(+), 22 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index d9b60134b3d9..13787a64fac9 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -655,6 +655,7 @@ impl ScalarUDFImpl for TakeUDF { &self, arg_exprs: &[Expr], schema: &dyn ExprSchema, + _arg_data_types: &[DataType], ) -> Result { if arg_exprs.len() != 3 { return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 026627a05e62..70ffa5064a52 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -153,7 +153,7 @@ impl ExprSchemable for Expr { // perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type - Ok(fun.return_type_from_exprs(args, schema)?) + Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 5ad420b2f382..f97fa44e27ad 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -157,9 +157,10 @@ impl ScalarUDF { &self, args: &[Expr], schema: &dyn ExprSchema, + arg_types: &[DataType], ) -> Result { // If the implementation provides a return_type_from_exprs, use it - self.inner.return_type_from_exprs(args, schema) + self.inner.return_type_from_exprs(args, schema, arg_types) } /// Do the function rewrite @@ -307,12 +308,17 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { &self, args: &[Expr], schema: &dyn ExprSchema, + arg_types: &[DataType], ) -> Result { - let arg_types = args - .iter() - .map(|arg| arg.get_type(schema)) - .collect::>>()?; - self.return_type(&arg_types) + if arg_types.is_empty() { + let arg_types = args + .iter() + .map(|arg| arg.get_type(schema)) + .collect::>>()?; + self.return_type(&arg_types) + } else { + self.return_type(arg_types) + } } /// Invoke the function on `args`, returning the appropriate result diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 858dbd30c124..417889100232 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -261,6 +261,7 @@ pub fn create_physical_expr( .iter() .map(|e| create_physical_expr(e, input_dfschema, execution_props)) .collect::>>()?; + match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { functions::create_physical_expr( @@ -270,15 +271,13 @@ pub fn create_physical_expr( execution_props, ) } - ScalarFunctionDefinition::UDF(fun) => { - let return_type = fun.return_type_from_exprs(args, input_dfschema)?; - - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - return_type, - ) - } + ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + args, + input_dfschema, + ), ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index d9c7c9e5c2a6..ede3e5badbb1 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -17,9 +17,10 @@ //! UDF support use crate::{PhysicalExpr, ScalarFunctionExpr}; -use arrow_schema::DataType; -use datafusion_common::Result; +use arrow_schema::Schema; +use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; +use datafusion_expr::{type_coercion::functions::data_types, Expr}; use std::sync::Arc; /// Create a physical expression of the UDF. @@ -28,8 +29,22 @@ use std::sync::Arc; pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], - return_type: DataType, + input_schema: &Schema, + args: &[Expr], + input_dfschema: &DFSchema, ) -> Result> { + let input_expr_types = input_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, fun.signature())?; + + // Since we have arg_types, we dont need args and schema. + let return_type = + fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), fun.fun(), @@ -42,8 +57,8 @@ pub fn create_physical_expr( #[cfg(test)] mod tests { - use arrow_schema::DataType; - use datafusion_common::Result; + use arrow_schema::{DataType, Schema}; + use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; @@ -97,7 +112,9 @@ mod tests { // create and register the udf let udf = ScalarUDF::from(TestScalarUDF::new()); - let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?; + let e = crate::expressions::lit(1.1); + let p_expr = + create_physical_expr(&udf, &[e], &Schema::empty(), &[], &DFSchema::empty())?; assert_eq!( p_expr From af1d499f60716d7c105e65cdf4f81c8bc731194f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 10 Mar 2024 13:37:08 +0800 Subject: [PATCH 2/3] rm incorrect assumption Signed-off-by: jayzhan211 --- datafusion/expr/src/udf.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index f97fa44e27ad..cd2b7d1a9070 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,7 +18,6 @@ //! [`ScalarUDF`]: Scalar User Defined Functions use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; -use crate::ExprSchemable; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, ScalarFunctionImplementation, Signature, @@ -159,6 +158,8 @@ impl ScalarUDF { schema: &dyn ExprSchema, arg_types: &[DataType], ) -> Result { + // we always pre-compute the argument types before called, so arg_types can be ensured to be non-empty + assert!(!arg_types.is_empty()); // If the implementation provides a return_type_from_exprs, use it self.inner.return_type_from_exprs(args, schema, arg_types) } @@ -306,19 +307,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// value for `('foo' | 'bar')` as it does for ('foobar'). fn return_type_from_exprs( &self, - args: &[Expr], - schema: &dyn ExprSchema, + _args: &[Expr], + _schema: &dyn ExprSchema, arg_types: &[DataType], ) -> Result { - if arg_types.is_empty() { - let arg_types = args - .iter() - .map(|arg| arg.get_type(schema)) - .collect::>>()?; - self.return_type(&arg_types) - } else { - self.return_type(arg_types) - } + assert!(!arg_types.is_empty()); + self.return_type(arg_types) } /// Invoke the function on `args`, returning the appropriate result From 6ec947fcd5e64cf23c26c3e0d959e3b223fa24e2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 10 Mar 2024 13:53:23 +0800 Subject: [PATCH 3/3] possible empty types Signed-off-by: jayzhan211 --- datafusion/expr/src/udf.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index cd2b7d1a9070..3002a745055f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -158,8 +158,6 @@ impl ScalarUDF { schema: &dyn ExprSchema, arg_types: &[DataType], ) -> Result { - // we always pre-compute the argument types before called, so arg_types can be ensured to be non-empty - assert!(!arg_types.is_empty()); // If the implementation provides a return_type_from_exprs, use it self.inner.return_type_from_exprs(args, schema, arg_types) } @@ -311,7 +309,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { _schema: &dyn ExprSchema, arg_types: &[DataType], ) -> Result { - assert!(!arg_types.is_empty()); self.return_type(arg_types) }