Skip to content

Commit

Permalink
Extend argument types for udf return_type_from_exprs (#9522)
Browse files Browse the repository at this point in the history
* Extend argument types for udf return type function

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm incorrect assumption

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* possible empty types

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Mar 10, 2024
1 parent f4107d4 commit 31fcd72
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ impl ScalarUDFImpl for TakeUDF {
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
_arg_data_types: &[DataType],
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl ExprSchemable for Expr {

// perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(fun.return_type_from_exprs(args, schema)?)
Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
Expand Down
15 changes: 6 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::ExprSchemable;
use crate::{
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
ScalarFunctionImplementation, Signature,
Expand Down Expand Up @@ -157,9 +156,10 @@ impl ScalarUDF {
&self,
args: &[Expr],
schema: &dyn ExprSchema,
arg_types: &[DataType],
) -> Result<DataType> {
// If the implementation provides a return_type_from_exprs, use it
self.inner.return_type_from_exprs(args, schema)
self.inner.return_type_from_exprs(args, schema, arg_types)
}

/// Do the function rewrite
Expand Down Expand Up @@ -305,14 +305,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// value for `('foo' | 'bar')` as it does for ('foobar').
fn return_type_from_exprs(
&self,
args: &[Expr],
schema: &dyn ExprSchema,
_args: &[Expr],
_schema: &dyn ExprSchema,
arg_types: &[DataType],
) -> Result<DataType> {
let arg_types = args
.iter()
.map(|arg| arg.get_type(schema))
.collect::<Result<Vec<_>>>()?;
self.return_type(&arg_types)
self.return_type(arg_types)
}

/// Invoke the function on `args`, returning the appropriate result
Expand Down
17 changes: 8 additions & 9 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ pub fn create_physical_expr(
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;

match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
functions::create_physical_expr(
Expand All @@ -264,15 +265,13 @@ pub fn create_physical_expr(
execution_props,
)
}
ScalarFunctionDefinition::UDF(fun) => {
let return_type = fun.return_type_from_exprs(args, input_dfschema)?;

udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
return_type,
)
}
ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
args,
input_dfschema,
),
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
Expand Down
29 changes: 23 additions & 6 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

//! UDF support
use crate::{PhysicalExpr, ScalarFunctionExpr};
use arrow_schema::DataType;
use datafusion_common::Result;
use arrow_schema::Schema;
use datafusion_common::{DFSchema, Result};
pub use datafusion_expr::ScalarUDF;
use datafusion_expr::{type_coercion::functions::data_types, Expr};
use std::sync::Arc;

/// Create a physical expression of the UDF.
Expand All @@ -28,8 +29,22 @@ use std::sync::Arc;
pub fn create_physical_expr(
fun: &ScalarUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
return_type: DataType,
input_schema: &Schema,
args: &[Expr],
input_dfschema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
let input_expr_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

// verify that the input data types are consistent with the function's `TypeSignature`
data_types(&input_expr_types, fun.signature())?;

// Since we have arg_types, we don't need args and schema.
let return_type =
fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;

Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun(),
Expand All @@ -42,8 +57,8 @@ pub fn create_physical_expr(

#[cfg(test)]
mod tests {
use arrow_schema::DataType;
use datafusion_common::Result;
use arrow_schema::{DataType, Schema};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -97,7 +112,9 @@ mod tests {
// create and register the udf
let udf = ScalarUDF::from(TestScalarUDF::new());

let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?;
let e = crate::expressions::lit(1.1);
let p_expr =
create_physical_expr(&udf, &[e], &Schema::empty(), &[], &DFSchema::empty())?;

assert_eq!(
p_expr
Expand Down

0 comments on commit 31fcd72

Please sign in to comment.