Support launching Map Pandas UDF on empty partitions #9557

Merged
10 changes: 10 additions & 0 deletions integration_tests/src/main/python/udf_test.py
@@ -400,3 +400,13 @@ def filter_func(iterator):
            .mapInArrow(filter_func, schema=f"a {data_type}, b {data_type}"),
        "PythonMapInArrowExec",
        conf=conf)


def test_map_pandas_udf_with_empty_partitions():
    def test_func(spark):
        df = spark.range(10).withColumn("const", f.lit(1))
        # The repartition will produce 4 empty partitions.
        return df.repartition(5, "const").mapInPandas(
            lambda data: [pd.DataFrame([len(list(data))])], schema="ret:integer")

    assert_gpu_and_cpu_are_equal_collect(test_func, conf=arrow_udf_conf)
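
As a side note on the test above (illustration only, not part of this change): repartitioning by a column that holds a single constant value hashes every row into the same partition, which is why 4 of the 5 requested partitions end up empty. A minimal standalone sketch, assuming a running PySpark `spark` session:

```python
import pyspark.sql.functions as f

# All 10 rows share the same "const" value, so hash partitioning sends them
# to one of the 5 requested partitions and leaves the other 4 empty.
df = spark.range(10).withColumn("const", f.lit(1)).repartition(5, "const")

# Per-partition row counts, e.g. [0, 10, 0, 0, 0] -- four empty partitions.
print(df.rdd.glom().map(len).collect())
```

mapInPandas still invokes the UDF once per partition, including the empty ones, so the CPU run of the test yields a 0 for each empty partition; the GPU path has to produce the same result, which is what this change enables.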
@@ -27,13 +27,17 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.rapids.shims.api.python.ShimBasePythonRunner
import org.apache.spark.sql.execution.python.PythonUDFRunner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.shims.ArrowUtilsShim
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils

@@ -241,6 +245,21 @@ abstract class GpuArrowPythonRunnerBase(
      }

      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
        if (inputIterator.nonEmpty) {
          writeNonEmptyIteratorOnGpu(dataOut)
        } else { // Partition is empty.
          // In this case the CPU still sends the schema to the Python workers by calling
          // the "start" API of the Java Arrow writer, while the GPU sends out nothing,
          // leading to an IPC error. It is not easy to mirror what Spark does on the
          // GPU, because the C++ Arrow writer used by the GPU only sends out the schema
          // when there is some data, and it does not expose a "start" API to do this.
          // So here we leverage the Java Arrow writer to do what Spark does.
          // That is fine because sending out the schema has nothing to do with the GPU.
          writeEmptyIteratorOnCpu(dataOut)
        }
      }

      private def writeNonEmptyIteratorOnGpu(dataOut: DataOutputStream): Unit = {
        val writer = {
          val builder = ArrowIPCWriterOptions.builder()
          builder.withMaxChunkSize(batchSize)
@@ -250,11 +269,11 @@
          })
          // Flatten the names of nested struct columns, required by cudf arrow IPC writer.
          GpuArrowPythonRunner.flattenNames(pythonInSchema).foreach { case (name, nullable) =>
            if (nullable) {
              builder.withColumnNames(name)
            } else {
              builder.withNotNullableColumnNames(name)
            }
          }
          Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut))
        }
@@ -277,6 +296,28 @@
        if (onDataWriteFinished != null) onDataWriteFinished()
      }
    }

      private def writeEmptyIteratorOnCpu(dataOut: DataOutputStream): Unit = {
        // most code is copied from Spark
        val arrowSchema = ArrowUtilsShim.toArrowSchema(pythonInSchema, timeZoneId)
        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
          s"stdout writer for empty partition", 0, Long.MaxValue)
        val root = VectorSchemaRoot.create(arrowSchema, allocator)

        Utils.tryWithSafeFinally {
          val writer = new ArrowStreamWriter(root, null, dataOut)
          writer.start()
          // No data to write
          writer.end()
          // The iterator can grab the semaphore even on an empty batch
          GpuSemaphore.releaseIfNecessary(TaskContext.get())
        } {
          root.close()
          allocator.close()
          if (onDataWriteFinished != null) onDataWriteFinished()
        }
      }

    }
  }
}
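
For context on `writeEmptyIteratorOnCpu` above: an Arrow IPC stream that carries only the schema message and the end-of-stream marker, with no record batches, is still a well-formed stream that the Python worker can open, which is why writing just the schema from the JVM side is enough for an empty partition. A minimal pyarrow sketch of that framing (illustration only, not part of this PR; assumes pyarrow is installed):

```python
import pyarrow as pa

# Write an IPC stream containing just the schema and the end-of-stream marker,
# roughly what the Java ArrowStreamWriter's start()/end() pair emits when no
# record batches are written.
sink = pa.BufferOutputStream()
schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
with pa.ipc.new_stream(sink, schema) as writer:
    pass  # no data to write

# A reader can still open the stream: it sees the schema and zero rows.
reader = pa.ipc.open_stream(sink.getvalue())
print(reader.schema)
print(reader.read_all().num_rows)  # 0
```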
@@ -87,31 +87,25 @@ trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase {
            }
          }
        }

-      if (pyInputIterator.hasNext) {
-        val pyRunner = new GpuArrowPythonRunnerBase(
-          chainedFunc,
-          pythonEvalType,
-          argOffsets,
-          pyInputSchema,
-          sessionLocalTimeZone,
-          pythonRunnerConf,
-          batchSize) {
-          override def toBatch(table: Table): ColumnarBatch = {
-            BatchGroupedIterator.extractChildren(table, localOutput)
-          }
+      val pyRunner = new GpuArrowPythonRunnerBase(
+        chainedFunc,
+        pythonEvalType,
+        argOffsets,
+        pyInputSchema,
+        sessionLocalTimeZone,
+        pythonRunnerConf,
+        batchSize) {
+        override def toBatch(table: Table): ColumnarBatch = {
+          BatchGroupedIterator.extractChildren(table, localOutput)
+        }

-        pyRunner.compute(pyInputIterator, context.partitionId(), context)
-          .map { cb =>
-            numOutputBatches += 1
-            numOutputRows += cb.numRows
-            cb
-          }
-      } else {
-        // Empty partition, return it directly
-        inputIter
-      }

+      pyRunner.compute(pyInputIterator, context.partitionId(), context)
+        .map { cb =>
+          numOutputBatches += 1
+          numOutputRows += cb.numRows
+          cb
+        }
    } // end of mapPartitionsInternal
  } // end of internalDoExecuteColumnar

@@ -39,10 +39,18 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object ArrowUtilsShim {
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] =
    ArrowUtils.getPythonRunnerConfMap(conf)

  def toArrowSchema(schema: StructType, timeZoneId: String,
      errorOnDuplicatedFieldNames: Boolean = true, largeVarTypes: Boolean = false): Schema = {
    ArrowUtils.toArrowSchema(schema, timeZoneId)
  }
}
@@ -19,10 +19,19 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object ArrowUtilsShim {
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] =
    ArrowPythonRunner.getPythonRunnerConfMap(conf)

  def toArrowSchema(schema: StructType, timeZoneId: String,
      errorOnDuplicatedFieldNames: Boolean = true, largeVarTypes: Boolean = false): Schema = {
    ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
  }
}