From 7e899f486c88ab2a9313995747c8d5c3d7cc18de Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 4 Oct 2023 08:39:49 -0500 Subject: [PATCH] GpuCoalesceBatches should throw SplitAndRetryOOM on GPU OOM error. (#9374) Signed-off-by: Jason Lowe --- .../scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala | 3 ++- .../com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala index e7acd575285..be5c750b2e0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala @@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRetry, withRetryNoSplit} import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion +import com.nvidia.spark.rapids.jni.SplitAndRetryOOM import com.nvidia.spark.rapids.shims.{ShimExpression, ShimUnaryExecNode} import org.apache.spark.TaskContext @@ -670,7 +671,7 @@ abstract class AbstractGpuCoalesceIterator( val it = batchesToCoalesce.batches val numBatches = it.length if (numBatches <= 1) { - throw new OutOfMemoryError(s"Cannot split a sequence of $numBatches batches") + throw new SplitAndRetryOOM(s"Cannot split a sequence of $numBatches batches") } val res = it.splitAt(numBatches / 2) Seq(BatchesToCoalesce(res._1), BatchesToCoalesce(res._2)) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala index 8cb21e9e6c0..f597e7d8d9c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala +++ 
b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala @@ -175,7 +175,7 @@ class GpuCoalesceBatchesRetrySuite test("coalesce gpu batches fails with OOM if it cannot split enough") { val iters = getIters(mockInjectSplitAndRetry = true) iters.foreach { iter => - assertThrows[OutOfMemoryError] { + assertThrows[SplitAndRetryOOM] { iter.next() // throws } val batches = iter.asInstanceOf[CoalesceIteratorMocks].getBatches()