Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implicit safeFree for RapidsBuffer #5471

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,7 @@ trait Arm {
block(r)
} catch {
case t: Throwable =>
try {
if (r != null) {
r.free()
}
} catch {
case e: Throwable =>
t.addSuppressed(e)
}
r.safeFree(t)
throw t
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.function.BiFunction

import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, Rmm, Table}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta

Expand Down Expand Up @@ -144,9 +145,7 @@ class RapidsBufferCatalog extends Logging {
/** Remove a buffer ID from the catalog and release the resources of the registered buffers. */
def removeBuffer(id: RapidsBufferId): Unit = {
val buffers = bufferMap.remove(id)
if (buffers != null) {
buffers.foreach(_.free())
}
buffers.safeFree()
}

/** Return the number of buffers currently in the catalog. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.StorageTier.{DEVICE, StorageTier}
import com.nvidia.spark.rapids.format.TableMeta

Expand Down Expand Up @@ -80,7 +81,7 @@ abstract class RapidsBufferStore(
// We need to release the `RapidsBufferStore` lock to prevent a lock order inversion
// deadlock: (1) `RapidsBufferBase.free` calls (2) `RapidsBufferStore.remove` and
// (1) `RapidsBufferStore.freeAll` calls (2) `RapidsBufferBase.free`.
values.foreach(_.free())
values.safeFree()
}

def nextSpillableBuffer(): RapidsBufferBase = synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util.function.BiFunction
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta

Expand Down Expand Up @@ -136,7 +137,7 @@ class RapidsGdsStore(
private[this] var currentOffset = 0L

override def close(): Unit = {
pendingBuffers.foreach(_.free())
pendingBuffers.safeFree()
pendingBuffers.clear()
batchWriteBuffer.close()
}
Expand Down
54 changes: 54 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ object RapidsPluginImplicits {
}
}

implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) {

  /**
   * safeFree: frees the wrapped RapidsBuffer, doing nothing when the buffer is null.
   *
   * If `free()` throws and a prior exception `e` was supplied, the new failure is recorded
   * on `e` through `addSuppressed` and this method returns normally, leaving it to the
   * caller to rethrow `e`. If no prior exception was supplied, the failure from `free()`
   * propagates to the caller unchanged.
   *
   * @param e exception already in flight that must stay the primary error; may be null
   */
  def safeFree(e: Throwable = null): Unit = {
    Option(rapidsBuffer).foreach { buffer =>
      try {
        buffer.free()
      } catch {
        case freeFailure: Throwable =>
          if (e != null) {
            // keep the caller's exception as the primary one
            e.addSuppressed(freeFailure)
          } else {
            // nothing to attach the failure to, so let it surface as-is
            throw freeFailure
          }
      }
    }
  }
}

implicit class AutoCloseableSeq[A <: AutoCloseable](val in: SeqLike[A, _]) {
/**
* safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each
Expand Down Expand Up @@ -87,12 +107,46 @@ object RapidsPluginImplicits {
}
}

implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: SeqLike[A, _]) {
  /**
   * safeFree: frees every non-null element of the sequence, continuing with the remaining
   * elements even when an earlier `free()` call fails. A null sequence is a no-op.
   *
   * When `error` is supplied, every failure raised by `free()` is attached to it via
   * `addSuppressed` and nothing is thrown from here. When `error` is null, the first
   * failure is remembered, later failures are added to it as suppressed exceptions, and
   * that first failure is thrown once all elements have been visited.
   *
   * @param error exception already in flight that should collect any free failures;
   *              may be null
   */
  def safeFree(error: Throwable = null): Unit = if (in != null) {
    var firstFailure: Throwable = null
    in.foreach { buffer =>
      if (buffer != null) {
        try {
          buffer.free()
        } catch {
          case failure: Throwable =>
            if (error != null) {
              error.addSuppressed(failure)
            } else if (firstFailure == null) {
              firstFailure = failure
            } else {
              firstFailure.addSuppressed(failure)
            }
        }
      }
    }
    // Surface whatever went wrong while freeing, but only after every element was tried.
    if (firstFailure != null) {
      throw firstFailure
    }
  }
}

implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) {
  /**
   * safeClose: views the array as a sequence and delegates to the sequence `safeClose`,
   * passing `e` through. A null array is a no-op.
   *
   * @param e exception forwarded to the sequence `safeClose`; may be null
   */
  def safeClose(e: Throwable = null): Unit = {
    Option(in).foreach { elements =>
      elements.toSeq.safeClose(e)
    }
  }
}
HaoYang670 marked this conversation as resolved.
Show resolved Hide resolved

implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) {
  /**
   * safeFree: views the array as a sequence and delegates to the sequence `safeFree`,
   * passing `e` through. A null array is a no-op.
   *
   * @param e exception forwarded to the sequence `safeFree`; may be null
   */
  def safeFree(e: Throwable = null): Unit = {
    Option(in).foreach { buffers =>
      buffers.toSeq.safeFree(e)
    }
  }
}

class MapsSafely[A, Repr] {
/**
* safeMap: safeMap implementation that is leveraged by other type-specific implicits.
Expand Down