Skip to content

Commit

Permalink
Improve unparsing to handle aggregations with grouping set (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov authored Oct 1, 2024
1 parent dc69f7b commit b15debb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
27 changes: 19 additions & 8 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use datafusion_common::{
tree_node::{Transformed, TreeNode},
Column, Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node, if one exists,
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -109,16 +111,16 @@ pub(crate) fn unproject_agg_exprs(
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unprojected_expr) = find_agg_expr(agg, &c) {
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(mut unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
if let Expr::WindowFunction(func) = &mut unprojected_expr {
// Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
for arg in &mut func.args {
if let Expr::Column(c) = arg {
if let Some(expr) = find_agg_expr(agg, c) {
if let Some(expr) = find_agg_expr(agg, c)? {
*arg = expr.clone();
}
}
Expand All @@ -127,7 +129,7 @@ pub(crate) fn unproject_agg_exprs(
Ok(Transformed::yes(unprojected_expr))
} else {
internal_err!(
"Tried to unproject agg expr not found in provided Aggregate!"
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
)
}
} else {
Expand Down Expand Up @@ -158,11 +160,20 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
.map(|e| e.data)
}

/// Looks up the expression of `agg` that corresponds to `column`.
///
/// The column is first resolved to an index in `agg.schema`; that index is then
/// used to pick the matching item from the group expressions followed by the
/// aggregate expressions.
///
/// Returns `Ok(None)` when `column` cannot be resolved in `agg.schema`, or when
/// the resolved index is past the end of the chained expressions.
///
/// # Errors
///
/// Propagates any error returned by `grouping_set_to_exprlist` when the group
/// expressions consist of a single [`Expr::GroupingSet`].
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> {
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
if let Ok(index) = agg.schema.index_of_column(column) {
agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)
if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
// The only group expression is a grouping set: expand it into its expression
// list and index into that list (followed by the aggregate expressions).
let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
return Ok(grouping_expr
.into_iter()
.chain(agg.aggr_expr.iter())
.nth(index));
} else {
// Plain GROUP BY: index directly into group expressions, then aggregate expressions.
return Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index));
};
} else {
None
// The column is not part of the aggregate's schema.
Ok(None)
}
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ fn roundtrip_statement() -> Result<()> {
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM person GROUP BY id, first_name"#,
"SELECT id, first_name, last_name, SUM(id) AS total_sum FROM person GROUP BY ROLLUP(id, first_name, last_name)",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
Expand Down

0 comments on commit b15debb

Please sign in to comment.