diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index eeacc48b85db..ca1582bcb34f 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -390,6 +390,7 @@ pub(crate) mod tests { &[self.column()], &[], &[], + &[], schema, self.column_name(), false, diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 38b92959e841..b57f36f728d7 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -315,6 +315,7 @@ mod tests { &[expr], &[], &[], + &[], schema, name, false, @@ -404,6 +405,7 @@ mod tests { &[col("b", &schema)?], &[], &[], + &[], &schema, "Sum(b)", false, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 154e77cd23ae..5320938d2eb8 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -245,6 +245,7 @@ pub fn bounded_window_exec( "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], + &[], &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4f9187595018..404bcbb2e7d4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1766,7 +1766,8 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = create_physical_exprs(args, logical_schema, execution_props)?; + let physical_args = + create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = create_physical_exprs(partition_by, logical_schema, execution_props)?; let order_by = @@ -1780,13 +1781,13 @@ pub fn create_window_expr_with_name( } let window_frame = Arc::new(window_frame.clone()); - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; windows::create_window_expr( fun, name, - &args, + &physical_args, + args, &partition_by, &order_by, window_frame, @@ -1837,7 +1838,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = + let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( @@ -1867,7 +1868,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, - &args, + &physical_args, &ordering_reqs, physical_input_schema, name, @@ -1889,7 +1890,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, - &args, + &physical_args, + args, &sort_exprs, &ordering_reqs, physical_input_schema, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index b05769a6ce9d..1c55c48fea40 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -33,7 +33,7 @@ use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; -use datafusion_functions_aggregate::expr_fn::approx_median; +use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -363,7 +363,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let expected = [ "+---------------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", + "| approx_percentile_cont(test.b,Float64(0.5)) |", "+---------------------------------------------+", "| 10 |", "+---------------------------------------------+", @@ -384,7 +384,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let df = create_test_table().await?; let expected = [ "+--------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "| approx_percentile_cont(test.b,arg_2) |", "+--------------------------------------+", "| 10 |", "+--------------------------------------+", diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index c76c1fc2c736..a04f4f349122 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -108,6 +108,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str &[col("d", &schema).unwrap()], &[], &[], + &[], &schema, "sum1", false, diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 4358691ee5a5..5bd19850cacc 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -252,6 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let partitionby_exprs = vec![]; let orderby_exprs = vec![]; + let logical_exprs = vec![]; // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -283,6 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &window_fn, fn_name.to_string(), &args, + &logical_exprs, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), @@ -699,6 +701,7 @@ async fn run_window_test( &window_fn, fn_name.clone(), &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -717,6 +720,7 @@ async fn run_window_test( &window_fn, fn_name, &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 81562bf12476..441e8953dffc 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::{fmt, str::FromStr}; use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; +use crate::{type_coercion::aggregates::*, Signature, Volatility}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; @@ -45,10 +45,6 @@ pub enum AggregateFunction { NthValue, /// Correlation Correlation, - /// Approximate continuous percentile function - ApproxPercentileCont, - /// Approximate continuous percentile function with weight - ApproxPercentileContWithWeight, /// Grouping Grouping, /// Bit And @@ -75,8 +71,6 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", Correlation => "CORR", - ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Grouping => "GROUPING", BitAnd => "BIT_AND", BitOr => "BIT_OR", @@ -113,11 +107,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - // approximate - "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - "approx_percentile_cont_with_weight" => { - AggregateFunction::ApproxPercentileContWithWeight - } // other "grouping" => AggregateFunction::Grouping, _ => { @@ -170,10 +159,6 @@ impl AggregateFunction { coerced_data_types[0].clone(), true, )))), - AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), - AggregateFunction::ApproxPercentileContWithWeight => { - Ok(coerced_data_types[0].clone()) - } AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), @@ -230,39 +215,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::ApproxPercentileCont => { - let mut variants = - Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants - .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - DataType::Float64, - int.clone(), - ])) - } - } - - Signature::one_of(variants, Volatility::Immutable) - } - AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Volatility::Immutable, - ), AggregateFunction::StringAgg => { Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fb5b3991ecd8..099851aece46 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -242,34 +242,6 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) } -/// Calculate an approximation of the specified `percentile` for `expr`. -pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileCont, - vec![expr, percentile], - false, - None, - None, - None, - )) -} - -/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`. -pub fn approx_percentile_cont_with_weight( - expr: Expr, - weight_expr: Expr, - percentile: Expr, -) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileContWithWeight, - vec![expr, weight_expr, percentile], - false, - None, - None, - None, - )) -} - /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index c06f177510e7..169436145aae 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -83,8 +83,8 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, - /// The number of arguments the aggregate function takes. - pub args_num: usize, + /// The logical expression of arguments the aggregate function takes. + pub input_exprs: &'a [Expr], } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 6c9a71bab46a..98324ed6120b 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,7 +17,6 @@ use std::ops::Deref; -use super::functions::can_coerce_from; use crate::{AggregateFunction, Signature, TypeSignature}; use arrow::datatypes::{ @@ -158,55 +157,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::ApproxPercentileCont => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - if input_types.len() == 3 && !input_types[2].is_integer() { - return plan_err!( - "The percentile sample points count for {:?} must be integer, not {:?}.", - agg_fun, input_types[2] - ); - } - let mut result = input_types.to_vec(); - if can_coerce_from(&Float64, &input_types[1]) { - result[1] = Float64; - } else { - return plan_err!( - "Could not coerce the percent argument for {:?} to Float64. Was {:?}.", - agg_fun, input_types[1] - ); - } - Ok(result) - } - AggregateFunction::ApproxPercentileContWithWeight => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { - return plan_err!( - "The weight argument for {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[1] - ); - } - if !matches!(input_types[2], Float64) { - return plan_err!( - "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, - input_types[2] - ); - } - Ok(input_types.to_vec()) - } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { @@ -459,15 +409,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { arg_type.is_integer() } -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on. -pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - /// Return `true` if `arg_type` is of a [`DataType`] that the /// [`AggregateFunction::StringAgg`] aggregation can operate on. pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { @@ -532,29 +473,6 @@ mod tests { assert_eq!(r[0], DataType::Decimal128(20, 3)); let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); assert_eq!(r[0], DataType::Decimal256(20, 3)); - - // ApproxPercentileCont input types - let input_types = vec![ - vec![DataType::Int8, DataType::Float64], - vec![DataType::Int16, DataType::Float64], - vec![DataType::Int32, DataType::Float64], - vec![DataType::Int64, DataType::Float64], - vec![DataType::UInt8, DataType::Float64], - vec![DataType::UInt16, DataType::Float64], - vec![DataType::UInt32, DataType::Float64], - vec![DataType::UInt64, DataType::Float64], - vec![DataType::Float32, DataType::Float64], - vec![DataType::Float64, DataType::Float64], - ]; - for input_type in &input_types { - let signature = AggregateFunction::ApproxPercentileCont.signature(); - let result = coerce_types( - &AggregateFunction::ApproxPercentileCont, - input_type, - &signature, - ); - assert_eq!(*input_type, result.unwrap()); - } } #[test] diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index b8b86d30557a..bc723c862953 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -28,7 +28,6 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; -use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref; use crate::approx_percentile_cont::ApproxPercentileAccumulator; @@ -118,12 +117,3 @@ impl AggregateUDFImpl for ApproxMedian { ))) } } - -impl PartialEq for ApproxMedian { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.signature == x.signature) - .unwrap_or(false) - } -} diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index e75417efc684..5ae5684d9cab 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use arrow::array::RecordBatch; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -22,12 +27,238 @@ use arrow::{ }, datatypes::DataType, }; +use arrow_schema::{Field, Schema}; -use datafusion_common::{downcast_value, internal_err, DataFusionError, ScalarValue}; -use datafusion_expr::Accumulator; +use datafusion_common::{ + downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, ScalarValue, +}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, + Volatility, +}; use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; +use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr; + +make_udaf_expr_and_func!( + ApproxPercentileCont, + approx_percentile_cont, + expression percentile, + "Computes the approximate percentile continuous of a set of numbers", + approx_percentile_cont_udaf +); + +pub struct ApproxPercentileCont { + signature: Signature, +} + +impl Debug for ApproxPercentileCont { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("ApproxPercentileCont") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileCont { + fn default() -> Self { + Self::new() + } +} + +impl ApproxPercentileCont { + /// Create a new [`ApproxPercentileCont`] aggregate function. + pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with a float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + DataType::Float64, + int.clone(), + ])) + } + } + Self { + signature: Signature::one_of(variants, Volatility::Immutable), + } + } + + pub(crate) fn create_accumulator( + &self, + args: AccumulatorArgs, + ) -> datafusion_common::Result { + let percentile = validate_input_percentile_expr(&args.input_exprs[1])?; + let tdigest_max_size = if args.input_exprs.len() == 3 { + Some(validate_input_max_size_expr(&args.input_exprs[2])?) + } else { + None + }; + + let accumulator: ApproxPercentileAccumulator = match args.input_type { + t @ (DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64) => { + if let Some(max_size) = tdigest_max_size { + ApproxPercentileAccumulator::new_with_max_size(percentile, t.clone(), max_size) + }else{ + ApproxPercentileAccumulator::new(percentile, t.clone()) + + } + } + other => { + return not_impl_err!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" + ) + } + }; + + Ok(accumulator) + } +} + +fn get_lit_value(expr: &Expr) -> datafusion_common::Result { + let empty_schema = Arc::new(Schema::empty()); + let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); + let expr = limited_convert_logical_expr_to_physical_expr(expr, &empty_schema)?; + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + +fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result { + let lit = get_lit_value(expr)?; + let percentile = match &lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q, + got => return not_impl_err!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + got.data_type() + ) + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return plan_err!( + "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" + ); + } + Ok(percentile) +} + +fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result { + let lit = get_lit_value(expr)?; + let max_size = match &lit { + ScalarValue::UInt8(Some(q)) => *q as usize, + ScalarValue::UInt16(Some(q)) => *q as usize, + ScalarValue::UInt32(Some(q)) => *q as usize, + ScalarValue::UInt64(Some(q)) => *q as usize, + ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, + got => return not_impl_err!( + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", + got.data_type() + ) + }; + Ok(max_size) +} + +impl AggregateUDFImpl for ApproxPercentileCont { + fn as_any(&self) -> &dyn Any { + self + } + + #[allow(rustdoc::private_intra_doc_links)] + /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised + /// state. + fn state_fields( + &self, + args: StateFieldsArgs, + ) -> datafusion_common::Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + format_state_name(args.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "count"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "max"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "min"), + DataType::Float64, + false, + ), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new("item", DataType::Float64, true), + false, + ), + ]) + } + + fn name(&self) -> &str { + "approx_percentile_cont" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + #[inline] + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(self.create_accumulator(acc_args)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + if !arg_types[0].is_numeric() { + return plan_err!("approx_percentile_cont requires numeric input types"); + } + if arg_types.len() == 3 && !arg_types[2].is_integer() { + return plan_err!( + "approx_percentile_cont requires integer max_size input types" + ); + } + Ok(arg_types[0].clone()) + } +} #[derive(Debug)] pub struct ApproxPercentileAccumulator { diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs similarity index 51% rename from datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs rename to datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 07c2aff3437f..a64218c606c4 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -15,105 +15,140 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::ApproxPercentileCont; -use crate::{AggregateExpr, PhysicalExpr}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; + use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; -use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; + +use datafusion_common::ScalarValue; +use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; use datafusion_physical_expr_common::aggregate::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -use crate::aggregate::utils::down_cast_any_ref; -use std::{any::Any, sync::Arc}; +make_udaf_expr_and_func!( + ApproxPercentileContWithWeight, + approx_percentile_cont_with_weight, + expression weight percentile, + "Computes the approximate percentile continuous with weight of a set of numbers", + approx_percentile_cont_with_weight_udaf +); /// APPROX_PERCENTILE_CONT_WITH_WEIGTH aggregate expression -#[derive(Debug)] pub struct ApproxPercentileContWithWeight { + signature: Signature, approx_percentile_cont: ApproxPercentileCont, - column_expr: Arc, - weight_expr: Arc, - percentile_expr: Arc, +} + +impl Debug for ApproxPercentileContWithWeight { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApproxPercentileContWithWeight") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileContWithWeight { + fn default() -> Self { + Self::new() + } } impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. - pub fn new( - expr: Vec>, - name: impl Into, - return_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, WeightExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 3); - - let sub_expr = vec![expr[0].clone(), expr[2].clone()]; - let approx_percentile_cont = - ApproxPercentileCont::new(sub_expr, name, return_type)?; - - Ok(Self { - approx_percentile_cont, - column_expr: expr[0].clone(), - weight_expr: expr[1].clone(), - percentile_expr: expr[2].clone(), - }) + pub fn new() -> Self { + Self { + signature: Signature::one_of( + // Accept any numeric value paired with a float64 percentile + NUMERICS + .iter() + .map(|t| { + TypeSignature::Exact(vec![ + t.clone(), + t.clone(), + DataType::Float64, + ]) + }) + .collect(), + Immutable, + ), + approx_percentile_cont: ApproxPercentileCont::new(), + } } } -impl AggregateExpr for ApproxPercentileContWithWeight { +impl AggregateUDFImpl for ApproxPercentileContWithWeight { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - self.approx_percentile_cont.field() + fn name(&self) -> &str { + "approx_percentile_cont_with_weight" } - #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised - /// state. - fn state_fields(&self) -> Result> { - self.approx_percentile_cont.state_fields() + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec> { - vec![ - self.column_expr.clone(), - self.weight_expr.clone(), - self.percentile_expr.clone(), - ] + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric input types" + ); + } + if !arg_types[1].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric weight input types" + ); + } + if arg_types[2] != DataType::Float64 { + return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); + } + Ok(arg_types[0].clone()) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" + ); + } + + if acc_args.input_exprs.len() != 3 { + return plan_err!( + "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + ); + } + + let sub_args = AccumulatorArgs { + input_exprs: &[ + acc_args.input_exprs[0].clone(), + acc_args.input_exprs[2].clone(), + ], + ..acc_args + }; let approx_percentile_cont_accumulator = - self.approx_percentile_cont.create_plain_accumulator()?; + self.approx_percentile_cont.create_accumulator(sub_args)?; let accumulator = ApproxPercentileWithWeightAccumulator::new( approx_percentile_cont_accumulator, ); Ok(Box::new(accumulator)) } - fn name(&self) -> &str { - self.approx_percentile_cont.name() - } -} - -impl PartialEq for ApproxPercentileContWithWeight { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.approx_percentile_cont == x.approx_percentile_cont - && self.column_expr.eq(&x.column_expr) - && self.weight_expr.eq(&x.weight_expr) - && self.percentile_expr.eq(&x.percentile_expr) - }) - .unwrap_or(false) + #[allow(rustdoc::private_intra_doc_links)] + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.approx_percentile_cont.state_fields(args) } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index cfd56619537b..062e148975bf 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -258,7 +258,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { return false; } - args.args_num == 1 + args.input_exprs.len() == 1 } fn create_groups_accumulator( diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index fabe15e416f4..daddb9d93f78 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -68,7 +68,10 @@ pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; +pub mod approx_percentile_cont_with_weight; +use crate::approx_percentile_cont::approx_percentile_cont_udaf; +use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::AggregateUDF; @@ -79,6 +82,8 @@ use std::sync::Arc; pub mod expr_fn { pub use super::approx_distinct; pub use super::approx_median::approx_median; + pub use super::approx_percentile_cont::approx_percentile_cont; + pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -127,6 +132,8 @@ pub fn all_default_aggregate_functions() -> Vec> { stddev::stddev_pop_udaf(), approx_median::approx_median_udaf(), approx_distinct::approx_distinct_udaf(), + approx_percentile_cont_udaf(), + approx_percentile_cont_with_weight_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 4c3effe7650a..42cf44f65d8f 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -332,7 +332,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let args2 = AccumulatorArgs { @@ -343,7 +343,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0c8e4ae34a90..acc21f14f44d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1055,31 +1055,6 @@ mod test { Ok(()) } - #[test] - fn agg_function_invalid_input_percentile() { - let empty = empty(); - let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(0.95), lit(42.0), lit(100.0)], - false, - None, - None, - None, - )); - - let err = Projection::try_new(vec![agg_expr], empty) - .err() - .unwrap() - .strip_backtrace(); - - let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:"; - assert!(!err - .strip_prefix(prefix) - .unwrap() - .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)")); - } - #[test] fn binary_op_date32_op_interval() -> Result<()> { // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 21884f840dbd..432267e045b2 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -46,6 +46,7 @@ use datafusion_expr::utils::AggregateOrderSensitivity; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + input_exprs: &[Expr], sort_exprs: &[Expr], ordering_req: &[PhysicalSortExpr], schema: &Schema, @@ -76,6 +77,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), + logical_args: input_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), schema: schema.clone(), @@ -231,6 +233,7 @@ pub struct AggregatePhysicalExpressions { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, + logical_args: Vec, /// Output / return type of this aggregate data_type: DataType, name: String, @@ -293,7 +296,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; @@ -308,7 +311,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; @@ -378,7 +381,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.groups_accumulator_supported(args) @@ -392,7 +395,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.create_groups_accumulator(args) @@ -434,6 +437,7 @@ impl AggregateExpr for AggregateFunctionExpr { create_aggregate_expr( &updated_fn, &self.args, + &self.logical_args, &self.sort_exprs, &self.ordering_req, &self.schema, @@ -468,6 +472,7 @@ impl AggregateExpr for AggregateFunctionExpr { let reverse_aggr = create_aggregate_expr( &reverse_udf, &self.args, + &self.logical_args, &reverse_sort_exprs, &reverse_ordering_req, &self.schema, diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index ea21c8e9a92b..dd534cc07d20 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -17,7 +17,7 @@ mod cast; pub mod column; -mod literal; +pub mod literal; pub use cast::{cast, cast_with_options, CastExpr}; pub use literal::{lit, Literal}; diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index f661400fcb10..d5cd3c6f4af0 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,18 +17,21 @@ use std::sync::Arc; -use crate::expressions::{self, CastExpr}; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; -use crate::tree_node::ExprContext; - use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::Schema; + use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; +use crate::expressions::literal::Literal; +use crate::expressions::{self, CastExpr}; +use crate::physical_expr::PhysicalExpr; +use crate::sort_expr::PhysicalSortExpr; +use crate::tree_node::ExprContext; + /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. pub type ExprPropertiesNode = ExprContext; @@ -115,6 +118,9 @@ pub fn limited_convert_logical_expr_to_physical_expr( schema: &Schema, ) -> Result> { match expr { + Expr::Alias(Alias { expr, .. }) => { + Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?) + } Expr::Column(col) => expressions::column::col(&col.name, schema), Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( limited_convert_logical_expr_to_physical_expr( @@ -124,10 +130,7 @@ pub fn limited_convert_logical_expr_to_physical_expr( cast_expr.data_type.clone(), None, ))), - Expr::Alias(alias_expr) => limited_convert_logical_expr_to_physical_expr( - alias_expr.expr.as_ref(), - schema, - ), + Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), _ => exec_err!( "Unsupported expression: {expr} for conversion to Arc" ), @@ -138,11 +141,12 @@ pub fn limited_convert_logical_expr_to_physical_expr( mod tests { use std::sync::Arc; - use super::*; - use arrow::array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use super::*; + #[test] fn scatter_int() -> Result<()> { let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs deleted file mode 100644 index f2068bbc92cc..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ /dev/null @@ -1,249 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::{DataType, Field}; -use arrow_array::RecordBatch; -use arrow_schema::Schema; - -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, ColumnarValue}; -use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// APPROX_PERCENTILE_CONT aggregate expression -#[derive(Debug)] -pub struct ApproxPercentileCont { - name: String, - input_data_type: DataType, - expr: Vec>, - percentile: f64, - tdigest_max_size: Option, -} - -impl ApproxPercentileCont { - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 2); - - let percentile = validate_input_percentile_expr(&expr[1])?; - - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: None, - }) - } - - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new_with_max_size( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral, TDigestMaxSize] - debug_assert_eq!(expr.len(), 3); - let percentile = validate_input_percentile_expr(&expr[1])?; - let max_size = validate_input_max_size_expr(&expr[2])?; - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: Some(max_size), - }) - } - - pub(crate) fn create_plain_accumulator(&self) -> Result { - let accumulator: ApproxPercentileAccumulator = match &self.input_data_type { - t @ (DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64) => { - if let Some(max_size) = self.tdigest_max_size { - ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), max_size) - - }else{ - ApproxPercentileAccumulator::new(self.percentile, t.clone()) - - } - } - other => { - return not_impl_err!( - "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" - ) - } - }; - Ok(accumulator) - } -} - -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &ApproxPercentileCont) -> bool { - self.name == other.name - && self.input_data_type == other.input_data_type - && self.percentile == other.percentile - && self.tdigest_max_size == other.tdigest_max_size - && self.expr.len() == other.expr.len() - && self - .expr - .iter() - .zip(other.expr.iter()) - .all(|(this, other)| this.eq(other)) - } -} - -fn get_lit_value(expr: &Arc) -> Result { - let empty_schema = Schema::empty(); - let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), - ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), - } -} - -fn validate_input_percentile_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let percentile = match &lit { - ScalarValue::Float32(Some(q)) => *q as f64, - ScalarValue::Float64(Some(q)) => *q, - got => return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.data_type() - ) - }; - - // Ensure the percentile is between 0 and 1. - if !(0.0..=1.0).contains(&percentile) { - return plan_err!( - "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ); - } - Ok(percentile) -} - -fn validate_input_max_size_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let max_size = match &lit { - ScalarValue::UInt8(Some(q)) => *q as usize, - ScalarValue::UInt16(Some(q)) => *q as usize, - ScalarValue::UInt32(Some(q)) => *q as usize, - ScalarValue::UInt64(Some(q)) => *q as usize, - ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, - got => return not_impl_err!( - "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", - got.data_type() - ) - }; - Ok(max_size) -} - -impl AggregateExpr for ApproxPercentileCont { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), false)) - } - - #[allow(rustdoc::private_intra_doc_links)] - /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised - /// state. - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "max_size"), - DataType::UInt64, - false, - ), - Field::new( - format_state_name(&self.name, "sum"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "count"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "max"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "min"), - DataType::Float64, - false, - ), - Field::new_list( - format_state_name(&self.name, "centroids"), - Field::new("item", DataType::Float64, true), - false, - ), - ]) - } - - fn expressions(&self) -> Vec> { - self.expr.clone() - } - - fn create_accumulator(&self) -> Result> { - let accumulator = self.create_plain_accumulator()?; - Ok(Box::new(accumulator)) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.eq(x)) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index df87a2e261a1..a1f5f153a9ff 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -36,6 +36,7 @@ use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; + /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. pub fn create_aggregate_expr( @@ -154,41 +155,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::ApproxPercentileCont, false) => { - if input_phy_exprs.len() == 2 { - Arc::new(expressions::ApproxPercentileCont::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } else { - Arc::new(expressions::ApproxPercentileCont::new_with_max_size( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - } - (AggregateFunction::ApproxPercentileCont, true) => { - return not_impl_err!( - "approx_percentile_cont(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::ApproxPercentileContWithWeight, false) => { - Arc::new(expressions::ApproxPercentileContWithWeight::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - (AggregateFunction::ApproxPercentileContWithWeight, true) => { - return not_impl_err!( - "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" - ); - } (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -232,15 +198,15 @@ pub fn create_aggregate_expr( mod tests { use arrow::datatypes::{DataType, Field}; - use super::*; + use datafusion_common::plan_err; + use datafusion_expr::{type_coercion, Signature}; + use crate::expressions::{ - try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, DistinctArrayAgg, Max, Min, + try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, + DistinctArrayAgg, Max, Min, }; - use datafusion_common::{plan_err, DataFusionError, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{type_coercion, Signature}; + use super::*; #[test] fn test_approx_expr() -> Result<()> { @@ -304,59 +270,6 @@ mod tests { Ok(()) } - #[test] - fn test_agg_approx_percentile_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect("failed to create aggregate expr"); - - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), false), - result_agg_phy_exprs.field().unwrap() - ); - } - } - - #[test] - fn test_agg_approx_percentile_invalid_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), - ]; - let err = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect_err("should fail due to invalid percentile"); - - assert!(matches!(err, DataFusionError::Plan(_))); - } - } - #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 9079a81e6241..c20902c11b86 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,8 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod approx_percentile_cont; -pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 592393f800d0..b9a159b21e3d 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,8 +35,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; -pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; @@ -65,8 +63,8 @@ pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; +pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; -pub use datafusion_physical_expr_common::expressions::{lit, Literal}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b6fc70be7cbc..b7d8d60f4f35 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1339,6 +1339,7 @@ mod tests { let aggregates = vec![create_aggregate_expr( &count_udaf(), &[lit(1i8)], + &[datafusion_expr::lit(1i8)], &[], &[], &input_schema, @@ -1787,6 +1788,7 @@ mod tests { &args, &[], &[], + &[], schema, "MEDIAN(a)", false, @@ -1975,10 +1977,12 @@ mod tests { options: sort_options, }]; let args = vec![col("b", schema)?]; + let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); datafusion_physical_expr_common::aggregate::create_aggregate_expr( &func, &args, + &logical_args, &sort_exprs, &ordering_req, schema, @@ -2005,10 +2009,12 @@ mod tests { options: sort_options, }]; let args = vec![col("b", schema)?]; + let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( + create_aggregate_expr( &func, &args, + &logical_args, &sort_exprs, &ordering_req, schema, diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 56d780e51394..fc60ab997375 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,7 +1194,7 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; @@ -1301,7 +1301,10 @@ mod tests { let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; + let log_expr = + Expr::Column(datafusion_common::Column::from(schema.fields[0].name())); let args = vec![col_expr]; + let log_args = vec![log_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, @@ -1322,6 +1325,7 @@ mod tests { &window_fn, fn_name, &args, + &log_args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 63ce473fc57e..ecfe123a43af 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -90,6 +90,7 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: &[Arc], + logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -144,6 +145,7 @@ pub fn create_window_expr( let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, + logical_args, &sort_exprs, order_by, input_schema, @@ -754,6 +756,7 @@ mod tests { &[col("a", &schema)?], &[], &[], + &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), false, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 83223a04d023..e5578ae62f3e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -486,9 +486,9 @@ enum AggregateFunction { // STDDEV = 11; // STDDEV_POP = 12; CORRELATION = 13; - APPROX_PERCENTILE_CONT = 14; + // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; - APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; + // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; // MEDIAN = 18; BIT_AND = 19; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f298dd241abf..4a7b9610e5bc 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -537,8 +537,6 @@ impl serde::Serialize for AggregateFunction { Self::Avg => "AVG", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", - Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", Self::BitAnd => "BIT_AND", Self::BitOr => "BIT_OR", @@ -563,8 +561,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "AVG", "ARRAY_AGG", "CORRELATION", - "APPROX_PERCENTILE_CONT", - "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", "BIT_AND", "BIT_OR", @@ -618,8 +614,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "AVG" => Ok(AggregateFunction::Avg), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), - "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), "BIT_AND" => Ok(AggregateFunction::BitAnd), "BIT_OR" => Ok(AggregateFunction::BitOr), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index fa0217e9ef4f..ffaef445d668 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1940,9 +1940,9 @@ pub enum AggregateFunction { /// STDDEV = 11; /// STDDEV_POP = 12; Correlation = 13, - ApproxPercentileCont = 14, + /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; - ApproxPercentileContWithWeight = 16, + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; Grouping = 17, /// MEDIAN = 18; BitAnd = 19, @@ -1974,10 +1974,6 @@ impl AggregateFunction { AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", - AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - AggregateFunction::ApproxPercentileContWithWeight => { - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" - } AggregateFunction::Grouping => "GROUPING", AggregateFunction::BitAnd => "BIT_AND", AggregateFunction::BitOr => "BIT_OR", @@ -1996,10 +1992,6 @@ impl AggregateFunction { "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), - "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => { - Some(Self::ApproxPercentileContWithWeight) - } "GROUPING" => Some(Self::Grouping), "BIT_AND" => Some(Self::BitAnd), "BIT_OR" => Some(Self::BitOr), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ed7b0129cc48..25b7413a984a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -147,12 +147,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::ApproxPercentileCont => { - Self::ApproxPercentileCont - } - protobuf::AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 04f7b596fea8..d9548325dac3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -118,10 +118,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, - AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, @@ -381,12 +377,6 @@ pub fn serialize_expr( }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 0a91df568a1d..b636c77641c7 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -126,7 +126,6 @@ pub fn parse_physical_window_expr( ) -> Result> { let window_node_expr = parse_physical_exprs(&proto.args, registry, input_schema, codec)?; - let partition_by = parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; @@ -178,10 +177,13 @@ pub fn parse_physical_window_expr( // TODO: Remove extended_schema if functions are all UDAF let extended_schema = schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. + let logical_exprs = &[]; create_window_expr( &fun, name, &window_node_expr, + logical_exprs, &partition_by, &order_by, Arc::new(window_frame), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d0011e4917bf..8a488d30cf24 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -496,11 +496,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; + // TODO: 'logical_exprs' is not supported for UDAF yet. + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. + let logical_exprs = &[]; // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ef462ac94b9a..3a4c35a93e16 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,12 +23,11 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, - WindowShift, + ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + StringAgg, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -270,13 +269,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxPercentileCont - } else if aggr_expr - .downcast_ref::() - .is_some() - { - protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d0f1c4aade5e..a496e226855a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,7 +26,6 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; -use datafusion_functions_aggregate::count::count_udaf; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -34,10 +33,11 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::approx_median::approx_median; +use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ - count, count_distinct, covar_pop, covar_samp, first_value, median, stddev, - stddev_pop, sum, var_pop, var_sample, + approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, + count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, + var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -663,6 +663,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev(lit(2.2)), stddev_pop(lit(2.2)), approx_median(lit(2)), + approx_percentile_cont(lit(2), lit(0.5)), + approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), ]; // ensure expressions created with the expr api can be round tripped @@ -1799,21 +1801,6 @@ fn roundtrip_count_distinct() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ApproxPercentileCont, - vec![col("bananas"), lit(0.42_f32)], - false, - None, - None, - None, - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - #[test] fn roundtrip_aggregate_udf() { #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e517482f1db0..7f66cdbf7663 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -303,6 +303,7 @@ fn roundtrip_window() -> Result<()> { &args, &[], &[], + &[], &schema, "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", false, @@ -458,6 +459,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &[col("b", &schema)?], &[], &[], + &[], &schema, "example_agg", false, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 7ba1893bb11a..0a6def3d6f27 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -76,26 +76,26 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8, Int8, Float64\)'. You might need to add explicit type casts. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Utf8, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Int8, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins -statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\). +statement error DataFusion error: External error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 # array agg can use order by