Skip to content

Commit

Permalink
[SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

`MapConcat` holds state in a map builder field, so it is a stateful expression:
```
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
```

Similarly, `MapFromEntries`, `CreateMap`, `MapFromArrays`, `StringToMap`, and `TransformKeys` need the same change.

### Why are the changes needed?

Stateful expressions should be marked as stateful.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A test case using `CreateMap` was added to `DataFrameSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #46721 from amaliujia/statefulexpr.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
amaliujia authored and cloud-fan committed May 28, 2024
1 parent a88cc1a commit af1ac1e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ case class MapConcat(children: Seq[Expression])
}
}

// MapConcat keeps a builder object in the `mapBuilder` field declared below, so the
// expression carries state and is flagged as stateful.
override def stateful: Boolean = true
// The result is nullable as soon as at least one child expression is nullable.
override def nullable: Boolean = children.exists(_.nullable)

// Builder created lazily on first access, parameterized with this expression's
// key and value types. This field is the state that makes the expression stateful.
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
Expand Down Expand Up @@ -828,6 +829,8 @@ case class MapFromEntries(child: Expression)

// Nullable when the child is nullable or when `nullEntries` is true.
// NOTE(review): `nullEntries` presumably reports that the input array may contain
// null entries -- confirm against its definition.
override def nullable: Boolean = child.nullable || nullEntries

// Flags this expression as stateful.
// NOTE(review): presumably because MapFromEntries holds a map builder field like the
// other map-building expressions -- the field is not visible here, confirm.
override def stateful: Boolean = true

// Result map type: the first component of `dataTypeDetails`, computed lazily and
// excluded from serialization (@transient).
// NOTE(review): `.get` assumes `dataTypeDetails` is defined, presumably guaranteed
// once checkInputDataTypes() has succeeded -- confirm.
@transient override lazy val dataType: MapType = dataTypeDetails.get._1

override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)

private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)

// CreateMap keeps a builder object in its `mapBuilder` field (declared just above),
// so the expression carries state and is flagged as stateful.
override def stateful: Boolean = true

override def eval(input: InternalRow): Any = {
var i = 0
while (i < keys.length) {
Expand Down Expand Up @@ -320,6 +322,8 @@ case class MapFromArrays(left: Expression, right: Expression)
valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
}

// MapFromArrays keeps a builder object in the `mapBuilder` field declared below,
// so the expression carries state and is flagged as stateful.
override def stateful: Boolean = true

private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)

override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
Expand Down Expand Up @@ -568,6 +572,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
this(child, Literal(","), Literal(":"))
}

// Flags this expression as stateful.
// NOTE(review): presumably because StringToMap holds a map builder field like the
// other map-building expressions -- the field is not visible here, confirm.
override def stateful: Boolean = true

// Expose the three inputs positionally: the text to parse, the delimiter between
// pairs, and the delimiter between a key and its value.
override def first: Expression = text
override def second: Expression = pairDelim
override def third: Expression = keyValueDelim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,8 @@ case class TransformKeys(

// Result map type: the key type is the data type produced by `function`, while
// `valueType` and `valueContainsNull` are used unchanged.
override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull)

// Flags this expression as stateful.
// NOTE(review): presumably because TransformKeys holds a map builder field like the
// other map-building expressions -- the field is not visible here, confirm.
override def stateful: Boolean = true

// Type check for this expression: delegates to TypeUtils.checkForMapKeyType, passing
// the data type produced by `function` (which becomes the key type of the result map),
// and returns that helper's TypeCheckResult.
override def checkInputDataTypes(): TypeCheckResult = {
TypeUtils.checkForMapKeyType(function.dataType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation}
Expand Down Expand Up @@ -2504,6 +2504,14 @@ class DataFrameSuite extends QueryTest
assert(row.getInt(0).toString == row.getString(2))
assert(row.getInt(0).toString == row.getString(3))
}

val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value"))))
val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback
df.select(v3, v3, v4, v4).collect().foreach { row =>
assert(row.getMap(0).toString() == row.getMap(1).toString())
assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}")
assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}")
}
}

test("SPARK-45216: Non-deterministic functions with seed") {
Expand Down

0 comments on commit af1ac1e

Please sign in to comment.