diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 750ccce8a1c..926b8a021e4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -43,8 +43,10 @@ trait RapidsShuffleFetchHandler { /** * Called when a buffer is received and has been handed off to the catalog. * @param bufferId - a tracked shuffle buffer id + * @return a boolean that lets the caller know the batch was accepted (true), or + * rejected (false), in which case the caller should dispose of the batch. */ - def batchReceived(bufferId: ShuffleReceivedBufferId): Unit + def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean /** * Called when the transport layer is not able to handle a fetch error for metadata @@ -450,13 +452,25 @@ class RapidsShuffleClient( val stats = tx.getStats + // the number of batches successfully received that the requesting iterator + // rejected (limit case) + var numBatchesRejected = 0 + // hand buffer off to the catalog buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer => val bId = track(consumed.contigBuffer, consumed.meta) - consumed.handler.batchReceived(bId.asInstanceOf[ShuffleReceivedBufferId]) + if (!consumed.handler.batchReceived(bId)) { + catalog.removeBuffer(bId) + numBatchesRejected += 1 + } transport.doneBytesInFlight(consumed.contigBuffer.getLength) } + if (numBatchesRejected > 0) { + logDebug(s"Removed ${numBatchesRejected} batches that were received after " + + s"tasks completed.") + } + logDebug(s"Received buffer size ${stats.receiveSize} in" + s" ${stats.txTimeMs} ms @ bw: [recv: ${stats.recvThroughput}] GB/sec") @@ -483,7 +497,8 @@ class RapidsShuffleClient( * @param meta [[TableMeta]] describing [[buffer]] * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog */ - private[shuffle] def track(buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferId = { + private[shuffle] def track( + buffer: DeviceMemoryBuffer, meta: TableMeta): ShuffleReceivedBufferId = { val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId() logDebug(s"Adding buffer id ${id} to catalog") if (buffer != null) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index f2e672cef14..083c0696a47 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId} +import com.nvidia.spark.rapids.{Arm, GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -56,7 +56,7 @@ class RapidsShuffleIterator( catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog, timeoutSeconds: Long = GpuShuffleEnv.shuffleFetchTimeoutSeconds) extends Iterator[ColumnarBatch] - with Logging { + with Logging with Arm { /** * General trait encapsulating either a buffer or an error. Used to hand off batches @@ -107,6 +107,10 @@ class RapidsShuffleIterator( // [[pendingFetchesByAddress]] is empty, and there are no [[batchesInFlight]]. private[this] var totalBatchesResolved: Long = 0L + // If the task finishes, and this iterator is releasing resources, we set this to true + // which allows us to reject incoming batches that would otherwise get leaked + private[this] var taskComplete: Boolean = false + blocksByAddress.foreach(bba => { // expected blocks per address if (pendingFetchesByAddress.put(bba._1, bba._2.size).nonEmpty){ @@ -198,46 +202,49 @@ class RapidsShuffleIterator( val handler = new RapidsShuffleFetchHandler { private[this] var clientExpectedBatches = 0L private[this] var clientResolvedBatches = 0L + def start(expectedBatches: Int): Unit = resolvedBatches.synchronized { if (expectedBatches == 0) { throw new IllegalStateException( s"Received an invalid response from shuffle server: " + - s"0 expected batches for $shuffleRequestsMapIndex") + s"0 expected batches for $shuffleRequestsMapIndex") } pendingFetchesByAddress.remove(blockManagerId) batchesInFlight = batchesInFlight + expectedBatches totalBatchesExpected = totalBatchesExpected + expectedBatches clientExpectedBatches = expectedBatches logDebug(s"Task: $taskAttemptId Client $blockManagerId " + - s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " + - s"flight, total expected by this client: $clientExpectedBatches, total resolved by " + - s"this client: $clientResolvedBatches") + s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " + + s"flight, total expected by this client: $clientExpectedBatches, total " + + s"resolved by this client: $clientResolvedBatches") } - def batchReceived(bufferId: ShuffleReceivedBufferId): Unit = + def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean = resolvedBatches.synchronized { - batchesInFlight = batchesInFlight - 1 - val nvtxRange = new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN) - try { - if (markedAsDone) { - throw new IllegalStateException( - "This iterator was marked done, but a batched showed up after!!") - } - totalBatchesResolved = totalBatchesResolved + 1 - clientResolvedBatches = clientResolvedBatches + 1 - resolvedBatches.offer(BufferReceived(bufferId)) - - if (clientExpectedBatches == clientResolvedBatches) { - logDebug(s"Task: $taskAttemptId Client $blockManagerId is " + - s"done fetching batches. Total batches expected $clientExpectedBatches, " + - s"total batches resolved $clientResolvedBatches.") - } else { - logDebug(s"Task: $taskAttemptId Client $blockManagerId is " + - s"NOT done fetching batches. Total batches expected $clientExpectedBatches, " + - s"total batches resolved $clientResolvedBatches.") + if (taskComplete) { + false + } else { + batchesInFlight = batchesInFlight - 1 + withResource(new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN)) { _ => + if (markedAsDone) { + throw new IllegalStateException( + "This iterator was marked done, but a batched showed up after!!") + } + totalBatchesResolved = totalBatchesResolved + 1 + clientResolvedBatches = clientResolvedBatches + 1 + resolvedBatches.offer(BufferReceived(bufferId)) + + if (clientExpectedBatches == clientResolvedBatches) { + logDebug(s"Task: $taskAttemptId Client $blockManagerId is " + + s"done fetching batches. Total batches expected $clientExpectedBatches, " + + s"total batches resolved $clientResolvedBatches.") + } else { + logDebug(s"Task: $taskAttemptId Client $blockManagerId is " + + s"NOT done fetching batches. Total batches expected " + + s"$clientExpectedBatches, total batches resolved $clientResolvedBatches.") + } } - } finally { - nvtxRange.close() + true } } @@ -263,7 +270,8 @@ class RapidsShuffleIterator( s"${clients.size} clients.") } - private[this] def receiveBufferCleaner(): Unit = { + private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized { + taskComplete = true if (hasNext) { logWarning(s"Iterator for task ${taskAttemptId} closing, " + s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala index ad7bf4e7eea..9195d09ed5d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala @@ -256,6 +256,58 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } } + test("successful buffer fetch - but handler rejected it") { + when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success) + when(mockHandler.batchReceived(any())).thenReturn(false) // reject incoming batches + + val numRows = 100 + val tableMeta = + RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, numRows) + val sizePerBuffer = 10000 + val expectedReceives = 1 + closeOnExcept(getBounceBuffer(sizePerBuffer)) { bounceBuffer => + val brs = prepareBufferReceiveState(tableMeta, bounceBuffer) + + assert(brs.hasNext) + + // Kick off receives + client.doIssueBufferReceives(brs) + + // If transactions are successful, we should have completed the receive + assert(!brs.hasNext) + + // we would issue as many requests as required in order to get the full contiguous + // buffer + verify(mockConnection, times(expectedReceives)) + .receive(any[Seq[AddressLengthTag]](), any[TransactionCallback]()) + + // the mock connection keeps track of every receive length + val totalReceived = mockConnection.receiveLengths.sum + val numBuffersUsed = mockConnection.receiveLengths.size + + assertResult(tableMeta.bufferMeta().size())(totalReceived) + assertResult(1)(numBuffersUsed) + + // we would perform 1 request to issue a `TransferRequest`, so the server can start. + verify(mockConnection, times(1)).request(any(), any(), any[TransactionCallback]()) + + // we will hand off a `DeviceMemoryBuffer` to the catalog + val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer]) + val tmCaptor = ArgumentCaptor.forClass(classOf[TableMeta]) + verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture()) + verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta]) + verify(mockStorage, times(1)) + .addBuffer(any(), dmbCaptor.capture(), any(), any()) + verify(mockCatalog, times(1)).removeBuffer(any()) + + val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] + assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) + + // after closing, we should have freed our bounce buffers. + assertResult(true)(bounceBuffer.isClosed) + } + } + test("successful buffer fetch multi-buffer") { when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)