Skip to content

Commit

Permalink
Degenerate table metas were not getting copied to the heap, which cou…
Browse files Browse the repository at this point in the history
…ld cause corruption (NVIDIA#529)

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Aug 8, 2020
1 parent 97f2e7b commit fbea5c5
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,11 @@ class RapidsShuffleClient(

val ptrs = new ArrayBuffer[PendingTransferRequest](allTables)
(0 until allTables).foreach { i =>
val tableMeta = metaResponse.tableMetas(i)
val tableMeta = ShuffleMetadata.copyTableMetaToHeap(metaResponse.tableMetas(i))
if (tableMeta.bufferMeta() != null) {
ptrs += PendingTransferRequest(
this,
ShuffleMetadata.copyTableMetaToHeap(tableMeta),
tableMeta,
connection.assignBufferTag(tableMeta.bufferMeta().id),
handler)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
}
}

test("successful degenerate metadata fetch") {
when(mockTransaction.getStatus).thenReturn(TransactionStatus.Success)
val shuffleRequests = RapidsShuffleTestHelper.getShuffleBlocks
val numRows = 100000
val numBatches = 3

RapidsShuffleTestHelper.mockDegenerateMetaResponse(mockTransport, numRows, numBatches)

// initialize metadata fetch
client.doFetch(shuffleRequests.map(_._1), mockHandler)

// the connection saw one request (for metadata)
assertResult(1)(mockConnection.requests.size)

// upon a successful response, the `start()` method in the fetch handler
// will be called with 3 expected batches
verify(mockHandler, times(1)).start(ArgumentMatchers.eq(numBatches))

// nothing gets queued to be received since it's just metadata
verify(mockTransport, times(0)).queuePending(any())

// ensure our handler (iterator) received 3 batches
verify(mockHandler, times(numBatches)).batchReceived(any())
}

test("errored/cancelled metadata fetch") {
Seq(TransactionStatus.Error, TransactionStatus.Cancelled).foreach { status =>
when(mockTransaction.getStatus).thenReturn(status)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ object RapidsShuffleTestHelper extends MockitoSugar with Arm {
MetaUtils.buildTableMeta(tableId, cols, tbl.getRowCount, contigTable.getBuffer)
}

def buildDegenerateMockTableMeta(tableId: Int): TableMeta = {
MetaUtils.buildDegenerateTableMeta(tableId, new ColumnarBatch(Array.empty, 123))
}

def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = {
val rows: Seq[Integer] = (0 until numRows.toInt).map(new Integer(_))
withResource(ColumnVector.fromBoxedInts(rows:_*)) { cvBase =>
Expand Down Expand Up @@ -104,6 +108,18 @@ object RapidsShuffleTestHelper extends MockitoSugar with Arm {
tableMetas
}

def mockDegenerateMetaResponse(
mockTransport: RapidsShuffleTransport,
numRows: Long,
numBatches: Int,
maximumResponseSize: Long = 10000): Seq[TableMeta] = {
val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta(b))
val res = ShuffleMetadata.buildMetaResponse(tableMetas, maximumResponseSize)
val refCountedRes = new RefCountedDirectByteBuffer(res)
when(mockTransport.getMetaBuffer(any())).thenReturn(refCountedRes)
tableMetas
}

def prepareMetaTransferResponse(
mockTransport: RapidsShuffleTransport,
numRows: Long): TableMeta =
Expand Down

0 comments on commit fbea5c5

Please sign in to comment.