Skip to content

Commit

Permalink
Support unparsing plans with both Aggregation and Window functions (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov authored Sep 30, 2024
1 parent 34f7d90 commit 1ed3963
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 44 deletions.
25 changes: 16 additions & 9 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ use super::{
rewrite_plan_for_sort_on_non_projected_fields,
subquery_alias_inner_query_and_columns,
},
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_window_exprs,
},
Unparser,
};

Expand Down Expand Up @@ -170,13 +173,17 @@ impl Unparser<'_> {
p: &Projection,
select: &mut SelectBuilder,
) -> Result<()> {
match find_agg_node_within_select(plan, None, true) {
Some(AggVariant::Aggregate(agg)) => {
match (
find_agg_node_within_select(plan, true),
find_window_nodes_within_select(plan, None, true),
) {
(Some(agg), window) => {
let window_option = window.as_deref();
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;
Expand All @@ -190,7 +197,7 @@ impl Unparser<'_> {
vec![],
));
}
Some(AggVariant::Window(window)) => {
(None, Some(window)) => {
let items = p
.expr
.iter()
Expand All @@ -202,7 +209,7 @@ impl Unparser<'_> {

select.projection(items);
}
None => {
_ => {
let items = p
.expr
.iter()
Expand Down Expand Up @@ -285,10 +292,10 @@ impl Unparser<'_> {
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
}
LogicalPlan::Filter(filter) => {
if let Some(AggVariant::Aggregate(agg)) =
find_agg_node_within_select(plan, None, select.already_projected())
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {
Expand Down
118 changes: 83 additions & 35 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,58 +18,81 @@
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
Column, Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};

/// One of the possible aggregation plans which can be found within a single select query.
pub(crate) enum AggVariant<'a> {
Aggregate(&'a Aggregate),
Window(Vec<&'a Window>),
/// Walks down the single-input chain below `plan` looking for the first
/// [Aggregate] node that belongs to the same select query.
///
/// The search gives up and returns `None` when:
/// * the current node has more than one input (e.g. a Join) or no input at all,
/// * a TableScan is reached, or
/// * a Projection is reached while `already_projected` is `true`.
///
/// The first Projection seen while `already_projected` is `false` is stepped
/// through, and the search continues with `already_projected` set to `true`.
pub(crate) fn find_agg_node_within_select(
    plan: &LogicalPlan,
    already_projected: bool,
) -> Option<&Aggregate> {
    // Every node this search steps through (Projection, Filter, ...) has exactly
    // one input, so a node with several inputs ends the search.
    let children = plan.inputs();
    if children.len() > 1 {
        return None;
    }
    let child = *children.first()?;

    match child {
        // An Aggregate node ends the search immediately.
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::TableScan(_) => None,
        LogicalPlan::Projection(_) if already_projected => None,
        LogicalPlan::Projection(_) => find_agg_node_within_select(child, true),
        _ => find_agg_node_within_select(child, already_projected),
    }
}

/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
/// Recursively searches children of [LogicalPlan] to find Window nodes, if any exist,
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
/// be found in a single select query.
pub(crate) fn find_agg_node_within_select<'a>(
/// If no Window node is found prior to this or at all before reaching the end
/// of the tree, None is returned.
pub(crate) fn find_window_nodes_within_select<'a>(
plan: &'a LogicalPlan,
mut prev_windows: Option<AggVariant<'a>>,
mut prev_windows: Option<Vec<&'a Window>>,
already_projected: bool,
) -> Option<AggVariant<'a>> {
// Note that none of the nodes that have a corresponding agg node can have more
) -> Option<Vec<&'a Window>> {
// Note that none of the nodes that have a corresponding Window node can have more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
let input = if input.len() > 1 {
return None;
return prev_windows;
} else {
input.first()?
};

// Agg nodes explicitly return immediately with a single node
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
match input {
LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)),
LogicalPlan::Window(window) => {
prev_windows = match &mut prev_windows {
Some(AggVariant::Window(windows)) => {
Some(windows) => {
windows.push(window);
prev_windows
}
_ => Some(AggVariant::Window(vec![window])),
_ => Some(vec![window]),
};
find_agg_node_within_select(input, prev_windows, already_projected)
find_window_nodes_within_select(input, prev_windows, already_projected)
}
LogicalPlan::Projection(_) => {
if already_projected {
prev_windows
} else {
find_agg_node_within_select(input, prev_windows, true)
find_window_nodes_within_select(input, prev_windows, true)
}
}
LogicalPlan::TableScan(_) => prev_windows,
_ => find_agg_node_within_select(input, prev_windows, already_projected),
_ => find_window_nodes_within_select(input, prev_windows, already_projected),
}
}

Expand All @@ -78,19 +101,30 @@ pub(crate) fn find_agg_node_within_select<'a>(
///
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
pub(crate) fn unproject_agg_exprs(
expr: &Expr,
agg: &Aggregate,
windows: Option<&[&Window]>,
) -> Result<Expr> {
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
// find the column in the agg schema
if let Ok(n) = agg.schema.index_of_column(&c) {
let unprojected_expr = agg
.group_expr
.iter()
.chain(agg.aggr_expr.iter())
.nth(n)
.unwrap();
if let Some(unprojected_expr) = find_agg_expr(agg, &c) {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(mut unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
if let Expr::WindowFunction(func) = &mut unprojected_expr {
// A window function can contain an aggregation column, e.g. 'avg(sum(ss_sales_price)) over ..', that needs to be unprojected
for arg in &mut func.args {
if let Expr::Column(c) = arg {
if let Some(expr) = find_agg_expr(agg, c) {
*arg = expr.clone();
}
}
}
}
Ok(Transformed::yes(unprojected_expr))
} else {
internal_err!(
"Tried to unproject agg expr not found in provided Aggregate!"
Expand All @@ -112,11 +146,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unproj) = windows
.iter()
.flat_map(|w| w.window_expr.iter())
.find(|window_expr| window_expr.schema_name().to_string() == c.name)
{
if let Some(unproj) = find_window_expr(windows, &c.name) {
Ok(Transformed::yes(unproj.clone()))
} else {
Ok(Transformed::no(Expr::Column(c)))
Expand All @@ -127,3 +157,21 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
})
.map(|e| e.data)
}

/// Resolves `column` against the schema of `agg` and returns the matching
/// expression, counting through the group expressions first and then the
/// aggregate expressions.
///
/// Returns `None` when the column is not found in the aggregate schema or when
/// its index lies beyond the combined expression list.
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> {
    let position = agg.schema.index_of_column(column).ok()?;
    agg.group_expr.iter().chain(&agg.aggr_expr).nth(position)
}

/// Scans the window expressions of every node in `windows`, in order, and
/// returns the first one whose schema name equals `column_name`.
///
/// Returns `None` when no window expression has a matching schema name.
fn find_window_expr<'a>(
    windows: &'a [&'a Window],
    column_name: &'a str,
) -> Option<&'a Expr> {
    for window in windows {
        for candidate in &window.window_expr {
            if candidate.schema_name().to_string() == column_name {
                return Some(candidate);
            }
        }
    }
    None
}
6 changes: 6 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ fn roundtrip_statement() -> Result<()> {
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
"WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)",
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM person GROUP BY id, first_name"#,
];

// For each test sql string, we transform as follows:
Expand All @@ -161,6 +166,7 @@ fn roundtrip_statement() -> Result<()> {
let state = MockSessionState::default()
.with_aggregate_function(sum_udaf())
.with_aggregate_function(count_udaf())
.with_aggregate_function(max_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
Expand Down

0 comments on commit 1ed3963

Please sign in to comment.