From 8236565f69f89d67a9251338485cecd76f726fd8 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sat, 8 Jun 2024 22:15:02 +0800 Subject: [PATCH 01/10] add stddev and stddev_pot udaf --- datafusion/functions-aggregate/src/lib.rs | 3 + datafusion/functions-aggregate/src/stddev.rs | 335 ++++++++++++++++++ .../tests/cases/roundtrip_logical_plan.rs | 4 +- 3 files changed, 341 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions-aggregate/src/stddev.rs diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index ff02d25ad00b..4660230445b9 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,6 +58,7 @@ pub mod macros; pub mod covariance; pub mod first_last; pub mod median; +pub mod stddev; pub mod sum; pub mod variance; @@ -74,6 +75,8 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::stddev::stddev; + pub use super::stddev::stddev_pop; pub use super::sum::sum; pub use super::variance::var_sample; } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs new file mode 100644 index 000000000000..8797ca73211f --- /dev/null +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::fmt::{Debug, Formatter}; + +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; + +use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr_common::aggregate::stats::StatsType; + +use crate::variance::VarianceAccumulator; + +make_udaf_expr_and_func!( + Stddev, + stddev, + expression, + "Compute the standard deviation of a set of numbers", + stddev_udaf +); + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +pub struct Stddev { + signature: Signature, +} + +impl Debug for Stddev { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Stddev") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Stddev { + fn default() -> Self { + Self::new() + } +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "stddev" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(args.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } +} + +make_udaf_expr_and_func!( + StddevPop, + stddev_pop, + expression, + "Compute the population standard deviation of a set of numbers", + stddev_pop_udaf +); + +/// STDDEV_POP population aggregate expression +#[derive(Debug)] +pub struct StddevPop { + signature: Signature, +} + +impl Default for StddevPop { + fn default() -> Self { + Self::new() + } +} + +impl StddevPop { + /// Create a new STDDEV aggregate function + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for StddevPop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "stddev_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(args.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } +} + +/// An accumulator to compute the average +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type)?, + }) + } + + pub fn get_m2(&self) -> f64 { + self.variance.get_m2() + } +} + +impl Accumulator for StddevAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.retract_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + let variance = self.variance.evaluate()?; + dbg!(variance.clone()); + match variance { + ScalarValue::Float64(e) => { + if e.is_none() { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) + } + } + _ => internal_err!("Variance should be f64"), + } + } + + fn size(&self) -> usize { + std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) + + self.variance.size() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{array::*, datatypes::*}; + + use datafusion_expr::AggregateUDF; + use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; + use datafusion_physical_expr_common::expressions::column::col; + + use super::*; + + #[test] + fn stddev_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = stddev_pop_udaf(); + let agg2 = stddev_pop_udaf(); + + let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?; + assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + + #[test] + fn stddev_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + let b = Arc::new(Float64Array::from(vec![None])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = stddev_pop_udaf(); + let agg2 = stddev_pop_udaf(); + + let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?; + assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + schema: &Schema, + ) -> Result { + let args1 = AccumulatorArgs { + data_type: &DataType::Float64, + schema, + ignore_nulls: false, + sort_exprs: &[], + name: "a", + is_distinct: false, + input_type: &DataType::Float64, + args_num: 1, + }; + + let args2 = AccumulatorArgs { + data_type: &DataType::Float64, + schema, + ignore_nulls: false, + sort_exprs: &[], + name: "a", + is_distinct: false, + input_type: &DataType::Float64, + args_num: 1, + }; + + let mut accum1 = agg1.accumulator(args1)?; + let mut accum2 = agg2.accumulator(args2)?; + + let value1 = vec![col("a", &schema)? + .evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows()))?]; + let value2 = vec![col("a", &schema)? + .evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows()))?]; + + accum1.update_batch(&value1)?; + accum2.update_batch(&value2)?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; + accum1.merge_batch(&state2)?; + let result = accum1.evaluate()?; + Ok(result) + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f32f4b04938f..d0e08033727d 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::expr_fn::{ - covar_pop, covar_samp, first_value, median, sum, var_sample, + covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -653,6 +653,8 @@ async fn roundtrip_expr_api() -> Result<()> { sum(lit(1)), median(lit(2)), var_sample(lit(2.2)), + stddev(lit(2.2)), + stddev_pop(lit(2.2)), ]; // ensure expressions created with the expr api can be round tripped From 07f41f9d89d80cab174a07d111c8997da508cf05 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sat, 8 Jun 2024 23:35:52 +0800 Subject: [PATCH 02/10] remove aggregation function stddev and stddev_pop --- datafusion/core/src/dataframe/mod.rs | 4 +- datafusion/expr/src/aggregate_function.rs | 13 - datafusion/expr/src/expr_fn.rs | 12 - .../expr/src/type_coercion/aggregates.rs | 37 --- datafusion/functions-aggregate/src/stddev.rs | 16 +- .../physical-expr/src/aggregate/build_in.rs | 122 +-------- .../physical-expr/src/aggregate/stddev.rs | 258 +----------------- .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 2 - datafusion/proto/src/logical_plan/to_proto.rs | 6 - .../proto/src/physical_plan/to_proto.rs | 6 +- 14 files changed, 27 insertions(+), 468 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5b1aef5d2b20..414e0c837faa 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,11 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION, + avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::median; +use datafusion_functions_aggregate::expr_fn::{median, stddev}; use datafusion_functions_aggregate::expr_fn::sum; use async_trait::async_trait; diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 8f683cabe6d6..1f1335a242c5 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -51,10 +51,6 @@ pub enum AggregateFunction { NthValue, /// Variance (Population) VariancePop, - /// Standard Deviation (Sample) - Stddev, - /// Standard Deviation (Population) - StddevPop, /// Correlation Correlation, /// Slope from linear regression @@ -110,8 +106,6 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", VariancePop => "VAR_POP", - Stddev => "STDDEV", - StddevPop => "STDDEV_POP", Correlation => "CORR", RegrSlope => "REGR_SLOPE", RegrIntercept => "REGR_INTERCEPT", @@ -163,9 +157,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "stddev" => AggregateFunction::Stddev, - "stddev_pop" => AggregateFunction::StddevPop, - "stddev_samp" => AggregateFunction::Stddev, "var_pop" => AggregateFunction::VariancePop, "regr_slope" => AggregateFunction::RegrSlope, "regr_intercept" => AggregateFunction::RegrIntercept, @@ -236,8 +227,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } - AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), - AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), AggregateFunction::RegrSlope | AggregateFunction::RegrIntercept | AggregateFunction::RegrCount @@ -310,8 +299,6 @@ impl AggregateFunction { AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::VariancePop - | AggregateFunction::Stddev - | AggregateFunction::StddevPop | AggregateFunction::ApproxMedian => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b1d9eb057753..0360478eac54 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -383,18 +383,6 @@ pub fn scalar_subquery(subquery: Arc) -> Expr { }) } -/// Create an expression to represent the stddev() aggregate function -pub fn stddev(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Stddev, - vec![expr], - false, - None, - None, - None, - )) -} - /// Create a grouping set pub fn grouping_set(exprs: Vec>) -> Expr { Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index b7004e200d70..e6e85ea84c00 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -183,16 +183,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::Stddev | AggregateFunction::StddevPop => { - if !is_stddev_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64]) - } AggregateFunction::Correlation => { if !is_correlation_support_arg_type(&input_types[0]) { return plan_err!( @@ -430,15 +420,6 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { } } -/// function return type of standard deviation -pub fn stddev_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("STDDEV does not support {arg_type:?}") - } -} - /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -533,13 +514,6 @@ pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { ) } -pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, @@ -704,17 +678,6 @@ mod tests { Ok(()) } - #[test] - fn test_stddev_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = stddev_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(stddev_return_type(&data_type).is_err()); - Ok(()) - } - #[test] fn test_covariance_return_data_type() -> Result<()> { let data_type = DataType::Float64; diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 8797ca73211f..24225c908a3d 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -42,6 +42,7 @@ make_udaf_expr_and_func!( /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression pub struct Stddev { signature: Signature, + alias: Vec } impl Debug for Stddev { @@ -64,6 +65,7 @@ impl Stddev { pub fn new() -> Self { Self { signature: Signature::numeric(1, Volatility::Immutable), + alias: vec!["stddev_samp".to_string()] } } } @@ -102,9 +104,16 @@ impl AggregateUDFImpl for Stddev { ]) } - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return internal_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + } Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } + + fn aliases(&self) -> &[String] { + &self.alias + } } make_udaf_expr_and_func!( @@ -166,7 +175,10 @@ impl AggregateUDFImpl for StddevPop { ]) } - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return internal_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + } Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 07409dd1f4dc..f09ddf4c026d 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -166,22 +166,6 @@ pub fn create_aggregate_expr( (AggregateFunction::VariancePop, true) => { return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } - (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Stddev, true) => { - return not_impl_err!("STDDEV(DISTINCT) aggregations are not available"); - } - (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::StddevPop, true) => { - return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); - } (AggregateFunction::Correlation, false) => { Arc::new(expressions::Correlation::new( input_phy_exprs[0].clone(), @@ -359,13 +343,13 @@ pub fn create_aggregate_expr( #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; - use expressions::{StddevPop, VariancePop}; + use expressions::VariancePop; use super::*; use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, Stddev, + Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -750,82 +734,6 @@ mod tests { Ok(()) } - #[test] - fn test_stddev_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Stddev]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Stddev { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_stddev_pop_expr() -> Result<()> { - let funcs = vec![AggregateFunction::StddevPop]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::StddevPop { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - #[test] fn test_median_expr() -> Result<()> { let funcs = vec![AggregateFunction::ApproxMedian]; @@ -942,32 +850,6 @@ mod tests { assert!(observed.is_err()); } - #[test] - fn test_stddev_return_type() -> Result<()> { - let observed = AggregateFunction::Stddev.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Int64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_stddev_no_utf8() { - let observed = AggregateFunction::Stddev.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - // Helper function // Create aggregate expr with type coercion fn create_physical_agg_expr_for_test( diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index ec8d8cea67c4..3891d055fcf1 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -17,168 +17,14 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::sync::Arc; +use arrow::array::ArrayRef; -use crate::aggregate::stats::StatsType; -use crate::aggregate::utils::down_cast_any_ref; -use crate::aggregate::variance::VarianceAccumulator; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::ScalarValue; use datafusion_common::{internal_err, Result}; +use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; -/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(Debug)] -pub struct Stddev { - name: String, - expr: Arc, -} - -/// STDDEV_POP population aggregate expression -#[derive(Debug)] -pub struct StddevPop { - name: String, - expr: Arc, -} - -impl Stddev { - /// Create a new STDDEV aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of stddev just support FLOAT64 and Decimal data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for Stddev { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Stddev { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - -impl StddevPop { - /// Create a new STDDEV aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of stddev just support FLOAT64 and Decimal data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for StddevPop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for StddevPop { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} +use crate::aggregate::stats::StatsType; +use crate::aggregate::variance::VarianceAccumulator; /// An accumulator to compute the average #[derive(Debug)] @@ -239,99 +85,3 @@ impl Accumulator for StddevAccumulator { + self.variance.size() } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use arrow::{array::*, datatypes::*}; - - #[test] - fn stddev_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); - - Ok(()) - } - - #[test] - fn stddev_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - let b = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 324699af5b5c..fdc821840d87 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -56,7 +56,6 @@ pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; -pub use crate::aggregate::stddev::{Stddev, StddevPop}; pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f8d229f48dc4..3f2ed8cc89e5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -483,8 +483,8 @@ enum AggregateFunction { VARIANCE_POP = 8; // COVARIANCE = 9; // COVARIANCE_POP = 10; - STDDEV = 11; - STDDEV_POP = 12; + // STDDEV = 11; + // STDDEV_POP = 12; CORRELATION = 13; APPROX_PERCENTILE_CONT = 14; APPROX_MEDIAN = 15; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6de030679c80..138ef5c852c7 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -540,8 +540,6 @@ impl serde::Serialize for AggregateFunction { Self::ApproxDistinct => "APPROX_DISTINCT", Self::ArrayAgg => "ARRAY_AGG", Self::VariancePop => "VARIANCE_POP", - Self::Stddev => "STDDEV", - Self::StddevPop => "STDDEV_POP", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", Self::ApproxMedian => "APPROX_MEDIAN", @@ -582,8 +580,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_DISTINCT", "ARRAY_AGG", "VARIANCE_POP", - "STDDEV", - "STDDEV_POP", "CORRELATION", "APPROX_PERCENTILE_CONT", "APPROX_MEDIAN", @@ -653,8 +649,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_DISTINCT" => Ok(AggregateFunction::ApproxDistinct), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), - "STDDEV" => Ok(AggregateFunction::Stddev), - "STDDEV_POP" => Ok(AggregateFunction::StddevPop), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e397f3545986..5595ed853cad 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1927,8 +1927,8 @@ pub enum AggregateFunction { VariancePop = 8, /// COVARIANCE = 9; /// COVARIANCE_POP = 10; - Stddev = 11, - StddevPop = 12, + // Stddev = 11, + // StddevPop = 12, Correlation = 13, ApproxPercentileCont = 14, ApproxMedian = 15, @@ -1967,8 +1967,6 @@ impl AggregateFunction { AggregateFunction::ApproxDistinct => "APPROX_DISTINCT", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::VariancePop => "VARIANCE_POP", - AggregateFunction::Stddev => "STDDEV", - AggregateFunction::StddevPop => "STDDEV_POP", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", AggregateFunction::ApproxMedian => "APPROX_MEDIAN", @@ -2005,8 +2003,6 @@ impl AggregateFunction { "APPROX_DISTINCT" => Some(Self::ApproxDistinct), "ARRAY_AGG" => Some(Self::ArrayAgg), "VARIANCE_POP" => Some(Self::VariancePop), - "STDDEV" => Some(Self::Stddev), - "STDDEV_POP" => Some(Self::StddevPop), "CORRELATION" => Some(Self::Correlation), "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), "APPROX_MEDIAN" => Some(Self::ApproxMedian), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f8a78bdbdced..90337e8d474a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -150,8 +150,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxDistinct => Self::ApproxDistinct, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::VariancePop => Self::VariancePop, - protobuf::AggregateFunction::Stddev => Self::Stddev, - protobuf::AggregateFunction::StddevPop => Self::StddevPop, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 15d0d6dd491d..454a2de99885 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -121,8 +121,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::VariancePop => Self::VariancePop, - AggregateFunction::Stddev => Self::Stddev, - AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::RegrSlope => Self::RegrSlope, AggregateFunction::RegrIntercept => Self::RegrIntercept, @@ -420,10 +418,6 @@ pub fn serialize_expr( AggregateFunction::VariancePop => { protobuf::AggregateFunction::VariancePop } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 834f59abb10d..801b6815baf9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,7 +28,7 @@ use datafusion::physical_plan::expressions::{ CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, + OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, Sum, TryCastExpr, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; @@ -283,10 +283,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::VariancePop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Stddev - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::StddevPop } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { From 9ec6a3474a227e1086d97408e0c4daf5c135ed00 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 00:23:35 +0800 Subject: [PATCH 03/10] register func and modified return type --- datafusion/functions-aggregate/src/lib.rs | 2 ++ datafusion/functions-aggregate/src/stddev.rs | 27 +++++++++++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 4660230445b9..2f58b9afaccd 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -91,6 +91,8 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), median::median_udaf(), variance::var_samp_udaf(), + stddev::stddev_udaf(), + stddev::stddev_pop_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 24225c908a3d..10e3587d0d25 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -22,7 +22,7 @@ use std::fmt::{Debug, Formatter}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{plan_err, ScalarValue}; use datafusion_common::{internal_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; @@ -85,7 +85,11 @@ impl AggregateUDFImpl for Stddev { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + if !arg_types[0].is_numeric() { + return plan_err!("Stddev requires numeric input types"); + } + + Ok(DataType::Float64) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -125,11 +129,19 @@ make_udaf_expr_and_func!( ); /// STDDEV_POP population aggregate expression -#[derive(Debug)] pub struct StddevPop { signature: Signature, } +impl Debug for StddevPop { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StddevPop") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + impl Default for StddevPop { fn default() -> Self { Self::new() @@ -137,7 +149,7 @@ impl Default for StddevPop { } impl StddevPop { - /// Create a new STDDEV aggregate function + /// Create a new STDDEV_POP aggregate function pub fn new() -> Self { Self { signature: Signature::numeric(1, Volatility::Immutable), @@ -183,7 +195,11 @@ impl AggregateUDFImpl for StddevPop { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + if !arg_types[0].is_numeric() { + return plan_err!("StddevPop requires numeric input types"); + } + + Ok(DataType::Float64) } } @@ -229,7 +245,6 @@ impl Accumulator for StddevAccumulator { fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; - dbg!(variance.clone()); match variance { ScalarValue::Float64(e) => { if e.is_none() { From 8f8693312a754802bb514084e4452c175e99278f Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 00:27:43 +0800 Subject: [PATCH 04/10] cargo fmt --- datafusion/core/src/dataframe/mod.rs | 6 +++--- datafusion/functions-aggregate/src/stddev.rs | 6 +++--- datafusion/physical-expr/src/aggregate/stddev.rs | 2 +- datafusion/proto/src/physical_plan/to_proto.rs | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8545ce55801a..06a85d303687 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,12 +50,12 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, min, utils::COUNT_STAR_EXPANSION, - TableProviderFilterPushDown, UNNAMED_TABLE, + avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, + UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::{median, stddev}; use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_functions_aggregate::expr_fn::{median, stddev}; use async_trait::async_trait; diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 10e3587d0d25..8ed9779478ea 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -22,8 +22,8 @@ use std::fmt::{Debug, Formatter}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::{plan_err, ScalarValue}; use datafusion_common::{internal_err, Result}; +use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; @@ -42,7 +42,7 @@ make_udaf_expr_and_func!( /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression pub struct Stddev { signature: Signature, - alias: Vec + alias: Vec, } impl Debug for Stddev { @@ -65,7 +65,7 @@ impl Stddev { pub fn new() -> Self { Self { signature: Signature::numeric(1, Volatility::Immutable), - alias: vec!["stddev_samp".to_string()] + alias: vec!["stddev_samp".to_string()], } } } diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 3891d055fcf1..3ade67b51905 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -19,8 +19,8 @@ use arrow::array::ArrayRef; -use datafusion_common::{internal_err, Result}; use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, Result}; use datafusion_expr::Accumulator; use crate::aggregate::stats::StatsType; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 88ac8490f77c..66405d4b9af6 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,8 +28,8 @@ use datafusion::physical_plan::expressions::{ CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, Regr, RegrType, RowNumber, StringAgg, TryCastExpr, - VariancePop, WindowShift, + RankType, Regr, RegrType, RowNumber, StringAgg, TryCastExpr, VariancePop, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; From 749aa429b71f8d01c99f29f108348ddd40f72a08 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 00:37:02 +0800 Subject: [PATCH 05/10] regen proto --- datafusion/proto/src/generated/prost.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e8896d0ceb15..1b38168ba1d2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1927,8 +1927,8 @@ pub enum AggregateFunction { VariancePop = 8, /// COVARIANCE = 9; /// COVARIANCE_POP = 10; - // Stddev = 11, - // StddevPop = 12, + /// STDDEV = 11; + /// STDDEV_POP = 12; Correlation = 13, ApproxPercentileCont = 14, ApproxMedian = 15, From 9453ca3ba844789eb6a45ab7785bff48e1236aa2 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 00:41:23 +0800 Subject: [PATCH 06/10] cargo clippy --- datafusion/functions-aggregate/src/stddev.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 8ed9779478ea..af7078cc2deb 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -345,10 +345,10 @@ mod tests { let mut accum1 = agg1.accumulator(args1)?; let mut accum2 = agg2.accumulator(args2)?; - let value1 = vec![col("a", &schema)? + let value1 = vec![col("a", schema)? .evaluate(batch1) .and_then(|v| v.into_array(batch1.num_rows()))?]; - let value2 = vec![col("a", &schema)? + let value2 = vec![col("a", schema)? .evaluate(batch2) .and_then(|v| v.into_array(batch2.num_rows()))?]; From 49759cfb14179571e4e3c6dc77f3b6eb2cb285ff Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 00:59:45 +0800 Subject: [PATCH 07/10] fix window function support --- datafusion/functions-aggregate/src/stddev.rs | 12 ++++++++++++ datafusion/sqllogictest/test_files/functions.slt | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index af7078cc2deb..cc794f7b37de 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -115,6 +115,10 @@ impl AggregateUDFImpl for Stddev { Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } + fn create_sliding_accumulator(&self, args: AccumulatorArgs) -> Result> { + self.accumulator(args) + } + fn aliases(&self) -> &[String] { &self.alias } @@ -194,6 +198,10 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } + fn create_sliding_accumulator(&self, args: AccumulatorArgs) -> Result> { + self.accumulator(args) + } + fn return_type(&self, arg_types: &[DataType]) -> Result { if !arg_types[0].is_numeric() { return plan_err!("StddevPop requires numeric input types"); @@ -261,6 +269,10 @@ impl Accumulator for StddevAccumulator { std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) + self.variance.size() } + + fn supports_retract_batch(&self) -> bool { + self.variance.supports_retract_batch() + } } #[cfg(test)] diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 56f30ad36296..f04d76822124 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -491,7 +491,7 @@ statement error Did you mean 'COUNT'? SELECT counter(*) from test; # Aggregate function -statement error Did you mean 'STDDEV'? +statement error Did you mean 'stddev'? SELECT STDEV(v1) from test; # Aggregate function From 3f0697499f887f30c3af20aa02cfba3217092611 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 01:01:08 +0800 Subject: [PATCH 08/10] cargo fmt --- datafusion/functions-aggregate/src/stddev.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index cc794f7b37de..a77eb67f6cb0 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -115,7 +115,10 @@ impl AggregateUDFImpl for Stddev { Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } - fn create_sliding_accumulator(&self, args: AccumulatorArgs) -> Result> { + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { self.accumulator(args) } @@ -198,7 +201,10 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn create_sliding_accumulator(&self, args: AccumulatorArgs) -> Result> { + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { self.accumulator(args) } From 671eff16cebe5355a463df6ebdc25aae4f03d194 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 02:05:54 +0800 Subject: [PATCH 09/10] throw not_impl_err instead --- datafusion/functions-aggregate/src/stddev.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index a77eb67f6cb0..dafbaf3b62e0 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -22,7 +22,7 @@ use std::fmt::{Debug, Formatter}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; @@ -110,7 +110,7 @@ impl AggregateUDFImpl for Stddev { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return internal_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); } Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } @@ -196,7 +196,7 @@ impl AggregateUDFImpl for StddevPop { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return internal_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); } Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } From 6d3e10621449fc7e59d197931bd81bb0e899a1a4 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 9 Jun 2024 10:55:41 +0800 Subject: [PATCH 10/10] use default sliding accumulator --- datafusion/functions-aggregate/src/stddev.rs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index dafbaf3b62e0..4c3effe7650a 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -115,13 +115,6 @@ impl AggregateUDFImpl for Stddev { Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } - fn create_sliding_accumulator( - &self, - args: AccumulatorArgs, - ) -> Result> { - self.accumulator(args) - } - fn aliases(&self) -> &[String] { &self.alias } @@ -201,13 +194,6 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn create_sliding_accumulator( - &self, - args: AccumulatorArgs, - ) -> Result> { - self.accumulator(args) - } - fn return_type(&self, arg_types: &[DataType]) -> Result { if !arg_types[0].is_numeric() { return plan_err!("StddevPop requires numeric input types");