Skip to content

Commit

Permalink
Check for count aggregate, add test for sum
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 15, 2024
1 parent d35a2c7 commit a05bcbe
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan};
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand Down Expand Up @@ -50,21 +53,40 @@ impl AnalyzerRule for CountWildcardRule {
fn is_wildcard(expr: &Expr) -> bool {
matches!(expr, Expr::Wildcard { qualifier: None })
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(
&aggregate_function.func_def,
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && aggregate_function.args.len() == 1
&& is_wildcard(&aggregate_function.args[0])
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
matches!(
&window_function.fun,
WindowFunctionDefinition::AggregateFunction(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && window_function.args.len() == 1
&& is_wildcard(&window_function.args[0])
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up(&|expr| match expr {
Expr::WindowFunction(mut window_function)
if window_function.args.len() == 1
&& is_wildcard(&window_function.args[0]) =>
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
}
Expr::AggregateFunction(mut aggregate_function)
if aggregate_function.args.len() == 1
&& is_wildcard(&aggregate_function.args[0]) =>
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::AggregateFunction(
Expand All @@ -86,7 +108,7 @@ mod tests {
use datafusion_expr::expr::Sort;
use datafusion_expr::{
col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use std::sync::Arc;
Expand Down Expand Up @@ -235,6 +257,17 @@ mod tests {
assert_plan_eq(&plan, expected)
}

#[test]
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let err = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
.unwrap_err()
.to_string();
assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
Ok(())
}

#[test]
fn test_count_wildcard_on_nesting() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down

0 comments on commit a05bcbe

Please sign in to comment.