Skip to content

Commit

Permalink
simplify function
Browse files Browse the repository at this point in the history
add additional tests for not removing cases
  • Loading branch information
Mert Akkaya committed Aug 2, 2024
1 parent 289d157 commit a8e1b05
Showing 1 changed file with 41 additions and 21 deletions.
62 changes: 41 additions & 21 deletions datafusion/optimizer/src/eliminate_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,23 @@ impl OptimizerRule for EliminateDistinct {
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<
datafusion_common::tree_node::Transformed<LogicalPlan>,
datafusion_common::DataFusionError,
> {
match plan {
LogicalPlan::Distinct(Distinct::All(distinct)) => {
let fields = distinct.schema().fields();
let all_fields = (0..fields.len()).collect::<Vec<_>>();
let func_deps = distinct.schema().functional_dependencies().clone();

for func_dep in func_deps.iter() {
if func_dep.source_indices == all_fields {
return Ok(Transformed::yes(distinct.inputs()[0].clone()));
}
}
Ok(Transformed::no(LogicalPlan::Distinct(Distinct::All(
distinct,
))))
) -> Result<Transformed<LogicalPlan>, datafusion_common::DataFusionError> {
let LogicalPlan::Distinct(Distinct::All(distinct)) = plan else {
return Ok(Transformed::no(plan));
};

let fields = distinct.schema().fields();
let all_fields = (0..fields.len()).collect::<Vec<_>>();
let func_deps = distinct.schema().functional_dependencies().clone();

for func_dep in func_deps.iter() {
if func_dep.source_indices == all_fields {
return Ok(Transformed::yes(distinct.inputs()[0].clone()));
}
_ => Ok(Transformed::no(plan)),
}
Ok(Transformed::no(LogicalPlan::Distinct(Distinct::All(
distinct,
))))
}
}

Expand Down Expand Up @@ -110,7 +106,6 @@ mod tests {
.build()?;

let expected = "Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test";
// No aggregate / scan / limit
assert_optimized_plan_equal(&plan, expected)
}

Expand All @@ -123,9 +118,34 @@ mod tests {
.distinct()?
.build()?;

// No aggregate / scan / limit
let expected =
"Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test";
assert_optimized_plan_equal(&plan, expected)
}

/// A `DISTINCT` placed over a plain projection of `a` and `b` is expected to
/// survive optimization: the expected plan text still starts with `Distinct:`.
///
/// NOTE(review): presumably the projected schema carries no functional
/// dependency spanning every output column, which is the condition the rule
/// checks before dropping the node -- confirm against the schema of
/// `test_table_scan`.
#[test]
fn do_not_eliminate_distinct() -> Result<()> {
    let scan = test_table_scan().unwrap();
    let projected = LogicalPlanBuilder::from(scan).project(vec![col("a"), col("b")])?;
    let plan = projected.distinct()?.build()?;

    // The plan must come back with the `Distinct` node still in place.
    assert_optimized_plan_equal(
        &plan,
        "Distinct:\n Projection: test.a, test.b\n TableScan: test",
    )
}

/// Groups by `a`, `b`, `c` (no aggregate expressions), projects only `a` and
/// `b`, then applies `DISTINCT`. The expected plan text keeps the `Distinct`
/// node on top of the projection and the aggregate.
///
/// NOTE(review): presumably the grouping on three columns does not make the
/// two projected columns unique on their own, so the rule has nothing to
/// eliminate -- confirm against how functional dependencies propagate through
/// the projection.
#[test]
fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
    let scan = test_table_scan().unwrap();
    let group_by = vec![col("a"), col("b"), col("c")];
    let no_aggregates = Vec::<Expr>::new();

    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_by, no_aggregates)?
        .project(vec![col("a"), col("b")])?
        .distinct()?
        .build()?;

    // The plan must come back with the `Distinct` node still in place.
    assert_optimized_plan_equal(
        &plan,
        "Distinct:\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[]]\n TableScan: test",
    )
}
}

0 comments on commit a8e1b05

Please sign in to comment.