From 2f76d2b2007b85b994c168fa1f95b0dbb549d886 Mon Sep 17 00:00:00 2001
From: Firestarman
Date: Tue, 26 Sep 2023 02:39:34 +0000
Subject: [PATCH] Address comments and add more tests

Signed-off-by: Firestarman
---
 dist/unshimmed-common-from-spark311.txt       |   1 +
 .../main/java/com/nvidia/spark/Retryable.java |  49 +++-----
 .../rapids/NonDeterministicRetrySuite.scala   | 106 ++++++++++++++----
 3 files changed, 102 insertions(+), 54 deletions(-)

diff --git a/dist/unshimmed-common-from-spark311.txt b/dist/unshimmed-common-from-spark311.txt
index b044c2cb4c2..cf67a19590a 100644
--- a/dist/unshimmed-common-from-spark311.txt
+++ b/dist/unshimmed-common-from-spark311.txt
@@ -5,6 +5,7 @@ com/nvidia/spark/ExclusiveModeGpuDiscoveryPlugin*
 com/nvidia/spark/GpuCachedBatchSerializer*
 com/nvidia/spark/ParquetCachedBatchSerializer*
 com/nvidia/spark/RapidsUDF*
+com/nvidia/spark/Retryable*
 com/nvidia/spark/SQLPlugin*
 com/nvidia/spark/rapids/ColumnarRdd*
 com/nvidia/spark/rapids/GpuColumnVectorUtils*
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/Retryable.java b/sql-plugin/src/main/java/com/nvidia/spark/Retryable.java
index afc742e6e8b..db1bb0e1fe6 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/Retryable.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/Retryable.java
@@ -17,48 +17,33 @@
 package com.nvidia.spark;
 
 /**
- * An interface that can be used by Retry framework of RAPIDS Plugin to handle the GPU OOMs.
+ * An interface that can be used to retry the processing of non-deterministic
+ * expressions on the GPU.
  *
- * GPU memory is a limited resource, so OOM can happen if too many tasks run in parallel.
- * Retry framework is introduced to improve the stability by retrying the work when it
- * meets OOMs. The overall process of Retry framework is similar as the below.
- * ```
- * Retryable retryable
- * retryable.checkpoint()
- * boolean hasOOM = false
- * do {
- *   try {
- *     runWorkOnGpu(retryable) // May lead to GPU OOM
- *     hasOOM = false
- *   } catch (OOMError oom) {
- *     hasOOM = true
- *     tryToReleaseSomeGpuMemoryFromLowPriorityTasks()
- *     retryable.restore()
- *   }
- * } while(hasOOM)
- * ```
- * In a retry, "checkpoint" will be called first and only once, which is used to save the
- * state for later loops. When OOM happens, "restore" will be called to restore the
- * state that saved by "checkpoint". After that, it will try the same work again. And
- * the whole process runs on Spark executors.
+ * GPU memory is a limited resource. When it runs out, the RAPIDS Accelerator
+ * for Apache Spark will use several different strategies to try to free more
+ * GPU memory to let the query complete.
+ * One of these strategies is to roll back the processing for one task, pause
+ * the task thread, then retry the task when more memory is available. This
+ * works transparently for any stateless deterministic processing. But technically
+ * an expression/UDF can be non-deterministic and/or keep state in between calls.
+ * This interface provides a checkpoint method to save any needed state, and a
+ * restore method to reset the state in the case of a retry.
  *
- * Retry framework expects the "runWorkOnGpu" always outputs the same result when running
- * it multiple times in a retry. So if "runWorkOnGpu" is non-deterministic, it can not be
- * used by Retry framework.
- * The "Retryable" is designed for this kind of cases. By implementing this interface,
- * "runWorkOnGpu" can become deterministic inside a retry process, making it usable for
- * Retry framework to improve the stability.
+ * Please note that a retry is not isolated to a single expression, so a restore can
+ * be called even after the expression has returned one or more batches of results. And
+ * each time checkpoint is called, any previously saved state can be overwritten.
  */
 public interface Retryable {
   /**
-   * Save the state, so it can be restored in case of an OOM Retry.
-   * This is called inside a Spark task context on executors.
+   * Save the state, so it can be restored in the case of a retry.
+   * (This is called inside a Spark task context on executors.)
    */
   void checkpoint();
 
   /**
    * Restore the state that was saved by calling to "checkpoint".
-   * This is called inside a Spark task context on executors.
+   * (This is called inside a Spark task context on executors.)
    */
   void restore();
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala
index 6d3f1d71cc8..b0cecd80c87 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala
@@ -16,18 +16,21 @@
 
 package com.nvidia.spark.rapids
 
-import ai.rapids.cudf.ColumnVector
+import ai.rapids.cudf.{ColumnVector, Table}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
 import com.nvidia.spark.rapids.jni.RmmSpark
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
+import org.apache.spark.sql.rapids.GpuGreaterThan
 import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{DoubleType, IntegerType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase {
   private val NUM_ROWS = 500
   private val RAND_SEED = 10
+  private val batchAttrs = Seq(AttributeReference("int", IntegerType)(ExprId(10)))
 
   private def buildBatch(ints: Seq[Int] = 0 until NUM_ROWS): ColumnarBatch = {
     new ColumnarBatch(
@@ -43,14 +46,14 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase {
         randCol1.copyToHost()
       }
       withResource(randHCol1) { _ =>
-        // store the state, and generate data again
+        assert(randHCol1.getRowCount.toInt == NUM_ROWS)
+        // Restore the state, and generate data again
         gpuRand.restore()
        val randHCol2 = withResource(gpuRand.columnarEval(inputCB)) { randCol2 =>
          randCol2.copyToHost()
        }
        withResource(randHCol2) { _ =>
          // check the two random columns are equal.
-          assert(randHCol1.getRowCount.toInt == NUM_ROWS)
          assert(randHCol1.getRowCount == randHCol2.getRowCount)
          (0 until randHCol1.getRowCount.toInt).foreach { pos =>
            assert(randHCol1.getDouble(pos) == randHCol2.getDouble(pos))
@@ -61,27 +64,86 @@
   }
 
   test("GPU project retry with GPU rand") {
-    val childOutput = Seq(AttributeReference("int", IntegerType)(NamedExpression.newExprId))
-    val projectRandOnly = Seq(
-      GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")(NamedExpression.newExprId))
-    val projectList = projectRandOnly ++ childOutput
+    def projectRand(): Seq[GpuExpression] = Seq(
+      GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")())
+
     Seq(true, false).foreach { useTieredProject =>
       // expression should be retryable
-      val randOnlyProjectList = GpuBindReferences.bindGpuReferencesTiered(projectRandOnly,
-        childOutput, useTieredProject)
-      assert(randOnlyProjectList.areAllRetryable)
-      val boundProjectList = GpuBindReferences.bindGpuReferencesTiered(projectList,
-        childOutput, useTieredProject)
-      assert(boundProjectList.areAllRetryable)
+      val boundProjectRand = GpuBindReferences.bindGpuReferencesTiered(projectRand(),
+        batchAttrs, useTieredProject)
+      assert(boundProjectRand.areAllRetryable)
+      // project with and without retry
+      val batches = Seq(true, false).safeMap { forceRetry =>
+        val boundProjectList = GpuBindReferences.bindGpuReferencesTiered(
+          projectRand() ++ batchAttrs, batchAttrs, useTieredProject)
+        assert(boundProjectList.areAllRetryable)
+
+        val sb = closeOnExcept(buildBatch()) { cb =>
+          SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
+        }
+        closeOnExcept(sb) { _ =>
+          if (forceRetry) {
+            RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId)
+          }
+        }
+        boundProjectList.projectAndCloseWithRetrySingleBatch(sb)
+      }
+      // check the random columns
+      val randCols = withResource(batches) { case Seq(retriedBatch, batch) =>
+        assert(retriedBatch.numRows() == batch.numRows())
+        assert(retriedBatch.numCols() == batch.numCols())
+        batches.safeMap(_.column(0).asInstanceOf[GpuColumnVector].copyToHost())
+      }
+      withResource(randCols) { case Seq(retriedRand, rand) =>
+        (0 until rand.getRowCount.toInt).foreach { pos =>
+          assert(retriedRand.getDouble(pos) == rand.getDouble(pos))
+        }
+      }
+    }
+  }
+
+  test("GPU filter retry with GPU rand") {
+    def filterRand(): Seq[GpuExpression] = Seq(
+      GpuGreaterThan(
+        GpuRand(GpuLiteral.create(RAND_SEED, IntegerType)),
+        GpuLiteral.create(0.1d, DoubleType)))
+
+    Seq(true, false).foreach { useTieredProject =>
+      // filter with and without retry
+      val tables = Seq(true, false).safeMap { forceRetry =>
+        val boundCondition = GpuBindReferences.bindGpuReferencesTiered(filterRand(),
+          batchAttrs, useTieredProject)
+        assert(boundCondition.areAllRetryable)
+
+        val cb = buildBatch()
+        if (forceRetry) {
+          RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId)
+        }
+        val batchSeq = GpuFilter.filterAndClose(cb, boundCondition,
+          NoopMetric, NoopMetric, NoopMetric).toSeq
+        withResource(batchSeq) { _ =>
+          val tables = batchSeq.safeMap(GpuColumnVector.from)
+          if (tables.size == 1) {
+            tables.head
+          } else {
+            withResource(tables) { _ =>
+              assert(tables.size > 1)
+              Table.concatenate(tables: _*)
+            }
+          }
+        }
+      }
 
-      // project with retry
-      val sb = closeOnExcept(buildBatch()) { cb =>
-        SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
+      // check the outputs
+      val cols = withResource(tables) { case Seq(retriedTable, table) =>
+        assert(retriedTable.getRowCount == table.getRowCount)
+        assert(retriedTable.getNumberOfColumns == table.getNumberOfColumns)
+        tables.safeMap(_.getColumn(0).copyToHost())
       }
-      RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId)
-      withResource(boundProjectList.projectAndCloseWithRetrySingleBatch(sb)) { outCB =>
-        // We can not verify the data, so only rows number here
-        assertResult(NUM_ROWS)(outCB.numRows())
+      withResource(cols) { case Seq(retriedInts, ints) =>
+        (0 until ints.getRowCount.toInt).foreach { pos =>
+          assert(retriedInts.getInt(pos) == ints.getInt(pos))
+        }
       }
     }
   }
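
Reviewer note (not part of the patch): the Javadoc above describes the checkpoint/restore contract, but the patch itself only exercises it through GpuRand. Below is a minimal sketch of what an implementation of com.nvidia.spark.Retryable could look like for a stateful generator. The SequenceGen class and its produce method are hypothetical, illustration-only names; only the Retryable interface and its two methods come from this change.

```
import com.nvidia.spark.Retryable

// A stateful generator: the values it emits depend on how many rows it has
// already produced, so replaying a batch would normally yield different output.
class SequenceGen extends Retryable {
  private var nextValue: Long = 0L   // live state, advances with every batch
  private var savedValue: Long = 0L  // snapshot taken by checkpoint()

  // Hypothetical per-batch work; each call continues where the previous stopped.
  def produce(numRows: Int): Array[Long] = {
    val out = Array.tabulate(numRows)(i => nextValue + i)
    nextValue += numRows
    out
  }

  // Called once before the work that may be retried: remember the current state.
  override def checkpoint(): Unit = savedValue = nextValue

  // Called when the work is retried after an OOM: roll the state back so the
  // replay emits exactly the same values as the failed attempt.
  override def restore(): Unit = nextValue = savedValue
}
```

This mirrors what GpuRand does with its seed, and it is the behavior the two new tests pin down by forcing retries with RmmSpark.forceRetryOOM and RmmSpark.forceSplitAndRetryOOM.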