diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 598110e9e207e..2c717f4ae9496 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -28,6 +28,10 @@ import org.apache.spark.sql.types._ * An interface for expressions that contain a [[QueryPlan]]. */ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { + + override lazy val deterministic: Boolean = children.forall(_.deterministic) && + plan.deterministic + /** The id of the subquery expression. */ def exprId: ExprId diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 92b7690b5e0fe..3f40f8da8c17b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -66,6 +66,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes } + /** + * Returns true when the all the expressions in the current node as well as all of its children + * are deterministic + */ + lazy val deterministic: Boolean = expressions.forall(_.deterministic) && + children.forall(_.deterministic) + /** * Attributes that are referenced by expressions but not provided by this node's children. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala index 91ce187f4d270..f0ff1569e1583 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.dsl.plans -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ListQuery, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, Expression, ListQuery, Literal, NamedExpression, Rand} import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, Union} import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.types.IntegerType @@ -83,4 +83,32 @@ class QueryPlanSuite extends SparkFunSuite { assert(countRelationsInPlan == 2) assert(countRelationsInPlanAndSubqueries == 5) } + + test("SPARK-37199: add a deterministic field to QueryPlan") { + val a: NamedExpression = AttributeReference("a", IntegerType)() + val aRand: NamedExpression = Alias(Add(a, Rand(1)), "aRand")() + val deterministicPlan = Project( + Seq(a), + Filter( + ListQuery(Project( + Seq(a), + UnresolvedRelation(TableIdentifier("t", None)) + )), + UnresolvedRelation(TableIdentifier("t", None)) + ) + ) + assert(deterministicPlan.deterministic) + + val nonDeterministicPlan = Project( + Seq(aRand), + Filter( + ListQuery(Project( + Seq(a), + UnresolvedRelation(TableIdentifier("t", None)) + )), + UnresolvedRelation(TableIdentifier("t", None)) + ) + ) + assert(!nonDeterministicPlan.deterministic) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index cd2b812b6a7cf..537800a03dbe6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1753,4 +1753,14 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-37199: deterministic in QueryPlan considers subquery") { + val deterministicQueryPlan = sql("select (select 1 as b) as b") + .queryExecution.executedPlan + assert(deterministicQueryPlan.deterministic) + + val nonDeterministicQueryPlan = sql("select (select rand(1) as b) as b") + .queryExecution.executedPlan + assert(!nonDeterministicQueryPlan.deterministic) + } }