diff --git a/datafusion/optimizer/src/eliminate_distinct.rs b/datafusion/optimizer/src/eliminate_distinct.rs index a11ef11fff34..6a583809bf40 100644 --- a/datafusion/optimizer/src/eliminate_distinct.rs +++ b/datafusion/optimizer/src/eliminate_distinct.rs @@ -57,27 +57,23 @@ impl OptimizerRule for EliminateDistinct { &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result< - datafusion_common::tree_node::Transformed, - datafusion_common::DataFusionError, - > { - match plan { - LogicalPlan::Distinct(Distinct::All(distinct)) => { - let fields = distinct.schema().fields(); - let all_fields = (0..fields.len()).collect::>(); - let func_deps = distinct.schema().functional_dependencies().clone(); - - for func_dep in func_deps.iter() { - if func_dep.source_indices == all_fields { - return Ok(Transformed::yes(distinct.inputs()[0].clone())); - } - } - Ok(Transformed::no(LogicalPlan::Distinct(Distinct::All( - distinct, - )))) + ) -> Result, datafusion_common::DataFusionError> { + let LogicalPlan::Distinct(Distinct::All(distinct)) = plan else { + return Ok(Transformed::no(plan)); + }; + + let fields = distinct.schema().fields(); + let all_fields = (0..fields.len()).collect::>(); + let func_deps = distinct.schema().functional_dependencies().clone(); + + for func_dep in func_deps.iter() { + if func_dep.source_indices == all_fields { + return Ok(Transformed::yes(distinct.inputs()[0].clone())); } - _ => Ok(Transformed::no(plan)), } + Ok(Transformed::no(LogicalPlan::Distinct(Distinct::All( + distinct, + )))) } } @@ -110,7 +106,6 @@ mod tests { .build()?; let expected = "Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; - // No aggregate / scan / limit assert_optimized_plan_equal(&plan, expected) } @@ -123,9 +118,34 @@ mod tests { .distinct()? .build()?; - // No aggregate / scan / limit let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn do_not_eliminate_distinct() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .distinct()? + .build()?; + + let expected = "Distinct:\n Projection: test.a, test.b\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn do_not_eliminate_distinct_with_aggr() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a"), col("b"), col("c")], Vec::::new())? + .project(vec![col("a"), col("b")])? + .distinct()? + .build()?; + + let expected = + "Distinct:\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[]]\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) + } }