Skip to content

Commit

Permalink
[SPARK-42896][SQL][PYTHON] Make mapInPandas / mapInArrow support …
Browse files Browse the repository at this point in the history
…barrier mode execution

### What changes were proposed in this pull request?

Make mapInPandas / mapInArrow support barrier mode execution

### Why are the changes needed?

This is the preparation PR for supporting mapInPandas / mapInArrow barrier execution in spark connect mode. The feature is required by machine learning use cases.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Closes #40520 from WeichenXu123/barrier-udf.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Mar 27, 2023
1 parent 80f8664 commit 06bf544
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,14 @@ class SparkConnectPlanner(val session: SparkSession) {
logical.MapInPandas(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput))
transformRelation(rel.getInput),
false)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.PythonMapInArrow(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput))
transformRelation(rel.getInput),
false)
case _ =>
throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
}
Expand Down
26 changes: 22 additions & 4 deletions python/pyspark/sql/pandas/map_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PandasMapOpsMixin:
"""

def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str]
self, func: "PandasMapIterFunction", schema: Union[StructType, str], isBarrier: bool = False
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
Expand Down Expand Up @@ -60,6 +60,7 @@ def mapInPandas(
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
isBarrier : Use barrier mode execution if True.
Examples
--------
Expand All @@ -74,6 +75,14 @@ def mapInPandas(
+---+---+
| 1| 21|
+---+---+
>>> # Set isBarrier=True to force the "mapInPandas" stage running in barrier mode,
>>> # it ensures all python UDF workers in the stage will be launched concurrently.
>>> df.mapInPandas(filter_func, df.schema, isBarrier=True).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
Notes
-----
Expand All @@ -93,11 +102,11 @@ def mapInPandas(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr())
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), isBarrier)
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
self, func: "ArrowMapIterFunction", schema: Union[StructType, str], isBarrier: bool = False
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
Expand All @@ -122,6 +131,7 @@ def mapInArrow(
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
isBarrier : Use barrier mode execution if True.
Examples
--------
Expand All @@ -137,6 +147,14 @@ def mapInArrow(
+---+---+
| 1| 21|
+---+---+
>>> # Set isBarrier=True to force the "mapInArrow" stage running in barrier mode,
>>> # it ensures all python UDF workers in the stage will be launched concurrently.
>>> df.mapInArrow(filter_func, df.schema, isBarrier=True).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
Notes
-----
Expand All @@ -157,7 +175,7 @@ def mapInArrow(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr(), isBarrier)
return DataFrame(jdf, self.sparkSession)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInPandas(_, output, _)
case oldVersion @ MapInPandas(_, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ PythonMapInArrow(_, output, _)
case oldVersion @ PythonMapInArrow(_, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ case class FlatMapGroupsInPandas(
case class MapInPandas(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

Expand All @@ -68,7 +69,8 @@ case class MapInPandas(
case class PythonMapInArrow(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val mapInPandas = MapInPandas(
pythonUdf,
output,
project)
project,
false)
val left = SubqueryAlias("temp0", mapInPandas)
val right = SubqueryAlias("temp1", mapInPandas)
val join = Join(left, right, Inner, None, JoinHint.NONE)
Expand Down
10 changes: 6 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3283,27 +3283,29 @@ class Dataset[T] private[sql](
* This function uses Apache Arrow as serialization format between Java executors and Python
* workers.
*/
private[sql] def mapInPandas(func: PythonUDF): DataFrame = {
private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInPandas(
func,
func.dataType.asInstanceOf[StructType].toAttributes,
logicalPlan))
logicalPlan,
isBarrier))
}

/**
* Applies a function to each partition in Arrow format. The user-defined function
* defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
* Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
*/
private[sql] def pythonMapInArrow(func: PythonUDF): DataFrame = {
private[sql] def pythonMapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
Dataset.ofRows(
sparkSession,
PythonMapInArrow(
func,
func.dataType.asInstanceOf[StructType].toAttributes,
logicalPlan))
logicalPlan,
isBarrier))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.python.FlatMapCoGroupsInPandasExec(
f.leftAttributes, f.rightAttributes,
func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child) =>
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
case logical.PythonMapInArrow(func, output, child) =>
execution.python.PythonMapInArrowExec(func, output, planLater(child)) :: Nil
case logical.MapInPandas(func, output, child, isBarrier) =>
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
case logical.PythonMapInArrow(func, output, child, isBarrier) =>
execution.python.PythonMapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
case logical.AttachDistributedSequence(attr, child) =>
execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
protected val func: Expression
protected val pythonEvalType: Int

protected val isBarrier: Boolean

private val pythonFunction = func.asInstanceOf[PythonUDF].func

override def producedAttributes: AttributeSet = AttributeSet(output)
Expand All @@ -50,7 +52,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { inputIter =>
def mapper(inputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
// Single function with one struct.
val argOffsets = Array(Array(0))
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
Expand Down Expand Up @@ -90,5 +92,11 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)
}

if (isBarrier) {
child.execute().barrier().mapPartitions(mapper)
} else {
child.execute().mapPartitionsInternal(mapper)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.SparkPlan
case class MapInPandasExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
child: SparkPlan,
override val isBarrier: Boolean)
extends MapInBatchExec {

override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.SparkPlan
case class PythonMapInArrowExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
child: SparkPlan,
override val isBarrier: Boolean)
extends MapInBatchExec {

override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
Expand Down

0 comments on commit 06bf544

Please sign in to comment.