Skip to content

Commit

Permalink
Shuffle/better error handling (NVIDIA#1121)
Browse files Browse the repository at this point in the history
* Handle all errors by passing to the RapidsShuffleIterator and allowing
it to create FetchFailed exceptions

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Remove log-then-throw

* On the server side, for now, we can at least return all buffers back to the catalog when there is an error

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Make the fetch failed exception include a cause, rather than add a suppressed exception

* Use closeOnExcept

* Update tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala

Fix exception text

Co-authored-by: Jason Lowe <jlowe@nvidia.com>

Co-authored-by: Jason Lowe <jlowe@nvidia.com>
  • Loading branch information
abellina and jlowe authored Nov 17, 2020
1 parent 5f38d81 commit b71eb2d
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class BufferReceiveState(
* Calls `transferError` on each `RapidsShuffleFetchHandler`
* @param errMsg - the message to pass onto the handlers
*/
def errorOcurred(errMsg: String): Unit = {
currentBlocks.foreach(_.block.request.handler.transferError(errMsg))
def errorOcurred(errMsg: String, throwable: Throwable = null): Unit = {
currentBlocks.foreach(_.block.request.handler.transferError(errMsg, throwable))
}

override def hasNext: Boolean = synchronized { hasMoreBuffers }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class BufferSendState(
isClosed = true
freeBounceBuffers()
request.close()
releaseAcquiredToCatalog()
}

case class RangeBuffer(
Expand Down Expand Up @@ -228,9 +229,6 @@ class BufferSendState(
* allowing us to safely return buffers to the catalog to be potentially freed if spilling.
*/
def releaseAcquiredToCatalog(): Unit = synchronized {
if (acquiredBuffs.isEmpty) {
logWarning("Told to close rapids buffers, but nothing was acquired")
}
acquiredBuffs.foreach(_.close())
acquiredBuffs = Seq.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ trait RapidsShuffleFetchHandler {
*
* @param errorMessage - a string containing an error message
*/
def transferError(errorMessage: String): Unit
def transferError(errorMessage: String, throwable: Throwable = null): Unit
}

/**
Expand Down Expand Up @@ -101,7 +101,8 @@ class RapidsShuffleClient(
clientCopyExecutor: Executor,
maximumMetadataSize: Long,
devStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage,
catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog) extends Logging {
catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog)
extends Logging with Arm {

object ShuffleClientOps {
/**
Expand Down Expand Up @@ -152,21 +153,16 @@ class RapidsShuffleClient(
import ShuffleClientOps._

private[this] def handleOp(op: Any): Unit = {
try {
op match {
case HandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler) =>
doHandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler)
case FetchRetry(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize) =>
doFetch(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize)
case IssueBufferReceives(bufferReceiveState) =>
doIssueBufferReceives(bufferReceiveState)
case HandleBounceBufferReceive(tx, bufferReceiveState) =>
doHandleBounceBufferReceive(tx, bufferReceiveState)
}
} catch {
case t: Throwable => {
logError("Exception occurred while handling shuffle client task.", t)
}
// functions we dispatch to must not throw
op match {
case HandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler) =>
doHandleMetadataResponse(tx, resp, shuffleRequests, rapidsShuffleFetchHandler)
case FetchRetry(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize) =>
doFetch(shuffleRequests, rapidsShuffleFetchHandler, fullResponseSize)
case IssueBufferReceives(bufferReceiveState) =>
doIssueBufferReceives(bufferReceiveState)
case HandleBounceBufferReceive(tx, bufferReceiveState) =>
doHandleBounceBufferReceive(tx, bufferReceiveState)
}
}

Expand Down Expand Up @@ -196,44 +192,47 @@ class RapidsShuffleClient(
def doFetch(shuffleRequests: Seq[ShuffleBlockBatchId],
handler: RapidsShuffleFetchHandler,
metadataSize: Long = maximumMetadataSize): Unit = {
val fetchRange = new NvtxRange("Client.fetch", NvtxColor.PURPLE)
try {
if (shuffleRequests.isEmpty) {
throw new IllegalStateException("Sending empty blockIds in the MetadataRequest?")
}
withResource(new NvtxRange("Client.fetch", NvtxColor.PURPLE)) { _ =>
if (shuffleRequests.isEmpty) {
throw new IllegalStateException("Sending empty blockIds in the MetadataRequest?")
}

// get a metadata response tag so we can send it with the request
val responseTag = connection.assignResponseTag

// serialize a request, note that this includes the responseTag in the message
val metaReq = new RefCountedDirectByteBuffer(ShuffleMetadata.buildShuffleMetadataRequest(
localExecutorId, // needed s.t. the server knows what endpoint to pick
responseTag,
shuffleRequests,
metadataSize))

logDebug(s"Requesting block_ids=[$shuffleRequests] from connection $connection, req: \n " +
s"${ShuffleMetadata.printRequest(ShuffleMetadata.getMetadataRequest(metaReq.getBuffer()))}")

val resp = transport.getMetaBuffer(metadataSize)

// make request
connection.request(
AddressLengthTag.from(
metaReq.acquire(),
connection.composeRequestTag(RequestType.MetadataRequest)),
AddressLengthTag.from(
resp.acquire(),
responseTag),
tx => {
try {
asyncOrBlock(HandleMetadataResponse(tx, resp, shuffleRequests, handler))
} finally {
metaReq.close()
}
})
} finally {
fetchRange.close()
// get a metadata response tag so we can send it with the request
val responseTag = connection.assignResponseTag

// serialize a request, note that this includes the responseTag in the message
val metaReq = new RefCountedDirectByteBuffer(ShuffleMetadata.buildShuffleMetadataRequest(
localExecutorId, // needed s.t. the server knows what endpoint to pick
responseTag,
shuffleRequests,
metadataSize))

logDebug(s"Requesting block_ids=[$shuffleRequests] from connection $connection, req: \n " +
s"${ShuffleMetadata.printRequest(
ShuffleMetadata.getMetadataRequest(metaReq.getBuffer()))}")

val resp = transport.getMetaBuffer(metadataSize)

// make request
connection.request(
AddressLengthTag.from(
metaReq.acquire(),
connection.composeRequestTag(RequestType.MetadataRequest)),
AddressLengthTag.from(
resp.acquire(),
responseTag),
tx => {
try {
asyncOrBlock(HandleMetadataResponse(tx, resp, shuffleRequests, handler))
} finally {
metaReq.close()
}
})
}
} catch {
case t: Throwable =>
handler.transferError("Error occurred while requesting metadata", t)
}
}

Expand All @@ -250,38 +249,44 @@ class RapidsShuffleClient(
resp: RefCountedDirectByteBuffer,
shuffleRequests: Seq[ShuffleBlockBatchId],
handler: RapidsShuffleFetchHandler): Unit = {
val start = System.currentTimeMillis()
val handleMetaRange = new NvtxRange("Client.handleMeta", NvtxColor.CYAN)
try {
tx.getStatus match {
case TransactionStatus.Success =>
// start the receives
val respBuffer = resp.getBuffer()
val metadataResponse: MetadataResponse = ShuffleMetadata.getMetadataResponse(respBuffer)

logDebug(s"Received from ${tx} response: \n:" +
s"${ShuffleMetadata.printResponse("received response", metadataResponse)}")

if (metadataResponse.fullResponseSize() <= respBuffer.capacity()) {
// signal to the handler how many batches are expected
handler.start(metadataResponse.tableMetasLength())

// queue up the receives
queueTransferRequests(metadataResponse, handler)
} else {
// NOTE: this path hasn't been tested yet.
logWarning("Large metadata message received, widening the receive size.")
asyncOrBlock(FetchRetry(shuffleRequests, handler, metadataResponse.fullResponseSize()))
}
case _ =>
handler.transferError(
tx.getErrorMessage.getOrElse(s"Unsuccessful metadata request ${tx}"))
val start = System.currentTimeMillis()
val handleMetaRange = new NvtxRange("Client.handleMeta", NvtxColor.CYAN)
try {
tx.getStatus match {
case TransactionStatus.Success =>
// start the receives
val respBuffer = resp.getBuffer()
val metadataResponse: MetadataResponse = ShuffleMetadata.getMetadataResponse(respBuffer)

logDebug(s"Received from ${tx} response: \n:" +
s"${ShuffleMetadata.printResponse("received response", metadataResponse)}")

if (metadataResponse.fullResponseSize() <= respBuffer.capacity()) {
// signal to the handler how many batches are expected
handler.start(metadataResponse.tableMetasLength())

// queue up the receives
queueTransferRequests(metadataResponse, handler)
} else {
// NOTE: this path hasn't been tested yet.
logWarning("Large metadata message received, widening the receive size.")
asyncOrBlock(
FetchRetry(shuffleRequests, handler, metadataResponse.fullResponseSize()))
}
case _ =>
handler.transferError(
tx.getErrorMessage.getOrElse(s"Unsuccessful metadata request ${tx}"))
}
} finally {
logDebug(s"Metadata response handled in ${TransportUtils.timeDiffMs(start)} ms")
handleMetaRange.close()
resp.close()
tx.close()
}
} finally {
logDebug(s"Metadata response handled in ${TransportUtils.timeDiffMs(start)} ms")
handleMetaRange.close()
resp.close()
tx.close()
} catch {
case t: Throwable =>
handler.transferError("Error occurred while handling metadata", t)
}
}

Expand All @@ -302,10 +307,16 @@ class RapidsShuffleClient(
* the transport's throttle logic.
*/
private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = {
if (!bufferReceiveState.hasIterated) {
sendTransferRequest(bufferReceiveState)
try {
if (!bufferReceiveState.hasIterated) {
sendTransferRequest(bufferReceiveState)
}
receiveBuffers(bufferReceiveState)
} catch {
case t: Throwable =>
bufferReceiveState.errorOcurred("Error issuing buffer receives", t)
bufferReceiveState.close()
}
receiveBuffers(bufferReceiveState)
}

private def receiveBuffers(bufferReceiveState: BufferReceiveState): Transaction = {
Expand Down Expand Up @@ -430,34 +441,38 @@ class RapidsShuffleClient(
*/
def doHandleBounceBufferReceive(tx: Transaction,
bufferReceiveState: BufferReceiveState): Unit = {
val nvtxRange = new NvtxRange("Buffer Callback", NvtxColor.RED)
try {
// consume buffers, which will be non-empty for batches that are ready
// to be handed off to the catalog
val buffMetas = bufferReceiveState.consumeWindow()
withResource(new NvtxRange("Buffer Callback", NvtxColor.RED)) { _ =>
withResource(tx) { tx =>
// consume buffers, which will be non-empty for batches that are ready
// to be handed off to the catalog
val buffMetas = bufferReceiveState.consumeWindow()

val stats = tx.getStats

// hand buffer off to the catalog
buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer =>
val bId = track(consumed.contigBuffer, consumed.meta)
consumed.handler.batchReceived(bId.asInstanceOf[ShuffleReceivedBufferId])
transport.doneBytesInFlight(consumed.contigBuffer.getLength)
}

val stats = tx.getStats
logDebug(s"Received buffer size ${stats.receiveSize} in" +
s" ${stats.txTimeMs} ms @ bw: [recv: ${stats.recvThroughput}] GB/sec")

// hand buffer off to the catalog
buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer =>
val bId = track(consumed.contigBuffer, consumed.meta)
consumed.handler.batchReceived(bId.asInstanceOf[ShuffleReceivedBufferId])
transport.doneBytesInFlight(consumed.contigBuffer.getLength)
if (bufferReceiveState.hasNext) {
logDebug(s"${bufferReceiveState} is not done.")
asyncOnCopyThread(IssueBufferReceives(bufferReceiveState))
} else {
logDebug(s"${bufferReceiveState} is DONE, closing.")
bufferReceiveState.close()
}
}
}

logDebug(s"Received buffer size ${stats.receiveSize} in" +
s" ${stats.txTimeMs} ms @ bw: [recv: ${stats.recvThroughput}] GB/sec")

if (bufferReceiveState.hasNext) {
logDebug(s"${bufferReceiveState} is not done.")
asyncOnCopyThread(IssueBufferReceives(bufferReceiveState))
} else {
logDebug(s"${bufferReceiveState} is DONE, closing.")
} catch {
case t: Throwable =>
bufferReceiveState.errorOcurred("Error while handling buffer receives", t)
bufferReceiveState.close()
}
} finally {
nvtxRange.close()
tx.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ class RapidsShuffleIterator(
blockManagerId: BlockManagerId,
blockId: ShuffleBlockBatchId,
mapIndex: Int,
errorMessage: String) extends ShuffleClientResult
errorMessage: String,
throwable: Throwable) extends ShuffleClientResult

// when batches (or errors) arrive from the transport, they are pushed
// to the `resolvedBatches` queue.
Expand Down Expand Up @@ -182,16 +183,15 @@ class RapidsShuffleIterator(
transport.makeClient(localExecutorId, blockManagerId)
} catch {
case t: Throwable => {
val errorMsg = s"Error getting client to fetch ${blockIds} from ${blockManagerId}: ${t}"
logError(errorMsg, t)
val BlockIdMapIndex(firstId, firstMapIndex) = shuffleRequestsMapIndex.head
throw new RapidsShuffleFetchFailedException(
blockManagerId,
firstId.shuffleId,
firstId.mapId,
firstMapIndex,
firstId.startReduceId,
errorMsg)
s"Error getting client to fetch ${blockIds} from ${blockManagerId}",
t)
}
}

Expand Down Expand Up @@ -241,15 +241,16 @@ class RapidsShuffleIterator(
}
}

override def transferError(errorMessage: String): Unit = resolvedBatches.synchronized {
// If Spark detects a single fetch failure, the whole task has failed
// as per `FetchFailedException`. In the future `mapIndex` will come from the
// error callback.
shuffleRequestsMapIndex.map { case BlockIdMapIndex(id, mapIndex) =>
resolvedBatches.offer(TransferError(
blockManagerId, id, mapIndex, errorMessage))
override def transferError(errorMessage: String, throwable: Throwable): Unit =
resolvedBatches.synchronized {
// If Spark detects a single fetch failure, the whole task has failed
// as per `FetchFailedException`. In the future `mapIndex` will come from the
// error callback.
shuffleRequestsMapIndex.map { case BlockIdMapIndex(id, mapIndex) =>
resolvedBatches.offer(TransferError(
blockManagerId, id, mapIndex, errorMessage, throwable))
}
}
}
}

logInfo(s"Client $blockManagerId triggered, for ${shuffleRequestsMapIndex.size} blocks")
Expand Down Expand Up @@ -337,18 +338,19 @@ class RapidsShuffleIterator(
}
catalog.removeBuffer(bufferId)
}
case Some(TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage)) =>
case Some(
TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) =>
taskContext.foreach(GpuSemaphore.releaseIfNecessary)
metricsUpdater.update(blockedTime, 0, 0, 0)
val errorMsg = s"Transfer error detected by shuffle iterator, failing task. ${errorMessage}"
logError(errorMsg)
throw new RapidsShuffleFetchFailedException(
val exp = new RapidsShuffleFetchFailedException(
blockManagerId,
shuffleBlockBatchId.shuffleId,
shuffleBlockBatchId.mapId,
mapIndex,
shuffleBlockBatchId.startReduceId,
errorMsg)
s"Transfer error detected by shuffle iterator, failing task. ${errorMessage}",
throwable)
throw exp
case None =>
// NOTE: this isn't perfect, since what we really want is the transport to
// bubble this error, but for now we'll make this a fatal exception.
Expand Down
Loading

0 comments on commit b71eb2d

Please sign in to comment.