Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unparsing to handle aggregations with grouping set #36

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use datafusion_common::{
tree_node::{Transformed, TreeNode},
Column, Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -109,16 +111,16 @@ pub(crate) fn unproject_agg_exprs(
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unprojected_expr) = find_agg_expr(agg, &c) {
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(mut unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
if let Expr::WindowFunction(func) = &mut unprojected_expr {
// Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
for arg in &mut func.args {
if let Expr::Column(c) = arg {
if let Some(expr) = find_agg_expr(agg, c) {
if let Some(expr) = find_agg_expr(agg, c)? {
*arg = expr.clone();
}
}
Expand All @@ -127,7 +129,7 @@ pub(crate) fn unproject_agg_exprs(
Ok(Transformed::yes(unprojected_expr))
} else {
internal_err!(
"Tried to unproject agg expr not found in provided Aggregate!"
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
)
}
} else {
Expand Down Expand Up @@ -158,11 +160,20 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
.map(|e| e.data)
}

fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> {
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
if let Ok(index) = agg.schema.index_of_column(column) {
agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)
if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
// For grouping set expr, we must operate by expression list from the grouping set
let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
return Ok(grouping_expr
.into_iter()
.chain(agg.aggr_expr.iter())
.nth(index));
} else {
return Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index));
};
} else {
None
Ok(None)
}
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ fn roundtrip_statement() -> Result<()> {
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM person GROUP BY id, first_name"#,
"SELECT id, first_name, last_name, SUM(id) AS total_sum FROM person GROUP BY ROLLUP(id, first_name, last_name)",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
Expand Down