Allow checkpoint and restore on non-deterministic expressions in GpuFilter and GpuProject #9287

Merged: 6 commits, Sep 28, 2023
Changes from 5 commits
64 changes: 64 additions & 0 deletions sql-plugin/src/main/java/com/nvidia/spark/Retryable.java
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark;

/**
* An interface that can be used by the Retry framework of the RAPIDS Plugin to handle GPU OOMs.
Collaborator:
I think this is too much detail for an end user. Could we adjust it a bit? Perhaps something more like:

An interface that can be used to retry the processing of non-deterministic expressions on the GPU.
GPU memory is a limited resource. When it runs out, the RAPIDS Accelerator for Apache Spark will
use several different strategies to try to free more GPU memory so the query can complete. One of
these strategies is to roll back the processing for one task, pause that task's thread, and then
retry the task when more memory is available. This works transparently for any stateless
deterministic processing. But technically an expression/UDF can be non-deterministic and/or keep
state in between calls. This interface provides a checkpoint method to save any needed state, and a
restore method to reset the state in the case of a retry. Please note that a retry is not isolated
to a single expression, so restore can be called even after the expression has returned
one or more batches of results.

Each time checkpoint is called, any previously saved state can be overwritten.

@firestarman (Collaborator, Author), Sep 26, 2023:

Sure, done.

*
* GPU memory is a limited resource, so OOMs can happen if too many tasks run in parallel.
* The Retry framework is introduced to improve stability by retrying the work when it
* hits an OOM. The overall process of the Retry framework is similar to the following:
* ```
* Retryable retryable
* retryable.checkpoint()
* boolean hasOOM = false
* do {
* try {
* runWorkOnGpu(retryable) // May lead to GPU OOM
* hasOOM = false
* } catch (OOMError oom) {
* hasOOM = true
* tryToReleaseSomeGpuMemoryFromLowPriorityTasks()
* retryable.restore()
* }
* } while(hasOOM)
* ```
* In a retry, "checkpoint" will be called first and only once; it saves the
* state for later loops. When an OOM happens, "restore" will be called to restore the
* state that was saved by "checkpoint". After that, the same work is tried again. The
* whole process runs on Spark executors.
*
* The Retry framework expects "runWorkOnGpu" to always produce the same result when it
* is run multiple times in a retry. So if "runWorkOnGpu" is non-deterministic, it cannot
* be used by the Retry framework.
* "Retryable" is designed for this kind of case. By implementing this interface,
* "runWorkOnGpu" can become deterministic inside a retry process, making it usable by
* the Retry framework to improve stability.
*/
public interface Retryable {
revans2 marked this conversation as resolved.
/**
* Save the state, so it can be restored in case of an OOM Retry.
Collaborator:
nit: can we drop the OOM here?

Save the state so it can be restored in the case of a retry.

@firestarman (Collaborator, Author), Sep 26, 2023:

Sure, done.

* This is called inside a Spark task context on executors.
*/
void checkpoint();

/**
* Restore the state that was saved by calling "checkpoint".
* This is called inside a Spark task context on executors.
*/
void restore();
}
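To make the contract concrete, here is a minimal standalone sketch of a stateful component implementing checkpoint/restore. The `XorShiftRand` class and its fields are illustrative names, not plugin code, and the `Retryable` interface is re-declared locally only so the sketch compiles on its own.

```java
// Minimal sketch (not plugin code): a stateful pseudo-random generator that
// supports the checkpoint/restore contract. The interface is re-declared
// here only so the example compiles standalone.
interface Retryable {
    void checkpoint();
    void restore();
}

class XorShiftRand implements Retryable {
    private long seed;
    private long savedSeed;

    XorShiftRand(long seed) { this.seed = seed; }

    // Each call mutates the internal state, so the generator is
    // non-deterministic across retries unless it is rolled back.
    long next() {
        seed ^= seed << 13;
        seed ^= seed >>> 7;
        seed ^= seed << 17;
        return seed;
    }

    @Override
    public void checkpoint() {
        savedSeed = seed; // overwrites any previously saved state
    }

    @Override
    public void restore() {
        seed = savedSeed;
    }
}

public class RetryableSketch {
    public static void main(String[] args) {
        XorShiftRand r = new XorShiftRand(42L);
        r.checkpoint();
        long a = r.next();
        long b = r.next();
        r.restore(); // simulate an OOM retry rolling the task back
        // After restore, the generator replays exactly the same values.
        System.out.println(a == r.next() && b == r.next()); // prints "true"
    }
}
```

Note how `checkpoint` simply overwrites the previous snapshot, matching the contract that each call may discard earlier saved state.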
@@ -21,6 +21,7 @@ import java.io.OutputStream
import scala.collection.mutable

import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, TableWriter}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit}
@@ -186,7 +187,7 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
/** Apply any necessary casts before writing batch out */
def transformAndClose(cb: ColumnarBatch): ColumnarBatch = cb

private val checkpointRestore = new CheckpointRestore {
private val checkpointRestore = new Retryable {
override def checkpoint(): Unit = ()
override def restore(): Unit = dropBufferedData()
}
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, ColumnVector, DType, Scalar, UnaryOp}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{ShimBinaryExpression, ShimExpression, ShimTernaryExpression, ShimUnaryExpression}
@@ -90,6 +91,19 @@ object GpuExpressionsUtils {
case ga: GpuAlias => extractGpuLit(ga.child)
case _ => None
}

/**
* Collect the Retryables from a Seq of expressions.
*/
def collectRetryables(expressions: Seq[Expression]): Seq[Retryable] = {
// There should be no dependency between an expression and its children for
// the checkpoint and restore operations.
expressions.flatMap { expr =>
expr.collect {
case r: Retryable => r
}
}
}
}

/**
@@ -169,10 +183,49 @@ trait GpuExpression extends Expression {
* this is seen.
*/
def disableTieredProjectCombine: Boolean = hasSideEffects

/**
* Whether the expression itself is non-deterministic when its "deterministic" is false,
* regardless of whether it has any non-deterministic children.
* An expression is actually a tree, and "deterministic" being false only means at
* least one tree node is non-deterministic; we need to know exactly which nodes
* are non-deterministic to check whether they implement Retryable.
*
* Defaults to false because Spark's Expression checks only the children by default,
* so by default an expression is non-deterministic iff it has non-deterministic children.
*
* NOTE: this should be updated accordingly when overriding "deterministic".
*/
val selfNonDeterministic: Boolean = false

/**
* true means this expression can be used inside a retry block, otherwise false.
* An expression is retryable when
* - it is deterministic, or
* - it is non-deterministic, but it is a Retryable and all of its children are retryable.
*/
lazy val retryable: Boolean = deterministic || {
val childrenAllRetryable = children.forall(_.asInstanceOf[GpuExpression].retryable)
if (selfNonDeterministic || children.forall(_.deterministic)) {
// self is non-deterministic, so we need to check whether it is a Retryable.
//
// "selfNonDeterministic" should be reliable enough, but it is still good to
// do this check for the one case where we are 100% sure self is non-deterministic
// (its "deterministic" is false but its children are all deterministic). This
// minimizes the possibility of missing expressions that happen to forget to
// override "selfNonDeterministic" correctly.
this.isInstanceOf[Retryable] && childrenAllRetryable
} else {
childrenAllRetryable
}
}
}
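The retryable rule above can be sketched as a tiny standalone model. `Expr`, its fields, and the two checks below are hypothetical names used only to illustrate the recursion; they are not the plugin's classes.

```java
import java.util.List;

// Illustrative model (not plugin code) of the "retryable" rule: a node can be
// used in a retry block if it is deterministic, or if the node itself
// implements Retryable and every child is retryable.
class Expr {
    final boolean selfDeterministic;   // is this node, by itself, deterministic?
    final boolean isRetryable;         // does this node implement Retryable?
    final List<Expr> children;

    Expr(boolean selfDeterministic, boolean isRetryable, List<Expr> children) {
        this.selfDeterministic = selfDeterministic;
        this.isRetryable = isRetryable;
        this.children = children;
    }

    // Mirrors Spark's default: deterministic iff self and all children are.
    boolean deterministic() {
        return selfDeterministic && children.stream().allMatch(Expr::deterministic);
    }

    boolean retryable() {
        if (deterministic()) return true;
        boolean childrenOk = children.stream().allMatch(Expr::retryable);
        // If this node is itself the non-deterministic one, it must
        // implement Retryable; otherwise only the children matter.
        return selfDeterministic ? childrenOk : (isRetryable && childrenOk);
    }
}

public class RetryableRule {
    public static void main(String[] args) {
        Expr rand = new Expr(false, true, List.of());   // non-deterministic, implements Retryable
        Expr bad  = new Expr(false, false, List.of());  // non-deterministic, no Retryable
        Expr add1 = new Expr(true, false, List.of(rand));
        Expr add2 = new Expr(true, false, List.of(bad));
        System.out.println(add1.retryable()); // prints "true"
        System.out.println(add2.retryable()); // prints "false"
    }
}
```

The usage in `main` shows why the distinction matters: a deterministic parent is retryable only when its non-deterministic child provides checkpoint/restore.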

abstract class GpuLeafExpression extends GpuExpression with ShimExpression {
override final def children: Seq[Expression] = Nil

/* no children, so only self can be non-deterministic */
override final val selfNonDeterministic: Boolean = !deterministic
}

trait GpuUnevaluable extends GpuExpression {
@@ -45,6 +45,7 @@ trait GpuUserDefinedFunction extends GpuExpression
override def hasSideEffects: Boolean = true

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
override val selfNonDeterministic: Boolean = !udfDeterministic

private[this] val nvtxRangeName = s"UDF: $name"
private[this] lazy val funcCls = TrampolineUtil.getSimpleName(function.getClass)
@@ -107,6 +108,7 @@ trait GpuRowBasedUserDefinedFunction extends GpuExpression
private[this] lazy val outputType = dataType.catalogString

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
override val selfNonDeterministic: Boolean = !udfDeterministic
override def hasSideEffects: Boolean = true

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, GroupByScanAggregation, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuOverrides.wrapExpr
import com.nvidia.spark.rapids.shims.{GpuWindowUtil, ShimExpression}
@@ -814,7 +815,7 @@ trait GpuRunningWindowFunction extends GpuWindowFunction {
* </code>
* which can be output.
*/
trait BatchedRunningWindowFixer extends AutoCloseable with CheckpointRestore {
trait BatchedRunningWindowFixer extends AutoCloseable with Retryable {
/**
* Fix up `windowedColumnOutput` with any stored state from previous batches.
* Like all window operations the input data will have been sorted by the partition
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}

import org.apache.spark.TaskContext
@@ -34,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
* If the data is needed after `allowSpilling` is called the implementations should get the data
* back and cache it again until allowSpilling is called once more.
*/
trait LazySpillable extends AutoCloseable with CheckpointRestore {
trait LazySpillable extends AutoCloseable with Retryable {

/**
* Indicate that we are done using the data for now and it can be spilled.
@@ -20,6 +20,7 @@ import scala.annotation.tailrec
import scala.collection.mutable

import ai.rapids.cudf.CudfColumnSizeOverflowException
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
@@ -223,19 +224,19 @@ object RmmRapidsRetryIterator extends Logging {
}

/**
* withRestoreOnRetry for CheckpointRestore. This helper function calls `fn` with no input and
* withRestoreOnRetry for Retryable. This helper function calls `fn` with no input and
* returns the result. In the event of an OOM Retry exception, it calls the restore() method
* of the input and then throws the oom exception. This is intended to be used within the `fn`
* of one of the withRetry* functions. It provides an opportunity to reset state in the case
* of a retry.
*
* @param r a single item T
* @param fn the work to perform. Takes no input and produces K
* @tparam T element type that must be a `CheckpointRestore` subclass
* @tparam T element type that must be a `Retryable` subclass
* @tparam K `fn` result type
* @return a single item of type K
*/
def withRestoreOnRetry[T <: CheckpointRestore, K](r: T)(fn: => K): K = {
def withRestoreOnRetry[T <: Retryable, K](r: T)(fn: => K): K = {
try {
fn
} catch {
@@ -250,19 +251,19 @@
}

/**
* withRestoreOnRetry for CheckpointRestore. This helper function calls `fn` with no input and
* withRestoreOnRetry for Retryable. This helper function calls `fn` with no input and
* returns the result. In the event of an OOM Retry exception, it calls the restore() method
* of each input item and then throws the oom exception. This is intended to be used within the `fn`
* of one of the withRetry* functions. It provides an opportunity to reset state in the case
* of a retry.
*
* @param r a Seq of item T
* @param fn the work to perform. Takes no input and produces K
* @tparam T element type that must be a `CheckpointRestore` subclass
* @tparam T element type that must be a `Retryable` subclass
* @tparam K `fn` result type
* @return a single item of type K
*/
def withRestoreOnRetry[T <: CheckpointRestore, K](r: Seq[T])(fn: => K): K = {
def withRestoreOnRetry[T <: Retryable, K](r: Seq[T])(fn: => K): K = {
try {
fn
} catch {
@@ -673,18 +674,6 @@
}
}
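As a sketch, the helper's control flow can be rendered in standalone Java as follows. `RetryOOMException` and the local `Retryable` declaration are stand-ins for the plugin's actual Scala types, used only so the example compiles and runs on its own.

```java
import java.util.function.Supplier;

// Standalone sketch of withRestoreOnRetry's control flow (the real plugin
// code is Scala). RetryOOMException and Retryable here are local stand-ins.
interface Retryable {
    void checkpoint();
    void restore();
}

class RetryOOMException extends RuntimeException {}

public class RestoreOnRetry {
    static <K> K withRestoreOnRetry(Retryable r, Supplier<K> fn) {
        try {
            return fn.get();
        } catch (RetryOOMException oom) {
            // Roll the state back to the last checkpoint, then rethrow so
            // the surrounding retry loop can attempt fn again.
            r.restore();
            throw oom;
        }
    }

    public static void main(String[] args) {
        int[] state = {0};
        Retryable r = new Retryable() {
            int saved;
            public void checkpoint() { saved = state[0]; }
            public void restore() { state[0] = saved; }
        };
        r.checkpoint();
        boolean[] failOnce = {true};
        int result;
        while (true) {
            try {
                result = withRestoreOnRetry(r, () -> {
                    state[0] += 1; // mutate some state during the work
                    if (failOnce[0]) { failOnce[0] = false; throw new RetryOOMException(); }
                    return state[0];
                });
                break;
            } catch (RetryOOMException oom) { /* outer retry loop */ }
        }
        System.out.println(result); // prints "1": the failed attempt was rolled back
    }
}
```

The failed first attempt increments `state` and throws; `restore()` rewinds it, so the successful second attempt sees the same starting state the checkpoint captured.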

trait CheckpointRestore {
/**
* Save state so it can be restored in case of an OOM Retry.
*/
def checkpoint(): Unit

/**
* Restore state that was checkpointed.
*/
def restore(): Unit
}

/**
* This is a wrapper that turns a target size into an autocloseable to allow it to be used
* in withRetry blocks. It is intended to be used to help with cases where the split calculation