Skip to content

Commit

Permalink
migrate count to UDAF
Browse files Browse the repository at this point in the history
Builtin Count was removed upstream.

TBD whether we want to re-implement `count_star` with new API.

Ref: apache/datafusion#10893
  • Loading branch information
Michael-J-Ward committed Jul 24, 2024
1 parent 287279a commit f7bd619
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use datafusion::functions_aggregate;
use datafusion_common::{Column, ScalarValue, TableReference};
use datafusion_expr::expr::Alias;
use datafusion_expr::{
aggregate_function,
expr::{
find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction,
},
Expand Down Expand Up @@ -326,21 +325,29 @@ fn col(name: &str) -> PyResult<PyExpr> {
})
}

/// Create a COUNT(1) aggregate expression
// TODO: do we want to create an equivalent?
// /// Create a COUNT(1) aggregate expression
// #[pyfunction]
// fn count_star() -> PyResult<PyExpr> {
// Ok(PyExpr {
// expr: Expr::AggregateFunction(AggregateFunction {
// func_def: datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(
// aggregate_function::AggregateFunction::Count,
// ),
// args: vec![lit(1)],
// distinct: false,
// filter: None,
// order_by: None,
// null_treatment: None,
// }),
// })
// }

/// Wrapper for [`functions_aggregate::expr_fn::count`]
/// Count the number of non-null values in the column
#[pyfunction]
fn count_star() -> PyResult<PyExpr> {
Ok(PyExpr {
expr: Expr::AggregateFunction(AggregateFunction {
func_def: datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(1)],
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
}),
})
fn count(expr: PyExpr) -> PyExpr {
functions_aggregate::expr_fn::count(expr.expr).into()
}

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
Expand Down Expand Up @@ -700,7 +707,6 @@ aggregate_function!(
aggregate_function!(array_agg, ArrayAgg);
aggregate_function!(avg, Avg);
aggregate_function!(corr, Correlation);
aggregate_function!(count, Count);
aggregate_function!(grouping, Grouping);
aggregate_function!(max, Max);
aggregate_function!(mean, Avg);
Expand Down Expand Up @@ -761,7 +767,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cosh))?;
m.add_wrapped(wrap_pyfunction!(cot))?;
m.add_wrapped(wrap_pyfunction!(count))?;
m.add_wrapped(wrap_pyfunction!(count_star))?;
// m.add_wrapped(wrap_pyfunction!(count_star))?;
m.add_wrapped(wrap_pyfunction!(covar))?;
m.add_wrapped(wrap_pyfunction!(covar_pop))?;
m.add_wrapped(wrap_pyfunction!(covar_samp))?;
Expand Down

0 comments on commit f7bd619

Please sign in to comment.