From 7c36b5885ae1aaaf8928fc519444c2ad541b04db Mon Sep 17 00:00:00 2001 From: Firestarman Date: Thu, 12 Nov 2020 17:08:50 +0800 Subject: [PATCH 1/3] Accelerate data transfer for map Pandas UDF node. Signed-off-by: Firestarman --- docs/configs.md | 2 +- docs/supported_ops.md | 39 +++- .../src/main/python/udf_cudf_test.py | 2 - integration_tests/src/main/python/udf_test.py | 1 - .../nvidia/spark/rapids/GpuOverrides.scala | 13 +- .../execution/python/GpuMapInPandasExec.scala | 183 ++++++++++++------ 6 files changed, 164 insertions(+), 76 deletions(-) diff --git a/docs/configs.md b/docs/configs.md index 8abcf8f00a2..1ea28bcfc96 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -304,7 +304,7 @@ Name | Description | Default Value | Notes spark.rapids.sql.exec.ArrowEvalPythonExec|The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled|true|None| spark.rapids.sql.exec.FlatMapCoGroupsInPandasExec|The backend for CoGrouped Aggregation Pandas UDF, it runs on CPU itself now but supports scheduling GPU resources for the Python process when enabled|false|This is disabled by default because Performance is not ideal now| spark.rapids.sql.exec.FlatMapGroupsInPandasExec|The backend for Grouped Map Pandas UDF, it runs on CPU itself now but supports scheduling GPU resources for the Python process when enabled|false|This is disabled by default because Performance is not ideal now| -spark.rapids.sql.exec.MapInPandasExec|The backend for Map Pandas Iterator UDF, it runs on CPU itself now but supports scheduling GPU resources for the Python process when enabled|false|This is disabled by default because Performance is not ideal now| +spark.rapids.sql.exec.MapInPandasExec|The backend for Map Pandas Iterator UDF. Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled.|true|None| spark.rapids.sql.exec.WindowInPandasExec|The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled. For now it only supports row based window frame.|false|This is disabled by default because it only supports row based frame for now| spark.rapids.sql.exec.WindowExec|Window-operator backend|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index fe803eb5d21..78dae6745a9 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -705,6 +705,29 @@ Accelerator supports are described below. NS +MapInPandasExec +The backend for Map Pandas Iterator UDF. Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled. +None +S +S +S +S +S +S +S +S +S* +S +NS +NS +NS +NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) +NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) +NS + + WindowInPandasExec The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled. For now it only supports row based window frame. This is disabled by default because it only supports row based frame for now @@ -10594,9 +10617,9 @@ Accelerator support is described below. NS NS NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS @@ -10637,9 +10660,9 @@ Accelerator support is described below. NS NS NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS @@ -10680,9 +10703,9 @@ Accelerator support is described below. NS NS NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS @@ -10723,9 +10746,9 @@ Accelerator support is described below. NS NS NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) NS diff --git a/integration_tests/src/main/python/udf_cudf_test.py b/integration_tests/src/main/python/udf_cudf_test.py index 1dd057caad9..36135eed3a1 100644 --- a/integration_tests/src/main/python/udf_cudf_test.py +++ b/integration_tests/src/main/python/udf_cudf_test.py @@ -41,7 +41,6 @@ _conf = { - 'spark.rapids.sql.exec.MapInPandasExec':'true', 'spark.rapids.sql.exec.FlatMapGroupsInPandasExec': 'true', 'spark.rapids.sql.exec.AggregateInPandasExec': 'true', 'spark.rapids.sql.exec.FlatMapCoGroupsInPandasExec': 'true', @@ -156,7 +155,6 @@ def gpu_run(spark): # ======= Test Flat Map In Pandas ======= -@allow_non_gpu('GpuMapInPandasExec','PythonUDF') @cudf_udf def test_map_in_pandas(enable_cudf_udf): def cpu_run(spark): diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py index 36ba96ac499..e8221a605b4 100644 --- a/integration_tests/src/main/python/udf_test.py +++ b/integration_tests/src/main/python/udf_test.py @@ -201,7 +201,6 @@ def pandas_add(data): conf=arrow_udf_conf) -@allow_non_gpu('MapInPandasExec', 'PythonUDF', 'Alias') @pytest.mark.parametrize('data_gen', [LongGen()], ids=idfn) def test_map_apply_udf(data_gen): def pandas_filter(iterator): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 874b4df020a..ad5df6d2229 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2008,7 +2008,7 @@ object GpuOverrides { "UDF run in an external python process. Does not actually run on the GPU, but " + "the transfer of data to/from it can be accelerated.", ExprChecks.fullAggAndProject( - TypeSig.commonCudfTypes + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes), + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), TypeSig.all, repeatingParamCheck = Some(RepeatingParamCheck( "param", @@ -2824,11 +2824,12 @@ object GpuOverrides { } }), exec[MapInPandasExec]( - "The backend for Map Pandas Iterator UDF, it runs on CPU itself now but supports " + - " scheduling GPU resources for the Python process when enabled", - ExecChecks.hiddenHack(), - (mapPy, conf, p, r) => new GpuMapInPandasExecMeta(mapPy, conf, p, r)) - .disabledByDefault("Performance is not ideal now"), + "The backend for Map Pandas Iterator UDF. Accelerates the data transfer between the" + + " Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + (mapPy, conf, p, r) => new GpuMapInPandasExecMeta(mapPy, conf, p, r)), exec[FlatMapGroupsInPandasExec]( "The backend for Grouped Map Pandas UDF, it runs on CPU itself now but supports " + " scheduling GPU resources for the Python process when enabled", diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala index fef286c57a1..0291f58fc40 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala @@ -16,22 +16,23 @@ package org.apache.spark.sql.rapids.execution.python +import ai.rapids.cudf import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.GpuMetric._ +import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore -import scala.collection.JavaConverters._ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, - PythonUDF, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} +import org.apache.spark.sql.vectorized.ColumnarBatch class GpuMapInPandasExecMeta( mapPandas: MapInPandasExec, @@ -44,13 +45,17 @@ class GpuMapInPandasExecMeta( override def noReplacementPossibleMessage(reasons: String): String = s"cannot run even partially on the GPU because $reasons" - // Ignore the udf since columnar way is not supported yet - override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + private val udf: BaseExprMeta[PythonUDF] = GpuOverrides.wrapExpr( + mapPandas.func.asInstanceOf[PythonUDF], conf, Some(this)) + private val resultAttrs: Seq[BaseExprMeta[Attribute]] = + mapPandas.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = resultAttrs :+ udf override def convertToGpu(): GpuExec = GpuMapInPandasExec( - mapPandas.func, - mapPandas.output, + udf.convertToGpu(), + resultAttrs.map(_.convertToGpu()).asInstanceOf[Seq[Attribute]], childPlans.head.convertIfNeeded() ) } @@ -59,10 +64,8 @@ class GpuMapInPandasExecMeta( * A relation produced by applying a function that takes an iterator of pandas DataFrames * and outputs an iterator of pandas DataFrames. * - * This GpuMapInPandasExec aims at supporting running Pandas functional code - * on GPU at Python side. - * - * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * This GpuMapInPandasExec aims at accelerating the data transfer between + * JVM and Python, and scheduling GPU resources for its Python processes. * */ case class GpuMapInPandasExec( @@ -71,13 +74,7 @@ case class GpuMapInPandasExec( child: SparkPlan) extends UnaryExecNode with GpuExec { - override def supportsColumnar = false - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - throw new IllegalStateException(s"Columnar execution is not supported by $this yet") - } - - // Most code is copied from MapInPandasExec, except two GPU related calls - private val pandasFunction = func.asInstanceOf[PythonUDF].func + private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func override def producedAttributes: AttributeSet = AttributeSet(output) @@ -85,52 +82,122 @@ case class GpuMapInPandasExec( override def outputPartitioning: Partitioning = child.outputPartitioning - override protected def doExecute(): RDD[InternalRow] = { - lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) - child.execute().mapPartitionsInternal { inputIter => - // Single function with one struct. - val argOffsets = Array(Array(0)) - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - val outputTypes = child.schema + override lazy val allMetrics: Map[String, GpuMetric] = Map( + NUM_OUTPUT_ROWS -> createMetric(outputRowsLevel, DESCRIPTION_NUM_OUTPUT_ROWS), + NUM_OUTPUT_BATCHES -> createMetric(outputBatchesLevel, DESCRIPTION_NUM_OUTPUT_BATCHES), + NUM_INPUT_ROWS -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_ROWS), + NUM_INPUT_BATCHES -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_BATCHES) + ) ++ spillMetrics - // Here we wrap it via another row so that Python sides understand it - // as a DataFrame. - val wrappedIter = inputIter.map(InternalRow(_)) + override protected def doExecute(): RDD[InternalRow] = + throw new IllegalStateException(s"Row-based execution should not occur for $this") - // DO NOT use iter.grouped(). See BatchIterator. - val batchIter = - if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS) + val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES) + val numInputRows = gpuLongMetric(NUM_INPUT_ROWS) + val numInputBatches = gpuLongMetric(NUM_INPUT_BATCHES) + val spillCallback = GpuMetric.makeSpillCallback(allMetrics) + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + val pyInputTypes = child.schema + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + // Start process + child.executeColumnar().mapPartitionsInternal { inputIter => + val queue: BatchQueue = new BatchQueue() val context = TaskContext.get() + context.addTaskCompletionListener[Unit](_ => queue.close()) + + // Single function with one struct. + val argOffsets = Array(Array(0)) + val pyInputSchema = StructType(StructField("in_struct", pyInputTypes) :: Nil) + val pythonOutputSchema = StructType(StructField("out_struct", + StructType.fromAttributes(output)) :: Nil) - // Start of GPU things if (isPythonOnGpuEnabled) { GpuPythonHelper.injectGpuInfo(chainedFunc, isPythonOnGpuEnabled) PythonWorkerSemaphore.acquireIfNecessary(context) } - // End of GPU things - - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - argOffsets, - StructType(StructField("struct", outputTypes) :: Nil), - sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) - - columnarBatchIter.flatMap { batch => - // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select - // the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) - } - } + + val contextAwareIter = new Iterator[ColumnarBatch] { + // This is to implement the same logic of `ContextAwareIterator` involved from 3.0.2 . + // Doing this can avoid shim layers for Spark versions before 3.0.2. + override def hasNext: Boolean = + !context.isCompleted() && !context.isInterrupted() && inputIter.hasNext + + override def next(): ColumnarBatch = inputIter.next() + } + + val pyInputIterator = new RebatchingRoundoffIterator(contextAwareIter, pyInputTypes, + batchSize, numInputRows, numInputBatches, spillCallback) + .map { batch => + // Here we wrap it via another column so that Python sides understand it + // as a DataFrame. + val structColumn = cudf.ColumnVector.makeStruct(GpuColumnVector.extractBases(batch): _*) + val pyInputBatch = withResource(structColumn) { stColumn => + val gpuColumn = GpuColumnVector.from(stColumn.incRefCount(), pyInputTypes) + new ColumnarBatch(Array(gpuColumn), batch.numRows()) + } + // cache the original batches for release later. + queue.add(batch, spillCallback) + pyInputBatch + } + + if (pyInputIterator.hasNext) { + val pyRunner = new GpuArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + argOffsets, + pyInputSchema, + sessionLocalTimeZone, + pythonRunnerConf, + batchSize, + () => queue.close(), + pythonOutputSchema, + Int.MaxValue) + val pythonOutputIter = pyRunner.compute(pyInputIterator, context.partitionId(), context) + + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = pythonOutputIter.hasNext + + override def next(): ColumnarBatch = { + // We can not assert the result batch from Python has the same row number with the + // input batch. Because Map Pandas UDF allows the output of arbitrary length + // and columns. + // Then try to read as many as possible by specifying `minReadTargetBatchSize` as + // `Int.MaxValue` when creating the `GpuArrowPythonRunner` above. + withResource(pythonOutputIter.next()) { cbFromPython => + numOutputBatches += 1 + numOutputRows += cbFromPython.numRows + extractChildren(cbFromPython) + } + } + + private[this] def extractChildren(batch: ColumnarBatch): ColumnarBatch = { + assert(batch.numCols() == 1, "Expect only one struct column") + assert(batch.column(0).dataType().isInstanceOf[StructType], + "Expect a struct column") + // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select + // the children here + val structColumn = batch.column(0).asInstanceOf[GpuColumnVector].getBase + val outputColumns = output.zipWithIndex.safeMap { + case (attr, i) => + withResource(structColumn.getChildColumnView(i)) { childView => + GpuColumnVector.from(childView.copyToColumnVector(), attr.dataType) + } + } + new ColumnarBatch(outputColumns.toArray, batch.numRows()) + } + } + } else { + // Empty partition, return it directly + inputIter + } + } // end of mapPartitionsInternal + } // end of doExecuteColumnar + } From bb25aa35d2a571f8fa132fd13446e7239db3ea59 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Mon, 29 Mar 2021 15:54:52 +0800 Subject: [PATCH 2/3] Add integration tests for map pandas. Signed-off-by: Firestarman --- integration_tests/src/main/python/udf_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py index e8221a605b4..3bb4f7617a7 100644 --- a/integration_tests/src/main/python/udf_test.py +++ b/integration_tests/src/main/python/udf_test.py @@ -213,6 +213,22 @@ def pandas_filter(iterator): conf=arrow_udf_conf) +@pytest.mark.parametrize('data_gen', data_gens_nested_for_udf, ids=idfn) +def test_pandas_map_udf_nested_type(data_gen): + # Spark supports limited types as the return type of pandas UDF. + # For details please go to + # https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/types.py#L28 + # So we return the data frames with one column of integral type for all types of the input. + def size_udf(pdf_itr): + for pdf in pdf_itr: + yield pd.DataFrame({"ret": [i for i in range(len(pdf))]}) + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, data_gen)\ + .mapInPandas(size_udf, schema="ret long"), + conf=arrow_udf_conf) + + def create_df(spark, data_gen, left_length, right_length): left = binary_op_df(spark, data_gen, length=left_length) right = binary_op_df(spark, data_gen, length=right_length) From 9971a4c21fe9ef887383ddac8553504daf660ffa Mon Sep 17 00:00:00 2001 From: Firestarman Date: Fri, 2 Apr 2021 13:45:57 +0800 Subject: [PATCH 3/3] Address some comments Signed-off-by: Firestarman --- docs/supported_ops.md | 32 ++++++++--------- integration_tests/src/main/python/udf_test.py | 35 +++++++++++++++---- .../nvidia/spark/rapids/GpuOverrides.scala | 11 ++++-- .../com/nvidia/spark/rapids/TypeChecks.scala | 13 +++++++ 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index b372490bfda..5ff8ec1c4ed 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10616,11 +10616,11 @@ Accelerator support is described below. NS NS NS + +PS* (missing nested DECIMAL, NULL, BINARY, MAP) NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, MAP) + reduction @@ -10659,11 +10659,11 @@ Accelerator support is described below. NS NS NS + +PS* (missing nested DECIMAL, NULL, BINARY, MAP) NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, MAP) + window @@ -10702,11 +10702,11 @@ Accelerator support is described below. NS NS NS + +PS* (missing nested DECIMAL, NULL, BINARY, MAP) NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, MAP) + project @@ -10745,11 +10745,11 @@ Accelerator support is described below. NS NS NS + +PS* (missing nested DECIMAL, NULL, BINARY, MAP) NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) -NS +PS* (missing nested DECIMAL, NULL, BINARY, MAP) + Quarter diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py index 3bb4f7617a7..0f42b8b2253 100644 --- a/integration_tests/src/main/python/udf_test.py +++ b/integration_tests/src/main/python/udf_test.py @@ -215,17 +215,38 @@ def pandas_filter(iterator): @pytest.mark.parametrize('data_gen', data_gens_nested_for_udf, ids=idfn) def test_pandas_map_udf_nested_type(data_gen): - # Spark supports limited types as the return type of pandas UDF. - # For details please go to - # https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/types.py#L28 - # So we return the data frames with one column of integral type for all types of the input. - def size_udf(pdf_itr): + # Supported UDF output types by plugin: (commonCudfTypes + ARRAY).nested() + STRUCT + # STRUCT represents the whole dataframe in Map Pandas UDF, so no struct column in UDF output. + # More details is here + # https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L119 + udf_out_schema = 'c_integral long,' \ + 'c_string string,' \ + 'c_fp double,' \ + 'c_bool boolean,' \ + 'c_date date,' \ + 'c_time timestamp,' \ + 'c_array_array array>,' \ + 'c_array_string array' + + def col_types_udf(pdf_itr): for pdf in pdf_itr: - yield pd.DataFrame({"ret": [i for i in range(len(pdf))]}) + # Return a data frame with columns of supported type, and there is only one row. + # The values can not be generated randomly because it should return the same data + # for both CPU and GPU runs. + yield pd.DataFrame({ + "c_integral": [len(pdf)], + "c_string": ["size" + str(len(pdf))], + "c_fp": [float(len(pdf))], + "c_bool": [False], + "c_date": [date(2021, 4, 2)], + "c_time": [datetime(2021, 4, 2, tzinfo=timezone.utc)], + "c_array_array": [[[len(pdf)]]], + "c_array_string": [["size" + str(len(pdf))]] + }) assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen)\ - .mapInPandas(size_udf, schema="ret long"), + .mapInPandas(col_types_udf, schema=udf_out_schema), conf=arrow_udf_conf) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index ad5df6d2229..89b55a59112 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2008,8 +2008,15 @@ object GpuOverrides { "UDF run in an external python process. Does not actually run on the GPU, but " + "the transfer of data to/from it can be accelerated.", ExprChecks.fullAggAndProject( - (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), - TypeSig.all, + // Different types of Pandas UDF support different sets of output type. Please refer to + // https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98 + // for more details. + // It is impossible to specify the exact type signature for each Pandas UDF type in a single + // expression 'PythonUDF'. + // So use the 'unionOfPandasUdfOut' to cover all types for Spark. The type signature of + // plugin is also an union of all the types of Pandas UDF. + (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT, + TypeSig.unionOfPandasUdfOut, repeatingParamCheck = Some(RepeatingParamCheck( "param", (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index e9469429d32..5ba7e8b3417 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -454,6 +454,19 @@ object TypeSig { val orderable: TypeSig = (BOOLEAN + BYTE + SHORT + INT + LONG + FLOAT + DOUBLE + DATE + TIMESTAMP + STRING + DECIMAL + NULL + BINARY + CALENDAR + ARRAY + STRUCT + UDT).nested() + /** + * Different types of Pandas UDF support different sets of output type. Please refer to + * https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98 + * for more details. + * + * It is impossible to specify the exact type signature for each Pandas UDF type in a single + * expression 'PythonUDF'. + * + * So here comes the union of all the sets of supported type, to cover all the cases. + */ + val unionOfPandasUdfOut = (commonCudfTypes + BINARY + DECIMAL + NULL + ARRAY + MAP).nested() + + STRUCT + def getDataType(expr: Expression): Option[DataType] = { try { Some(expr.dataType)