Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert stddev and stddev_pop to UDAF #10834

Merged
merged 11 commits into from
Jun 9, 2024
6 changes: 3 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_aggregate::expr_fn::{median, stddev};

use async_trait::async_trait;

Expand Down
13 changes: 0 additions & 13 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ pub enum AggregateFunction {
NthValue,
/// Variance (Population)
VariancePop,
/// Standard Deviation (Sample)
Stddev,
/// Standard Deviation (Population)
StddevPop,
/// Correlation
Correlation,
/// Slope from linear regression
Expand Down Expand Up @@ -107,8 +103,6 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
Expand Down Expand Up @@ -159,9 +153,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
"var_pop" => AggregateFunction::VariancePop,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
Expand Down Expand Up @@ -231,8 +222,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
Expand Down Expand Up @@ -304,8 +293,6 @@ impl AggregateFunction {
}
AggregateFunction::Avg
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,6 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
})
}

/// Create an expression to represent the stddev() aggregate function
pub fn stddev(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Stddev,
vec![expr],
false,
None,
None,
None,
))
}

/// Create a grouping set
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
Expand Down
37 changes: 0 additions & 37 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64])
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
Expand Down Expand Up @@ -408,15 +398,6 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
}
}

/// function return type of standard deviation
pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
plan_err!("STDDEV does not support {arg_type:?}")
}
}

/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
Expand Down Expand Up @@ -511,13 +492,6 @@ pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
)
}

pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
Expand Down Expand Up @@ -664,17 +638,6 @@ mod tests {
Ok(())
}

#[test]
fn test_stddev_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
let result_type = stddev_return_type(&data_type)?;
assert_eq!(DataType::Float64, result_type);

let data_type = DataType::Decimal128(36, 10);
assert!(stddev_return_type(&data_type).is_err());
Ok(())
}

#[test]
fn test_covariance_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
Expand Down
5 changes: 5 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub mod macros;
pub mod covariance;
pub mod first_last;
pub mod median;
pub mod stddev;
pub mod sum;
pub mod variance;

Expand All @@ -74,6 +75,8 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
pub use super::variance::var_sample;
}
Expand All @@ -88,6 +91,8 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_pop_udaf(),
median::median_udaf(),
variance::var_samp_udaf(),
stddev::stddev_udaf(),
stddev::stddev_pop_udaf(),
]
}

Expand Down
Loading