diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index 452f1862b274..666ea73027b3 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -524,22 +524,31 @@ pub fn aggregate_functional_dependencies(
         }
     }
 
-    // If we have a single GROUP BY key, we can guarantee uniqueness after
+    // When we have GROUP BY keys, we can guarantee uniqueness after
     // aggregation:
-    if group_by_expr_names.len() == 1 {
-        // If `source_indices` contain 0, delete this functional dependency
-        // as it will be added anyway with mode `Dependency::Single`:
-        aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0));
-        // Add a new functional dependency associated with the whole table:
-        aggregate_func_dependencies.push(
-            // Use nullable property of the group by expression
-            FunctionalDependence::new(
-                vec![0],
-                target_indices,
-                aggr_fields[0].is_nullable(),
-            )
-            .with_mode(Dependency::Single),
-        );
+    if !group_by_expr_names.is_empty() {
+        let count = group_by_expr_names.len();
+        let source_indices = (0..count).collect::<Vec<_>>();
+        let nullable = source_indices
+            .iter()
+            .any(|idx| aggr_fields[*idx].is_nullable());
+        // If the GROUP BY expressions do not already act as a determinant:
+        if !aggregate_func_dependencies.iter().any(|item| {
+            // If `item.source_indices` is a subset of the GROUP BY expressions,
+            // we need not add a new dependency, since `item.source_indices`
+            // already defines this relation.
+
+            // This simple comparison works because GROUP BY expressions
+            // always come here as a prefix of the schema:
+            item.source_indices.iter().all(|idx| idx < &count)
+        }) {
+            // Add a new functional dependency associated with the whole table,
+            // using the nullable property of the GROUP BY expressions:
+            aggregate_func_dependencies.push(
+                FunctionalDependence::new(source_indices, target_indices, nullable)
+                    .with_mode(Dependency::Single),
+            );
+        }
     }
     FunctionalDependencies::new(aggregate_func_dependencies)
 }
diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs
index 430517121f2a..f73eeacfbf0e 100644
--- a/datafusion/optimizer/src/replace_distinct_aggregate.rs
+++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -77,6 +77,23 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
         match plan {
             LogicalPlan::Distinct(Distinct::All(input)) => {
                 let group_expr = expand_wildcard(input.schema(), &input, None)?;
+
+                let field_count = input.schema().fields().len();
+                for dep in input.schema().functional_dependencies().iter() {
+                    // If the DISTINCT is exactly the same as a previous GROUP BY,
+                    // we can simply remove it. The length check guards the slice
+                    // below against dependencies with fewer source columns:
+                    if dep.source_indices.len() >= field_count
+                        && dep.source_indices[..field_count]
+                            .iter()
+                            .enumerate()
+                            .all(|(idx, f_idx)| idx == *f_idx)
+                    {
+                        return Ok(Transformed::yes(input.as_ref().clone()));
+                    }
+                }
+
+                // Replace with aggregation:
                 let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
                     input,
                     group_expr,
@@ -165,3 +182,78 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
         Some(BottomUp)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
+    use crate::test::*;
+
+    use datafusion_common::Result;
+    use datafusion_expr::{
+        col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
+    };
+    use datafusion_functions_aggregate::sum::sum;
+
+    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        assert_optimized_plan_eq(
+            Arc::new(ReplaceDistinctWithAggregate::new()),
+            plan.clone(),
+            expected,
+        )
+    }
+
+    #[test]
+    fn eliminate_redundant_distinct_simple() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("c")], Vec::<Expr>::new())?
+            .project(vec![col("c")])?
+            .distinct()?
+            .build()?;
+
+        let expected = "Projection: test.c\n  Aggregate: groupBy=[[test.c]], aggr=[[]]\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn eliminate_redundant_distinct_pair() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected =
+            "Projection: test.a, test.b\n  Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn do_not_eliminate_distinct() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
+        let table_scan = test_table_scan().unwrap();
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .build()?;
+
+        let expected =
+            "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n  Projection: test.a, test.b\n    Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n      TableScan: test";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+}
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index d776e6598cbe..31e7fd8c55f3 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -40,7 +40,7 @@ use hashbrown::HashSet;
 /// single distinct to group by optimizer rule
 /// ```text
 /// Before:
-/// SELECT a, count(DINSTINCT b), sum(c)
+/// SELECT a, count(DISTINCT b), sum(c)
 /// FROM t
 /// GROUP BY a
 ///
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index ee72289d66eb..abeeb767b948 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -4536,19 +4536,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
 logical_plan
 01)Limit: skip=0, fetch=5
 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-04)------TableScan: aggregate_test_100 projection=[c3]
+03)----TableScan: aggregate_test_100 projection=[c3]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=5
 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
-07)------------CoalescePartitionsExec
-08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
-09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
+06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
 
 query I
 SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
@@ -4699,19 +4694,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
 logical_plan
 01)Limit: skip=0, fetch=5
 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
-04)------TableScan: aggregate_test_100 projection=[c3]
+03)----TableScan: aggregate_test_100 projection=[c3]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=5
 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
-07)------------CoalescePartitionsExec
-08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
-09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
+06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
 
 statement ok
 set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true;
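
For reference, the interaction between the two Rust hunks above can be exercised in isolation. The sketch below is not part of the patch; it assumes `Dependency`, `FunctionalDependence`, and `FunctionalDependencies` are re-exported from the `datafusion_common` crate root, and it mirrors the `eliminate_redundant_distinct_pair` test: it builds the whole-GROUP-BY dependency that `aggregate_functional_dependencies` now emits, then applies the prefix-identity check that `ReplaceDistinctWithAggregate` uses to drop the redundant `Distinct`.

use datafusion_common::{Dependency, FunctionalDependence, FunctionalDependencies};

fn main() {
    // After `GROUP BY a, b` with no aggregate expressions, the output schema
    // is (a, b) and the first hunk registers columns [0, 1] as a determinant
    // for the whole table:
    let deps = FunctionalDependencies::new(vec![FunctionalDependence::new(
        vec![0, 1], // source: the GROUP BY prefix
        vec![0, 1], // target: every output column
        false,      // nullable: none of the keys is nullable here
    )
    .with_mode(Dependency::Single)]);

    // `SELECT DISTINCT a, b` over that output: the source indices form an
    // identity prefix covering all `field_count` columns, so the rows are
    // already unique and the Distinct is redundant:
    let field_count = 2;
    let redundant = deps.iter().any(|dep| {
        dep.source_indices.len() >= field_count
            && dep.source_indices[..field_count]
                .iter()
                .enumerate()
                .all(|(idx, f_idx)| idx == *f_idx)
    });
    assert!(redundant);
}

Because the GROUP BY columns always arrive as a prefix of the aggregate output schema (see the comment in the first hunk), the comparison against `0..field_count` is sufficient; the length guard keeps the check conservative when a dependency covers fewer columns than the schema.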