From af1ac1edc2a96c9aba949e3100ddae37b6f0e5b2 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 27 May 2024 22:40:13 -0700 Subject: [PATCH] [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions ### What changes were proposed in this pull request? MapConcat contains a state so it is stateful: ``` private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) ``` Similarly `MapFromEntries, CreateMap, MapFromArrays, StringToMap, and TransformKeys` need the same change. ### Why are the changes needed? Stateful expression should be marked as stateful. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #46721 from amaliujia/statefulexpr. Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- .../catalyst/expressions/collectionOperations.scala | 3 +++ .../sql/catalyst/expressions/complexTypeCreator.scala | 6 ++++++ .../catalyst/expressions/higherOrderFunctions.scala | 2 ++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++++- 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 632e2f3d3e973..ea117f876550e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -713,6 +713,7 @@ case class MapConcat(children: Seq[Expression]) } } + override def stateful: Boolean = true override def nullable: Boolean = children.exists(_.nullable) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) @@ -828,6 +829,8 @@ case class MapFromEntries(child: Expression) override def nullable: Boolean = child.nullable || nullEntries + override def stateful: Boolean = true + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4c0d005340606..167c02c0bafc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -245,6 +245,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def stateful: Boolean = true + override def eval(input: InternalRow): Any = { var i = 0 while (i < keys.length) { @@ -320,6 +322,8 @@ case class MapFromArrays(left: Expression, right: Expression) valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } + override def stateful: Boolean = true + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { @@ -568,6 +572,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E this(child, Literal(","), Literal(":")) } + override def stateful: Boolean = true + override def first: Expression = text override def second: Expression = pairDelim override def third: Expression = keyValueDelim diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 896f3e9774f37..80bcf156133ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -920,6 +920,8 @@ case class TransformKeys( override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull) + override def stateful: Boolean = true + override def checkInputDataTypes(): TypeCheckResult = { TypeUtils.checkForMapKeyType(function.dataType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f11ad230ec160..760ee80260808 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation} @@ -2504,6 +2504,14 @@ class DataFrameSuite extends QueryTest assert(row.getInt(0).toString == row.getString(2)) assert(row.getInt(0).toString == row.getString(3)) } + + val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value")))) + val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback + df.select(v3, v3, v4, v4).collect().foreach { row => + assert(row.getMap(0).toString() == row.getMap(1).toString()) + assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}") + assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}") + } } test("SPARK-45216: Non-deterministic functions with seed") {