diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
index e7acd575285..be5c750b2e0 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRetry, withRetryNoSplit}
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
+import com.nvidia.spark.rapids.jni.SplitAndRetryOOM
 import com.nvidia.spark.rapids.shims.{ShimExpression, ShimUnaryExecNode}
 
 import org.apache.spark.TaskContext
@@ -670,7 +671,7 @@ abstract class AbstractGpuCoalesceIterator(
     val it = batchesToCoalesce.batches
     val numBatches = it.length
     if (numBatches <= 1) {
-      throw new OutOfMemoryError(s"Cannot split a sequence of $numBatches batches")
+      throw new SplitAndRetryOOM(s"Cannot split a sequence of $numBatches batches")
     }
     val res = it.splitAt(numBatches / 2)
     Seq(BatchesToCoalesce(res._1), BatchesToCoalesce(res._2))
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala
index 8cb21e9e6c0..f597e7d8d9c 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala
@@ -175,7 +175,7 @@ class GpuCoalesceBatchesRetrySuite
   test("coalesce gpu batches fails with OOM if it cannot split enough") {
     val iters = getIters(mockInjectSplitAndRetry = true)
     iters.foreach { iter =>
-      assertThrows[OutOfMemoryError] {
+      assertThrows[SplitAndRetryOOM] {
         iter.next() // throws
       }
       val batches = iter.asInstanceOf[CoalesceIteratorMocks].getBatches()
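
For context, the halving logic touched above is the split callback used by the retry framework: when the remaining work cannot be split any further, it now raises SplitAndRetryOOM (a retryable signal) rather than a JVM OutOfMemoryError, which Spark generally treats as fatal. Below is a minimal, self-contained Scala sketch of that split policy; the local SplitAndRetryOOM class and the simplified BatchesToCoalesce wrapper are stand-ins for the plugin's real types, not the actual API.

```scala
object SplitPolicySketch {
  // Local stand-in for com.nvidia.spark.rapids.jni.SplitAndRetryOOM, included only
  // so this sketch compiles on its own.
  class SplitAndRetryOOM(msg: String) extends RuntimeException(msg)

  // Simplified stand-in for the plugin's BatchesToCoalesce, holding plain values
  // instead of spillable GPU batches.
  case class BatchesToCoalesce[T](batches: Seq[T])

  // Same halving policy as in the diff: split the pending batches in two so the
  // retry framework can attempt each half separately; once only one batch remains
  // there is nothing left to split, so signal a retryable failure instead of
  // throwing java.lang.OutOfMemoryError.
  def splitInHalf[T](in: BatchesToCoalesce[T]): Seq[BatchesToCoalesce[T]] = {
    val numBatches = in.batches.length
    if (numBatches <= 1) {
      throw new SplitAndRetryOOM(s"Cannot split a sequence of $numBatches batches")
    }
    val (first, second) = in.batches.splitAt(numBatches / 2)
    Seq(BatchesToCoalesce(first), BatchesToCoalesce(second))
  }

  def main(args: Array[String]): Unit = {
    // Four batches split into two halves of two.
    println(splitInHalf(BatchesToCoalesce(Seq(1, 2, 3, 4))))
    // A single batch cannot be split further; this call would throw SplitAndRetryOOM:
    // splitInHalf(BatchesToCoalesce(Seq(1)))
  }
}
```

The test change mirrors this: the suite now expects SplitAndRetryOOM when splitting bottoms out, so the failure stays inside the retry machinery instead of surfacing as a fatal OutOfMemoryError.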