Merge pull request #9734 from NVIDIA/branch-23.12
[auto-merge] branch-23.12 to branch-24.02 [skip ci] [bot]
nvauto authored Nov 15, 2023
2 parents a46849d + c824a45 commit a3d1e46
Showing 2 changed files with 12 additions and 4 deletions.
@@ -340,6 +340,8 @@ object BatchWithPartitionDataUtils {
   * Splits the input ColumnarBatch into smaller batches, wraps these batches with partition
   * data, and returns them as a sequence of [[BatchWithPartitionData]].
   *
+  * This function does not take ownership of `batch`, and callers should make sure to close it.
+  *
   * @note Partition values are merged with the columnar batches lazily by the resulting Iterator
   *       to save GPU memory.
   * @param batch Input ColumnarBatch.
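
To make the new ownership note concrete, the caller-side pattern looks roughly like the sketch below. This is a minimal illustrative sketch, not code from the PR: it assumes the documented function is the `splitAndCombineBatchWithPartitionData` called in the next hunk, that the spark-rapids `withResource` helper is in scope, and that `readFileBatch()`, `partitionData`, and `partitionSchema` are hypothetical local values.

    // The caller keeps ownership of the input batch, so it closes it via withResource.
    // Because the utility does not take ownership, closing `batch` when the block exits
    // is the caller's job (which is exactly what the call-site fix below does with `cb`).
    val result = withResource(readFileBatch()) { batch =>   // readFileBatch() is hypothetical
      splitAndCombineBatchWithPartitionData(batch, partitionData, partitionSchema)
    }
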
@@ -502,9 +504,10 @@ object BatchWithPartitionDataUtils {
           throw new SplitAndRetryOOM("GPU OutOfMemory: cannot split input with one row")
         }
         // Split the batch into two halves
-        val cb = batchWithPartData.inputBatch.getColumnarBatch()
-        splitAndCombineBatchWithPartitionData(cb, splitPartitionData,
-          batchWithPartData.partitionSchema)
+        withResource(batchWithPartData.inputBatch.getColumnarBatch()) { cb =>
+          splitAndCombineBatchWithPartitionData(cb, splitPartitionData,
+            batchWithPartData.partitionSchema)
+        }
       }
     }
   }
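
The fix above relies on the spark-rapids `withResource` helper to guarantee that `cb` is closed even if `splitAndCombineBatchWithPartitionData` throws (for example, another SplitAndRetryOOM during a retry). As a rough sketch of the idiom (assumed semantics, not the project's actual helper), it behaves like a try/finally around an AutoCloseable:

    // Simplified sketch of the withResource idiom used above: run the block,
    // then always close the resource, whether or not the block threw.
    def withResourceSketch[T <: AutoCloseable, V](resource: T)(block: T => V): V = {
      try {
        block(resource)
      } finally {
        resource.close()
      }
    }

Without the wrapper, `cb` would not be closed on this path once the callee no longer takes ownership of it.
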
@@ -78,12 +78,17 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery
     withResource(buildBatch(getSampleValueData)) { valueBatch =>
       withResource(buildBatch(partCols)) { partBatch =>
         withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch =>
+          // we incRefCounts here because `addPartitionValuesToBatch` takes ownership of
+          // `valueBatch`, but we are keeping it alive since its columns are part of
+          // `expectedBatch`
+          GpuColumnVector.incRefCounts(valueBatch)
           val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
             partRows, partValues, partSchema, maxGpuColumnSizeBytes)
           withResource(resultBatchIter) { _ =>
             RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId)
             // Assert that the final count of rows matches expected batch
-            val rowCounts = resultBatchIter.map(_.numRows()).sum
+            // We also need to close each batch coming from `resultBatchIter`.
+            val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum
             assert(rowCounts == expectedBatch.numRows())
           }
         }
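
The second test change reflects the same ownership rules on the consumer side: each batch produced by `resultBatchIter` is owned by whoever pulls it and must be closed. Expanded into an explicit loop, the new row-count line is roughly equivalent to the sketch below (illustrative only; `resultBatchIter` and `withResource` as in the test above):

    // Roughly what `resultBatchIter.map(withResource(_){_.numRows()}).sum` does:
    // pull each batch, read its row count, and close it before moving on.
    var rowCounts = 0
    while (resultBatchIter.hasNext) {
      withResource(resultBatchIter.next()) { batch =>
        rowCounts += batch.numRows()
      }
    }
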
Expand Down
