Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove batches if they are received after the iterator detects that t… #1180

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ trait RapidsShuffleFetchHandler {
/**
* Called when a buffer is received and has been handed off to the catalog.
* @param bufferId - a tracked shuffle buffer id
* @return a boolean that lets the caller know the batch was accepted (true), or
* rejected (false), in which case the caller should dispose of the batch.
*/
def batchReceived(bufferId: ShuffleReceivedBufferId): Unit
def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean

/**
* Called when the transport layer is not able to handle a fetch error for metadata
Expand Down Expand Up @@ -450,13 +452,22 @@ class RapidsShuffleClient(

val stats = tx.getStats

val toDelete = new ArrayBuffer[ShuffleReceivedBufferId]()

// hand buffer off to the catalog
buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer =>
val bId = track(consumed.contigBuffer, consumed.meta)
consumed.handler.batchReceived(bId.asInstanceOf[ShuffleReceivedBufferId])
if (!consumed.handler.batchReceived(bId)) {
toDelete.append(bId)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
transport.doneBytesInFlight(consumed.contigBuffer.getLength)
}

if (toDelete.nonEmpty) {
logDebug(s"Received ${toDelete.size} batches after task completed. Freeing.")
toDelete.foreach(id => catalog.removeBuffer(id))
}

logDebug(s"Received buffer size ${stats.receiveSize} in" +
s" ${stats.txTimeMs} ms @ bw: [recv: ${stats.recvThroughput}] GB/sec")

Expand All @@ -483,7 +494,8 @@ class RapidsShuffleClient(
* @param meta [[TableMeta]] describing [[buffer]]
* @return the [[RapidsBufferId]] to be used to look up the buffer from catalog
*/
private[shuffle] def track(buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferId = {
private[shuffle] def track(
buffer: DeviceMemoryBuffer, meta: TableMeta): ShuffleReceivedBufferId = {
val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId()
logDebug(s"Adding buffer id ${id} to catalog")
if (buffer != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class RapidsShuffleIterator(
// [[pendingFetchesByAddress]] is empty, and there are no [[batchesInFlight]].
private[this] var totalBatchesResolved: Long = 0L

// Set to true once the task has finished and this iterator is releasing its resources,
// so that batches arriving afterwards are rejected instead of being leaked.
private[this] var taskComplete: Boolean = false

blocksByAddress.foreach(bba => {
// expected blocks per address
if (pendingFetchesByAddress.put(bba._1, bba._2.size).nonEmpty){
Expand Down Expand Up @@ -198,46 +202,52 @@ class RapidsShuffleIterator(
val handler = new RapidsShuffleFetchHandler {
private[this] var clientExpectedBatches = 0L
private[this] var clientResolvedBatches = 0L

def start(expectedBatches: Int): Unit = resolvedBatches.synchronized {
if (expectedBatches == 0) {
throw new IllegalStateException(
s"Received an invalid response from shuffle server: " +
s"0 expected batches for $shuffleRequestsMapIndex")
s"0 expected batches for $shuffleRequestsMapIndex")
}
pendingFetchesByAddress.remove(blockManagerId)
batchesInFlight = batchesInFlight + expectedBatches
totalBatchesExpected = totalBatchesExpected + expectedBatches
clientExpectedBatches = expectedBatches
logDebug(s"Task: $taskAttemptId Client $blockManagerId " +
s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " +
s"flight, total expected by this client: $clientExpectedBatches, total resolved by " +
s"this client: $clientResolvedBatches")
s"Expecting $expectedBatches batches, $batchesInFlight batches currently in " +
s"flight, total expected by this client: $clientExpectedBatches, total " +
s"resolved by this client: $clientResolvedBatches")
}

def batchReceived(bufferId: ShuffleReceivedBufferId): Unit =
def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean =
resolvedBatches.synchronized {
batchesInFlight = batchesInFlight - 1
val nvtxRange = new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN)
try {
if (markedAsDone) {
throw new IllegalStateException(
"This iterator was marked done, but a batched showed up after!!")
}
totalBatchesResolved = totalBatchesResolved + 1
clientResolvedBatches = clientResolvedBatches + 1
resolvedBatches.offer(BufferReceived(bufferId))

if (clientExpectedBatches == clientResolvedBatches) {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
} else {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"NOT done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
if (taskComplete) {
false
} else {
batchesInFlight = batchesInFlight - 1
val nvtxRange = new NvtxRange(s"BATCH RECEIVED", NvtxColor.DARK_GREEN)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
try {
if (markedAsDone) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalStateException(
"This iterator was marked done, but a batched showed up after!!")
}
totalBatchesResolved = totalBatchesResolved + 1
clientResolvedBatches = clientResolvedBatches + 1
resolvedBatches.offer(BufferReceived(bufferId))

if (clientExpectedBatches == clientResolvedBatches) {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"done fetching batches. Total batches expected $clientExpectedBatches, " +
s"total batches resolved $clientResolvedBatches.")
} else {
logDebug(s"Task: $taskAttemptId Client $blockManagerId is " +
s"NOT done fetching batches. Total batches expected " +
s"$clientExpectedBatches, total batches resolved $clientResolvedBatches.")
}
} finally {
nvtxRange.close()
}
} finally {
nvtxRange.close()
true
}
}

Expand All @@ -263,7 +273,8 @@ class RapidsShuffleIterator(
s"${clients.size} clients.")
}

private[this] def receiveBufferCleaner(): Unit = {
private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized {
taskComplete = true
if (hasNext) {
logWarning(s"Iterator for task ${taskAttemptId} closing, " +
s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,58 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
}
}

test("successful buffer fetch - but handler rejected it") {
  // Scenario: the transfer itself succeeds, but the fetch handler turns down every
  // batch it is offered, so the client is expected to remove the buffer it tracked.
  when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)
  when(mockHandler.batchReceived(any())).thenReturn(false)

  val rowCount = 100
  val bounceBufferSize = 10000
  val receivesExpected = 1
  val tableMeta =
    RapidsShuffleTestHelper.prepareMetaTransferResponse(mockTransport, rowCount)

  closeOnExcept(getBounceBuffer(bounceBufferSize)) { bounceBuffer =>
    val receiveState = prepareBufferReceiveState(tableMeta, bounceBuffer)

    // There is work pending before the receives are issued ...
    assert(receiveState.hasNext)

    client.doIssueBufferReceives(receiveState)

    // ... and none left afterwards, since the mocked transaction reports success.
    assert(!receiveState.hasNext)

    // The whole contiguous buffer is expected to arrive in a single receive call.
    verify(mockConnection, times(receivesExpected))
      .receive(any[Seq[AddressLengthTag]](), any[TransactionCallback]())

    // The mock connection records the length of every receive it was asked to do.
    val receiveLengths = mockConnection.receiveLengths
    assertResult(tableMeta.bufferMeta().size())(receiveLengths.sum)
    assertResult(1)(receiveLengths.size)

    // Exactly one request (the `TransferRequest`) goes out so the server can start.
    verify(mockConnection, times(1)).request(any(), any(), any[TransactionCallback]())

    // The received data is tracked once, with metadata matching what was requested.
    val metaCaptor = ArgumentCaptor.forClass(classOf[TableMeta])
    verify(client, times(1)).track(any[DeviceMemoryBuffer](), metaCaptor.capture())
    verifyTableMeta(tableMeta, metaCaptor.getValue.asInstanceOf[TableMeta])

    // One `DeviceMemoryBuffer` is added to storage ...
    val bufferCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer])
    verify(mockStorage, times(1))
      .addBuffer(any(), bufferCaptor.capture(), any(), any())

    // ... and, because the handler rejected the batch, one buffer is removed again.
    verify(mockCatalog, times(1)).removeBuffer(any())

    val trackedBuffer = bufferCaptor.getValue.asInstanceOf[DeviceMemoryBuffer]
    assertResult(tableMeta.bufferMeta().size())(trackedBuffer.getLength)

    // The bounce buffer must have been closed by the end of the receive.
    assertResult(true)(bounceBuffer.isClosed)
  }
}

test("successful buffer fetch multi-buffer") {
when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)

Expand Down