Skip to content

Commit

Permalink
Add multiple group by expression handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafasrepo committed Aug 2, 2024
1 parent 5598fb4 commit c3257a6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 45 deletions.
36 changes: 21 additions & 15 deletions datafusion/common/src/functional_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,22 +524,28 @@ pub fn aggregate_functional_dependencies(
}
}

// If we have a single GROUP BY key, we can guarantee uniqueness after
// When we have a GROUP BY key, we can guarantee uniqueness after
// aggregation:
if group_by_expr_names.len() == 1 {
// If `source_indices` contain 0, delete this functional dependency
// as it will be added anyway with mode `Dependency::Single`:
aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0));
// Add a new functional dependency associated with the whole table:
aggregate_func_dependencies.push(
// Use nullable property of the group by expression
FunctionalDependence::new(
vec![0],
target_indices,
aggr_fields[0].is_nullable(),
)
.with_mode(Dependency::Single),
);
if !group_by_expr_names.is_empty() {
let source_indices = (0..group_by_expr_names.len()).collect::<Vec<_>>();
let nullable = source_indices
.iter()
.any(|idx| aggr_fields[*idx].is_nullable());
// If `source_indices` is not already a determinant in the existing `aggregate_func_dependencies`.
if !aggregate_func_dependencies.iter().any(|item| {
// `item.source_indices` is a subset of the `source_indices`. In this case, we shouldn't add
// `source_indices` as `item.source_indices` defines this relation already.
item.source_indices
.iter()
.all(|idx| source_indices.contains(idx))
}) {
// Add a new functional dependency associated with the whole table:
aggregate_func_dependencies.push(
// Use nullable property of the group by expression
FunctionalDependence::new(source_indices, target_indices, nullable)
.with_mode(Dependency::Single),
);
}
}
FunctionalDependencies::new(aggregate_func_dependencies)
}
Expand Down
51 changes: 23 additions & 28 deletions datafusion/optimizer/src/eliminate_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@
//! Optimizer rule to replaces redundant aggregations on a plan.
//! This saves time in planning and executing the query.

use std::ops::Deref;
use crate::optimizer::ApplyOrder;
use datafusion_common::{Result};
use datafusion_common::display::{ToStringifiedPlan};
use datafusion_common::tree_node::TreeNode;
use datafusion_expr::{logical_plan::{LogicalPlan}, Aggregate, Join, Distinct};
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::Result;
use datafusion_expr::{logical_plan::LogicalPlan, Aggregate, Distinct, Join};
use std::ops::Deref;

#[derive(Default)]
pub struct EliminateAggregate{
pub struct EliminateAggregate {
group_bys: Vec<LogicalPlan>,
}

impl EliminateAggregate {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {group_bys: Vec::new()}
Self {
group_bys: Vec::new(),
}
}
}

Expand Down Expand Up @@ -68,36 +70,26 @@ impl OptimizerRule for EliminateAggregate {

for func_dep in func_deps.iter() {
if func_dep.source_indices == all_fields {
return Ok(Some(distinct.inputs()[0].clone()))
return Ok(Some(distinct.inputs()[0].clone()));
}
}
return Ok(None)
},
return Ok(None);
}
LogicalPlan::Distinct(Distinct::On(distinct)) => {
let fields = distinct.schema.fields();
let all_fields = (0..fields.len()).collect::<Vec<_>>();
let func_deps = distinct.schema.functional_dependencies().clone();

for func_dep in func_deps.iter() {
if func_dep.source_indices == all_fields {
return Ok(Some(distinct.input.as_ref().clone()))
return Ok(Some(distinct.input.as_ref().clone()));
}
}
return Ok(None)
},
LogicalPlan::Aggregate(Aggregate {
..
}) => {
Ok(None)
},
LogicalPlan::Join(Join {
..
}) => {
Ok(None)
return Ok(None);
}
_ => {
Ok(None)
},
LogicalPlan::Aggregate(Aggregate { .. }) => Ok(None),
LogicalPlan::Join(Join { .. }) => Ok(None),
_ => Ok(None),
}
}

Expand All @@ -113,7 +105,7 @@ impl OptimizerRule for EliminateAggregate {
#[cfg(test)]
mod tests {
use crate::eliminate_aggregate::EliminateAggregate;
use datafusion_common::{Result};
use datafusion_common::Result;
use datafusion_expr::{
col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
};
Expand All @@ -122,7 +114,11 @@ mod tests {
use crate::test::*;

fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateAggregate::new()), plan.clone(), expected)
assert_optimized_plan_eq(
Arc::new(EliminateAggregate::new()),
plan.clone(),
expected,
)
}

#[test]
Expand Down Expand Up @@ -152,5 +148,4 @@ mod tests {
let expected = "Distinct:\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test";
assert_optimized_plan_equal(&plan, expected)
}

}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ pub mod analyzer;
pub mod common_subexpr_eliminate;
pub mod decorrelate;
pub mod decorrelate_predicate_subquery;
pub mod eliminate_aggregate;
pub mod eliminate_cross_join;
pub mod eliminate_duplicated_expr;
pub mod eliminate_filter;
pub mod eliminate_group_by_constant;
pub mod eliminate_aggregate;
pub mod eliminate_join;
pub mod eliminate_limit;
pub mod eliminate_nested_union;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ use datafusion_expr::logical_plan::LogicalPlan;

use crate::common_subexpr_eliminate::CommonSubexprEliminate;
use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
use crate::eliminate_aggregate::EliminateAggregate;
use crate::eliminate_cross_join::EliminateCrossJoin;
use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr;
use crate::eliminate_filter::EliminateFilter;
use crate::eliminate_group_by_constant::EliminateGroupByConstant;
use crate::eliminate_aggregate::EliminateAggregate;
use crate::eliminate_join::EliminateJoin;
use crate::eliminate_limit::EliminateLimit;
use crate::eliminate_nested_union::EliminateNestedUnion;
Expand Down

0 comments on commit c3257a6

Please sign in to comment.