Skip to content

Commit

Permalink
Remove batches if they are received after the iterator detects that t… (
Browse files Browse the repository at this point in the history
NVIDIA#1180)

* Remove batches if they are received after the iterator detects that the task has completed

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Dec 1, 2020
1 parent 515dd1a commit 6c23218
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ trait RapidsShuffleFetchHandler {
/**
* Called when a buffer is received and has been handed off to the catalog.
* @param bufferId - a tracked shuffle buffer id
* @return a boolean that lets the caller know the batch was accepted (true), or
* rejected (false), in which case the caller should dispose of the batch.
*/
def batchReceived(bufferId: ShuffleReceivedBufferId): Unit
def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean

/**
* Called when the transport layer is not able to handle a fetch error for metadata
Expand Down Expand Up @@ -450,13 +452,25 @@ class RapidsShuffleClient(

val stats = tx.getStats

// the number of batches successfully received that the requesting iterator
// rejected (limit case)
var numBatchesRejected = 0

// hand buffer off to the catalog
buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer =>
val bId = track(consumed.contigBuffer, consumed.meta)
consumed.handler.batchReceived(bId.asInstanceOf[ShuffleReceivedBufferId])
if (!consumed.handler.batchReceived(bId)) {
catalog.removeBuffer(bId)
numBatchesRejected += 1
}
transport.doneBytesInFlight(consumed.contigBuffer.getLength)
}

if (numBatchesRejected > 0) {
logDebug(s"Removed ${numBatchesRejected} batches that were received after " +
s"tasks completed.")
}

logDebug(s"Received buffer size ${stats.receiveSize} in" +
s" ${stats.txTimeMs} ms @ bw: [recv: ${stats.recvThroughput}] GB/sec")

Expand All @@ -483,7 +497,8 @@ class RapidsShuffleClient(
* @param meta [[TableMeta]] describing [[buffer]]
* @return the [[RapidsBufferId]] to be used to look up the buffer from catalog
*/
private[shuffle] def track(buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferId = {
private[shuffle] def track(
buffer: DeviceMemoryBuffer, meta: TableMeta): ShuffleReceivedBufferId = {
val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId()
logDebug(s"Adding buffer id ${id} to catalog")
if (buffer != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId}
import com.nvidia.spark.rapids.{Arm, GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId}

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -56,7 +56,7 @@ class RapidsShuffleIterator(
catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog,
timeoutSeconds: Long = GpuShuffleEnv.shuffleFetchTimeoutSeconds)
extends Iterator[ColumnarBatch]
with Logging {
with Logging with Arm {

/**
* General trait encapsulating either a buffer or an error. Used to hand off batches
Expand Down Expand Up @@ -107,6 +107,10 @@ class RapidsShuffleIterator(
// [[pendingFetchesByAddress]] is empty, and there are no [[batchesInFlight]].
private[this] var totalBatchesResolved: Long = 0L

// If the task finishes, and this iterator is releasing resources, we set this to true
// which allows us to reject incoming batches that would otherwise get leaked
private[this] var taskComplete: Boolean = false

blocksByAddress.foreach(bba => {
// expected blocks per address
if (pendingFetchesByAddress.put(bba._1, bba._2.size).nonEmpty){
Expand Down Expand Up @@ -198,46 +202,49 @@ class RapidsShuffleIterator(
val handler = new RapidsShuffleFetchHandler {
private[this] var clientExpectedBatches = 0L
private[this] var clientResolvedBatches = 0L

def start(expectedBatches: Int): Unit = resolvedBatches.synchronized {
if (expectedBatches == 0) {
throw new IllegalStateException(
s"Received an invalid response from shuffle server: " +
s"0 expected batches for $shuffleRequestsMapIndex")
s"0 expected batches for $shuffleRequestsMapIndex")
}
pendingFetchesByAddress.remove(blockManagerId)
batchesInFlight = batchesInFlight + expectedBatches
totalBatchesExpected = totalBatchesExpected + expectedBatches
clientExpectedBatches = expectedBatches
logDebug(s"Task: $taskAttemptId Client $blockManagerId " +
s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " +
s"flight, total expected by this client: $clientExpectedBatches, total resolved by " +
s"this client: $clientResolvedBatches")
s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " +
s"flight, total expected by this client: $clientExpectedBatches, total " +
s"resolved by this client: $clientResolvedBatches")
}

def batchReceived(bufferId: ShuffleReceivedBufferId): Unit =
def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean =
resolvedBatches.synchronized {
batchesInFlight = batchesInFlight - 1
val nvtxRange = new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN)
try {
if (markedAsDone) {
throw new IllegalStateException(
"This iterator was marked done, but a batched showed up after!!")
}
totalBatchesResolved = totalBatchesResolved + 1
clientResolvedBatches = clientResolvedBatches + 1
resolvedBatches.offer(BufferReceived(bufferId))

if (clientExpectedBatches == clientResolvedBatches) {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
} else {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"NOT done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
if (taskComplete) {
false
} else {
batchesInFlight = batchesInFlight - 1
withResource(new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN)) { _ =>
if (markedAsDone) {
throw new IllegalStateException(
"This iterator was marked done, but a batched showed up after!!")
}
totalBatchesResolved = totalBatchesResolved + 1
clientResolvedBatches = clientResolvedBatches + 1
resolvedBatches.offer(BufferReceived(bufferId))

if (clientExpectedBatches == clientResolvedBatches) {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
} else {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"NOT done fetching batches. Total batches expected " +
s"$clientExpectedBatches, total batches resolved $clientResolvedBatches.")
}
}
} finally {
nvtxRange.close()
true
}
}

Expand All @@ -263,7 +270,8 @@ class RapidsShuffleIterator(
s"${clients.size} clients.")
}

private[this] def receiveBufferCleaner(): Unit = {
private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized {
taskComplete = true
if (hasNext) {
logWarning(s"Iterator for task ${taskAttemptId} closing, " +
s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,58 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
}
}

test("successful buffer fetch - but handler rejected it") {
  // The transaction itself succeeds, but the fetch handler refuses every batch it is offered.
  when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)
  when(mockHandler.batchReceived(any())).thenReturn(false)

  val tableMeta = RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, 100)
  val expectedReceives = 1

  closeOnExcept(getBounceBuffer(10000)) { bounceBuffer =>
    val receiveState = prepareBufferReceiveState(tableMeta, bounceBuffer)

    // Nothing has been received yet, so the state must still have work pending.
    assert(receiveState.hasNext)

    // Start the receives; with successful transactions the state drains completely.
    client.doIssueBufferReceives(receiveState)
    assert(!receiveState.hasNext)

    // Exactly as many receives as are needed to fill the contiguous buffer were issued.
    verify(mockConnection, times(expectedReceives))
      .receive(any[Seq[AddressLengthTag]](), any[TransactionCallback]())

    // The mock connection records the length of each receive it was asked to perform.
    val receiveLengths = mockConnection.receiveLengths
    assertResult(tableMeta.bufferMeta().size())(receiveLengths.sum)
    assertResult(1)(receiveLengths.size)

    // A single `TransferRequest` is sent so that the server starts transmitting.
    verify(mockConnection, times(1)).request(any(), any(), any[TransactionCallback]())

    // The received `DeviceMemoryBuffer` is tracked with the expected table metadata ...
    val deviceBufferCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer])
    val tableMetaCaptor = ArgumentCaptor.forClass(classOf[TableMeta])
    verify(client, times(1)).track(any[DeviceMemoryBuffer](), tableMetaCaptor.capture())
    verifyTableMeta(tableMeta, tableMetaCaptor.getValue.asInstanceOf[TableMeta])

    // ... it is added to storage once ...
    verify(mockStorage, times(1))
      .addBuffer(any(), deviceBufferCaptor.capture(), any(), any())

    // ... and, because the handler rejected the batch, removed from the catalog once.
    verify(mockCatalog, times(1)).removeBuffer(any())

    val trackedBuffer = deviceBufferCaptor.getValue.asInstanceOf[DeviceMemoryBuffer]
    assertResult(tableMeta.bufferMeta().size())(trackedBuffer.getLength)

    // Once the receive is finished the bounce buffer must have been freed.
    assertResult(true)(bounceBuffer.isClosed)
  }
}

test("successful buffer fetch multi-buffer") {
when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)

Expand Down

0 comments on commit 6c23218

Please sign in to comment.