[SPARK-46812][SQL][PYTHON] Make mapInPandas / mapInArrow support ResourceProfile

### What changes were proposed in this pull request?

Support stage-level scheduling for some PySpark DataFrame APIs (mapInPandas and mapInArrow).

### Why are the changes needed?

The introduction of barrier mode in Spark (#40520) makes it possible to implement Spark ML use cases (pure Python algorithms) on top of DataFrame APIs such as mapInPandas and mapInArrow. To control the task resources those workloads receive, the same DataFrame APIs also need to support stage-level scheduling.
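
For example, a pure-Python training step could run every partition under barrier mode while requesting dedicated task resources through a ResourceProfile. A minimal sketch (illustrative only, not part of this PR; the GPU amount and the training body are placeholders):

``` python
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder

def train(iterator):
    # Placeholder training body: under barrier mode all tasks start together,
    # and each task can rely on the resources requested in the profile.
    for batch in iterator:
        yield batch

# Request 4 CPU cores and 1 GPU per task (amounts are illustrative).
treqs = TaskResourceRequests().cpus(4).resource("gpu", 1)
rp = ResourceProfileBuilder().require(treqs).build

df = spark.range(10)
df.mapInPandas(train, "id long", barrier=True, profile=rp).collect()
```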

### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new optional argument `profile` to mapInPandas and mapInArrow.

``` python
def mapInPandas(
    self,
    func: "PandasMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> "DataFrame":

def mapInArrow(
    self,
    func: "ArrowMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
```

How to use it? Take mapInPandas as an example:

``` python
from pyspark import TaskContext
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder

def func(iterator):
    # Verify that each task gets the 3 CPU cores requested by the profile.
    tc = TaskContext.get()
    assert tc.cpus() == 3
    for batch in iterator:
        yield batch

df = spark.range(10)

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build

df.mapInPandas(func, "id long", False, rp).collect()
```
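
The same `profile` argument applies to mapInArrow; reusing `func` and `rp` from above (and assuming pyarrow is installed), the equivalent call would be:

``` python
df.mapInArrow(func, "id long", False, rp).collect()
```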

### How was this patch tested?

The newly added unit tests pass. Manual testing is also needed with dynamic allocation enabled and disabled.
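
For the manual check, dynamic allocation can be toggled through standard Spark configs when building the session — an illustrative sketch (not from this PR; set the flag to "false" for the off case):

``` python
from pyspark.sql import SparkSession

# Illustrative only: the dynamic-allocation configs below are standard Spark
# settings, not something introduced by this PR.
spark = (
    SparkSession.builder
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.shuffleTracking.enabled", "true")
    .getOrCreate()
)
```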

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44852 from wbo4958/df-rp.

Lead-authored-by: Bobby Wang <bobwang@nvidia.com>
Co-authored-by: Bobby Wang <wbo4958@gmail.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
2 people authored and WeichenXu123 committed Feb 19, 2024
1 parent 0818096 commit c4e4497
Showing 12 changed files with 206 additions and 25 deletions.
@@ -549,13 +549,15 @@ class SparkConnectPlanner(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier)
isBarrier,
None)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.MapInArrow(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier)
isBarrier,
None)
case _ =>
throw InvalidPlanInput(
s"Function with EvalType: ${pythonUdf.evalType} is not supported")
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -530,6 +530,7 @@ def __hash__(self):
"pyspark.sql.tests.test_udf_profiler",
"pyspark.sql.tests.test_udtf",
"pyspark.sql.tests.test_utils",
"pyspark.sql.tests.test_resources",
],
)

61 changes: 56 additions & 5 deletions python/pyspark/sql/pandas/map_ops.py
@@ -15,9 +15,13 @@
# limitations under the License.
#
import sys
from typing import Union, TYPE_CHECKING
from typing import Union, TYPE_CHECKING, Optional

from py4j.java_gateway import JavaObject

from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests
from pyspark.rdd import PythonEvalType
from pyspark.resource import ResourceProfile
from pyspark.sql.types import StructType

if TYPE_CHECKING:
@@ -32,7 +36,11 @@ class PandasMapOpsMixin:
"""

def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
self,
func: "PandasMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -65,6 +73,12 @@ def mapInPandas(
.. versionadded: 3.5.0
profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInPandas.
.. versionadded: 4.0.0
Examples
--------
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
@@ -141,11 +155,17 @@ def mapInPandas(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
self,
func: "ArrowMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -175,6 +195,11 @@ def mapInArrow(
.. versionadded: 3.5.0
profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInArrow.
.. versionadded: 4.0.0
Examples
--------
>>> import pyarrow # doctest: +SKIP
@@ -220,9 +245,35 @@ def mapInArrow(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier)

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def _build_java_profile(
self, profile: Optional[ResourceProfile] = None
) -> Optional[JavaObject]:
"""Build the java ResourceProfile based on PySpark ResourceProfile"""
from pyspark.sql import DataFrame

assert isinstance(self, DataFrame)

jrp = None
if profile is not None:
if profile._java_resource_profile is not None:
jrp = profile._java_resource_profile
else:
jvm = self.sparkSession.sparkContext._jvm
assert jvm is not None

builder = jvm.org.apache.spark.resource.ResourceProfileBuilder()
ereqs = ExecutorResourceRequests(jvm, profile._executor_resource_requests)
treqs = TaskResourceRequests(jvm, profile._task_resource_requests)
builder.require(ereqs._java_executor_resource_requests)
builder.require(treqs._java_task_resource_requests)
jrp = builder.build()
return jrp


def _test() -> None:
import doctest
104 changes: 104 additions & 0 deletions python/pyspark/sql/tests/test_resources.py
@@ -0,0 +1,104 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark import SparkContext, TaskContext
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
from pyspark.sql import SparkSession
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import ReusedPySparkTestCase


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
class ResourceProfileTestsMixin(object):
def test_map_in_arrow_without_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 1
for batch in iterator:
yield batch

df = self.spark.range(10)
df.mapInArrow(func, "id long").collect()

def test_map_in_arrow_with_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 3
for batch in iterator:
yield batch

df = self.spark.range(10)

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
df.mapInArrow(func, "id long", False, rp).collect()

def test_map_in_pandas_without_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 1
for batch in iterator:
yield batch

df = self.spark.range(10)
df.mapInPandas(func, "id long").collect()

def test_map_in_pandas_with_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 3
for batch in iterator:
yield batch

df = self.spark.range(10)

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
df.mapInPandas(func, "id long", False, rp).collect()


class ResourceProfileTests(ResourceProfileTestsMixin, ReusedPySparkTestCase):
@classmethod
def setUpClass(cls):
cls.sc = SparkContext("local-cluster[1, 4, 1024]", cls.__name__, conf=cls.conf())
cls.spark = SparkSession(cls.sc)

@classmethod
def tearDownClass(cls):
super(ResourceProfileTests, cls).tearDownClass()
cls.spark.stop()


if __name__ == "__main__":
from pyspark.sql.tests.test_resources import * # noqa: F401

try:
import xmlrunner

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
@@ -382,13 +382,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInPandas(_, output, _, _)
case oldVersion @ MapInPandas(_, output, _, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInArrow(_, output, _, _)
case oldVersion @ MapInArrow(_, output, _, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -77,7 +78,8 @@ case class MapInPandas(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {
isBarrier: Boolean,
profile: Option[ResourceProfile]) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

@@ -93,7 +95,8 @@ case class MapInArrow(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {
isBarrier: Boolean,
profile: Option[ResourceProfile]) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

@@ -709,7 +709,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
val left = SubqueryAlias("temp0", mapInPandas)
val right = SubqueryAlias("temp1", mapInPandas)
val join = Join(left, right, Inner, None, JoinHint.NONE)
@@ -729,7 +730,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
assertAnalysisSuccess(mapInPandas)
}

@@ -745,7 +747,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
assertAnalysisSuccess(mapInArrow)
}

17 changes: 13 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -37,6 +37,7 @@ import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -3515,29 +3516,37 @@ class Dataset[T] private[sql](
* This function uses Apache Arrow as serialization format between Java executors and Python
* workers.
*/
private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
private[sql] def mapInPandas(
func: PythonUDF,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInPandas(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
isBarrier,
Option(profile)))
}

/**
* Applies a function to each partition in Arrow format. The user-defined function
* defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
* Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
*/
private[sql] def mapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
private[sql] def mapInArrow(
func: PythonUDF,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInArrow(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
isBarrier,
Option(profile)))
}

/**
@@ -867,10 +867,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.python.FlatMapCoGroupsInArrowExec(
f.leftAttributes, f.rightAttributes,
func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child, isBarrier) =>
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
case logical.MapInArrow(func, output, child, isBarrier) =>
execution.python.MapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
case logical.MapInPandas(func, output, child, isBarrier, profile) =>
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
case logical.MapInArrow(func, output, child, isBarrier, profile) =>
execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
case logical.AttachDistributedSequence(attr, child) =>
execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan

Expand All @@ -29,7 +30,8 @@ case class MapInArrowExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan,
override val isBarrier: Boolean)
override val isBarrier: Boolean,
override val profile: Option[ResourceProfile])
extends MapInBatchExec {

override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF