Skip to content

Commit

Permalink
Convert stddev and stddev_pop to UDAF (apache#10834)
Browse files Browse the repository at this point in the history
* add stddev and stddev_pot udaf

* remove aggregation function stddev and stddev_pop

* register func and modified return type

* cargo fmt

* regen proto

* cargo clippy

* fix window function support

* cargo fmt

* throw not_impl_err instead

* use default sliding accumulator
  • Loading branch information
goldmedal authored and findepi committed Jul 16, 2024
1 parent 50a1c6c commit 87dca5f
Show file tree
Hide file tree
Showing 17 changed files with 389 additions and 469 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_aggregate::expr_fn::{median, stddev};

use async_trait::async_trait;

Expand Down
13 changes: 0 additions & 13 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ pub enum AggregateFunction {
NthValue,
/// Variance (Population)
VariancePop,
/// Standard Deviation (Sample)
Stddev,
/// Standard Deviation (Population)
StddevPop,
/// Correlation
Correlation,
/// Slope from linear regression
Expand Down Expand Up @@ -107,8 +103,6 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
Expand Down Expand Up @@ -159,9 +153,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
"var_pop" => AggregateFunction::VariancePop,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
Expand Down Expand Up @@ -231,8 +222,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
Expand Down Expand Up @@ -304,8 +293,6 @@ impl AggregateFunction {
}
AggregateFunction::Avg
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,6 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
})
}

/// Create an expression to represent the stddev() aggregate function
pub fn stddev(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Stddev,
vec![expr],
false,
None,
None,
None,
))
}

/// Create a grouping set
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
Expand Down
37 changes: 0 additions & 37 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64])
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
Expand Down Expand Up @@ -408,15 +398,6 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
}
}

/// function return type of standard deviation
pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
plan_err!("STDDEV does not support {arg_type:?}")
}
}

/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
Expand Down Expand Up @@ -511,13 +492,6 @@ pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
)
}

pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
Expand Down Expand Up @@ -664,17 +638,6 @@ mod tests {
Ok(())
}

#[test]
fn test_stddev_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
let result_type = stddev_return_type(&data_type)?;
assert_eq!(DataType::Float64, result_type);

let data_type = DataType::Decimal128(36, 10);
assert!(stddev_return_type(&data_type).is_err());
Ok(())
}

#[test]
fn test_covariance_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
Expand Down
5 changes: 5 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub mod macros;
pub mod covariance;
pub mod first_last;
pub mod median;
pub mod stddev;
pub mod sum;
pub mod variance;

Expand All @@ -74,6 +75,8 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
pub use super::variance::var_sample;
}
Expand All @@ -88,6 +91,8 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_pop_udaf(),
median::median_udaf(),
variance::var_samp_udaf(),
stddev::stddev_udaf(),
stddev::stddev_pop_udaf(),
]
}

Expand Down
Loading

0 comments on commit 87dca5f

Please sign in to comment.