diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py index 57602a9bdca..606f9a31dc4 100644 --- a/integration_tests/src/main/python/spark_session.py +++ b/integration_tests/src/main/python/spark_session.py @@ -164,6 +164,9 @@ def is_spark_330_or_later(): def is_spark_340_or_later(): return spark_version() >= "3.4.0" +def is_spark_341(): + return spark_version() == "3.4.1" + def is_spark_350_or_later(): return spark_version() >= "3.5.0" diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py index 88281279162..9e3f5d05bcc 100644 --- a/integration_tests/src/main/python/udf_test.py +++ b/integration_tests/src/main/python/udf_test.py @@ -15,7 +15,7 @@ import pytest from conftest import is_at_least_precommit_run, is_not_utc -from spark_session import is_databricks_runtime, is_before_spark_330, is_before_spark_350, is_spark_340_or_later +from spark_session import is_databricks_runtime, is_before_spark_330, is_before_spark_350, is_spark_341 from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version @@ -43,12 +43,6 @@ import pyarrow from typing import Iterator, Tuple - -if is_databricks_runtime() and is_spark_340_or_later(): - # Databricks 13.3 does not use separate reader/writer threads for Python UDFs - # which can lead to hangs. Skipping these tests until the Python UDF handling is updated. - pytestmark = pytest.mark.skip(reason="https://github.com/NVIDIA/spark-rapids/issues/9493") - arrow_udf_conf = { 'spark.sql.execution.arrow.pyspark.enabled': 'true', 'spark.rapids.sql.exec.WindowInPandasExec': 'true', @@ -182,7 +176,10 @@ def group_size_udf(to_process: pd.Series) -> int: low_upper_win = Window.partitionBy('a').orderBy('b').rowsBetween(-3, 3) -udf_windows = [no_part_win, unbounded_win, cur_follow_win, pre_cur_win, low_upper_win] +running_win_param = pytest.param(pre_cur_win, marks=pytest.mark.xfail( + condition=is_databricks_runtime() and is_spark_341(), + reason='DB13.3 wrongly uses RunningWindowFunctionExec to evaluate a PythonUDAF and it will fail even on CPU')) +udf_windows = [no_part_win, unbounded_win, cur_follow_win, running_win_param, low_upper_win] window_ids = ['No_Partition', 'Unbounded', 'Unbounded_Following', 'Unbounded_Preceding', 'Lower_Upper'] @@ -338,8 +335,8 @@ def create_df(spark, data_gen, left_length, right_length): @ignore_order @pytest.mark.parametrize('data_gen', [ShortGen(nullable=False)], ids=idfn) def test_cogroup_apply_udf(data_gen): - def asof_join(l, r): - return pd.merge_asof(l, r, on='a', by='b') + def asof_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + return pd.merge_ordered(left, right) def do_it(spark): left, right = create_df(spark, data_gen, 500, 500) diff --git a/jenkins/databricks/build.sh b/jenkins/databricks/build.sh index a68b272257b..7ac947c4686 100755 --- a/jenkins/databricks/build.sh +++ b/jenkins/databricks/build.sh @@ -1,6 +1,6 @@ #!/bin/bash # -# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,7 +52,13 @@ declare -A artifacts initialize() { # install rsync to be used for copying onto the databricks nodes - sudo apt install -y maven rsync + sudo apt install -y rsync + + if [[ ! 
-d $HOME/apache-maven-3.6.3 ]]; then + wget https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz -P /tmp + tar xf /tmp/apache-maven-3.6.3-bin.tar.gz -C $HOME + sudo ln -s $HOME/apache-maven-3.6.3/bin/mvn /usr/local/bin/mvn + fi # Archive file location of the plugin repository SPARKSRCTGZ=${SPARKSRCTGZ:-''} diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 97402f5a58e..63a5084b679 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -490,7 +490,7 @@ private static StructType structFromTypes(DataType[] format) { return new StructType(fields); } - private static StructType structFromAttributes(List format) { + public static StructType structFromAttributes(List format) { StructField[] fields = new StructField[format.size()]; int i = 0; for (Attribute attribute: format) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala index bf2d2474dfe..92588885be0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala @@ -18,11 +18,15 @@ package com.nvidia.spark.rapids import scala.collection.mutable.ArrayBuffer +import ai.rapids.cudf.Table +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq + import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType} /** * Utility class with methods for calculating various metrics about GPU memory usage - * prior to allocation. + * prior to allocation, along with some operations with batches. */ object GpuBatchUtils { @@ -175,4 +179,37 @@ object GpuBatchUtils { bytes } } + + /** + * Concatenate the input batches into a single one. + * The caller is responsible for closing the returned batch. + * + * @param spillBatches the batches to be concatenated, will be closed after the call + * returns. + * @return the concatenated SpillableColumnarBatch or None if the input is empty. + */ + def concatSpillBatchesAndClose( + spillBatches: Seq[SpillableColumnarBatch]): Option[SpillableColumnarBatch] = { + val retBatch = if (spillBatches.length >= 2) { + // two or more batches, concatenate them + val (concatTable, types) = RmmRapidsRetryIterator.withRetryNoSplit(spillBatches) { _ => + withResource(spillBatches.safeMap(_.getColumnarBatch())) { batches => + val batchTypes = GpuColumnVector.extractTypes(batches.head) + withResource(batches.safeMap(GpuColumnVector.from)) { tables => + (Table.concatenate(tables: _*), batchTypes) + } + } + } + // Make the concatenated table spillable. 
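+      // "withResource" closes the temporary concatenated table once it has been wrapped,
+      // so the caller only needs to close the returned SpillableColumnarBatch.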
+ withResource(concatTable) { _ => + SpillableColumnarBatch(GpuColumnVector.from(concatTable, types), + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } + } else if (spillBatches.length == 1) { + // only one batch + spillBatches.head + } else null + + Option(retBatch) + } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala index 33ca8c906f6..b4582b3e0d5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.rapids.execution import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.Table -import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, RmmRapidsRetryIterator, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource} +import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -41,27 +40,7 @@ object GpuSubPartitionHashJoin { */ def concatSpillBatchesAndClose( spillBatches: Seq[SpillableColumnarBatch]): Option[SpillableColumnarBatch] = { - val retBatch = if (spillBatches.length >= 2) { - // two or more batches, concatenate them - val (concatTable, types) = RmmRapidsRetryIterator.withRetryNoSplit(spillBatches) { _ => - withResource(spillBatches.safeMap(_.getColumnarBatch())) { batches => - val batchTypes = GpuColumnVector.extractTypes(batches.head) - withResource(batches.safeMap(GpuColumnVector.from)) { tables => - (Table.concatenate(tables: _*), batchTypes) - } - } - } - // Make the concatenated table spillable. 
- withResource(concatTable) { _ => - SpillableColumnarBatch(GpuColumnVector.from(concatTable, types), - SpillPriorities.ACTIVE_BATCHING_PRIORITY) - } - } else if (spillBatches.length == 1) { - // only one batch - spillBatches.head - } else null - - Option(retBatch) + GpuBatchUtils.concatSpillBatchesAndClose(spillBatches) } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala index b97415b31ba..87bcbab785b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/BatchGroupUtils.scala @@ -22,13 +22,14 @@ import ai.rapids.cudf import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.rapids.execution.python.shims.GpuPythonArrowOutput +import org.apache.spark.sql.rapids.execution.python.shims.GpuBasePythonRunner import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim import org.apache.spark.sql.vectorized.ColumnarBatch @@ -195,7 +196,7 @@ private[python] object BatchGroupUtils { def executePython[IN]( pyInputIterator: Iterator[IN], output: Seq[Attribute], - pyRunner: GpuPythonRunnerBase[IN], + pyRunner: GpuBasePythonRunner[IN], outputRows: GpuMetric, outputBatches: GpuMetric): Iterator[ColumnarBatch] = { val context = TaskContext.get() @@ -394,38 +395,72 @@ private[python] object BatchGroupedIterator { class CombiningIterator( inputBatchQueue: BatchQueue, pythonOutputIter: Iterator[ColumnarBatch], - pythonArrowReader: GpuPythonArrowOutput, + pythonArrowReader: GpuArrowOutput, numOutputRows: GpuMetric, numOutputBatches: GpuMetric) extends Iterator[ColumnarBatch] { - // For `hasNext` we are waiting on the queue to have something inserted into it - // instead of waiting for a result to be ready from Python. The reason for this - // is to let us know the target number of rows in the batch that we want when reading. - // It is a bit hacked up but it works. In the future when we support spilling we should - // store the number of rows separate from the batch. That way we can get the target batch - // size out without needing to grab the GpuSemaphore which we cannot do if we might block - // on a read operation. - override def hasNext: Boolean = inputBatchQueue.hasNext || pythonOutputIter.hasNext + // This is only for the input. + private var pendingInput: Option[SpillableColumnarBatch] = None + Option(TaskContext.get()).foreach(onTaskCompletion(_)(pendingInput.foreach(_.close()))) + + // The Python output should line up row for row so we only look at the Python output + // iterator and no need to check the `inputPending` who will be consumed when draining + // the Python output. 
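+  // (Rows cached in "pendingInput" were already handed to the Python writer, so their
+  // results will still arrive through "pythonOutputIter".)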
+  override def hasNext: Boolean = pythonOutputIter.hasNext
 
   override def next(): ColumnarBatch = {
-    val numRows = inputBatchQueue.peekBatchSize
+    val numRows = inputBatchQueue.peekBatchNumRows()
     // Updates the expected batch size for next read
-    pythonArrowReader.setMinReadTargetBatchSize(numRows)
+    pythonArrowReader.setMinReadTargetNumRows(numRows)
     // Reads next batch from Python and combines it with the input batch by the left side.
     withResource(pythonOutputIter.next()) { cbFromPython =>
-      assert(cbFromPython.numRows() == numRows)
-      withResource(inputBatchQueue.remove()) { origBatch =>
+      // The batch from Python may have more rows than the current input batch.
+      assert(cbFromPython.numRows() >= numRows,
+        s"Expects >=$numRows rows but got ${cbFromPython.numRows()} from the Python worker")
+      withResource(concatInputBatch(cbFromPython.numRows())) { concated =>
         numOutputBatches += 1
         numOutputRows += numRows
-        combine(origBatch, cbFromPython)
+        GpuColumnVector.combineColumns(concated, cbFromPython)
       }
     }
   }
 
-  private def combine(lBatch: ColumnarBatch, rBatch: ColumnarBatch): ColumnarBatch = {
-    val lColumns = GpuColumnVector.extractColumns(lBatch).map(_.incRefCount())
-    val rColumns = GpuColumnVector.extractColumns(rBatch).map(_.incRefCount())
-    new ColumnarBatch(lColumns ++ rColumns, lBatch.numRows())
+  private def concatInputBatch(targetNumRows: Int): ColumnarBatch = {
+    withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { buf =>
+      var curNumRows = pendingInput.map(_.numRows()).getOrElse(0)
+      pendingInput.foreach(buf.append(_))
+      pendingInput = None
+      while (curNumRows < targetNumRows) {
+        val scb = inputBatchQueue.remove()
+        if (scb != null) {
+          buf.append(scb)
+          curNumRows = curNumRows + scb.numRows()
+        }
+      }
+      assert(buf.nonEmpty, "The input queue is empty")
+
+      if (curNumRows > targetNumRows) {
+        // Need to split the last batch
+        val Array(first, second) = withRetryNoSplit(buf.remove(buf.size - 1)) { lastScb =>
+          val splitIdx = lastScb.numRows() - (curNumRows - targetNumRows)
+          withResource(lastScb.getColumnarBatch()) { lastCb =>
+            val batchTypes = GpuColumnVector.extractTypes(lastCb)
+            withResource(GpuColumnVector.from(lastCb)) { table =>
+              table.contiguousSplit(splitIdx).safeMap(
+                SpillableColumnarBatch(_, batchTypes, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
+            }
+          }
+        }
+        buf.append(first)
+        pendingInput = Some(second)
+      }
+
+      val ret = GpuBatchUtils.concatSpillBatchesAndClose(buf.toSeq)
+      // "ret" should be non-empty because we have already checked that "buf" is not empty.
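+      // "getColumnarBatch" materializes the concatenated data back into a GPU ColumnarBatch,
+      // and "withResource" closes the spillable wrapper afterwards.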
+ withResource(ret.get) { concatedScb => + concatedScb.getColumnarBatch() + } + } // end of withResource(mutable.ArrayBuffer) } } @@ -560,3 +595,4 @@ class CoGroupedIterator( keyOrdering.compare(leftKeyRow, rightKeyRow) } } + diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala index 9c02d231706..3f3f2803f5c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala @@ -22,7 +22,6 @@ import ai.rapids.cudf import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.python.PythonWorkerSemaphore import com.nvidia.spark.rapids.shims.ShimUnaryExecNode @@ -141,9 +140,7 @@ case class GpuAggregateInPandasExec( // Start processing child.executeColumnar().mapPartitionsInternal { inputIter => - val queue: BatchQueue = new BatchQueue() val context = TaskContext.get() - onTaskCompletion(queue.close()) if (isPythonOnGpuEnabled) { GpuPythonHelper.injectGpuInfo(pyFuncs, isPythonOnGpuEnabled) @@ -164,51 +161,47 @@ case class GpuAggregateInPandasExec( } // Second splits into separate group batches. - val miniAttrs = gpuGroupingExpressions ++ allInputs - val pyInputIter = BatchGroupedIterator(miniIter, miniAttrs.asInstanceOf[Seq[Attribute]], - groupingRefs.indices) - .map { groupedBatch => - // Resolves the group key and the python input from a grouped batch. Then - // - Caches the key to be combined with the Python output later. And - // - Returns the python input to be sent to Python later. - withResource(groupedBatch) { grouped => - // key batch. - // No `safeMap` because here does not increase the ref count. - // (`Seq.indices.map()` is NOT lazy, so it is safe to be used to slice the columns.) - val keyCudfColumns = groupingRefs.indices.map( - grouped.column(_).asInstanceOf[GpuColumnVector].getBase) - val keyBatch = if (keyCudfColumns.isEmpty) { - // No grouping columns, then the whole batch is a group. Returns the dedicated batch - // as the group key. - // This batch means there is only one empty row, just like the 'new UnsafeRow()' - // used in Spark. The row number setting to 1 is because Python returns only one row - // as the aggregate result for the whole batch, and 'CombiningIterator' requires the - // the same row number for both the key batch and the result batch to be combined. - new ColumnarBatch(Array(), 1) - } else { - // Uses `cudf.Table.gather` to pick the first row in each group as the group key. - // Doing this is because - // - The Python worker produces only one row as the aggregate result, - // - The key rows in a group are equal to each other. - // - // (Now this is done group by group, so the performance would not be good when - // there are too many small groups.) 
- withResource(new cudf.Table(keyCudfColumns: _*)) { table => - withResource(cudf.ColumnVector.fromInts(0)) { gatherMap => - withResource(table.gather(gatherMap)) { oneRowTable => - GpuColumnVector.from(oneRowTable, groupingRefs.map(_.dataType).toArray) - } - } + val miniAttrs = (gpuGroupingExpressions ++ allInputs).asInstanceOf[Seq[Attribute]] + val keyConverter = (groupedBatch: ColumnarBatch) => { + // No `safeMap` because here does not increase the ref count. + // (`Seq.indices.map()` is NOT lazy, so it is safe to be used to slice the columns.) + val keyCudfColumns = groupingRefs.indices.map( + groupedBatch.column(_).asInstanceOf[GpuColumnVector].getBase) + if (keyCudfColumns.isEmpty) { + // No grouping columns, then the whole batch is a group. Returns the dedicated batch + // as the group key. + // This batch means there is only one empty row, just like the 'new UnsafeRow()' + // used in Spark. The row number setting to 1 is because Python returns only one row + // as the aggregate result for the whole batch, and 'CombiningIterator' requires the + // the same row number for both the key batch and the result batch to be combined. + new ColumnarBatch(Array(), 1) + } else { + // Uses `cudf.Table.gather` to pick the first row in each group as the group key. + // Doing this is because + // - The Python worker produces only one row as the aggregate result, + // - The key rows in a group are equal to each other. + // + // (Now this is done group by group, so the performance would not be good when + // there are too many small groups.) + withResource(new cudf.Table(keyCudfColumns: _*)) { table => + withResource(cudf.ColumnVector.fromInts(0)) { gatherMap => + withResource(table.gather(gatherMap)) { oneRowTable => + GpuColumnVector.from(oneRowTable, groupingRefs.map(_.dataType).toArray) } } - queue.add(keyBatch) + } + } + } - // Python input batch - val pyInputColumns = pyInputRefs.indices.safeMap { idx => - grouped.column(idx + groupingRefs.size).asInstanceOf[GpuColumnVector].incRefCount() - } - new ColumnarBatch(pyInputColumns.toArray, groupedBatch.numRows()) + val batchProducer = new BatchProducer( + BatchGroupedIterator(miniIter, miniAttrs, groupingRefs.indices), Some(keyConverter)) + val pyInputIter = batchProducer.asIterator.map { batch => + withResource(batch) { _ => + val pyInputColumns = pyInputRefs.indices.safeMap { idx => + batch.column(idx + groupingRefs.size).asInstanceOf[GpuColumnVector].incRefCount() } + new ColumnarBatch(pyInputColumns.toArray, batch.numRows()) + } } // Third, sends to Python to execute the aggregate and returns the result. @@ -223,16 +216,15 @@ case class GpuAggregateInPandasExec( pythonRunnerConf, // The whole group data should be written in a single call, so here is unlimited Int.MaxValue, - DataTypeUtilsShim.fromAttributes(pyOutAttributes), - () => queue.finish()) + DataTypeUtilsShim.fromAttributes(pyOutAttributes)) val pyOutputIterator = pyRunner.compute(pyInputIter, context.partitionId(), context) val combinedAttrs = gpuGroupingExpressions.map(_.toAttribute) ++ pyOutAttributes val resultRefs = GpuBindReferences.bindGpuReferences(resultExprs, combinedAttrs) // Gets the combined batch for each group and projects for the output. 
- new CombiningIterator(queue, pyOutputIterator, pyRunner, mNumOutputRows, - mNumOutputBatches).map { combinedBatch => + new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, pyRunner, + mNumOutputRows, mNumOutputBatches).map { combinedBatch => withResource(combinedBatch) { batch => GpuProjectExec.project(batch, resultRefs) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala index 45fec7c81d2..5e588cae7bd 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf._ import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.python.PythonWorkerSemaphore @@ -171,62 +171,151 @@ class RebatchingRoundoffIterator( } /** - * A simple queue that holds the pending batches that need to line up with - * and combined with batches coming back from python + * A trait provides dedicated APIs for GPU reading batches from python. + * This is also for easy type declarations since it is implemented by an inner class + * of BatchProducer. */ -class BatchQueue extends AutoCloseable { - private val queue: mutable.Queue[SpillableColumnarBatch] = - mutable.Queue[SpillableColumnarBatch]() - private var isSet = false - - def add(batch: ColumnarBatch): Unit = synchronized { - queue.enqueue(SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) - if (!isSet) { - // Wake up anyone waiting for the first batch. - isSet = true - notifyAll() - } - } +trait BatchQueue { + /** Return and remove the first batch in the cache. Caller should close it. */ + def remove(): SpillableColumnarBatch - def finish(): Unit = synchronized { - if (!isSet) { - // Wake up anyone waiting for the first batch. - isSet = true - notifyAll() + /** Get the number of rows in the next batch, without actually getting the batch. */ + def peekBatchNumRows(): Int +} + +/** + * It accepts an iterator as input and will cache the batches when pulling them in from + * the input for later combination with batches coming back from python by the reader. + * It also supports an optional converter to convert input batches and put the converted + * result to the cache queue. This is for GpuAggregateInPandas to build and cache key + * batches. + * + * Call "getBatchQueue" to get the internal cache queue and specify it to the output + * combination iterator. + * To access the batches from input, call "asIterator" to get the output iterator. + */ +class BatchProducer( + input: Iterator[ColumnarBatch], + converter: Option[ColumnarBatch => ColumnarBatch] = None +) extends AutoCloseable { producer => + + Option(TaskContext.get()).foreach(onTaskCompletion(_)(close())) + + // A queue that holds the pending batches that need to line up with and combined + // with batches coming back from python. 
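+  // It is implemented as an inner class so it can share the producer's lock; see the
+  // comment on "BatchQueueImpl" below for the details.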
+ private[this] val batchQueue = new BatchQueueImpl + + /** Get the internal BatchQueue */ + def getBatchQueue: BatchQueue = batchQueue + + // The cache that holds the pending batches pulled in by the "produce" call for + // the reader peeking the next rows number when the "batchQueue" is empty, and + // consumed by the iterator returned from "asIterator". + // (In fact, there is usually only ONE batch. But using a queue here is because in + // theory "produce" can be called multiple times, then more than one batch can be + // pulled in.) + private[this] val pendingOutput = mutable.Queue[SpillableColumnarBatch]() + + private def produce(): ColumnarBatch = { + if (input.hasNext) { + val cb = input.next() + // Need to duplicate this batch for "next" + pendingOutput.enqueue(SpillableColumnarBatch(GpuColumnVector.incRefCounts(cb), + SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) + cb + } else { + null } } - def remove(): ColumnarBatch = synchronized { - if (queue.isEmpty) { - null - } else { - withResource(queue.dequeue()) { scp => - scp.getColumnarBatch() + /** Return an iterator to access the batches from the input */ + def asIterator: Iterator[ColumnarBatch] = { + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = producer.synchronized { + pendingOutput.nonEmpty || input.hasNext + } + + override def next(): ColumnarBatch = producer.synchronized { + if (!hasNext) { + throw new NoSuchElementException() + } + if (pendingOutput.nonEmpty) { + withResource(pendingOutput.dequeue()) { scb => + scb.getColumnarBatch() + } + } else { + closeOnExcept(input.next()) { cb => + // Need to duplicate it for later combination with Python output + batchQueue.add(GpuColumnVector.incRefCounts(cb)) + cb + } + } } } } - def hasNext: Boolean = synchronized { - if (!isSet) { - wait() + override def close(): Unit = producer.synchronized { + batchQueue.close() + while (pendingOutput.nonEmpty) { + pendingOutput.dequeue().close() } - queue.nonEmpty } - /** - * Get the number of rows in the next batch, without actually getting the batch. - */ - def peekBatchSize: Int = synchronized { - queue.head.numRows() - } + // Put this batch queue inside the BatchProducer to share the same lock with the + // output iterator returned by "asIterator" and make sure the batch movement from + // input iterator to this queue is an atomic operation. + // In a two-threaded Python runner, using two locks to protect the batch pulling + // from the input and the batch queue separately can not ensure batches in the + // queue has the same order as they are pulled in from the input. Because there is + // a race when the reader and the writer append batches to the queue. + // One possible case is: + // 1) the writer thread gets a batch A, but next it pauses. + // 2) then the reader thread gets the next Batch B, and appends it to the queue. + // 3) the writer thread restores and appends batch A to the queue. + // Therefore, batch A and B have the reversed order in the queue now, leading to data + // corruption when doing the combination. 
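+  // Holding the producer's lock ("producer.synchronized") around both the pull from the
+  // input and the enqueue makes the two steps atomic, so the queue keeps the input order.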
+ private class BatchQueueImpl extends BatchQueue with AutoCloseable { + private val queue = mutable.Queue[SpillableColumnarBatch]() + + /** Add a batch to the queue, the input batch will be taken over, do not use it anymore */ + private[python] def add(batch: ColumnarBatch): Unit = { + val cb = converter.map { convert => + withResource(batch)(convert) + }.getOrElse(batch) + queue.enqueue(SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) + } + + /** Return and remove the first batch in the cache. Caller should close it */ + override def remove(): SpillableColumnarBatch = producer.synchronized { + if (queue.isEmpty) { + null + } else { + queue.dequeue() + } + } - override def close(): Unit = synchronized { - if (!isSet) { - isSet = true - notifyAll() + /** Get the number of rows in the next batch, without actually getting the batch. */ + override def peekBatchNumRows(): Int = producer.synchronized { + // Try to pull in the next batch for peek + if (queue.isEmpty) { + val cb = produce() + if (cb != null) { + add(cb) + } + } + + if (queue.nonEmpty) { + queue.head.numRows() + } else { + 0 // Should not go here but just in case. + } } - while(queue.nonEmpty) { - queue.dequeue().close() + + override def close(): Unit = producer.synchronized { + while (queue.nonEmpty) { + queue.dequeue().close() + } } } } @@ -285,10 +374,7 @@ case class GpuArrowEvalPythonExec( val inputRDD = child.executeColumnar() inputRDD.mapPartitions { iter => - val queue: BatchQueue = new BatchQueue() val context = TaskContext.get() - onTaskCompletion(context)(queue.close()) - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip // Not sure why we are doing this in every task. It is not going to change, but it might @@ -318,14 +404,11 @@ case class GpuArrowEvalPythonExec( }.toArray) val boundReferences = GpuBindReferences.bindReferences(allInputs.toSeq, childOutput) - val batchedIterator = new RebatchingRoundoffIterator(iter, inputSchema, targetBatchSize, - numInputRows, numInputBatches) - val pyInputIterator = batchedIterator.map { batch => - // We have to do the project before we add the batch because the batch might be closed - // when it is added - val ret = GpuProjectExec.project(batch, boundReferences) - queue.add(batch) - ret + val batchProducer = new BatchProducer( + new RebatchingRoundoffIterator(iter, inputSchema, targetBatchSize, numInputRows, + numInputBatches)) + val pyInputIterator = batchProducer.asIterator.map { batch => + withResource(batch)(GpuProjectExec.project(_, boundReferences)) } if (isPythonOnGpuEnabled) { @@ -342,11 +425,10 @@ case class GpuArrowEvalPythonExec( timeZone, runnerConf, targetBatchSize, - pythonOutputSchema, - () => queue.finish()) + pythonOutputSchema) val outputIterator = pyRunner.compute(pyInputIterator, context.partitionId(), context) - new CombiningIterator(queue, outputIterator, pyRunner, numOutputRows, + new CombiningIterator(batchProducer.getBatchQueue, outputIterator, pyRunner, numOutputRows, numOutputBatches) } else { // Empty partition, return it directly diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowPythonRunner.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowPythonRunner.scala deleted file mode 100644 index b323ac62843..00000000000 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowPythonRunner.scala +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
- * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.rapids.execution.python - -import java.io.{DataInputStream, DataOutputStream} - -import ai.rapids.cudf._ -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.ArrowStreamWriter - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python._ -import org.apache.spark.rapids.shims.api.python.ShimBasePythonRunner -import org.apache.spark.sql.execution.python.PythonUDFRunner -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.execution.python.shims.{GpuArrowPythonRunner, GpuPythonArrowOutput} -import org.apache.spark.sql.rapids.shims.ArrowUtilsShim -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils - -class BufferToStreamWriter(outputStream: DataOutputStream) extends HostBufferConsumer { - private[this] val tempBuffer = new Array[Byte](128 * 1024) - - override def handleBuffer(hostBuffer: HostMemoryBuffer, length: Long): Unit = { - withResource(hostBuffer) { buffer => - var len = length - var offset: Long = 0 - while(len > 0) { - val toCopy = math.min(tempBuffer.length, len).toInt - buffer.getBytes(tempBuffer, 0, offset, toCopy) - outputStream.write(tempBuffer, 0, toCopy) - len = len - toCopy - offset = offset + toCopy - } - } - } -} - -class StreamToBufferProvider(inputStream: DataInputStream) extends HostBufferProvider { - private[this] val tempBuffer = new Array[Byte](128 * 1024) - - override def readInto(hostBuffer: HostMemoryBuffer, length: Long): Long = { - var amountLeft = length - var totalRead : Long = 0 - while (amountLeft > 0) { - val amountToRead = Math.min(tempBuffer.length, amountLeft).toInt - val amountRead = inputStream.read(tempBuffer, 0, amountToRead) - if (amountRead <= 0) { - // Reached EOF - amountLeft = 0 - } else { - amountLeft -= amountRead - hostBuffer.setBytes(totalRead, tempBuffer, 0, amountRead) - totalRead += amountRead - } - } - totalRead - } -} - -/** - * Base class of GPU Python runners who will be mixed with GpuPythonArrowOutput - * to produce columnar batches. - */ -abstract class GpuPythonRunnerBase[IN]( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]]) - extends ShimBasePythonRunner[IN, ColumnarBatch](funcs, evalType, argOffsets) - -/** - * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. 
- */ -abstract class GpuArrowPythonRunnerBase( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonInSchema: StructType, - timeZoneId: String, - conf: Map[String, String], - batchSize: Long, - pythonOutSchema: StructType = null, - onDataWriteFinished: () => Unit = null) - extends GpuPythonRunnerBase[ColumnarBatch](funcs, evalType, argOffsets) - with GpuPythonArrowOutput { - - def toBatch(table: Table): ColumnarBatch = { - GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) - } - - override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize - require( - bufferSize >= 4, - "Pandas execution requires more than 4 bytes. Please set higher buffer. " + - s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") - - protected class RapidsWriter( - env: SparkEnv, - inputIterator: Iterator[ColumnarBatch], - partitionIndex: Int, - context: TaskContext) { - - def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings - dataOut.writeInt(conf.size) - for ((k, v) <- conf) { - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - def writeInputToStream(dataOut: DataOutputStream): Boolean = { - if (inputIterator.nonEmpty) { - writeNonEmptyIteratorOnGpu(dataOut) - } else { // Partition is empty. - // In this case CPU will still send the schema to Python workers by calling - // the "start" API of the Java Arrow writer, but GPU will send out nothing, - // leading to the IPC error. And it is not easy to do as what Spark does on - // GPU, because the C++ Arrow writer used by GPU will only send out the schema - // iff there is some data. Besides, it does not expose a "start" API to do this. - // So here we leverage the Java Arrow writer to do similar things as Spark. - // It is OK because sending out schema has nothing to do with GPU. - writeEmptyIteratorOnCpu(dataOut) - // Returning false because nothing was written - false - } - } - - private def writeNonEmptyIteratorOnGpu(dataOut: DataOutputStream): Boolean = { - val writer = { - val builder = ArrowIPCWriterOptions.builder() - builder.withMaxChunkSize(batchSize) - builder.withCallback((table: Table) => { - table.close() - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - }) - // Flatten the names of nested struct columns, required by cudf arrow IPC writer. 
- GpuArrowPythonRunner.flattenNames(pythonInSchema).foreach { case (name, nullable) => - if (nullable) { - builder.withColumnNames(name) - } else { - builder.withNotNullableColumnNames(name) - } - } - Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) - } - - var wrote = false - Utils.tryWithSafeFinally { - while (inputIterator.hasNext) { - wrote = false - val table = withResource(inputIterator.next()) { nextBatch => - GpuColumnVector.from(nextBatch) - } - withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => - // The callback will handle closing table and releasing the semaphore - writer.write(table) - wrote = true - } - } - // The iterator can grab the semaphore even on an empty batch - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - } { - writer.close() - dataOut.flush() - if (onDataWriteFinished != null) onDataWriteFinished() - } - wrote - } - - private def writeEmptyIteratorOnCpu(dataOut: DataOutputStream): Unit = { - // most code is copied from Spark - val arrowSchema = ArrowUtilsShim.toArrowSchema(pythonInSchema, timeZoneId) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for empty partition", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - Utils.tryWithSafeFinally { - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() - // No data to write - writer.end() - // The iterator can grab the semaphore even on an empty batch - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - } { - root.close() - allocator.close() - if (onDataWriteFinished != null) onDataWriteFinished() - } - } - } -} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowReader.scala new file mode 100644 index 00000000000..b31b5de331a --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowReader.scala @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io.DataInputStream + +import ai.rapids.cudf.{ArrowIPCOptions, HostBufferProvider, HostMemoryBuffer, NvtxColor, NvtxRange, StreamedTableReader, Table} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.TaskContext +import org.apache.spark.sql.vectorized.ColumnarBatch + + +/** A helper class to read arrow data from the input stream to host buffer. 
*/ +private[rapids] class StreamToBufferProvider( + inputStream: DataInputStream) extends HostBufferProvider { + private[this] val tempBuffer = new Array[Byte](128 * 1024) + + override def readInto(hostBuffer: HostMemoryBuffer, length: Long): Long = { + var amountLeft = length + var totalRead : Long = 0 + while (amountLeft > 0) { + val amountToRead = Math.min(tempBuffer.length, amountLeft).toInt + val amountRead = inputStream.read(tempBuffer, 0, amountToRead) + if (amountRead <= 0) { + // Reached EOF + amountLeft = 0 + } else { + amountLeft -= amountRead + hostBuffer.setBytes(totalRead, tempBuffer, 0, amountRead) + totalRead += amountRead + } + } + totalRead + } +} + +trait GpuArrowOutput { + /** + * Update the expected rows number for next reading. + */ + private[rapids] final def setMinReadTargetNumRows(numRows: Int): Unit = { + minReadTargetNumRows = numRows + } + + /** Convert the table received from the Python side to a batch. */ + protected def toBatch(table: Table): ColumnarBatch + + /** + * Default to `Int.MaxValue` to try to read as many as possible. + * Change it by calling `setMinReadTargetNumRows` before a reading. + */ + private var minReadTargetNumRows: Int = Int.MaxValue + + def newGpuArrowReader: GpuArrowReader = new GpuArrowReader + + class GpuArrowReader extends AutoCloseable { + private[this] var tableReader: StreamedTableReader = _ + private[this] var batchLoaded: Boolean = true + + /** Make the reader ready to read data, should be called before reading any batch */ + final def start(stream: DataInputStream): Unit = { + if (tableReader == null) { + val builder = ArrowIPCOptions.builder().withCallback( + () => GpuSemaphore.acquireIfNecessary(TaskContext.get())) + tableReader = Table.readArrowIPCChunked(builder.build(), new StreamToBufferProvider(stream)) + } + } + + final def isStarted: Boolean = tableReader != null + + final def mayHasNext: Boolean = batchLoaded + + final def readNext(): ColumnarBatch = { + val table = + withResource(new NvtxRange("read python batch", NvtxColor.DARK_GREEN)) { _ => + // The GpuSemaphore is acquired in a callback + tableReader.getNextIfAvailable(minReadTargetNumRows) + } + if (table != null) { + batchLoaded = true + withResource(table)(toBatch) + } else { + batchLoaded = false + null + } + } + + def close(): Unit = { + if (tableReader != null) { + tableReader.close() + tableReader = null + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowWriter.scala new file mode 100644 index 00000000000..14e2a57533d --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowWriter.scala @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io.DataOutputStream + +import ai.rapids.cudf.{ArrowIPCWriterOptions, HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, Table, TableWriter} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuSemaphore} +import com.nvidia.spark.rapids.Arm.withResource +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.arrow.vector.types.pojo.Schema + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +/** A helper class to write arrow data from host buffer to the output stream */ +private[rapids] class BufferToStreamWriter( + outputStream: DataOutputStream) extends HostBufferConsumer { + + private[this] val tempBuffer = new Array[Byte](128 * 1024) + + override def handleBuffer(hostBuffer: HostMemoryBuffer, length: Long): Unit = { + withResource(hostBuffer) { buffer => + var len = length + var offset: Long = 0 + while(len > 0) { + val toCopy = math.min(tempBuffer.length, len).toInt + buffer.getBytes(tempBuffer, 0, offset, toCopy) + outputStream.write(tempBuffer, 0, toCopy) + len = len - toCopy + offset = offset + toCopy + } + } + } +} + +trait GpuArrowWriter extends AutoCloseable { + + protected[this] val inputSchema: StructType + protected[this] val maxBatchSize: Long + + private[this] var tableWriter: TableWriter = _ + private[this] var writerOptions: ArrowIPCWriterOptions = _ + + private def buildWriterOptions: ArrowIPCWriterOptions = { + val builder = ArrowIPCWriterOptions.builder() + builder.withMaxChunkSize(maxBatchSize) + builder.withCallback((table: Table) => { + table.close() + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + }) + // Flatten the names of nested struct columns, required by cudf arrow IPC writer. + GpuArrowWriter.flattenNames(inputSchema).foreach { case (name, nullable) => + if (nullable) { + builder.withColumnNames(name) + } else { + builder.withNotNullableColumnNames(name) + } + } + builder.build() + } + + /** Make the writer ready to write data, should be called before writing any batch */ + final def start(dataOut: DataOutputStream): Unit = { + if (tableWriter == null) { + if (writerOptions == null) { + writerOptions = buildWriterOptions + } + tableWriter = Table.writeArrowIPCChunked(writerOptions, new BufferToStreamWriter(dataOut)) + } + } + + final def writeAndClose(batch: ColumnarBatch): Unit = withResource(batch) { _ => + write(batch) + } + + final def write(batch: ColumnarBatch): Unit = { + withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => + // The callback will handle closing table and releasing the semaphore + tableWriter.write(GpuColumnVector.from(batch)) + } + } + + /** This is design to reuse the writer options */ + final def reset(): Unit = { + if (tableWriter != null) { + tableWriter.close() + tableWriter = null + } + } + + def close(): Unit = { + if (tableWriter != null) { + tableWriter.close() + tableWriter = null + writerOptions = null + } + } + +} + +object GpuArrowWriter { + /** + * Create a simple GpuArrowWriter in case you don't want to implement a new one + * by extending from the trait. 
+ */ + def apply(schema: StructType, maxSize: Long): GpuArrowWriter = { + new GpuArrowWriter { + override protected val inputSchema: StructType = schema + override protected val maxBatchSize: Long = maxSize + } + } + + def flattenNames(d: DataType, nullable: Boolean = true): Seq[(String, Boolean)] = + d match { + case s: StructType => + s.flatMap(sf => Seq((sf.name, sf.nullable)) ++ flattenNames(sf.dataType, sf.nullable)) + case m: MapType => + flattenNames(m.keyType, nullable) ++ flattenNames(m.valueType, nullable) + case a: ArrayType => flattenNames(a.elementType, nullable) + case _ => Nil + } +} + +abstract class GpuArrowPythonWriter( + override val inputSchema: StructType, + override val maxBatchSize: Long) extends GpuArrowWriter { + + protected def writeUDFs(dataOut: DataOutputStream): Unit + + def writeCommand(dataOut: DataOutputStream, confs: Map[String, String]): Unit = { + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(confs.size) + for ((k, v) <- confs) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + writeUDFs(dataOut) + } + + /** + * This is for writing the empty partition. + * In this case CPU will still send the schema to Python workers by calling + * the "start" API of the Java Arrow writer, but GPU will send out nothing, + * leading to the IPC error. And it is not easy to do as what Spark does on + * GPU, because the C++ Arrow writer used by GPU will only send out the schema + * iff there is some data. Besides, it does not expose a "start" API to do this. + * So here we leverage the Java Arrow writer to do similar things as Spark. + * It is OK because sending out schema has nothing to do with GPU. + * (Most code is copied from Spark) + */ + final def writeEmptyIteratorOnCpu(dataOut: DataOutputStream, + arrowSchema: Schema): Unit = { + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for empty partition", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + Utils.tryWithSafeFinally { + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + // No data to write + writer.end() + } { + root.close() + allocator.close() + } + } + +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala index 0baf2d2661f..6c2f716583f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasExec import org.apache.spark.sql.rapids.execution.python.BatchGroupUtils._ -import org.apache.spark.sql.rapids.execution.python.shims._ +import org.apache.spark.sql.rapids.execution.python.shims.GpuGroupedPythonRunnerFactory import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -122,11 +122,8 @@ case class GpuFlatMapGroupsInPandasExec( val GroupArgs(dedupAttrs, argOffsets, groupingOffsets) = resolveArgOffsets(child, groupingAttributes) - val runnerShims = GpuArrowPythonRunnerShims(conf, - 
chainedFunc, - Array(argOffsets), - DataTypeUtilsShim.fromAttributes(dedupAttrs), - pythonOutputSchema) + val runnerFactory = GpuGroupedPythonRunnerFactory(conf, chainedFunc, Array(argOffsets), + DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema) // Start processing. Map grouped batches to ArrowPythonRunner results. child.executeColumnar().mapPartitionsInternal { inputIter => @@ -142,7 +139,7 @@ case class GpuFlatMapGroupsInPandasExec( if (pyInputIter.hasNext) { // Launch Python workers only when the data is not empty. - val pyRunner = runnerShims.getRunner() + val pyRunner = runnerFactory.getRunner() executePython(pyInputIter, localOutput, pyRunner, mNumOutputRows, mNumOutputBatches) } else { // Empty partition, return it directly diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInBatchExec.scala similarity index 92% rename from sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala rename to sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInBatchExec.scala index 2ee51096fcb..d5c3ef50400 100644 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInBatchExec.scala @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -/*** spark-rapids-shim-json-lines -{"spark": "341db"} -spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python +import scala.collection.JavaConverters.seqAsJavaListConverter + import ai.rapids.cudf import ai.rapids.cudf.Table import com.nvidia.spark.rapids._ @@ -62,6 +60,8 @@ trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase { val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) val localOutput = output + val localBatchSize = batchSize + val localEvalType = pythonEvalType // Start process child.executeColumnar().mapPartitionsInternal { inputIter => @@ -77,8 +77,7 @@ trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase { } val pyInputIterator = new RebatchingRoundoffIterator(inputIter, pyInputTypes, - batchSize, numInputRows, numInputBatches) - .map { batch => + localBatchSize, numInputRows, numInputBatches).map { batch => // Here we wrap it via another column so that Python sides understand it // as a DataFrame. 
withResource(batch) { b => @@ -88,15 +87,16 @@ trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase { new ColumnarBatch(Array(gpuColumn), b.numRows()) } } - } + } val pyRunner = new GpuArrowPythonRunner( chainedFunc, - pythonEvalType, + localEvalType, argOffsets, pyInputSchema, sessionLocalTimeZone, pythonRunnerConf, - batchSize) { + localBatchSize, + GpuColumnVector.structFromAttributes(localOutput.asJava)) { override def toBatch(table: Table): ColumnarBatch = { BatchGroupedIterator.extractChildren(table, localOutput) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonRunnerCommon.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonRunnerCommon.scala new file mode 100644 index 00000000000..3fdd94b61d3 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonRunnerCommon.scala @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import ai.rapids.cudf.Table +import com.nvidia.spark.rapids.GpuColumnVector + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.execution.python.shims.GpuBasePythonRunner +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A trait to put some common things from Spark for the basic GPU Arrow Python runners + */ +trait GpuPythonRunnerCommon { _: GpuBasePythonRunner[_] => + + protected val pythonOutSchema: StructType + + protected def toBatch(table: Table): ColumnarBatch = { + GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) + } + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala index 66c18011a4e..ab56a0b24b5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala @@ -499,24 +499,19 @@ trait GpuWindowInPandasExecBase extends ShimUnaryExecNode with GpuPythonExecBase // 8) Start processing. 
child.executeColumnar().mapPartitions { inputIter => val context = TaskContext.get() - val queue: BatchQueue = new BatchQueue() - onTaskCompletion(context)(queue.close()) val boundDataRefs = GpuBindReferences.bindGpuReferences(dataInputs.toSeq, childOutput) // Re-batching the input data by GroupingIterator val boundPartitionRefs = GpuBindReferences.bindGpuReferences(gpuPartitionSpec, childOutput) - val groupedIterator = new GroupingIterator(inputIter, boundPartitionRefs, - numInputRows, numInputBatches) - val pyInputIterator = groupedIterator.map { batch => - // We have to do the project before we add the batch because the batch might be closed - // when it is added - val projectedBatch = GpuProjectExec.project(batch, boundDataRefs) - // Compute the window bounds and insert to the head of each row for one batch - val inputBatch = withResource(projectedBatch) { projectedCb => - insertWindowBounds(projectedCb) + val batchProducer = new BatchProducer( + new GroupingIterator(inputIter, boundPartitionRefs, numInputRows, numInputBatches)) + val pyInputIterator = batchProducer.asIterator.map { batch => + withResource(batch) { _ => + withResource(GpuProjectExec.project(batch, boundDataRefs)) { projectedCb => + // Compute the window bounds and insert to the head of each row for one batch + insertWindowBounds(projectedCb) + } } - queue.add(batch) - inputBatch } if (isPythonOnGpuEnabled) { @@ -534,12 +529,11 @@ trait GpuWindowInPandasExecBase extends ShimUnaryExecNode with GpuPythonExecBase pythonRunnerConf, /* The whole group data should be written in a single call, so here is unlimited */ Int.MaxValue, - pythonOutputSchema, - () => queue.finish()) + pythonOutputSchema) val outputIterator = pyRunner.compute(pyInputIterator, context.partitionId(), context) - new CombiningIterator(queue, outputIterator, pyRunner, numOutputRows, - numOutputBatches).map(projectResult(_)) + new CombiningIterator(batchProducer.getBatchQueue, outputIterator, pyRunner, + numOutputRows, numOutputBatches).map(projectResult) } else { // Empty partition, return the input iterator directly inputIter diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala index 1151ad55d8f..0dd0274cec8 100644 --- a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala @@ -35,7 +35,6 @@ {"spark": "333"} {"spark": "340"} {"spark": "341"} -{"spark": "341db"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala deleted file mode 100644 index d4aeef00369..00000000000 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "311"} -{"spark": "312"} -{"spark": "313"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.rapids.shims.api.python - -import java.io.DataInputStream -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.BasePythonRunner - -// pid is not a constructor argument in 30x and 31x -abstract class ShimBasePythonRunner[IN, OUT]( - funcs : scala.Seq[org.apache.spark.api.python.ChainedPythonFunctions], - evalType : scala.Int, argOffsets : scala.Array[scala.Array[scala.Int]] -) extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets) { - protected abstract class ShimReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext - ) extends ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) -} diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala new file mode 100644 index 00000000000..a6f58095c95 --- /dev/null +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataInputStream +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.sql.rapids.execution.python.GpuArrowOutput +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A trait that can be mixed-in with `GpuBasePythonRunner`. It implements the logic from + * Python (Arrow) to GPU/JVM (ColumnarBatch). 
+ */ +trait GpuArrowPythonOutput extends GpuArrowOutput { _: GpuBasePythonRunner[_] => + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + val gpuArrowReader = newGpuArrowReader + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + // Because of batching and other things we have to be sure that we release the semaphore + // before any operation that could block. This is because we are using multiple threads + // for a single task and the GpuSemaphore might not wake up both threads associated with + // the task, so a reader can be blocked waiting for data, while a writer is waiting on + // the semaphore + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + if (gpuArrowReader.isStarted && gpuArrowReader.mayHasNext) { + val batch = gpuArrowReader.readNext() + if (batch != null) { + batch + } else { + gpuArrowReader.close() // reach the end, close the reader + read() // read the end signal + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + gpuArrowReader.start(stream) + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala new file mode 100644 index 00000000000..38be0680dc6 --- /dev/null +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
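The read() implementations above share one shape across shims: release the GPU semaphore first, then either pull the next chunk from an already started Arrow reader or dispatch on a control code read from the worker stream. A minimal standalone sketch of that dispatch, using locally defined stand-in codes (not Spark's SpecialLengths constants) and println placeholders for the real handlers:

import java.io.{ByteArrayInputStream, DataInputStream}

object ReaderLoopSketch {
  // Stand-in control codes, for illustration only.
  private val StartArrowStream = -6
  private val TimingData = -3
  private val PythonExceptionThrown = -2
  private val EndOfDataSection = -1

  // Mirrors the dispatch shape of read(): non-terminal codes are handled and the
  // loop keeps reading; the exception and end-of-data codes terminate it.
  def drain(in: DataInputStream): Unit = in.readInt() match {
    case StartArrowStream =>
      println("start an Arrow IPC chunked reader on this stream"); drain(in)
    case TimingData =>
      println("consume timing data"); drain(in)
    case PythonExceptionThrown =>
      println("rebuild and throw the Python exception")
    case EndOfDataSection =>
      println("handle end of data section")
    case other =>
      println(s"unexpected control code $other")
  }

  def main(args: Array[String]): Unit = {
    val bytes = java.nio.ByteBuffer.allocate(8).putInt(TimingData).putInt(EndOfDataSection).array()
    drain(new DataInputStream(new ByteArrayInputStream(bytes)))
  }
}

Releasing the semaphore before the blocking read matters because, as the in-code comment explains, the reader and writer of one task run on different threads and the semaphore might not wake both of them.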
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "321db"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "340"} +{"spark": "341"} +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataOutputStream +import java.net.Socket + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.execution.python.PythonUDFRunner +import org.apache.spark.sql.rapids.execution.python.{GpuArrowPythonWriter, GpuPythonRunnerCommon} +import org.apache.spark.sql.rapids.shims.ArrowUtilsShim +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class GpuArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonInSchema: StructType, + timeZoneId: String, + conf: Map[String, String], + maxBatchSize: Long, + override val pythonOutSchema: StructType, + jobArtifactUUID: Option[String] = None) + extends GpuBasePythonRunner[ColumnarBatch](funcs, evalType, argOffsets, jobArtifactUUID) + with GpuArrowPythonOutput with GpuPythonRunnerCommon { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[ColumnarBatch], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + val arrowWriter = new GpuArrowPythonWriter(pythonInSchema, maxBatchSize) { + override protected def writeUDFs(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + } + val isInputNonEmpty = inputIterator.nonEmpty + lazy val arrowSchema = ArrowUtilsShim.toArrowSchema(pythonInSchema, timeZoneId) + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + arrowWriter.writeCommand(dataOut, conf) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + if (isInputNonEmpty) { + arrowWriter.start(dataOut) + Utils.tryWithSafeFinally { + while (inputIterator.hasNext) { + arrowWriter.writeAndClose(inputIterator.next()) + } + } { + arrowWriter.close() + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + dataOut.flush() + } + } else { + // The iterator can grab the semaphore even on an empty batch + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + arrowWriter.writeEmptyIteratorOnCpu(dataOut, arrowSchema) + } + } + } + } +} diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala similarity index 50% rename from sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala index fa2f3f3fc72..0f5613289e6 100644 --- 
a/sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,9 @@ */ /*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} {"spark": "320"} {"spark": "321"} {"spark": "321cdh"} @@ -33,28 +36,14 @@ {"spark": "340"} {"spark": "341"} spark-rapids-shim-json-lines ***/ -package org.apache.spark.rapids.shims.api.python +package org.apache.spark.sql.rapids.execution.python.shims -import java.io.DataInputStream -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.BasePythonRunner - -abstract class ShimBasePythonRunner[IN, OUT]( - funcs : scala.Seq[org.apache.spark.api.python.ChainedPythonFunctions], - evalType : scala.Int, argOffsets : scala.Array[scala.Array[scala.Int]] -) extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets) { - protected abstract class ShimReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext - ) extends ReaderIterator(stream, writerThread, startTime, env, worker, pid, - releasedOrClosed, context) -} +abstract class GpuBasePythonRunner[IN]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + jobArtifactUUID: Option[String] // Introduced after 341 +) extends BasePythonRunner[IN, ColumnarBatch](funcs, evalType, argOffsets) diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala index 7757a0c3582..e6610e49dab 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,7 +27,6 @@ {"spark": "324"} {"spark": "330"} {"spark": "330cdh"} -{"spark": "330db"} {"spark": "331"} {"spark": "332"} {"spark": "332cdh"} @@ -42,14 +41,13 @@ package org.apache.spark.sql.rapids.execution.python.shims import java.io.DataOutputStream import java.net.Socket -import ai.rapids.cudf.{ArrowIPCWriterOptions, NvtxColor, NvtxRange, Table} -import com.nvidia.spark.rapids.{GpuColumnVector, GpuSemaphore} import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.GpuSemaphore import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.execution.python.PythonUDFRunner -import org.apache.spark.sql.rapids.execution.python._ +import org.apache.spark.sql.rapids.execution.python.{GpuArrowWriter, GpuPythonRunnerCommon} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -69,9 +67,10 @@ class GpuCoGroupedArrowPythonRunner( timeZoneId: String, conf: Map[String, String], batchSize: Int, - pythonOutSchema: StructType) - extends GpuPythonRunnerBase[(ColumnarBatch, ColumnarBatch)](funcs, evalType, argOffsets) - with GpuPythonArrowOutput { + override val pythonOutSchema: StructType, + jobArtifactUUID: Option[String] = None) + extends GpuBasePythonRunner[(ColumnarBatch, ColumnarBatch)](funcs, evalType, + argOffsets, jobArtifactUUID) with GpuArrowPythonOutput with GpuPythonRunnerCommon { protected override def newWriterThread( env: SparkEnv, @@ -82,14 +81,11 @@ class GpuCoGroupedArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings dataOut.writeInt(conf.size) for ((k, v) <- conf) { PythonRDD.writeUTF(k, dataOut) PythonRDD.writeUTF(v, dataOut) } - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } @@ -107,42 +103,20 @@ class GpuCoGroupedArrowPythonRunner( // The iterator can grab the semaphore even on an empty batch GpuSemaphore.releaseIfNecessary(TaskContext.get()) dataOut.writeInt(0) + dataOut.flush() } private def writeGroupBatch(groupBatch: ColumnarBatch, batchSchema: StructType, dataOut: DataOutputStream): Unit = { - val writer = { - val builder = ArrowIPCWriterOptions.builder() - builder.withMaxChunkSize(batchSize) - builder.withCallback((table: Table) => { - table.close() - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - }) - // Flatten the names of nested struct columns, required by cudf arrow IPC writer. 
- GpuArrowPythonRunner.flattenNames(batchSchema).foreach { case (name, nullable) => - if (nullable) { - builder.withColumnNames(name) - } else { - builder.withNotNullableColumnNames(name) - } - } - Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) - } - + val gpuArrowWriter = GpuArrowWriter(batchSchema, batchSize) Utils.tryWithSafeFinally { - withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => - // The callback will handle closing table and releasing the semaphore - writer.write(GpuColumnVector.from(groupBatch)) - } + gpuArrowWriter.start(dataOut) + gpuArrowWriter.write(groupBatch) } { - writer.close() + gpuArrowWriter.reset() dataOut.flush() } } // end of writeGroup } } // end of newWriterThread - - def toBatch(table: Table): ColumnarBatch = { - GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) - } } diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala similarity index 79% rename from sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala rename to sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index d7491ed5a9a..6e6f5edd1d9 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -36,22 +36,21 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims -import org.apache.spark.api.python._ -import org.apache.spark.sql.rapids.execution.python._ +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.rapids.shims.ArrowUtilsShim -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch -case class GpuArrowPythonRunnerShims( - conf: org.apache.spark.sql.internal.SQLConf, - chainedFunc: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], - dedupAttrs: StructType, - pythonOutputSchema: StructType) { +case class GpuGroupedPythonRunnerFactory( + conf: org.apache.spark.sql.internal.SQLConf, + chainedFunc: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]], + dedupAttrs: StructType, + pythonOutputSchema: StructType) { val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) - def getRunner(): GpuPythonRunnerBase[ColumnarBatch] = { + def getRunner(): GpuBasePythonRunner[ColumnarBatch] = { new GpuArrowPythonRunner( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala deleted file mode 100644 index 91dc6d3789f..00000000000 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuMapInBatchExec.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "311"} -{"spark": "312"} -{"spark": "313"} -{"spark": "320"} -{"spark": "321"} -{"spark": "321cdh"} -{"spark": "321db"} -{"spark": "322"} -{"spark": "323"} -{"spark": "324"} -{"spark": "330"} -{"spark": "330cdh"} -{"spark": "330db"} -{"spark": "331"} -{"spark": "332"} -{"spark": "332cdh"} -{"spark": "332db"} -{"spark": "333"} -{"spark": "340"} -{"spark": "341"} -{"spark": "350"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.execution.python - -import ai.rapids.cudf -import ai.rapids.cudf.Table -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.python.PythonWorkerSemaphore -import com.nvidia.spark.rapids.shims.ShimUnaryExecNode - -import org.apache.spark.{ContextAwareIterator, TaskContext} -import org.apache.spark.api.python.ChainedPythonFunctions -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression} -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.rapids.execution.python.shims.GpuArrowPythonRunner -import org.apache.spark.sql.rapids.shims.ArrowUtilsShim -import org.apache.spark.sql.types.{StructField, StructType} -import org.apache.spark.sql.vectorized.ColumnarBatch - -/* - * A relation produced by applying a function that takes an iterator of batches - * such as pandas DataFrame or PyArrow's record batches, and outputs an iterator of them. - */ -trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase { - - protected val func: Expression - protected val pythonEvalType: Int - - private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func - - override def producedAttributes: AttributeSet = AttributeSet(output) - - private val batchSize = conf.arrowMaxRecordsPerBatch - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = { - val (numInputRows, numInputBatches, numOutputRows, numOutputBatches) = commonGpuMetrics() - - val pyInputTypes = child.schema - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) - val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) - val localOutput = output - - // Start process - child.executeColumnar().mapPartitionsInternal { inputIter => - val context = TaskContext.get() - - // Single function with one struct. 
- val argOffsets = Array(Array(0)) - val pyInputSchema = StructType(StructField("in_struct", pyInputTypes) :: Nil) - - if (isPythonOnGpuEnabled) { - GpuPythonHelper.injectGpuInfo(chainedFunc, isPythonOnGpuEnabled) - PythonWorkerSemaphore.acquireIfNecessary(context) - } - - val contextAwareIter = new ContextAwareIterator(context, inputIter) - - val pyInputIterator = new RebatchingRoundoffIterator(contextAwareIter, pyInputTypes, - batchSize, numInputRows, numInputBatches) - .map { batch => - // Here we wrap it via another column so that Python sides understand it - // as a DataFrame. - withResource(batch) { b => - val structColumn = cudf.ColumnVector.makeStruct(GpuColumnVector.extractBases(b): _*) - withResource(structColumn) { stColumn => - val gpuColumn = GpuColumnVector.from(stColumn.incRefCount(), pyInputTypes) - new ColumnarBatch(Array(gpuColumn), b.numRows()) - } - } - } - - val pyRunner = new GpuArrowPythonRunner( - chainedFunc, - pythonEvalType, - argOffsets, - pyInputSchema, - sessionLocalTimeZone, - pythonRunnerConf, - batchSize) { - override def toBatch(table: Table): ColumnarBatch = { - BatchGroupedIterator.extractChildren(table, localOutput) - } - } - - pyRunner.compute(pyInputIterator, context.partitionId(), context) - .map { cb => - numOutputBatches += 1 - numOutputRows += cb.numRows - cb - } - } // end of mapPartitionsInternal - } // end of internalDoExecuteColumnar - -} diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala deleted file mode 100644 index 681cdd3b11c..00000000000 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/*** spark-rapids-shim-json-lines -{"spark": "311"} -{"spark": "312"} -{"spark": "313"} -{"spark": "320"} -{"spark": "321"} -{"spark": "321cdh"} -{"spark": "321db"} -{"spark": "322"} -{"spark": "323"} -{"spark": "324"} -{"spark": "330"} -{"spark": "330cdh"} -{"spark": "330db"} -{"spark": "331"} -{"spark": "332"} -{"spark": "332cdh"} -{"spark": "332db"} -{"spark": "333"} -{"spark": "340"} -{"spark": "341"} -{"spark": "350"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.execution.python.shims - -import java.io.{DataInputStream, DataOutputStream} -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean - -import ai.rapids.cudf._ -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python._ -import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * A trait that can be mixed-in with `GpuPythonRunnerBase`. It implements the logic from - * Python (Arrow) to GPU/JVM (ColumnarBatch). - */ -trait GpuPythonArrowOutput { _: GpuPythonRunnerBase[_] => - - /** - * Default to `Int.MaxValue` to try to read as many as possible. - * Change it by calling `setMinReadTargetBatchSize` before a reading. - */ - private var minReadTargetBatchSize: Int = Int.MaxValue - - /** - * Update the expected batch size for next reading. - */ - private[python] final def setMinReadTargetBatchSize(size: Int): Unit = { - minReadTargetBatchSize = size - } - - /** Convert the table received from the Python side to a batch. */ - protected def toBatch(table: Table): ColumnarBatch - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext - ): Iterator[ColumnarBatch] = { - newReaderIterator(stream, writerThread, startTime, env, worker, None, releasedOrClosed, - context) - } - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - - new ShimReaderIterator(stream, writerThread, startTime, env, worker, pid, releasedOrClosed, - context) { - - private[this] var arrowReader: StreamedTableReader = _ - - onTaskCompletion(context) { - if (arrowReader != null) { - arrowReader.close() - arrowReader = null - } - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - // Because of batching and other things we have to be sure that we release the semaphore - // before any operation that could block. 
This is because we are using multiple threads - // for a single task and the GpuSemaphore might not wake up both threads associated with - // the task, so a reader can be blocked waiting for data, while a writer is waiting on - // the semaphore - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - if (arrowReader != null && batchLoaded) { - // The GpuSemaphore is acquired in a callback - val table = - withResource(new NvtxRange("read python batch", NvtxColor.DARK_GREEN)) { _ => - arrowReader.getNextIfAvailable(minReadTargetBatchSize) - } - if (table == null) { - batchLoaded = false - arrowReader.close() - arrowReader = null - read() - } else { - withResource(table) { _ => - batchLoaded = true - toBatch(table) - } - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - val builder = ArrowIPCOptions.builder() - builder.withCallback(() => - GpuSemaphore.acquireIfNecessary(TaskContext.get())) - arrowReader = Table.readArrowIPCChunked(builder.build(), - new StreamToBufferProvider(stream)) - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } -} - -/** - * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. - */ -class GpuArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonInSchema: StructType, - timeZoneId: String, - conf: Map[String, String], - batchSize: Long, - pythonOutSchema: StructType = null, - onDataWriteFinished: () => Unit = null) - extends GpuArrowPythonRunnerBase(funcs, evalType, argOffsets, pythonInSchema, timeZoneId, - conf, batchSize, pythonOutSchema, onDataWriteFinished) { - - protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[ColumnarBatch], - partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { - - val workerImpl = new RapidsWriter(env, inputIterator, partitionIndex, context) - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - workerImpl.writeCommand(dataOut) - } - - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - workerImpl.writeInputToStream(dataOut) - } - } - } -} - -object GpuArrowPythonRunner { - def flattenNames(d: DataType, nullable: Boolean = true): Seq[(String, Boolean)] = - d match { - case s: StructType => - s.flatMap(sf => Seq((sf.name, sf.nullable)) ++ flattenNames(sf.dataType, sf.nullable)) - case m: MapType => - flattenNames(m.keyType, nullable) ++ flattenNames(m.valueType, nullable) - case a: ArrayType => flattenNames(a.elementType, nullable) - case _ => Nil - } -} diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala new file mode 100644 index 00000000000..d043d36e23b --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "321db"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "340"} +{"spark": "341"} +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataInputStream +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.sql.rapids.execution.python.GpuArrowOutput +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A trait that can be mixed-in with `GpuBasePythonRunner`. It implements the logic from + * Python (Arrow) to GPU/JVM (ColumnarBatch). + */ +trait GpuArrowPythonOutput extends GpuArrowOutput { _: GpuBasePythonRunner[_] => + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], // new parameter from Spark 320 + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, pid, releasedOrClosed, + context) { + val gpuArrowReader = newGpuArrowReader + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + // Because of batching and other things we have to be sure that we release the semaphore + // before any operation that could block.
This is because we are using multiple threads + // for a single task and the GpuSemaphore might not wake up both threads associated with + // the task, so a reader can be blocked waiting for data, while a writer is waiting on + // the semaphore + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + if (gpuArrowReader.isStarted && gpuArrowReader.mayHasNext) { + val batch = gpuArrowReader.readNext() + if (batch != null) { + batch + } else { + gpuArrowReader.close() // reach the end, close the reader + read() // read the end signal + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + gpuArrowReader.start(stream) + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql-plugin/src/main/spark321db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala b/sql-plugin/src/main/spark321db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala index 86d34414991..988cbe2521c 100644 --- a/sql-plugin/src/main/spark321db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala +++ b/sql-plugin/src/main/spark321db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.python.PythonWorkerSemaphore import org.apache.spark.TaskContext @@ -34,7 +33,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.rapids.execution.python.{BatchQueue, CombiningIterator, GpuPythonHelper, GpuPythonUDF, GpuWindowInPandasExecBase, GroupingIterator} +import org.apache.spark.sql.rapids.execution.python.{BatchProducer, CombiningIterator, GpuPythonHelper, GpuWindowInPandasExecBase, GroupingIterator} import org.apache.spark.sql.rapids.execution.python.shims.GpuArrowPythonRunner import org.apache.spark.sql.rapids.shims.{ArrowUtilsShim, DataTypeUtilsShim} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -111,9 +110,7 @@ case class GpuWindowInPandasExec( // 2) Extract window functions, here should be Python (Pandas) UDFs val allWindowExpressions = expressionsWithFrameIndex.map(_._1) - val udfExpressions = allWindowExpressions.map { - case e: GpuWindowExpression => e.windowFunction.asInstanceOf[GpuPythonUDF] - } + val udfExpressions = PythonUDFShim.getUDFExpressions(allWindowExpressions) // We shouldn't be chaining anything here. // All chained python functions should only contain one function. val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -196,24 +193,19 @@ case class GpuWindowInPandasExec( // 8) Start processing. 
child.executeColumnar().mapPartitions { inputIter => val context = TaskContext.get() - val queue: BatchQueue = new BatchQueue() - onTaskCompletion(context)(queue.close()) val boundDataRefs = GpuBindReferences.bindGpuReferences(dataInputs, childOutput) // Re-batching the input data by GroupingIterator val boundPartitionRefs = GpuBindReferences.bindGpuReferences(gpuPartitionSpec, childOutput) - val groupedIterator = new GroupingIterator(inputIter, boundPartitionRefs, - numInputRows, numInputBatches) - val pyInputIterator = groupedIterator.map { batch => - // We have to do the project before we add the batch because the batch might be closed - // when it is added - val projectedBatch = GpuProjectExec.project(batch, boundDataRefs) - // Compute the window bounds and insert to the head of each row for one batch - val inputBatch = withResource(projectedBatch) { projectedCb => - insertWindowBounds(projectedCb) + val batchProducer = new BatchProducer( + new GroupingIterator(inputIter, boundPartitionRefs, numInputRows, numInputBatches)) + val pyInputIterator = batchProducer.asIterator.map { batch => + withResource(batch) { _ => + withResource(GpuProjectExec.project(batch, boundDataRefs)) { projectedCb => + // Compute the window bounds and insert to the head of each row for one batch + insertWindowBounds(projectedCb) + } } - queue.add(batch) - inputBatch } if (isPythonOnGpuEnabled) { @@ -231,12 +223,11 @@ case class GpuWindowInPandasExec( pythonRunnerConf, /* The whole group data should be written in a single call, so here is unlimited */ Int.MaxValue, - pythonOutputSchema, - () => queue.finish()) + pythonOutputSchema) val outputIterator = pyRunner.compute(pyInputIterator, context.partitionId(), context) - new CombiningIterator(queue, outputIterator, pyRunner, numOutputRows, - numOutputBatches).map(projectResult(_)) + new CombiningIterator(batchProducer.getBatchQueue, outputIterator, pyRunner, + numOutputRows, numOutputBatches).map(projectResult) } else { // Empty partition, return the input iterator directly inputIter diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala index e905e0687cd..6623fc8765f 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. 
See the NOTICE file distributed with @@ -19,7 +19,6 @@ /*** spark-rapids-shim-json-lines {"spark": "321db"} -{"spark": "330db"} {"spark": "332db"} spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims @@ -27,15 +26,13 @@ package org.apache.spark.sql.rapids.execution.python.shims import java.io.DataOutputStream import java.net.Socket -import ai.rapids.cudf._ -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.GpuSemaphore import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.sql.execution.python.PythonUDFRunner -import org.apache.spark.sql.rapids.execution.python.{BufferToStreamWriter, GpuPythonRunnerBase} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.rapids.execution.python.{GpuArrowPythonWriter, GpuPythonRunnerCommon} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -59,10 +56,11 @@ class GpuGroupUDFArrowPythonRunner( pythonInSchema: StructType, timeZoneId: String, conf: Map[String, String], - batchSize: Long, - pythonOutSchema: StructType) - extends GpuPythonRunnerBase[ColumnarBatch](funcs, evalType, argOffsets) - with GpuPythonArrowOutput { + maxBatchSize: Long, + override val pythonOutSchema: StructType, + jobArtifactUUID: Option[String] = None) + extends GpuBasePythonRunner[ColumnarBatch](funcs, evalType, argOffsets, jobArtifactUUID) + with GpuArrowPythonOutput with GpuPythonRunnerCommon { protected override def newWriterThread( env: SparkEnv, @@ -72,64 +70,34 @@ class GpuGroupUDFArrowPythonRunner( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings - dataOut.writeInt(conf.size) - for ((k, v) <- conf) { - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) + val arrowWriter = new GpuArrowPythonWriter(pythonInSchema, maxBatchSize) { + override protected def writeUDFs(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } + } - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + arrowWriter.writeCommand(dataOut, conf) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - // write out number of columns Utils.tryWithSafeFinally { - val builder = ArrowIPCWriterOptions.builder() - builder.withMaxChunkSize(batchSize) - builder.withCallback((table: Table) => { - table.close() - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - }) - // Flatten the names of nested struct columns, required by cudf Arrow IPC writer. 
- GpuArrowPythonRunner.flattenNames(pythonInSchema).foreach { case (name, nullable) => - if (nullable) { - builder.withColumnNames(name) - } else { - builder.withNotNullableColumnNames(name) - } - } while(inputIterator.hasNext) { - val writer = { - // write 1 out to indicate there is more to read - dataOut.writeInt(1) - Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) - } - val table = withResource(inputIterator.next()) { nextBatch => - GpuColumnVector.from(nextBatch) - } - withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => - // The callback will handle closing table and releasing the semaphore - writer.write(table) - } - writer.close() + // write 1 out to indicate there is more to read + dataOut.writeInt(1) + arrowWriter.start(dataOut) + arrowWriter.writeAndClose(inputIterator.next()) + arrowWriter.reset() dataOut.flush() } - // indicate not to read more - // The iterator can grab the semaphore even on an empty batch - GpuSemaphore.releaseIfNecessary(TaskContext.get()) } { + arrowWriter.close() // tell serializer we are done dataOut.writeInt(0) dataOut.flush() + GpuSemaphore.releaseIfNecessary(TaskContext.get()) } } } } - - def toBatch(table: Table): ColumnarBatch = { - GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) - } } diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala similarity index 90% rename from sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala rename to sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index 3da7f7dc99e..ed0d5816f40 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -21,12 +21,12 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims -import org.apache.spark.api.python._ -import org.apache.spark.sql.rapids.execution.python._ +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.rapids.shims.ArrowUtilsShim -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch -case class GpuArrowPythonRunnerShims( + +case class GpuGroupedPythonRunnerFactory( conf: org.apache.spark.sql.internal.SQLConf, chainedFunc: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], @@ -38,7 +38,7 @@ case class GpuArrowPythonRunnerShims( val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) - def getRunner(): GpuPythonRunnerBase[ColumnarBatch] = { + def getRunner(): GpuBasePythonRunner[ColumnarBatch] = { if (zeroConfEnabled && maxBytes > 0L) { new GpuGroupUDFArrowPythonRunner( chainedFunc, diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala new file mode 100644 index 00000000000..f229e50528a --- /dev/null +++ 
b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonOutput.scala @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330db"} +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataInputStream +import java.util.concurrent.atomic.AtomicBoolean + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.sql.rapids.execution.python.GpuArrowOutput +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A trait that can be mixed-in with `GpuBasePythonRunner`. It implements the logic from + * Python (Arrow) to GPU/JVM (ColumnarBatch). + */ +trait GpuArrowPythonOutput extends GpuArrowOutput { _: GpuBasePythonRunner[_] => + + protected def newReaderIterator( + stream: DataInputStream, + writer: Writer, + startTime: Long, + env: SparkEnv, + worker: PythonWorker, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { + val gpuArrowReader = newGpuArrowReader + + protected override def read(): ColumnarBatch = { + if (writer.exception.isDefined) { + throw writer.exception.get + } + try { + // Because of batching and other things we have to be sure that we release the semaphore + // before any operation that could block. 
This is because we are using multiple threads + // for a single task and the GpuSemaphore might not wake up both threads associated with + // the task, so a reader can be blocked waiting for data, while a writer is waiting on + // the semaphore + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + if (gpuArrowReader.isStarted && gpuArrowReader.mayHasNext) { + val batch = gpuArrowReader.readNext() + if (batch != null) { + batch + } else { + gpuArrowReader.close() // reach the end, close the reader + read() // read the end signal + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + gpuArrowReader.start(stream) + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala new file mode 100644 index 00000000000..fc79236095a --- /dev/null +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunner.scala @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "330db"} +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataOutputStream + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.python.PythonUDFRunner +import org.apache.spark.sql.rapids.execution.python.{GpuArrowPythonWriter, GpuPythonRunnerCommon} +import org.apache.spark.sql.rapids.shims.ArrowUtilsShim +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. 
+ */ +class GpuArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonInSchema: StructType, + timeZoneId: String, + conf: Map[String, String], + maxBatchSize: Long, + override val pythonOutSchema: StructType, + jobArtifactUUID: Option[String] = None) + extends GpuBasePythonRunner[ColumnarBatch](funcs, evalType, argOffsets, jobArtifactUUID) + with GpuArrowPythonOutput with GpuPythonRunnerCommon { + + protected override def newWriter( + env: SparkEnv, + worker: PythonWorker, + inputIterator: Iterator[ColumnarBatch], + partitionIndex: Int, + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { + + val arrowWriter = new GpuArrowPythonWriter(pythonInSchema, maxBatchSize) { + override protected def writeUDFs(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + } + val isInputNonEmpty = inputIterator.nonEmpty + lazy val arrowSchema = ArrowUtilsShim.toArrowSchema(pythonInSchema, timeZoneId) + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + arrowWriter.writeCommand(dataOut, conf) + } + + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { + if (isInputNonEmpty) { + arrowWriter.start(dataOut) + try { + if (inputIterator.hasNext) { + arrowWriter.writeAndClose(inputIterator.next()) + dataOut.flush() + true + } else { + arrowWriter.close() // all batches are written, close the writer + false + } + } catch { + case t: Throwable => + arrowWriter.close() + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + throw t + } + } else { + // The iterator can grab the semaphore even on an empty batch + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + arrowWriter.writeEmptyIteratorOnCpu(dataOut, arrowSchema) + false + } + } + } + } +} diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala new file mode 100644 index 00000000000..6f657950a91 --- /dev/null +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330db"} +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataOutputStream + +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRDD, PythonWorker} +import org.apache.spark.sql.execution.python.PythonUDFRunner +import org.apache.spark.sql.rapids.execution.python.{GpuArrowWriter, GpuPythonRunnerCommon} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Python UDF Runner for cogrouped UDFs, designed for `GpuFlatMapCoGroupsInPandasExec` only. + * + * It sends Arrow batches from two different DataFrames, groups them in Python, + * and receives them back in the JVM as batches of a single DataFrame. + */ +class GpuCoGroupedArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + leftSchema: StructType, + rightSchema: StructType, + timeZoneId: String, + conf: Map[String, String], + batchSize: Int, + override val pythonOutSchema: StructType, + jobArtifactUUID: Option[String] = None) + extends GpuBasePythonRunner[(ColumnarBatch, ColumnarBatch)](funcs, evalType, + argOffsets, jobArtifactUUID) with GpuArrowPythonOutput with GpuPythonRunnerCommon { + + protected override def newWriter( + env: SparkEnv, + worker: PythonWorker, // Changed from "Socket" to "PythonWorker" since DB 341 + inputIterator: Iterator[(ColumnarBatch, ColumnarBatch)], + partitionIndex: Int, + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { + if (inputIterator.hasNext) { + // For each group, first send the number of DataFrames in the group, then send + // the first df, followed by the second df. + dataOut.writeInt(2) + val (leftGroupBatch, rightGroupBatch) = inputIterator.next() + withResource(Seq(leftGroupBatch, rightGroupBatch)) { _ => + writeGroupBatch(leftGroupBatch, leftSchema, dataOut) + writeGroupBatch(rightGroupBatch, rightSchema, dataOut) + } + true + } else { + // The iterator can grab the semaphore even on an empty batch + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + // End of data is marked by sending 0.
+ dataOut.writeInt(0) + false + } + } + + private def writeGroupBatch(groupBatch: ColumnarBatch, batchSchema: StructType, + dataOut: DataOutputStream): Unit = { + val gpuArrowWriter = GpuArrowWriter(batchSchema, batchSize) + try { + gpuArrowWriter.start(dataOut) + gpuArrowWriter.write(groupBatch) + } catch { + case t: Throwable => + // release the semaphore in case of exception in the middle of writing a batch + GpuSemaphore.releaseIfNecessary(TaskContext.get()) + throw t + } finally { + gpuArrowWriter.reset() + dataOut.flush() + } + } // end of writeGroupBatch + } + } +} diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala new file mode 100644 index 00000000000..ec3bde02434 --- /dev/null +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330db"} +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import java.io.DataOutputStream + +import com.nvidia.spark.rapids.GpuSemaphore + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.python.PythonUDFRunner +import org.apache.spark.sql.rapids.execution.python.{GpuArrowPythonWriter, GpuPythonRunnerCommon} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Group Map UDF specific serializer for Databricks because they have a special GroupUDFSerializer. + * The main difference here from the GpuArrowPythonRunner is that it creates a new Arrow + * Stream for each grouped data. + * The overall flow is: + * - send a 1 to indicate more data is coming + * - create a new Arrow Stream for each grouped data + * - send the schema + * - send that group of data + * - close that Arrow stream + * - Repeat starting at sending 1 if more data, otherwise send a 0 to indicate no + * more data being sent. 
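Per the flow spelled out in the comment above, the Databricks group-UDF writer frames each group's Arrow stream with a leading 1 and ends the whole sequence with a 0 (the cogrouped runner writes a 2 instead, one group batch per DataFrame). A small self-contained sketch of that framing over a DataOutputStream, with a stub standing in for the GpuArrowPythonWriter:

import java.io.{ByteArrayOutputStream, DataOutputStream}

object GroupFramingSketch {
  // Stub for "write one group as an Arrow stream"; the real code delegates to a
  // GpuArrowPythonWriter that starts, writes, and resets an Arrow IPC stream.
  private def writeOneGroup(out: DataOutputStream, group: Seq[Int]): Unit =
    group.foreach(out.writeInt)

  // Frame every group with a leading 1 and terminate the sequence with a 0.
  def writeGroups(out: DataOutputStream, groups: Iterator[Seq[Int]]): Unit = {
    while (groups.hasNext) {
      out.writeInt(1) // more data is coming
      writeOneGroup(out, groups.next())
      out.flush()
    }
    out.writeInt(0) // no more data
    out.flush()
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    writeGroups(new DataOutputStream(buffer), Iterator(Seq(1, 2), Seq(3)))
    println(s"framed two groups into ${buffer.size()} bytes")
  }
}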
+ */
+class GpuGroupUDFArrowPythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    pythonInSchema: StructType,
+    timeZoneId: String,
+    conf: Map[String, String],
+    batchSize: Long,
+    override val pythonOutSchema: StructType,
+    jobArtifactUUID: Option[String] = None)
+  extends GpuBasePythonRunner[ColumnarBatch](funcs, evalType, argOffsets, jobArtifactUUID)
+  with GpuArrowPythonOutput with GpuPythonRunnerCommon {
+
+  protected override def newWriter(
+      env: SparkEnv,
+      worker: PythonWorker, // As of db341, changed from Socket to PythonWorker
+      inputIterator: Iterator[ColumnarBatch],
+      partitionIndex: Int,
+      context: TaskContext): Writer = {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
+
+      val arrowWriter = new GpuArrowPythonWriter(pythonInSchema, batchSize) {
+        override protected def writeUDFs(dataOut: DataOutputStream): Unit = {
+          PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+        }
+      }
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        arrowWriter.writeCommand(dataOut, conf)
+      }
+
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
+        try {
+          if (inputIterator.hasNext) {
+            dataOut.writeInt(1)
+            arrowWriter.start(dataOut)
+            arrowWriter.writeAndClose(inputIterator.next())
+            arrowWriter.reset()
+            dataOut.flush()
+            true
+          } else {
+            // The iterator can grab the semaphore even on an empty batch
+            GpuSemaphore.releaseIfNecessary(TaskContext.get())
+            // tell the serializer we are done
+            dataOut.writeInt(0)
+            dataOut.flush()
+            false
+          }
+        } catch {
+          case t: Throwable =>
+            arrowWriter.close()
+            // release the semaphore in case of exception in the middle of writing a batch
+            GpuSemaphore.releaseIfNecessary(TaskContext.get())
+            throw t
+        }
+      }
+    }
+  }
+}
diff --git a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
index 6018f5e51b1..36ffc1db926 100644
--- a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
+++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{CollectLimitExec, GlobalLimitExec, SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
 import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null
+import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
 import org.apache.spark.sql.types.StringType
 
 trait Spark341PlusDBShims extends Spark332PlusDBShims {
@@ -56,7 +57,35 @@ trait Spark341PlusDBShims extends Spark332PlusDBShims {
       (a, conf, p, r) => new UnaryExprMeta[Empty2Null](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression = GpuEmpty2Null(child)
       }
-    )
+    ),
+    GpuOverrides.expr[PythonUDAF](
+      "UDF run in an external python process. Does not actually run on the GPU, but " +
+        "the transfer of data to/from it can be accelerated",
+      ExprChecks.fullAggAndProject(
+        // Different types of Pandas UDF support different sets of output types. Please refer to
+        // https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98
+        // for more details.
+        // It is impossible to specify the exact type signature for each Pandas UDF type in a
+        // single expression 'PythonUDF'.
+        // So use the 'unionOfPandasUdfOut' to cover all types for Spark. The type signature of
+        // the plugin is also a union of all the types of Pandas UDF.
+        (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT,
+        TypeSig.unionOfPandasUdfOut,
+        repeatingParamCheck = Some(RepeatingParamCheck(
+          "param",
+          (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
+          TypeSig.all))),
+      (a, conf, p, r) => new ExprMeta[PythonUDAF](a, conf, p, r) {
+        override def replaceMessage: String = "not block GPU acceleration"
+
+        override def noReplacementPossibleMessage(reasons: String): String =
+          s"blocks running on GPU because $reasons"
+
+        override def convertToGpu(): GpuExpression =
+          GpuPythonUDAF(a.name, a.func, a.dataType,
+            childExprs.map(_.convertToGpu()), a.evalType, a.udfDeterministic, a.resultId)
+      })
   ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
     super.getExprs ++ shimExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs
   }
diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala
deleted file mode 100644
index 1cf8abeab2d..00000000000
--- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/*** spark-rapids-shim-json-lines
-{"spark": "341db"}
-spark-rapids-shim-json-lines ***/
-package org.apache.spark.rapids.shims.api.python
-
-import java.io.DataInputStream
-import java.util.concurrent.atomic.AtomicBoolean
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
-
-abstract class ShimBasePythonRunner[IN, OUT](
-    funcs : scala.Seq[org.apache.spark.api.python.ChainedPythonFunctions],
-    evalType : scala.Int, argOffsets : scala.Array[scala.Array[scala.Int]]
-) extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets, None) {
-  protected abstract class ShimReaderIterator(
-    stream: DataInputStream,
-    writer: Writer,
-    startTime: Long,
-    env: SparkEnv,
-    worker: PythonWorker,
-    pid: Option[Int],
-    releasedOrClosed: AtomicBoolean,
-    context: TaskContext
-  ) extends ReaderIterator(stream, writer, startTime, env, worker, pid,
-    releasedOrClosed, context)
-}
diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala
new file mode 100644
index 00000000000..81c78db1cfc
--- /dev/null
+++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuBasePythonRunner.scala
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "341db"} +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.execution.python.shims + +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.sql.vectorized.ColumnarBatch + +abstract class GpuBasePythonRunner[IN]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + jobArtifactUUID: Option[String] +) extends BasePythonRunner[IN, ColumnarBatch](funcs, evalType, argOffsets, jobArtifactUUID) diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala deleted file mode 100644 index 9c245cf2636..00000000000 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuCoGroupedArrowPythonRunner.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "341db"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.execution.python.shims - -import java.io.DataOutputStream - -import ai.rapids.cudf.{ArrowIPCWriterOptions, NvtxColor, NvtxRange, Table} -import com.nvidia.spark.rapids.{GpuColumnVector, GpuSemaphore} -import com.nvidia.spark.rapids.Arm.withResource - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRDD, PythonWorker} -import org.apache.spark.sql.execution.python.PythonUDFRunner -import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils - -/** - * Python UDF Runner for cogrouped UDFs, designed for `GpuFlatMapCoGroupsInPandasExec` only. - * - * It sends Arrow batches from two different DataFrames, groups them in Python, - * and receive it back in JVM as batches of single DataFrame. 
- */ -class GpuCoGroupedArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - leftSchema: StructType, - rightSchema: StructType, - timeZoneId: String, - conf: Map[String, String], - batchSize: Int, - pythonOutSchema: StructType) - extends GpuPythonRunnerBase[(ColumnarBatch, ColumnarBatch)](funcs, evalType, argOffsets) - with GpuPythonArrowOutput { - - protected override def newWriter( - env: SparkEnv, - worker: PythonWorker, - inputIterator: Iterator[(ColumnarBatch, ColumnarBatch)], - partitionIndex: Int, - context: TaskContext): Writer = { - new Writer(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings - dataOut.writeInt(conf.size) - for ((k, v) <- conf) { - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - // For each we first send the number of dataframes in each group then send - // first df, then send second df. End of data is marked by sending 0. - var wrote = false - while (inputIterator.hasNext) { - wrote = false - dataOut.writeInt(2) - val (leftGroupBatch, rightGroupBatch) = inputIterator.next() - withResource(Seq(leftGroupBatch, rightGroupBatch)) { _ => - wrote = writeGroupBatch(leftGroupBatch, leftSchema, dataOut) - wrote = writeGroupBatch(rightGroupBatch, rightSchema, dataOut) - } - } - // The iterator can grab the semaphore even on an empty batch - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - dataOut.writeInt(0) - wrote - } - - private def writeGroupBatch(groupBatch: ColumnarBatch, batchSchema: StructType, - dataOut: DataOutputStream): Boolean = { - val writer = { - val builder = ArrowIPCWriterOptions.builder() - builder.withMaxChunkSize(batchSize) - builder.withCallback((table: Table) => { - table.close() - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - }) - // Flatten the names of nested struct columns, required by cudf arrow IPC writer. - GpuArrowPythonRunner.flattenNames(batchSchema).foreach { case (name, nullable) => - if (nullable) { - builder.withColumnNames(name) - } else { - builder.withNotNullableColumnNames(name) - } - } - Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) - } - var wrote = false - Utils.tryWithSafeFinally { - withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => - // The callback will handle closing table and releasing the semaphore - writer.write(GpuColumnVector.from(groupBatch)) - wrote = true - } - } { - writer.close() - dataOut.flush() - } - wrote - } // end of writeGroup - } - } // end of newWriterThread - - def toBatch(table: Table): ColumnarBatch = { - GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) - } -} diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala deleted file mode 100644 index c1aea19a194..00000000000 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "341db"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.execution.python.shims - -import java.io.DataOutputStream - -import ai.rapids.cudf._ -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python._ -import org.apache.spark.sql.execution.python.PythonUDFRunner -import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils - -/** - * Group Map UDF specific serializer for Databricks because they have a special GroupUDFSerializer. - * The main difference here from the GpuArrowPythonRunner is that it creates a new Arrow - * Stream for each grouped data. - * The overall flow is: - * - send a 1 to indicate more data is coming - * - create a new Arrow Stream for each grouped data - * - send the schema - * - send that group of data - * - close that Arrow stream - * - Repeat starting at sending 1 if more data, otherwise send a 0 to indicate no - * more data being sent. - */ -class GpuGroupUDFArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonInSchema: StructType, - timeZoneId: String, - conf: Map[String, String], - batchSize: Long, - pythonOutSchema: StructType) - extends GpuPythonRunnerBase[ColumnarBatch](funcs, evalType, argOffsets) - with GpuPythonArrowOutput { - - protected override def newWriter( - env: SparkEnv, - worker: PythonWorker, - inputIterator: Iterator[ColumnarBatch], - partitionIndex: Int, - context: TaskContext): Writer = { - new Writer(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings - dataOut.writeInt(conf.size) - for ((k, v) <- conf) { - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - var wrote = false - // write out number of columns - Utils.tryWithSafeFinally { - val builder = ArrowIPCWriterOptions.builder() - builder.withMaxChunkSize(batchSize) - builder.withCallback((table: Table) => { - table.close() - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - }) - // Flatten the names of nested struct columns, required by cudf Arrow IPC writer. 
- GpuArrowPythonRunner.flattenNames(pythonInSchema).foreach { case (name, nullable) => - if (nullable) { - builder.withColumnNames(name) - } else { - builder.withNotNullableColumnNames(name) - } - } - while(inputIterator.hasNext) { - wrote = false - val writer = { - // write 1 out to indicate there is more to read - dataOut.writeInt(1) - Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) - } - val table = withResource(inputIterator.next()) { nextBatch => - GpuColumnVector.from(nextBatch) - } - withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => - // The callback will handle closing table and releasing the semaphore - writer.write(table) - wrote = true - } - writer.close() - dataOut.flush() - } - // indicate not to read more - // The iterator can grab the semaphore even on an empty batch - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - } { - // tell serializer we are done - dataOut.writeInt(0) - dataOut.flush() - } - wrote - } - } - } - - def toBatch(table: Table): ColumnarBatch = { - GpuColumnVector.from(table, GpuColumnVector.extractTypes(pythonOutSchema)) - } -} diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala similarity index 93% rename from sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala rename to sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index bf05656c861..d3e3415290a 100644 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuArrowPythonRunnerShims.scala +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -20,13 +20,12 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims import org.apache.spark.api.python._ -import org.apache.spark.sql.rapids.execution.python._ import org.apache.spark.sql.rapids.shims.ArrowUtilsShim import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch //TODO is this needed? we already have a similar version in spark321db -case class GpuArrowPythonRunnerShims( +case class GpuGroupedPythonRunnerFactory( conf: org.apache.spark.sql.internal.SQLConf, chainedFunc: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], @@ -38,7 +37,7 @@ case class GpuArrowPythonRunnerShims( val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) - def getRunner(): GpuPythonRunnerBase[ColumnarBatch] = { + def getRunner(): GpuBasePythonRunner[ColumnarBatch] = { if (zeroConfEnabled && maxBytes > 0L) { new GpuGroupUDFArrowPythonRunner( chainedFunc, diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala deleted file mode 100644 index 35fe8979d94..00000000000 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuPythonArrowShims.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/*** spark-rapids-shim-json-lines -{"spark": "341db"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.execution.python.shims - -import java.io.{DataInputStream, DataOutputStream} -import java.util.concurrent.atomic.AtomicBoolean - -import ai.rapids.cudf._ -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python._ -import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * A trait that can be mixed-in with `GpuPythonRunnerBase`. It implements the logic from - * Python (Arrow) to GPU/JVM (ColumnarBatch). - */ -trait GpuPythonArrowOutput { _: GpuPythonRunnerBase[_] => - - /** - * Default to `Int.MaxValue` to try to read as many as possible. - * Change it by calling `setMinReadTargetBatchSize` before a reading. - */ - private var minReadTargetBatchSize: Int = Int.MaxValue - - /** - * Update the expected batch size for next reading. - */ - private[python] final def setMinReadTargetBatchSize(size: Int): Unit = { - minReadTargetBatchSize = size - } - - /** Convert the table received from the Python side to a batch. */ - protected def toBatch(table: Table): ColumnarBatch - - protected def newReaderIterator( - stream: DataInputStream, - writer: Writer, - startTime: Long, - env: SparkEnv, - worker: PythonWorker, - releasedOrClosed: AtomicBoolean, - context: TaskContext - ): Iterator[ColumnarBatch] = { - newReaderIterator(stream, writer, startTime, env, worker, None, releasedOrClosed, - context) - } - - protected def newReaderIterator( - stream: DataInputStream, - writer: Writer, - startTime: Long, - env: SparkEnv, - worker: PythonWorker, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - - new ShimReaderIterator(stream, writer, startTime, env, worker, pid, releasedOrClosed, - context) { - - private[this] var arrowReader: StreamedTableReader = _ - - onTaskCompletion(context) { - if (arrowReader != null) { - arrowReader.close() - arrowReader = null - } - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writer.exception.isDefined) { - throw writer.exception.get - } - try { - // Because of batching and other things we have to be sure that we release the semaphore - // before any operation that could block. 
This is because we are using multiple threads - // for a single task and the GpuSemaphore might not wake up both threads associated with - // the task, so a reader can be blocked waiting for data, while a writer is waiting on - // the semaphore - GpuSemaphore.releaseIfNecessary(TaskContext.get()) - if (arrowReader != null && batchLoaded) { - // The GpuSemaphore is acquired in a callback - val table = - withResource(new NvtxRange("read python batch", NvtxColor.DARK_GREEN)) { _ => - arrowReader.getNextIfAvailable(minReadTargetBatchSize) - } - if (table == null) { - batchLoaded = false - arrowReader.close() - arrowReader = null - read() - } else { - withResource(table) { _ => - batchLoaded = true - toBatch(table) - } - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - val builder = ArrowIPCOptions.builder() - builder.withCallback(() => - GpuSemaphore.acquireIfNecessary(TaskContext.get())) - arrowReader = Table.readArrowIPCChunked(builder.build(), - new StreamToBufferProvider(stream)) - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } -} - -/** - * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. - */ -class GpuArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonInSchema: StructType, - timeZoneId: String, - conf: Map[String, String], - batchSize: Long, - pythonOutSchema: StructType = null, - onDataWriteFinished: () => Unit = null) - extends GpuArrowPythonRunnerBase(funcs, evalType, argOffsets, pythonInSchema, timeZoneId, - conf, batchSize, pythonOutSchema, onDataWriteFinished) { - - protected override def newWriter( - env: SparkEnv, - worker: PythonWorker, - inputIterator: Iterator[ColumnarBatch], - partitionIndex: Int, - context: TaskContext): Writer = { - new Writer(env, worker, inputIterator, partitionIndex, context) { - - val workerImpl = new RapidsWriter(env, inputIterator, partitionIndex, context) - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - workerImpl.writeCommand(dataOut) - } - - override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - workerImpl.writeInputToStream(dataOut) - } - } - } -} - - -object GpuArrowPythonRunner { - def flattenNames(d: DataType, nullable: Boolean = true): Seq[(String, Boolean)] = - d match { - case s: StructType => - s.flatMap(sf => Seq((sf.name, sf.nullable)) ++ flattenNames(sf.dataType, sf.nullable)) - case m: MapType => - flattenNames(m.keyType, nullable) ++ flattenNames(m.valueType, nullable) - case a: ArrayType => flattenNames(a.elementType, nullable) - case _ => Nil - } -} diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/api/python/ShimBasePythonRunner.scala deleted file mode 100644 index d100d931b9c..00000000000 --- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/api/python/ShimBasePythonRunner.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "350"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.rapids.shims.api.python - -import java.io.DataInputStream -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.BasePythonRunner - -abstract class ShimBasePythonRunner[IN, OUT]( - funcs : scala.Seq[org.apache.spark.api.python.ChainedPythonFunctions], - evalType : scala.Int, - argOffsets : scala.Array[scala.Array[scala.Int]], - jobArtifactUUID: Option[String] = None) // TODO shim this - extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets, jobArtifactUUID) { - protected abstract class ShimReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext - ) extends ReaderIterator(stream, writerThread, startTime, env, worker, pid, - releasedOrClosed, context) -}
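For reference, the framing used by the group UDF runners in this change reduces to a simple start/stop protocol on the DataOutputStream: write a 1 before every per-group Arrow payload and a single 0 after the last one. The standalone Scala sketch below is illustrative only; the object and method names (GroupFramingSketch, writeGroups) are hypothetical and not part of the plugin, and the byte-array payloads stand in for the complete Arrow IPC streams that GpuArrowWriter/GpuArrowPythonWriter produce in the real runners.

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Minimal sketch of the start/stop framing described above:
// a 1 announces another per-group payload, a trailing 0 tells the reader to stop.
object GroupFramingSketch {
  def writeGroups(out: DataOutputStream, groups: Seq[Array[Byte]]): Unit = {
    groups.foreach { payload =>
      out.writeInt(1)    // more data is coming
      out.write(payload) // placeholder for one self-contained Arrow IPC stream
      out.flush()
    }
    out.writeInt(0)      // end of data
    out.flush()
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeGroups(new DataOutputStream(buf), Seq(Array[Byte](1, 2), Array[Byte](3)))
    println(s"wrote ${buf.size()} bytes")
  }
}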