From dc69f7b0ce87c63efb9c6e48947ed162033c0902 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Tue, 1 Oct 2024 16:39:11 -0700 Subject: [PATCH] Fix to unparse the plan with multiple UNION statements into an SQL string (#12605) (#37) * fix unparse multiple UNION statement * enhance the error message * cargo fmt --------- Co-authored-by: Jax Liu Co-authored-by: Phillip LeBlanc --- datafusion/sql/src/unparser/plan.rs | 12 +++++++----- datafusion/sql/tests/cases/plan_to_sql.rs | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 603d2d77cce7..4e36b50afec7 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -522,7 +522,7 @@ impl Unparser<'_> { let input_exprs: Vec = union .inputs .iter() - .map(|input| self.select_to_sql_expr(input, &mut None)) + .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; let union_expr = SetExpr::SetOperation { @@ -532,10 +532,12 @@ impl Unparser<'_> { right: Box::new(input_exprs[1].clone()), }; - query - .as_mut() - .expect("to have a query builder") - .body(Box::new(union_expr)); + let Some(query) = query.as_mut() else { + return internal_err!( + "UNION ALL operator only valid in a statement context" + ); + }; + query.body(Box::new(union_expr)); Ok(()) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 6ded9eed0a3e..330ac3218a26 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -151,6 +151,9 @@ fn roundtrip_statement() -> Result<()> { SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total FROM person GROUP BY id, first_name"#, + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", ]; // For each test sql string, we transform as follows: