From eceeb6fff579c532b1b386c2f4b4c15910d97178 Mon Sep 17 00:00:00 2001
From: Zhixiong Chen
Date: Wed, 17 Aug 2022 14:41:37 +0800
Subject: [PATCH] [SPARK-39819][SQL] DS V2 aggregate push down can work with
 Top N or Paging (Sort with expressions) (#525)

* [SPARK-39784][SQL] Put Literal values on the right side of the data source filter after translating Catalyst Expression to data source filter

### What changes were proposed in this pull request?
Even though the literal value could be on either side of a filter, e.g. both `a > 1` and `1 < a` are valid, after translating a Catalyst Expression into a data source filter we want the literal value on the right side so it is easier for the data source to handle these filters. We already do this kind of normalization for V1 Filter; we should have the same behavior for V2 Filter.

Before this PR, filters that have the literal value on the left side, e.g. `1 > a`, are kept as is. After this PR, we normalize them to `a < 1` so the data source does not need to inspect each filter (and flip the sides itself).

### Why are the changes needed?
We should follow V1 Filter's behavior and normalize filters when translating Catalyst Expressions to DS Filters so that literal values end up on the right side. The data source then does not need to check every single filter to figure out whether it needs to flip the sides.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New test.

Closes #37197 from huaxingao/flip.

Authored-by: huaxingao
Signed-off-by: huaxingao
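The normalization described above boils down to swapping the operands of a binary comparison and flipping the operator when the literal is on the left. A minimal, self-contained sketch of that idea (the `Lit`/`ColumnRef`/`Comparison` types are illustrative stand-ins, not Spark's actual filter classes):

```scala
sealed trait Operand
case class ColumnRef(name: String) extends Operand
case class Lit(value: Any) extends Operand
case class Comparison(op: String, left: Operand, right: Operand)

object NormalizeComparison {
  // Flipping the operator keeps the predicate's meaning when the operands are swapped.
  private val flipped = Map(">" -> "<", "<" -> ">", ">=" -> "<=", "<=" -> ">=")

  def apply(c: Comparison): Comparison = c match {
    // literal <op> column  ==>  column <flipped-op> literal, e.g. 1 < a becomes a > 1
    case Comparison(op, lit: Lit, col: ColumnRef) =>
      Comparison(flipped.getOrElse(op, op), col, lit)
    // already normalized, or not a simple literal/column comparison: leave it alone
    case other => other
  }
}

// NormalizeComparison(Comparison("<", Lit(1), ColumnRef("a")))
//   == Comparison(">", ColumnRef("a"), Lit(1))
```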
* [SPARK-39836][SQL] Simplify V2ExpressionBuilder by extract common method

### What changes were proposed in this pull request?
Currently, `V2ExpressionBuilder` has a lot of similar code; we can extract it into one common method and simplify the implementation with it.

### Why are the changes needed?
Simplify `V2ExpressionBuilder` by extracting a common method.

### Does this PR introduce _any_ user-facing change?
No. This only updates the inner implementation.

### How was this patch tested?
N/A

Closes #37249 from beliefer/SPARK-39836.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39858][SQL] Remove unnecessary `AliasHelper` or `PredicateHelper` for some rules

### What changes were proposed in this pull request?
While working with `AliasHelper`, I found that some rules inherit it without actually using it. This PR removes the unnecessary `AliasHelper` or `PredicateHelper` in the following cases:
- The rule inherits `AliasHelper` without using it. In this case, we can remove `AliasHelper` directly.
- The rule inherits `PredicateHelper` without using it. In this case, we can remove `PredicateHelper` directly.
- The rule inherits both `AliasHelper` and `PredicateHelper`. Since `PredicateHelper` already extends `AliasHelper`, we can remove `AliasHelper`.
- The rule inherits both `OperationHelper` and `PredicateHelper`. Since `OperationHelper` already extends `PredicateHelper`, we can remove `PredicateHelper`.
- The rule inherits both `PlanTest` and `PredicateHelper`. Since `PlanTest` already extends `PredicateHelper`, we can remove `PredicateHelper`.
- The rule inherits both `QueryTest` and `PredicateHelper`. Since `QueryTest` already extends `PredicateHelper`, we can remove `PredicateHelper`.

### Why are the changes needed?
Remove unnecessary `AliasHelper` or `PredicateHelper` from some rules.

### Does this PR introduce _any_ user-facing change?
No. This only improves the inner implementation.

### How was this patch tested?
N/A

Closes #37272 from beliefer/SPARK-39858.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39784][SQL][FOLLOW-UP] Use BinaryComparison instead of Predicate (if) for type check

### What changes were proposed in this pull request?
Follow up on this [comment](https://github.com/apache/spark/pull/37197#discussion_r928570992).

### Why are the changes needed?
Code simplification.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing test.

Closes #37278 from huaxingao/followup.

Authored-by: huaxingao
Signed-off-by: Dongjoon Hyun

* [SPARK-39909] Organize the check of push down information for JDBCV2Suite

### What changes were proposed in this pull request?
This PR changes the check method from `check(one_large_string)` to `check(small_string1, small_string2, ...)`.

### Why are the changes needed?
It helps us check the results individually and makes the code clearer.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #37342 from yabola/fix.

Authored-by: chenliang.lu
Signed-off-by: huaxingao

* [SPARK-39961][SQL] DS V2 push-down translate Cast if the cast is safe

### What changes were proposed in this pull request?
Currently, DS V2 push-down translates `Cast` only if ANSI mode is enabled. In fact, if the cast is safe (e.g. casting a number to string, or an int to long), we can translate it too. This PR calls `Cast.canUpCast` so that we can translate `Cast` to V2 `Cast` safely.

Note: the rule `SimplifyCasts` optimizes away some safe casts, e.g. int to long, so we may not always see the `Cast`.

### Why are the changes needed?
Widen the range of cases where DS V2 pushes down `Cast`.

### Does this PR introduce _any_ user-facing change?
Yes. `Cast` can be pushed down to the data source in more cases.

### How was this patch tested?
Test cases updated.

Closes #37388 from beliefer/SPARK-39961.

Authored-by: Jiaan Geng
Signed-off-by: Dongjoon Hyun

* [SPARK-38901][SQL] DS V2 supports push down misc functions

### What changes were proposed in this pull request?
Currently, Spark has some misc functions; please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L688. These functions are shown below: `AES_ENCRYPT`, `AES_DECRYPT`, `SHA1`, `SHA2`, `MD5`, `CRC32`.

Function|PostgreSQL|ClickHouse|H2|MySQL|Oracle|Redshift|Snowflake|DB2|Vertica|Exasol|SqlServer|Yellowbrick|Mariadb|Singlestore|
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
`AesEncrypt`|Yes|Yes|Yes|Yes|Yes|NO|Yes|Yes|NO|NO|NO|Yes|Yes|Yes|
`AesDecrypt`|Yes|Yes|Yes|Yes|Yes|NO|Yes|Yes|NO|NO|NO|Yes|Yes|Yes|
`Sha1`|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|
`Sha2`|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|
`Md5`|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|Yes|
`Crc32`|No|Yes|No|Yes|NO|Yes|NO|Yes|NO|NO|NO|NO|NO|Yes|

DS V2 should support pushing down these misc functions.

### Why are the changes needed?
Make DS V2 support pushing down misc functions.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New tests.

Closes #37169 from chenzhx/misc.

Authored-by: chenzhx
Signed-off-by: Wenchen Fan
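For orientation, the V2 form of such a pushed-down misc function is a `GeneralScalarExpression` whose name matches one of the cases handled by `V2ExpressionSQLBuilder` (`SHA1`, `SHA2`, `MD5`, `CRC32`, `AES_ENCRYPT`, ...). A small sketch of what the builder is expected to produce for `SHA2(col, 256)`, assuming Spark 3.4's connector expression API; it is illustrative, not code from this patch:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

object MiscFunctionPushDownSketch {
  // V2 representation of SHA2(col, 256): a named scalar expression over a column
  // reference and a literal, which a dialect can render back to SQL text.
  def sha2(column: String, bitLength: Int): GeneralScalarExpression =
    new GeneralScalarExpression(
      "SHA2",
      Array[Expression](FieldReference(column), LiteralValue(bitLength, IntegerType)))
}
```

`V2ExpressionSQLBuilder` would then render it as a `SHA2(...)` call when the JDBC dialect builds the pushed-down SQL.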
* [SPARK-39964][SQL] DS V2 pushdown should unify the translate path

### What changes were proposed in this pull request?
Currently, DS V2 pushdown has two translation paths: `DataSourceStrategy.translateAggregate`, used to translate aggregate functions, and `V2ExpressionBuilder`, used to translate other functions and expressions. We can unify them. After this PR, translation has only one code path, which is easier for developers to work with and to read.

### Why are the changes needed?
Unify the translation path for DS V2 pushdown.

### Does this PR introduce _any_ user-facing change?
No. This only updates the inner implementation.

### How was this patch tested?
N/A

Closes #37391 from beliefer/SPARK-39964.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39819][SQL] DS V2 aggregate push down can work with Top N or Paging (Sort with expressions)

### What changes were proposed in this pull request?
Currently, DS V2 aggregate push-down cannot work with DS V2 Top N push-down (`ORDER BY col LIMIT m`) or DS V2 Paging push-down (`ORDER BY col LIMIT m OFFSET n`). If we can push down aggregates together with Top N or Paging, performance will be better.

This PR only lets aggregates be pushed down when the ORDER BY expressions are also GROUP BY expressions. The idea of this PR is:
1. When we set the expected outputs of `ScanBuilderHolder`, hold a map from the expected outputs to the original expressions (which contain the original columns).
2. When we try to push down Top N or Paging, restore the original expressions for `SortOrder`.

### Why are the changes needed?
Letting DS V2 aggregate push-down work with Top N or Paging (sort on group expressions) gives users better performance.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New test cases.

Closes #37320 from beliefer/SPARK-39819_new.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39929][SQL] DS V2 supports push down string functions(non ANSI)

**What changes were proposed in this pull request?**
Support more commonly used string functions: BIT_LENGTH, CHAR_LENGTH, CONCAT.

The mainstream databases' support for these functions is shown below.

Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
BIT_LENGTH | Yes | Yes | Yes | Yes | Yes | no | no | no | no | Yes | Yes | Yes | no | Yes | no | Yes | no | no | no | no | no | Yes
CHAR_LENGTH | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | no | Yes | Yes | Yes | Yes
CONCAT | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | no | no | no | Yes | Yes

**Why are the changes needed?**
DS V2 supports pushing down string functions.

**Does this PR introduce any user-facing change?**
No. New feature.

**How was this patch tested?**
New tests.

Closes #37427 from zheniantoushipashi/SPARK-39929.

Authored-by: biaobiao.sun <1319027852@qq.com>
Signed-off-by: Wenchen Fan
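The string functions above are translated with the same "all children or nothing" pattern that the rest of `V2ExpressionBuilder` now uses: a multi-argument expression such as `CONCAT(a, b, c)` is only pushed down if every child translates. A simplified, generic sketch of that pattern (the `translate` parameter stands in for `V2ExpressionBuilder.generateExpression`; names are illustrative):

```scala
object AllOrNothingTranslation {
  // Translate every child, but give up on the whole expression if any child fails.
  def translateAll[A, B](children: Seq[A])(translate: A => Option[B]): Option[Seq[B]] = {
    val translated = children.flatMap(translate(_))
    if (translated.length == children.length) Some(translated) else None
  }
}

// Example: a three-argument CONCAT where one argument cannot be translated yields None,
// so the whole CONCAT stays in Spark instead of being pushed down.
// AllOrNothingTranslation.translateAll(Seq("a", "b", "?"))(s => if (s == "?") None else Some(s))
//   == None
```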
``` test("scan with filter push-down with date time functions") { val df9 = sql("SELECT name FROM h2.test.datetime WHERE " + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") checkFiltersRemoved(df9) val expectedPlanFragment9 = "PushedFilters: [DATE1 IS NOT NULL, EXTRACT(DAY_OF_YEAR FROM DATE1) > 100], " + "PushedTopN: ORDER BY [EXTRACT(DAY_OF_YEAR FROM DATE1) ASC NULLS FIRST] LIMIT 1," checkPushedInfo(df9, expectedPlanFragment9) checkAnswer(df9, Seq(Row("alex"))) } ``` The test case output failure show below. ``` "== Parsed Logical Plan == 'GlobalLimit 1 +- 'LocalLimit 1 +- 'Sort ['dayofyear('date1) ASC NULLS FIRST], true +- 'Project ['name] +- 'Filter ('dayofyear('date1) > 100) +- 'UnresolvedRelation [h2, test, datetime], [], false == Analyzed Logical Plan == name: string GlobalLimit 1 +- LocalLimit 1 +- Project [name#x] +- Sort [dayofyear(date1#x) ASC NULLS FIRST], true +- Project [name#x, date1#x] +- Filter (dayofyear(date1#x) > 100) +- SubqueryAlias h2.test.datetime +- RelationV2[NAME#x, DATE1#x, TIME1#x] h2.test.datetime test.datetime == Optimized Logical Plan == Project [name#x] +- RelationV2[NAME#x] test.datetime == Physical Plan == *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$145f6181a [NAME#x] PushedFilters: [DATE1 IS NOT NULL, EXTRACT(DAY_OF_YEAR FROM DATE1) > 100], PushedTopN: ORDER BY [org.apache.spark.sql.connector.expressions.Extract3b95fce9 ASC NULLS FIRST] LIMIT 1, ReadSchema: struct " did not contain "PushedFilters: [DATE1 IS NOT NULL, EXTRACT(DAY_OF_YEAR FROM DATE1) > 100], PushedTopN: ORDER BY [EXTRACT(DAY_OF_YEAR FROM DATE1) ASC NULLS FIRST] LIMIT 1," ``` ### Why are the changes needed? Fix an implement bug. The reason of the bug is the Extract the function does not implement the toString method when pushing down the JDBC data source. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New test case. Closes #37469 from chenzhx/spark-master. 
* code update

Signed-off-by: huaxingao
Signed-off-by: Wenchen Fan
Signed-off-by: Dongjoon Hyun
Co-authored-by: huaxingao
Co-authored-by: Jiaan Geng
Co-authored-by: chenliang.lu
Co-authored-by: biaobiao.sun <1319027852@qq.com>
---
 .../spark/sql/connector/expressions/Cast.java |   5 +
 .../sql/connector/expressions/Extract.java    |   7 +
 .../expressions/GeneralScalarExpression.java  |  54 ++
 .../util/V2ExpressionSQLBuilder.java          |   9 +
 .../sql/catalyst/analysis/Analyzer.scala      |   4 +-
 .../optimizer/CostBasedJoinReorder.scala      |   2 +-
 .../sql/catalyst/optimizer/Optimizer.scala    |   4 +-
 .../sql/catalyst/optimizer/expressions.scala  |   6 +-
 .../sql/catalyst/planning/patterns.scala      |   6 +-
 ...xtractPredicatesWithinOutputSetSuite.scala |   5 +-
 .../BinaryComparisonSimplificationSuite.scala |   2 +-
 .../BooleanSimplificationSuite.scala          |   2 +-
 .../EliminateSubqueryAliasesSuite.scala       |   2 +-
 .../PushFoldableIntoBranchesSuite.scala       |   3 +-
 .../RemoveRedundantAliasAndProjectSuite.scala |   2 +-
 .../optimizer/SimplifyConditionalSuite.scala  |   2 +-
 .../catalyst/util/V2ExpressionBuilder.scala   | 325 ++++-----
 .../execution/OptimizeMetadataOnlyQuery.scala |   2 +-
 .../spark/sql/execution/SparkStrategies.scala |   4 +-
 .../adaptive/LogicalQueryStageStrategy.scala  |   3 +-
 .../datasources/DataSourceStrategy.scala      |  63 +-
 .../PruneFileSourcePartitions.scala           |   3 +-
 .../datasources/v2/PushDownUtils.scala        |   4 +-
 .../v2/V2ScanRelationPushDown.scala           |  39 +-
 .../PlanDynamicPruningFilters.scala           |   5 +-
 .../execution/python/ExtractPythonUDFs.scala  |   2 +-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala |  20 +-
 .../datasources/FileSourceStrategySuite.scala |   4 +-
 .../v2/DataSourceV2StrategySuite.scala        |  67 +-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 667 +++++++++++++-----
 .../SimpleTextHadoopFsRelationSuite.scala     |   3 +-
 31 files changed, 848 insertions(+), 478 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
index 26b97b46fe2ef..44111913f124b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
@@ -42,4 +42,9 @@ public Cast(Expression expression, DataType dataType) {
 
   @Override
   public Expression[] children() { return new Expression[]{ expression() }; }
+
+  @Override
+  public String toString() {
+    return "CAST(" + expression.describe() + " AS " + dataType.typeName() + ")";
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Extract.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Extract.java
index a925f1ee31a98..ed9f4415f7da1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Extract.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Extract.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 import java.io.Serializable;
 
@@ -59,4 +60,10 @@ public Extract(String field, Expression source) {
 
   @Override
   public Expression[] children() { return new Expression[]{ source() }; }
+
+  @Override
+  public String toString() {
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 6dfaad0d26eb4..06fb5be583aa8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -340,6 +340,24 @@
  *    <li>Since version: 3.4.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>BIT_LENGTH</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>BIT_LENGTH(src)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CHAR_LENGTH</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CHAR_LENGTH(src)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CONCAT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CONCAT(col1, col2, ..., colN)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
  *  <li>Name: <code>OVERLAY</code>
  *   <ul>
  *    <li>SQL semantic: <code>OVERLAY(string, replace, position[, length])</code></li>
@@ -364,6 +382,42 @@
  *    <li>Since version: 3.4.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>AES_ENCRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_ENCRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>AES_DECRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_DECRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA1</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA1(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA2</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA2(expr, bitLength)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>MD5</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>MD5(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CRC32</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CRC32(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
 * </ol>
 * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
 * including: add, subtract, multiply, divide, remainder, pmod.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 60708ede19c8f..6cb8a8b116433 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -147,6 +147,15 @@ public String build(Expression expr) {
       case "DATE_ADD":
       case "DATE_DIFF":
       case "TRUNC":
+      case "AES_ENCRYPT":
+      case "AES_DECRYPT":
+      case "SHA1":
+      case "SHA2":
+      case "MD5":
+      case "CRC32":
+      case "BIT_LENGTH":
+      case "CHAR_LENGTH":
+      case "CONCAT":
         return visitSQLFunction(name,
           Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6322abb7e6c72..726a0dd56c343 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2369,7 +2369,7 @@ class Analyzer(override val catalogManager: CatalogManager)
    *
    * Note: CTEs are handled in CTESubstitution.
    */
-  object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
+  object ResolveSubquery extends Rule[LogicalPlan] {
    /**
     * Resolve the correlated expressions in a subquery, as if the expressions live in the outer
     * plan. All resolved outer references are wrapped in an [[OuterReference]]
@@ -2538,7 +2538,7 @@ class Analyzer(override val catalogManager: CatalogManager)
    * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
    * underlying aggregate operator and then projected away after the original operator.
    */
-  object ResolveAggregateFunctions extends Rule[LogicalPlan] with AliasHelper {
+  object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
       _.containsPattern(AGGREGATE), ruleId) {
       // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 659384a507746..471f0bd554105 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -401,7 +401,7 @@ case class Cost(card: BigInt, size: BigInt) {
  *
  * Filters (2) and (3) are not implemented.
  */
-object JoinReorderDPFilters extends PredicateHelper {
+object JoinReorderDPFilters {
  /**
   * Builds join graph information to be used by the filtering strategies.
   * Currently, it builds the sets of star/non-star joins.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0a53e9d73cd58..21ee6d915b03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -728,7 +728,7 @@ object LimitPushDown extends Rule[LogicalPlan] { * safe to pushdown Filters and Projections through it. Filter pushdown is handled by another * rule PushDownPredicates. Once we add UNION DISTINCT, we will not be able to pushdown Projections. */ -object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper { +object PushProjectionThroughUnion extends Rule[LogicalPlan] { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -1450,7 +1450,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { * This rule improves performance of predicate pushdown for cascading joins such as: * Filter-Join-Join-Join. Most predicates can be pushed down in a single pass. */ -object PushDownPredicates extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(FILTER, JOIN)) { CombineFilters.applyLocally diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 74f643ede4a9f..c32971f43b0c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -84,7 +84,7 @@ object ConstantFolding extends Rule[LogicalPlan] { * - Using this mapping, replace occurrence of the attributes with the corresponding constant values * in the AND node. */ -object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { +object ConstantPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAllPatterns(LITERAL, FILTER), ruleId) { case f: Filter => @@ -496,7 +496,7 @@ object SimplifyBinaryComparison /** * Simplifies conditional expressions (if / case). */ -object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { +object SimplifyConditionals extends Rule[LogicalPlan] { private def falseOrNullLiteral(e: Expression): Boolean = e match { case FalseLiteral => true case Literal(null, _) => true @@ -575,7 +575,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { /** * Push the foldable expression into (if / case) branches. */ -object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { +object PushFoldableIntoBranches extends Rule[LogicalPlan] { // To be conservative here: it's only a guaranteed win if all but at most only one branch // end up being not foldable. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f33d137ffd607..11bdfca34377b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -trait OperationHelper extends AliasHelper with PredicateHelper { +trait OperationHelper extends PredicateHelper { import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions type ReturnType = @@ -116,7 +116,7 @@ trait OperationHelper extends AliasHelper with PredicateHelper { * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if * necessary. */ -object PhysicalOperation extends OperationHelper with PredicateHelper { +object PhysicalOperation extends OperationHelper { override protected def legacyMode: Boolean = true } @@ -125,7 +125,7 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { * operations even if they are non-deterministic, as long as they satisfy the * requirement of CollapseProject and CombineFilters. */ -object ScanOperation extends OperationHelper with PredicateHelper { +object ScanOperation extends OperationHelper { override protected def legacyMode: Boolean = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala index ed141ef923e0a..10f9a88c429c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala @@ -22,10 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types.BooleanType -class ExtractPredicatesWithinOutputSetSuite - extends SparkFunSuite - with PredicateHelper - with PlanTest { +class ExtractPredicatesWithinOutputSetSuite extends SparkFunSuite with PlanTest { private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)) private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)) private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index c02691848c8f0..b10d693c01689 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { +class BinaryComparisonSimplificationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val 
batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 07f16f438cc56..b4d0fca42a9cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.BooleanType -class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper { +class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index 4df1a145a271b..35334a590f1f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { +class EliminateSubqueryAliasesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 2f6cff3675fb5..7c50bcaf090ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, Timesta import org.apache.spark.unsafe.types.CalendarInterval -class PushFoldableIntoBranchesSuite - extends PlanTest with ExpressionEvalHelper with PredicateHelper { +class PushFoldableIntoBranchesSuite extends PlanTest with ExpressionEvalHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("PushFoldableIntoBranches", FixedPoint(50), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 4b02a847880f7..d399a638adc9c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.MetadataBuilder -class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper { +class RemoveRedundantAliasAndProjectSuite extends PlanTest { object Optimize extends 
RuleExecutor[LogicalPlan] { val batches = Batch( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 2a685bfeefcb2..e812ef0f6a530 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{BooleanType, IntegerType} -class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper { +class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("SimplifyConditionals", FixedPoint(50), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 041ddf9fd07bc..7e15de56e2dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,9 +18,12 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} -import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.sql.execution.datasources.PushableExpression +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType} /** * The builder to generate V2 expressions from catalyst expressions. 
@@ -88,126 +91,50 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case Cast(child, dataType, _, true) => + case Cast(child, dataType, _, ansiEnabled) + if ansiEnabled || Cast.canUpCast(child.dataType, dataType) => generateExpression(child).map(v => new V2Cast(v, dataType)) - case Abs(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v))) - case Coalesce(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Greatest(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("GREATEST", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Least(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("LEAST", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => + generateAggregateFunc(aggregateFunction, isDistinct) + case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) + case Coalesce(children) => generateExpressionWithName("COALESCE", children) + case Greatest(children) => generateExpressionWithName("GREATEST", children) + case Least(children) => generateExpressionWithName("LEAST", children) case Rand(child, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpression(child) - .map(v => new GeneralScalarExpression("RAND", Array[V2Expression](v))) - } - case log: Logarithm => - val l = generateExpression(log.left) - val r = generateExpression(log.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("LOG", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Log10(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v))) - case Log2(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG2", Array[V2Expression](v))) - case Log(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) - case Exp(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) - case Pow(left, right) => - val l = generateExpression(left) - val r = generateExpression(right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Sqrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) - case Floor(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) - case Ceil(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) - case round: Round => - val l = generateExpression(round.left) - val r = generateExpression(round.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ROUND", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Sin(child) => generateExpression(child) - .map(v => new 
GeneralScalarExpression("SIN", Array[V2Expression](v))) - case Sinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SINH", Array[V2Expression](v))) - case Cos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COS", Array[V2Expression](v))) - case Cosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COSH", Array[V2Expression](v))) - case Tan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TAN", Array[V2Expression](v))) - case Tanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TANH", Array[V2Expression](v))) - case Cot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COT", Array[V2Expression](v))) - case Asin(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASIN", Array[V2Expression](v))) - case Asinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASINH", Array[V2Expression](v))) - case Acos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOS", Array[V2Expression](v))) - case Acosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOSH", Array[V2Expression](v))) - case Atan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v))) - case Atanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v))) - case atan2: Atan2 => - val l = generateExpression(atan2.left) - val r = generateExpression(atan2.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ATAN2", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Cbrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v))) - case ToDegrees(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("DEGREES", Array[V2Expression](v))) - case ToRadians(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v))) - case Signum(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v))) - case wb: WidthBucket => - val childrenExpressions = wb.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == wb.children.length) { - Some(new GeneralScalarExpression("WIDTH_BUCKET", - childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("RAND", Seq(child)) + } + case log: Logarithm => generateExpressionWithName("LOG", log.children) + case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) + case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) + case Log(child) => generateExpressionWithName("LN", Seq(child)) + case Exp(child) => generateExpressionWithName("EXP", Seq(child)) + case pow: Pow => generateExpressionWithName("POWER", pow.children) + case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) + case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) + case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) + case round: Round => generateExpressionWithName("ROUND", round.children) + case Sin(child) => generateExpressionWithName("SIN", Seq(child)) + case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) + case Cos(child) => generateExpressionWithName("COS", Seq(child)) + case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) + 
case Tan(child) => generateExpressionWithName("TAN", Seq(child)) + case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) + case Cot(child) => generateExpressionWithName("COT", Seq(child)) + case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) + case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) + case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) + case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) + case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) + case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) + case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) + case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) + case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) + case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) + case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) + case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -233,6 +160,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { val r = generateExpression(b.right) if (l.isDefined && r.isDefined) { b match { + case _: BinaryComparison if l.get.isInstanceOf[LiteralValue[_]] && + r.get.isInstanceOf[FieldReference] => + Some(new V2Predicate(flipComparisonOperatorName(b.sqlOperator), + Array[V2Expression](r.get, l.get))) case _: Predicate => Some(new V2Predicate(b.sqlOperator, Array[V2Expression](l.get, r.get))) case _ => @@ -254,10 +185,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) - case BitwiseNot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) + case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) val values = branches.map(_._2).flatMap(generateExpression(_, true)) @@ -278,93 +207,35 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case iff: If => - val childrenExpressions = iff.children.flatMap(generateExpression(_)) - if (iff.children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("CASE_WHEN", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("SUBSTRING", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Upper(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v))) - case Lower(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v))) - case translate: StringTranslate => - val childrenExpressions 
= translate.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == translate.children.length) { - Some(new GeneralScalarExpression("TRANSLATE", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrim => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimLeft => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimRight => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("SUBSTRING", children) + case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) + case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) + case BitLength(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("BIT_LENGTH", Seq(child)) + case Length(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("CHAR_LENGTH", Seq(child)) + case concat: Concat => generateExpressionWithName("CONCAT", concat.children) + case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) + case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) + case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) + case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("OVERLAY", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateAdd => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_ADD", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateDiff => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_DIFF", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: TruncDate => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("TRUNC", childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("OVERLAY", children) + case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children) + case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) + case date: TruncDate => generateExpressionWithName("TRUNC", date.children) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -397,6 +268,10 @@ class 
V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Extract("WEEK", v)) case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) + case Crc32(child) => generateExpressionWithName("CRC32", Seq(child)) + case Md5(child) => generateExpressionWithName("MD5", Seq(child)) + case Sha1(child) => generateExpressionWithName("SHA1", Seq(child)) + case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children) // TODO supports other expressions case ApplyFunctionExpression(function, children) => val childrenExpressions = children.flatMap(generateExpression(_)) @@ -408,6 +283,66 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case _ => None } + + private def generateAggregateFunc( + aggregateFunction: AggregateFunction, + isDistinct: Boolean): Option[AggregateFunc] = aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // COUNT(any literal) is the same as COUNT(*) + case Literal(_, _) => Some(new CountStar()) + case PushableExpression(expr) => Some(new Count(expr, isDistinct)) + case _ => None + } + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, isDistinct)) + case aggregate.VariancePop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_POP", isDistinct, Array(expr))) + case aggregate.VarianceSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_SAMP", isDistinct, Array(expr))) + case aggregate.StddevPop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_POP", isDistinct, Array(expr))) + case aggregate.StddevSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_SAMP", isDistinct, Array(expr))) + case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", isDistinct, Array(left, right))) + case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", isDistinct, Array(left, right))) + case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("CORR", isDistinct, Array(left, right))) + // TODO supports other aggregate functions + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } + case _ => None + } + + private def flipComparisonOperatorName(operatorName: String): String = { + operatorName match { + case ">" => "<" + case "<" => ">" + case ">=" => "<=" + case "<=" => ">=" + case _ => operatorName + } + } + + private def generateExpressionWithName( + v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + None + } + } } object ColumnOrField { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index d95e86bba0528..00b1ec749d762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -160,7 +160,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic * A pattern that finds the partitioned table relation node inside the given plan, and returns a * pair of the partition attributes and the table relation node. */ - object PartitionedRelation extends PredicateHelper { + object PartitionedRelation { def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 422f1f041a58b..878e2599c4a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -167,9 +167,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ - object JoinSelection extends Strategy - with PredicateHelper - with JoinSelectionHelper { + object JoinSelection extends Strategy with JoinSelectionHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index bcf9dc1544ce3..f5484b56bdf1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractSingleColumnNullAwareAntiJoin} import org.apache.spark.sql.catalyst.plans.LeftAnti @@ -35,7 +34,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNes * stage in case of the larger join child relation finishes before the smaller relation. Note * that this rule needs to applied before regular join strategies. 
*/ -object LogicalQueryStageStrategy extends Strategy with PredicateHelper { +object LogicalQueryStageStrategy extends Strategy { private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match { case LogicalQueryStage(_, _: BroadcastQueryStageExec) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5709e2e1484df..67297d99d254c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,8 +41,8 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -700,59 +700,6 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { - if (agg.filter.isEmpty) { - agg.aggregateFunction match { - case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) - case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) - case count: aggregate.Count if count.children.length == 1 => - count.children.head match { - // COUNT(any literal) is the same as COUNT(*) - case Literal(_, _) => Some(new CountStar()) - case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) - case _ => None - } - case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) - case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) - case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc( - "VAR_POP", agg.isDistinct, Array(FieldReference(name)))) - case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc( - "VAR_SAMP", agg.isDistinct, Array(FieldReference(name)))) - case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc( - "STDDEV_POP", agg.isDistinct, Array(FieldReference(name)))) - case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc( - "STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name)))) - case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left), - PushableColumnWithoutNestedColumn(right), _) => - Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct, - Array(FieldReference(left), FieldReference(right)))) - case 
aggregate.CovSample(PushableColumnWithoutNestedColumn(left), - PushableColumnWithoutNestedColumn(right), _) => - Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct, - Array(FieldReference(left), FieldReference(right)))) - case aggregate.Corr(PushableColumnWithoutNestedColumn(left), - PushableColumnWithoutNestedColumn(right), _) => - Some(new GeneralAggregateFunc("CORR", agg.isDistinct, - Array(FieldReference(left), FieldReference(right)))) - case aggregate.V2Aggregator(aggrFunc, children, _, _) => - val translatedExprs = children.flatMap(PushableExpression.unapply(_)) - if (translatedExprs.length == children.length) { - Some(new UserDefinedAggregateFunc(aggrFunc.name(), - aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression])) - } else { - None - } - case _ => None - } - } else { - None - } - } - /** * Translate aggregate expressions and group by expressions. * @@ -761,13 +708,13 @@ object DataSourceStrategy protected[sql] def translateAggregation( aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { - def translateGroupBy(e: Expression): Option[V2Expression] = e match { + def translate(e: Expression): Option[V2Expression] = e match { case PushableExpression(expr) => Some(expr) case _ => None } - val translatedAggregates = aggregates.flatMap(translateAggregate) - val translatedGroupBys = groupBy.flatMap(translateGroupBy) + val translatedAggregates = aggregates.flatMap(translate).asInstanceOf[Seq[AggregateFunc]] + val translatedGroupBys = groupBy.flatMap(translate) if (translatedAggregates.length != aggregates.length || translatedGroupBys.length != groupBy.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 2e8e5426d47be..42cd2c99090ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * statistics will be updated. And the partition filters will be kept in the filters of returned * logical plan. 
*/ -private[sql] object PruneFileSourcePartitions - extends Rule[LogicalPlan] with PredicateHelper { +private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { private def rebuildPhysicalOperation( projects: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 66d26446f40a6..41b0432980130 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, SchemaPruning} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType -object PushDownUtils extends PredicateHelper { +object PushDownUtils { /** * Pushes down filters to the data source reader * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index e90f59f310fcb..c3b50550a5e7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation @@ -34,7 +34,7 @@ import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} import org.apache.spark.sql.util.SchemaUtils._ -object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { @@ -189,12 +189,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22] // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation. 
// scalastyle:on - val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => - AttributeReference(s"group_col_$i", e.dataType)() + val groupOutputMap = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() -> e } - val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) => - AttributeReference(s"agg_func_$i", e.dataType)() + val groupOutput = groupOutputMap.unzip._1 + val aggOutputMap = finalAggExprs.zipWithIndex.map { case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() -> e } + val aggOutput = aggOutputMap.unzip._1 val newOutput = groupOutput ++ aggOutput val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) => @@ -204,6 +206,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } holder.pushedAggregate = Some(translatedAgg) + holder.pushedAggOutputMap = AttributeMap(groupOutputMap ++ aggOutputMap) holder.output = newOutput logInfo( s""" @@ -406,15 +409,21 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit sHolder.pushedLimit = Some(limit) } (operation, isPushed && !isPartiallyPushed) - case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) - // Without building the Scan, we do not know the resulting column names after aggregate - // push-down, and thus can't push down Top-N which needs to know the ordering column names. - // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same - // columns, which we know the resulting column names: the original table columns. - if sHolder.pushedAggregate.isEmpty && filter.isEmpty && - CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => + case s @ Sort(order, _, operation @ ScanOperation(project, Nil, sHolder: ScanBuilderHolder)) + if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => val aliasMap = getAliasMap(project) - val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] + val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap)) + val newOrder = if (sHolder.pushedAggregate.isDefined) { + // `ScanBuilderHolder` has different output columns after aggregate push-down. Here we + // replace the attributes in ordering expressions with the original table output columns. 
+ aliasReplacedOrder.map { + _.transform { + case a: Attribute => sHolder.pushedAggOutputMap.getOrElse(a, a) + }.asInstanceOf[SortOrder] + } + } else { + aliasReplacedOrder.asInstanceOf[Seq[SortOrder]] + } val normalizedOrders = DataSourceStrategy.normalizeExprs( newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]] val orders = DataSourceStrategy.translateSortOrders(normalizedOrders) @@ -544,6 +553,8 @@ case class ScanBuilderHolder( var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] var pushedAggregate: Option[Aggregation] = None + + var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression] } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index 9a05e396d4a70..a3ed0633dd807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSeq, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSeq, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode @@ -34,8 +34,7 @@ import org.apache.spark.sql.execution.joins._ * results of broadcast. For joins that are not planned as broadcast hash joins we keep * the fallback mechanism with subquery duplicate. */ -case class PlanDynamicPruningFilters(sparkSession: SparkSession) - extends Rule[SparkPlan] with PredicateHelper { +case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[SparkPlan] { /** * Identify the shape in which keys of a given plan are broadcasted. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 62b99f74f96ac..6a765d6b7e1ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -158,7 +158,7 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. 
*/ -object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[LogicalPlan] { private type EvalType = Int private type EvalTypeChecker = EvalType => Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 5a909b704e24c..b847fcdd19136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -45,7 +45,8 @@ private[sql] object H2Dialect extends JdbcDialect { Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN", "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN", - "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM") + "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "MD5", "SHA1", "SHA2", + "BIT_LENGTH", "CHAR_LENGTH", "CONCAT") override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) @@ -123,5 +124,22 @@ private[sql] object H2Dialect extends JdbcDialect { } s"EXTRACT($newField FROM $source)" } + + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + if (isSupportedFunction(funcName)) { + funcName match { + case "MD5" => + "RAWTOHEX(HASH('MD5', " + inputs.mkString(",") + "))" + case "SHA1" => + "RAWTOHEX(HASH('SHA-1', " + inputs.mkString(",") + "))" + case "SHA2" => + "RAWTOHEX(HASH('SHA-" + inputs(1) + "'," + inputs(0) + "))" + case _ => super.visitSQLFunction(funcName, inputs) + } + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support function: $funcName"); + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 50f32126e5dec..a87e5761560a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet} import org.apache.spark.sql.catalyst.util import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation @@ -40,7 +40,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} import org.apache.spark.util.Utils -class FileSourceStrategySuite extends QueryTest with SharedSparkSession with PredicateHelper { +class FileSourceStrategySuite extends QueryTest with SharedSparkSession { import testImplicits._ protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index 
1a5a382afdc6b..b3ec30823713b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -18,14 +18,77 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType} class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { + val attrInts = Seq( + $"cint".int, + $"c.int".int, + GetStructField($"a".struct(StructType( + StructField("cstr", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), + GetStructField($"a".struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), + GetStructField($"a.b".struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: + StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), + GetStructField($"a.b".struct(StructType( + StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), + GetStructField(GetStructField($"a".struct(StructType( + StructField("cstr1", StringType, nullable = true) :: + StructField("b", StructType(StructField("cint", IntegerType, nullable = true) :: + StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) + ).zip(Seq( + "cint", + "`c.int`", // single level field that contains `dot` in name + "a.cint", // two level nested field + "a.`c.int`", // two level nested field, and nested level contains `dot` + "`a.b`.cint", // two level nested field, and top level contains `dot` + "`a.b`.`c.int`", // two level nested field, and both levels contain `dot` + "a.b.cint" // three level nested field + )) + + test("SPARK-39784: translate binary expression") { attrInts + .foreach { case (attrInt, intColName) => + testTranslateFilter(EqualTo(attrInt, 1), + Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(EqualTo(1, attrInt), + Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(EqualNullSafe(attrInt, 1), + Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(EqualNullSafe(1, attrInt), + Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(GreaterThan(attrInt, 1), + Some(new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(GreaterThan(1, attrInt), + Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(LessThan(attrInt, 1), + Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(LessThan(1, attrInt), + Some(new Predicate(">", 
Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(GreaterThanOrEqual(attrInt, 1), + Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(GreaterThanOrEqual(1, attrInt), + Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + + testTranslateFilter(LessThanOrEqual(attrInt, 1), + Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + testTranslateFilter(LessThanOrEqual(1, attrInt), + Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))))) + } + } + test("SPARK-36644: Push down boolean column filter") { testTranslateFilter(Symbol("col").boolean, Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index be2c8ce057559..cbb3dc250b07f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -43,6 +43,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val tempDir = Utils.createTempDir() val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" + val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++ + Array.fill(15)(0.toByte) var conn: java.sql.Connection = null val testH2Dialect = new JdbcDialect { @@ -177,6 +179,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() + + conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name TEXT(32),b BINARY(20))") + .executeUpdate() + val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)") + stmt.setString(1, "jen") + stmt.setBytes(2, testBytes) + stmt.executeUpdate() } H2Dialect.registerFunction("my_avg", IntegralAverage) H2Dialect.registerFunction("my_strlen", StrLen(CharLength)) @@ -209,7 +218,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") { val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)") checkSchemaNames(df, Seq("NAME")) - checkPushedInfo(df, "PushedFilters: [], PushedLimit: LIMIT 3, ") + checkPushedInfo(df, + "PushedFilters: []", + "PushedLimit: LIMIT 3") checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) } @@ -237,7 +248,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .where($"dept" === 1).limit(1) checkLimitRemoved(df1) checkPushedInfo(df1, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedLimit: LIMIT 1") checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df2 = spark.read @@ -250,14 +262,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df2, false) checkPushedInfo(df2, - "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedLimit: LIMIT 1") checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false))) 
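// --- Illustrative sketch (not part of the patch) --------------------------------
// The SPARK-39784 test above asserts that a comparison written with the literal on
// the left, e.g. GreaterThan(1, attrInt), is translated with the column on the left
// and a mirrored operator, e.g. Predicate("<", [cint, 1]). Below is a minimal,
// self-contained sketch of that normalization idea; Col, Lit and Cmp are
// hypothetical stand-ins, not Spark classes.
object FlipSketch {
  sealed trait Operand
  case class Col(name: String) extends Operand
  case class Lit(value: Any) extends Operand
  case class Cmp(op: String, left: Operand, right: Operand)

  // Mirror image of each comparison operator, e.g. `1 > a` becomes `a < 1`.
  private val flipped =
    Map(">" -> "<", "<" -> ">", ">=" -> "<=", "<=" -> ">=", "=" -> "=", "<=>" -> "<=>")

  def normalize(c: Cmp): Cmp = (c.left, c.right) match {
    case (lit: Lit, col: Col) => Cmp(flipped(c.op), col, lit) // literal on the left: flip sides
    case _                    => c                            // already normalized
  }
}
// FlipSketch.normalize(FlipSketch.Cmp(">", FlipSketch.Lit(1), FlipSketch.Col("cint")))
// returns Cmp("<", Col("cint"), Lit(1)), matching the expectation in the test above.
// ---------------------------------------------------------------------------------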
val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") checkSchemaNames(df3, Seq("NAME")) checkLimitRemoved(df3) checkPushedInfo(df3, - "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedLimit: LIMIT 1") checkAnswer(df3, Seq(Row("alex"))) val df4 = spark.read @@ -282,7 +296,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df5, false) // LIMIT is pushed down only if all the filters are pushed down - checkPushedInfo(df5, "PushedFilters: [], ") + checkPushedInfo(df5, "PushedFilters: []") checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } @@ -304,7 +318,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkOffsetRemoved(df1) checkPushedInfo(df1, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedOffset: OFFSET 1,") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedOffset: OFFSET 1") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read @@ -314,7 +329,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkOffsetRemoved(df2, false) checkPushedInfo(df2, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read @@ -324,7 +340,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkOffsetRemoved(df3, false) checkPushedInfo(df3, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df3, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df4 = spark.read @@ -336,7 +353,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .offset(1) checkOffsetRemoved(df4, false) - checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], ReadSchema:") + checkPushedInfo(df4, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "ReadSchema:") checkAnswer(df4, Seq(Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df5 = spark.read @@ -361,7 +380,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkOffsetRemoved(df6, false) // OFFSET is pushed down only if all the filters are pushed down - checkPushedInfo(df6, "PushedFilters: [], ") + checkPushedInfo(df6, "PushedFilters: []") checkAnswer(df6, Seq(Row(10000.00, 1300.0, "dav"), Row(9000.00, 1200.0, "cat"))) } @@ -374,7 +393,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df1) checkOffsetRemoved(df1) checkPushedInfo(df1, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, PushedOffset: OFFSET 1,") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read @@ -386,7 +407,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df2, false) checkOffsetRemoved(df2, false) checkPushedInfo(df2, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read @@ -398,7 +420,9 @@ class 
JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df3) checkOffsetRemoved(df3, false) checkPushedInfo(df3, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedLimit: LIMIT 2", + "ReadSchema:") checkAnswer(df3, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df4 = spark.read @@ -411,7 +435,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df4, false) checkOffsetRemoved(df4, false) checkPushedInfo(df4, - "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + " ReadSchema:") checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df5 = spark.read @@ -422,8 +447,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkLimitRemoved(df5) checkOffsetRemoved(df5) - checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + - "PushedOffset: OFFSET 1, PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkPushedInfo(df5, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2", + "ReadSchema:") checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df6 = spark.read @@ -435,7 +463,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkLimitRemoved(df6, false) checkOffsetRemoved(df6, false) - checkPushedInfo(df6, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkPushedInfo(df6, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df7 = spark.read @@ -447,8 +477,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkLimitRemoved(df7) checkOffsetRemoved(df7, false) - checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]," + - " PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkPushedInfo(df7, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2", + "ReadSchema:") checkAnswer(df7, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df8 = spark.read @@ -461,7 +493,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkLimitRemoved(df8, false) checkOffsetRemoved(df8, false) - checkPushedInfo(df8, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkPushedInfo(df8, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df8, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df9 = spark.read @@ -476,7 +510,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df9, false) checkOffsetRemoved(df9, false) checkPushedInfo(df9, - "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 2, ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedLimit: LIMIT 2", " ReadSchema:") checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) val df10 = spark.read @@ -505,7 +540,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .offset(1) checkLimitRemoved(df11, false) checkOffsetRemoved(df11, false) - checkPushedInfo(df11, "PushedFilters: [], ") + checkPushedInfo(df11, "PushedFilters: []") checkAnswer(df11, Seq(Row(9000.00, 1200.0, "cat"))) } 
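// --- Illustrative sketch (not part of the patch) --------------------------------
// The hunks above replace single concatenated expectations such as
// "PushedFilters: [...], PushedLimit: LIMIT 2, ..." with separate checkPushedInfo
// arguments, so each piece of push-down metadata is asserted on its own. The
// varargs helper below is a hypothetical sketch of that idea, not the suite's
// actual helper implementation.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.FormattedMode

object PushedInfoCheckSketch {
  // Assert every expected fragment independently against the formatted plan,
  // so a failure names the specific fragment that is missing.
  def checkFragments(df: DataFrame, expectedFragments: String*): Unit = {
    val explain = df.queryExecution.explainString(FormattedMode)
    expectedFragments.foreach { fragment =>
      assert(explain.contains(fragment), s"missing fragment in explain output: $fragment")
    }
  }
}
// Usage mirrors the rewritten calls above, e.g.
// PushedInfoCheckSketch.checkFragments(df1,
//   "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", "PushedLimit: LIMIT 2")
// ---------------------------------------------------------------------------------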
@@ -518,7 +553,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df1) checkOffsetRemoved(df1) checkPushedInfo(df1, - "[DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, PushedOffset: OFFSET 1,") + "[DEPT IS NOT NULL, DEPT = 1]", + "PushedLimit: LIMIT 2", + " PushedOffset: OFFSET 1") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read @@ -530,7 +567,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df2) checkOffsetRemoved(df2, false) checkPushedInfo(df2, - "[DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, ReadSchema:") + "[DEPT IS NOT NULL, DEPT = 1]", + "PushedLimit: LIMIT 2", + "ReadSchema:") checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read @@ -542,7 +581,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df3, false) checkOffsetRemoved(df3) checkPushedInfo(df3, - "[DEPT IS NOT NULL, DEPT = 1], PushedOffset: OFFSET 1, ReadSchema:") + "[DEPT IS NOT NULL, DEPT = 1]", + "PushedOffset: OFFSET 1", + "ReadSchema:") checkAnswer(df3, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df4 = spark.read @@ -555,7 +596,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df4, false) checkOffsetRemoved(df4, false) checkPushedInfo(df4, - "[DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + "[DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df5 = spark.read @@ -566,8 +608,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df5) checkOffsetRemoved(df5) - checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + - "PushedOffset: OFFSET 1, PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkPushedInfo(df5, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2", + "ReadSchema:") checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df6 = spark.read @@ -579,8 +624,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df6) checkOffsetRemoved(df6, false) - checkPushedInfo(df6, "[DEPT IS NOT NULL, DEPT = 1]," + - " PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkPushedInfo(df6, + "[DEPT IS NOT NULL, DEPT = 1]", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2", + "ReadSchema:") checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df7 = spark.read @@ -592,7 +639,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df7, false) checkOffsetRemoved(df7, false) - checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkPushedInfo(df7, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df7, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df8 = spark.read @@ -605,7 +654,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df8, false) checkOffsetRemoved(df8, false) - checkPushedInfo(df8, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkPushedInfo(df8, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "ReadSchema:") checkAnswer(df8, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df9 = spark.read 
@@ -620,7 +671,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkLimitRemoved(df9, false) checkOffsetRemoved(df9, false) checkPushedInfo(df9, - "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 2, ReadSchema:") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedLimit: LIMIT 2", + "ReadSchema:") checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1") @@ -645,7 +698,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df11, false) checkOffsetRemoved(df11, false) - checkPushedInfo(df11, "PushedFilters: [], ") + checkPushedInfo(df11, "PushedFilters: []") checkAnswer(df11, Seq(Row(9000.00, 1200.0, "cat"))) } @@ -668,7 +721,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkSortRemoved(df1) checkLimitRemoved(df1) checkPushedInfo(df1, - "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + "PushedFilters: []", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read @@ -682,8 +736,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkSortRemoved(df2) checkLimitRemoved(df2) - checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + - "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1") checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read @@ -697,8 +752,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkSortRemoved(df3, false) checkLimitRemoved(df3, false) - checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + - "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1, ") + checkPushedInfo(df3, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1") checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df4 = @@ -706,67 +762,58 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkSchemaNames(df4, Seq("NAME")) checkSortRemoved(df4) checkLimitRemoved(df4) - checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + - "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1, ") + checkPushedInfo(df4, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1") checkAnswer(df4, Seq(Row("david"))) val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") checkSortRemoved(df5, false) - checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ") + checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]") checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true))) - val df6 = spark.read - .table("h2.test.employee") - .groupBy("DEPT").sum("SALARY") - .orderBy("DEPT") - .limit(1) - checkSortRemoved(df6, false) - checkLimitRemoved(df6, false) - checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," + - " PushedFilters: [], PushedGroupByExpressions: [DEPT], ") - checkAnswer(df6, Seq(Row(1, 19000.00))) - val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 
3) } - val df7 = spark.read + val df6 = spark.read .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkSortRemoved(df7, false) - checkLimitRemoved(df7, false) - checkPushedInfo(df7, "PushedFilters: [], ") - checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + checkSortRemoved(df6, false) + checkLimitRemoved(df6, false) + checkPushedInfo(df6, "PushedFilters: []") + checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) - val df8 = spark.read + val df7 = spark.read .table("h2.test.employee") .sort(sub($"NAME")) .limit(1) - checkSortRemoved(df8, false) - checkLimitRemoved(df8, false) - checkPushedInfo(df8, "PushedFilters: [], ") - checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) + checkSortRemoved(df7, false) + checkLimitRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: []") + checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false))) - val df9 = spark.read + val df8 = spark.read .table("h2.test.employee") .select($"DEPT", $"name", $"SALARY", when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) .sort("key", "dept", "SALARY") .limit(3) - checkSortRemoved(df9) - checkLimitRemoved(df9) - checkPushedInfo(df9, "PushedFilters: [], " + - "PushedTopN: " + - "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + - "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,") - checkAnswer(df9, + checkSortRemoved(df8) + checkLimitRemoved(df8) + checkPushedInfo(df8, + "PushedFilters: []", + "PushedTopN: ORDER BY " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END" + + " ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df8, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) - val df10 = spark.read + val df9 = spark.read .option("partitionColumn", "dept") .option("lowerBound", "0") .option("upperBound", "2") @@ -776,13 +823,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) .orderBy($"key", $"dept", $"SALARY") .limit(3) - checkSortRemoved(df10, false) - checkLimitRemoved(df10, false) - checkPushedInfo(df10, "PushedFilters: [], " + - "PushedTopN: " + - "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + - "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,") - checkAnswer(df10, + checkSortRemoved(df9, false) + checkLimitRemoved(df9, false) + checkPushedInfo(df9, + "PushedFilters: []", + "PushedTopN: ORDER BY " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + + "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df9, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) } @@ -794,7 +842,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkSortRemoved(df1) checkPushedInfo(df1, - "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + "PushedFilters: []", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1") checkAnswer(df1, Seq(Row("cathy", 9000.00))) val df2 = spark.read @@ -805,15 +854,205 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel 
.limit(1) checkSortRemoved(df2) checkPushedInfo(df2, - "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + - "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]", + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1") checkAnswer(df2, Seq(Row(2, "david", 10000.00))) } + test("scan with aggregate push-down, top N push-down and offset push-down") { + val df1 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + + val paging1 = df1.offset(1).limit(1) + checkSortRemoved(paging1) + checkLimitRemoved(paging1) + checkPushedInfo(paging1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging1, Seq(Row(2, 22000.00))) + + val topN1 = df1.limit(1) + checkSortRemoved(topN1) + checkLimitRemoved(topN1) + checkPushedInfo(topN1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN1, Seq(Row(1, 19000.00))) + + val df2 = spark.read + .table("h2.test.employee") + .select($"DEPT".cast("string").as("my_dept"), $"SALARY") + .groupBy("my_dept").sum("SALARY") + .orderBy("my_dept") + + val paging2 = df2.offset(1).limit(1) + checkSortRemoved(paging2) + checkLimitRemoved(paging2) + checkPushedInfo(paging2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging2, Seq(Row("2", 22000.00))) + + val topN2 = df2.limit(1) + checkSortRemoved(topN2) + checkLimitRemoved(topN2) + checkPushedInfo(topN2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN2, Seq(Row("1", 19000.00))) + + val df3 = spark.read + .table("h2.test.employee") + .groupBy("dept").sum("SALARY") + .orderBy($"dept".cast("string")) + + val paging3 = df3.offset(1).limit(1) + checkSortRemoved(paging3) + checkLimitRemoved(paging3) + checkPushedInfo(paging3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging3, Seq(Row(2, 22000.00))) + + val topN3 = df3.limit(1) + checkSortRemoved(topN3) + checkLimitRemoved(topN3) + checkPushedInfo(topN3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN3, Seq(Row(1, 19000.00))) + + val df4 = spark.read + .table("h2.test.employee") + .groupBy("DEPT", "IS_MANAGER").sum("SALARY") + .orderBy("DEPT", "IS_MANAGER") + + val paging4 = df4.offset(1).limit(1) + checkSortRemoved(paging4) + checkLimitRemoved(paging4) + checkPushedInfo(paging4, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT, IS_MANAGER]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging4, Seq(Row(1, true, 10000.00))) + + val topN4 = df4.limit(1) + checkSortRemoved(topN4) + checkLimitRemoved(topN4) + checkPushedInfo(topN4, + "PushedAggregates: 
[SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT, IS_MANAGER]", + "PushedFilters: []", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN4, Seq(Row(1, false, 9000.00))) + + val df5 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"IS_MANAGER", $"DEPT".cast("string").as("my_dept")) + .groupBy("my_dept", "IS_MANAGER").sum("SALARY") + .orderBy("my_dept", "IS_MANAGER") + + val paging5 = df5.offset(1).limit(1) + checkSortRemoved(paging5) + checkLimitRemoved(paging5) + checkPushedInfo(paging5, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: " + + "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging5, Seq(Row("1", true, 10000.00))) + + val topN5 = df5.limit(1) + checkSortRemoved(topN5) + checkLimitRemoved(topN5) + checkPushedInfo(topN5, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]", + "PushedFilters: []", + "PushedTopN: " + + "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN5, Seq(Row("1", false, 9000.00))) + + val df6 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"SALARY") + .groupBy("dept").agg(sum("SALARY")) + .orderBy(sum("SALARY")) + + val paging6 = df6.offset(1).limit(1) + checkSortRemoved(paging6) + checkLimitRemoved(paging6) + checkPushedInfo(paging6, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging6, Seq(Row(1, 19000.00))) + + val topN6 = df6.limit(1) + checkSortRemoved(topN6) + checkLimitRemoved(topN6) + checkPushedInfo(topN6, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN6, Seq(Row(6, 12000.00))) + + val df7 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"SALARY") + .groupBy("dept").agg(sum("SALARY").as("total")) + .orderBy("total") + + val paging7 = df7.offset(1).limit(1) + checkSortRemoved(paging7) + checkLimitRemoved(paging7) + checkPushedInfo(paging7, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging7, Seq(Row(1, 19000.00))) + + val topN7 = df7.limit(1) + checkSortRemoved(topN7) + checkLimitRemoved(topN7) + checkPushedInfo(topN7, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN7, Seq(Row(6, 12000.00))) + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) checkFiltersRemoved(df) - checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, ID > 1], ") + checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, ID > 1]") checkAnswer(df, Row("mary", 2)) val df2 = spark.table("h2.test.employee").filter($"name".isin("amy", "cathy")) @@ -834,32 +1073,34 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df5 = spark.table("h2.test.employee").filter($"is_manager".and($"salary" > 10000)) checkFiltersRemoved(df5) - checkPushedInfo(df5, "PushedFilters: 
[IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " + + checkPushedInfo(df5, + "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL", "IS_MANAGER = true, SALARY > 10000.00]") checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true))) val df6 = spark.table("h2.test.employee").filter($"is_manager".or($"salary" > 10000)) checkFiltersRemoved(df6) - checkPushedInfo(df6, "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], ") + checkPushedInfo(df6, "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)]") checkAnswer(df6, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df7 = spark.table("h2.test.employee").filter(not($"is_manager") === true) checkFiltersRemoved(df7) - checkPushedInfo(df7, "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true) = TRUE], ") + checkPushedInfo(df7, "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true) = TRUE]") checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "alex", 12000, 1200, false))) val df8 = spark.table("h2.test.employee").filter($"is_manager" === true) checkFiltersRemoved(df8) - checkPushedInfo(df8, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = TRUE], ") + checkPushedInfo(df8, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = TRUE]") checkAnswer(df8, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df9 = spark.table("h2.test.employee") .filter(when($"dept" > 1, true).when($"is_manager", false).otherwise($"dept" > 3)) checkFiltersRemoved(df9) - checkPushedInfo(df9, "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE " + - "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END], ") + checkPushedInfo(df9, + "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE", + "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END]") checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) @@ -867,7 +1108,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .select($"NAME".as("myName"), $"ID".as("myID")) .filter($"myID" > 1) checkFiltersRemoved(df10) - checkPushedInfo(df10, "PushedFilters: [ID IS NOT NULL, ID > 1], ") + checkPushedInfo(df10, "PushedFilters: [ID IS NOT NULL, ID > 1]") checkAnswer(df10, Row("mary", 2)) val df11 = sql( @@ -1052,7 +1293,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " + "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]" } else { - "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL]," + "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']" } checkPushedInfo(df6, expectedPlanFragment6) checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true))) @@ -1131,6 +1372,47 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedFilters: [DATE1 IS NOT NULL, ((EXTRACT(DAY_OF_WEEK FROM DATE1) % 7) + 1) = 4]" checkPushedInfo(df8, expectedPlanFragment8) checkAnswer(df8, Seq(Row("alex"))) + + val df9 = sql("SELECT name FROM h2.test.datetime WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFiltersRemoved(df9) + val expectedPlanFragment9 = + "PushedFilters: [DATE1 IS NOT NULL, EXTRACT(DAY_OF_YEAR FROM DATE1) > 100], " + + "PushedTopN: ORDER BY [EXTRACT(DAY_OF_YEAR FROM DATE1) ASC NULLS FIRST] LIMIT 1," + checkPushedInfo(df9, 
      expectedPlanFragment9)
+    checkAnswer(df9, Seq(Row("alex")))
+  }
+
+  test("scan with filter push-down with misc functions") {
+    val df1 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "md5(b) = '4371fe0aa613bcb081543a37d241adcb'")
+    checkFiltersRemoved(df1)
+    val expectedPlanFragment1 = "PushedFilters: [B IS NOT NULL, " +
+      "MD5(B) = '4371fe0aa613bcb081543a37d241adcb']"
+    checkPushedInfo(df1, expectedPlanFragment1)
+    checkAnswer(df1, Seq(Row("jen")))
+
+    val df2 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "sha1(b) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'")
+    checkFiltersRemoved(df2)
+    val expectedPlanFragment2 = "PushedFilters: [B IS NOT NULL, " +
+      "SHA1(B) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'],"
+    checkPushedInfo(df2, expectedPlanFragment2)
+    checkAnswer(df2, Seq(Row("jen")))
+
+    val df3 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "sha2(b, 256) = '911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109'")
+    checkFiltersRemoved(df3)
+    val expectedPlanFragment3 = "PushedFilters: [B IS NOT NULL, (SHA2(B, 256)) = " +
+      "'911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109']"
+    checkPushedInfo(df3, expectedPlanFragment3)
+    checkAnswer(df3, Seq(Row("jen")))
+
+    val df4 = sql("SELECT * FROM h2.test.employee WHERE crc32(name) = '142689369'")
+    checkFiltersRemoved(df4, false)
+    val expectedPlanFragment4 = "PushedFilters: [NAME IS NOT NULL], "
+    checkPushedInfo(df4, expectedPlanFragment4)
+    checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))
   }
 
   test("scan with filter push-down with UDF") {
@@ -1142,18 +1424,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],")
       checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
 
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-        val df2 = sql(
-          """
-            |SELECT *
-            |FROM h2.test.people
-            |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+      val df2 = sql(
+        """
+          |SELECT *
+          |FROM h2.test.people
+          |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
         """.stripMargin)
-        checkFiltersRemoved(df2)
-        checkPushedInfo(df2,
-          "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
-        checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
-      }
+      checkFiltersRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
     } finally {
       JdbcDialects.unregisterDialect(testH2Dialect)
       JdbcDialects.registerDialect(H2Dialect)
@@ -1212,7 +1492,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       Seq(Row("test", "people", false), Row("test", "empty_table", false),
         Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
         Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false),
-        Row("test", "datetime", false)))
+        Row("test", "datetime", false), Row("test", "binary1", false)))
   }
 
   test("SQL API: create table as select") {
@@ -1339,6 +1619,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       "PushedFilters: [NAME IS NOT NULL]"
     checkPushedInfo(df5, expectedPlanFragment5)
     checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    val df6 = sql("SELECT * FROM h2.test.employee WHERE bit_length(name) = 40")
+    checkFiltersRemoved(df6)
+    checkPushedInfo(df6, "[NAME IS NOT NULL, BIT_LENGTH(NAME) = 40]")
+    checkAnswer(df6, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+    val df7 = sql("SELECT * FROM h2.test.employee WHERE char_length(name) = 5")
+    checkFiltersRemoved(df7)
+    checkPushedInfo(df7, "[NAME IS NOT NULL, CHAR_LENGTH(NAME) = 5]")
+    checkAnswer(df6, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+    val df8 = sql("SELECT * FROM h2.test.employee WHERE " +
+      "concat(name, ',' , cast(salary as string)) = 'cathy,9000.00'")
+    checkFiltersRemoved(df8)
+    checkPushedInfo(df8, "[(CONCAT(NAME, ',', CAST(SALARY AS string))) = 'cathy,9000.00']")
+    checkAnswer(df8, Seq(Row(1, "cathy", 9000, 1200, false)))
   }
 
   test("scan with aggregate push-down: MAX AVG with filter and group by") {
@@ -1346,9 +1642,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       " GROUP BY DePt")
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " +
-      "PushedGroupByExpressions: [DEPT], ")
+    checkPushedInfo(df,
+      "PushedAggregates: [MAX(SALARY), AVG(BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0)))
   }
 
@@ -1367,9 +1664,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df = sql("SELECT MAX(ID), AVG(ID) FROM h2.test.people WHERE id > 0")
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " +
-      "PushedFilters: [ID IS NOT NULL, ID > 0], " +
-      "PushedGroupByExpressions: [], ")
+    checkPushedInfo(df,
+      "PushedAggregates: [MAX(ID), AVG(ID)]",
+      "PushedFilters: [ID IS NOT NULL, ID > 0]",
+      "PushedGroupByExpressions: []")
     checkAnswer(df, Seq(Row(2, 1.5)))
   }
 
@@ -1414,7 +1712,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val df = sql("SELECT name FROM h2.test.employee GROUP BY name")
     checkAggregateRemoved(df)
     checkPushedInfo(df,
-      "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
+      "PushedAggregates: []",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [NAME]")
     checkAnswer(df, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))
 
     val df2 = spark.read
@@ -1427,7 +1727,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .agg(Map.empty[String, String])
     checkAggregateRemoved(df2, false)
     checkPushedInfo(df2,
-      "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
+      "PushedAggregates: []",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [NAME]")
     checkAnswer(df2, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))
 
     val df3 = sql("SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as" +
@@ -1504,8 +1806,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("scan with aggregate push-down: SUM with group by") {
     val df1 = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
    checkAggregateRemoved(df1)
-    checkPushedInfo(df1, "PushedAggregates: [SUM(SALARY)], " +
-      "PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+    checkPushedInfo(df1,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df1, Seq(Row(19000), Row(22000), Row(12000)))
 
     val df2 = sql(
@@ -1620,8 +1924,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("scan with aggregate push-down: DISTINCT SUM with group by") {
     val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " +
-      "PushedFilters: [], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(df,
+      "PushedAggregates: [SUM(DISTINCT SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
   }
 
@@ -1630,8 +1936,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       " GROUP BY DEPT, NAME")
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]")
+    checkPushedInfo(df,
+      "PushedAggregates: [MAX(SALARY), MIN(BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT, NAME]")
     checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000),
       Row(12000, 1200)))
   }
@@ -1644,8 +1952,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     assert(filters1.isEmpty)
     checkAggregateRemoved(df1)
-    checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]")
+    checkPushedInfo(df1,
+      "PushedAggregates: [MAX(SALARY)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT, NAME]")
     checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000),
       Row("2#david", 10000), Row("6#jen", 12000)))
 
@@ -1656,8 +1966,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     assert(filters2.isEmpty)
     checkAggregateRemoved(df2)
-    checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]")
+    checkPushedInfo(df2,
+      "PushedAggregates: [MAX(SALARY), MIN(BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT, NAME]")
     checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
       Row("2#david", 11300), Row("6#jen", 13200)))
 
@@ -1665,7 +1977,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       " FROM h2.test.employee WHERE dept > 0 GROUP BY concat_ws('#', DEPT, NAME)")
     checkFiltersRemoved(df3)
     checkAggregateRemoved(df3, false)
-    checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], ")
+    checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]")
     checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
       Row("2#david", 11300), Row("6#jen", 13200)))
   }
@@ -1676,8 +1988,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     // filter over aggregate not push down
     checkFiltersRemoved(df, false)
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(df,
+      "PushedAggregates: [MAX(SALARY), MIN(BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
   }
 
@@ -1686,8 +2000,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .groupBy($"DEPT")
       .min("SALARY").as("total")
     checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " +
-      "PushedFilters: [], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(df,
+      "PushedAggregates: [MIN(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
   }
 
@@ -1701,8 +2017,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .orderBy($"total")
     checkFiltersRemoved(query, false)// filter over aggregate not pushed down
     checkAggregateRemoved(query)
-    checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(query,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
   }
 
@@ -1711,7 +2029,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     val decrease = udf { (x: Double, y: Double) => x - y }
     val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
     checkAggregateRemoved(query)
-    checkPushedInfo(query, "PushedAggregates: [SUM(SALARY), SUM(BONUS)], ")
+    checkPushedInfo(query, "PushedAggregates: [SUM(SALARY), SUM(BONUS)]")
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
@@ -1740,12 +2058,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df,
-      """
-        |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
-        |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
-        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
-        |PushedGroupByExpressions: [DEPT],
-        |""".stripMargin.replaceAll("\n", " "))
+      """
+      |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
+      |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
+      |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+      |PushedGroupByExpressions: [DEPT],
+      |""".stripMargin.replaceAll("\n", " "))
     checkAnswer(df, Seq(Row(10000d, 10000d, 20000d, 20000d), Row(2500d, 2500d, 5000d, 5000d),
       Row(0d, 0d, null, null)))
   }
@@ -1777,8 +2095,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt")
     checkFiltersRemoved(df1)
     checkAggregateRemoved(df1)
-    checkPushedInfo(df1, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(df1,
+      "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df1, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
 
     val df2 = sql("SELECT COVAR_POP(DISTINCT bonus, bonus), COVAR_SAMP(DISTINCT bonus, bonus)" +
@@ -1794,8 +2114,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       " GROUP BY DePt")
     checkFiltersRemoved(df1)
    checkAggregateRemoved(df1)
-    checkPushedInfo(df1, "PushedAggregates: [CORR(BONUS, BONUS)], " +
-      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
+    checkPushedInfo(df1,
+      "PushedAggregates: [CORR(BONUS, BONUS)]",
+      "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df1, Seq(Row(1d), Row(1d), Row(null)))
 
     val df2 = sql("SELECT CORR(DISTINCT bonus, bonus) FROM h2.test.employee WHERE dept > 0" +
@@ -1875,8 +2197,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       "MAX(CASE WHEN (SALARY <= 8000.00) OR (SALARY >= 10000.00) THEN 0.00 ELSE SALARY END), " +
       "MIN(CASE WHEN (SALARY <= 8000.00) AND (SALARY IS NOT NULL) THEN SALARY ELSE 0.00 END), " +
       "SUM(CASE WHEN SALARY > 10000.00 THEN 2 WHEN SALARY > 8000.00 THEN 1 END), " +
-      "AVG(CASE WHEN (SALARY <= 8000.00) AND (SALARY IS NULL) THEN SALARY ELSE 0.00 END)], " +
-      "PushedFilters: [], " +
+      "AVG(CASE WHEN (SALARY <= 8000.00) AND (SALARY IS NULL) THEN SALARY ELSE 0.00 END)]",
+      "PushedFilters: []",
       "PushedGroupByExpressions: [DEPT],")
     checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 0d, 0d, 2, 0d),
       Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 0d, 0d, 3, 0d),
@@ -1955,12 +2277,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
     val df2 = sql("SELECT `dept_id1`, COUNT(`dept_id`) FROM h2.test.dept GROUP BY `dept_id1`")
     checkPushedInfo(df2,
-      "PushedGroupByExpressions: [dept_id1]", "PushedAggregates: [COUNT(dept_id)]")
+      "PushedGroupByExpressions: [dept_id1]",
+      "PushedAggregates: [COUNT(dept_id)]")
     checkAnswer(df2, Seq(Row(1, 2)))
 
     val df3 = sql("SELECT `dept_id`, COUNT(`dept_id1`) FROM h2.test.dept GROUP BY `dept_id`")
     checkPushedInfo(df3,
-      "PushedGroupByExpressions: [dept_id]", "PushedAggregates: [COUNT(dept_id1)]")
+      "PushedGroupByExpressions: [dept_id]",
+      "PushedAggregates: [COUNT(dept_id1)]")
     checkAnswer(df3, Seq(Row(1, 1), Row(2, 1)))
   }
 
@@ -2050,7 +2374,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .filter($"total" > 1000)
     checkAggregateRemoved(df)
     checkPushedInfo(df,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
 
     val df2 = spark.table("h2.test.employee")
@@ -2060,7 +2386,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
      .filter($"total" > 1000)
     checkAggregateRemoved(df2)
     checkPushedInfo(df2,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [DEPT]")
     checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
   }
 
@@ -2077,7 +2405,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
      .filter($"total" > 1000)
     checkAggregateRemoved(df, false)
     checkPushedInfo(df,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [NAME]")
     checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00),
       Row("david", 10000.00), Row("jen", 12000.00)))
 
@@ -2093,30 +2423,25 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
      .filter($"total" > 1000)
     checkAggregateRemoved(df2, false)
     checkPushedInfo(df2,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedFilters: []",
+      "PushedGroupByExpressions: [NAME]")
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00),
       Row("david", 10000.00), Row("jen", 12000.00)))
   }
 
   test("scan with aggregate push-down: partial push-down AVG with overflow") {
-    def createDataFrame: DataFrame = spark.read
-      .option("partitionColumn", "id")
-      .option("lowerBound", "0")
-      .option("upperBound", "2")
-      .option("numPartitions", "2")
-      .table("h2.test.item")
-      .agg(avg($"PRICE").as("avg"))
-
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
-        val df = createDataFrame
+        val df = spark.read
+          .option("partitionColumn", "id")
+          .option("lowerBound", "0")
+          .option("upperBound", "2")
+          .option("numPartitions", "2")
+          .table("h2.test.item")
+          .agg(avg($"PRICE").as("avg"))
        checkAggregateRemoved(df, false)
-        df.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment =
-              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
-            checkKeywordsExistsInExplain(df, expected_plan_fragment)
-        }
+        checkPushedInfo(df, "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]")
        if (ansiEnabled) {
          val e = intercept[SparkException] {
            df.collect()
          }
@@ -2138,13 +2463,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people")
       checkAggregateRemoved(df1)
       checkPushedInfo(df1,
-        "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: []")
+        "PushedAggregates: [iavg(ID)]",
+        "PushedFilters: []",
+        "PushedGroupByExpressions: []")
       checkAnswer(df1, Seq(Row(1)))
 
       val df2 = sql("SELECT name, h2.my_avg(id) FROM h2.test.people group by name")
       checkAggregateRemoved(df2)
       checkPushedInfo(df2,
-        "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: [NAME]")
+        "PushedAggregates: [iavg(ID)]",
+        "PushedFilters: []",
+        "PushedGroupByExpressions: [NAME]")
       checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
       withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
         val df3 = sql(
@@ -2155,8 +2484,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         """.stripMargin)
         checkAggregateRemoved(df3)
         checkPushedInfo(df3,
-          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
-            " PushedFilters: [], PushedGroupByExpressions: []")
+          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]",
+          "PushedFilters: []",
+          "PushedGroupByExpressions: []")
         checkAnswer(df3, Seq(Row(2)))
 
         val df4 = sql(
@@ -2169,8 +2499,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         """.stripMargin)
         checkAggregateRemoved(df4)
         checkPushedInfo(df4,
-          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
-            " PushedFilters: [], PushedGroupByExpressions: [NAME]")
+          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]",
+          "PushedFilters: []",
+          "PushedGroupByExpressions: [NAME]")
         checkAnswer(df4, Seq(Row("fred", 2), Row("mary", 2)))
       }
     } finally {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 2ec593b95c9b6..006ef14ed73da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.sources
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.catalyst.catalog.CatalogUtils
-import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.types._
 
-class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper {
+class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
   override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
 
   // We have a very limited number of supported types at here since it is just for a