Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
winningsix committed Oct 19, 2023
1 parent 2dffe58 commit bb4cd76
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,6 @@ public final GpuColumnVector incRefCount() {
cudfCv.incRefCount();
return this;
}

@Override
public final void close() {
// Just pass through the reference counting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,7 @@ case class GpuRangePartitioner(

def computeBoundsAndCloseWithRetry(batch: ColumnarBatch): (Array[Int], Array[GpuColumnVector]) = {
val types = GpuColumnVector.extractTypes(batch)
val spillableBatch = SpillableColumnarBatch(GpuColumnVector.incRefCounts(batch),
SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
withRetryNoSplit(spillableBatch) { sb =>
withRetryNoSplit(SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) { sb =>
val partedTable = withResource(sb.getColumnarBatch()) { cb =>
val parts = withResource(new NvtxRange("Calculate part", NvtxColor.CYAN)) { _ =>
computePartitionIndexes(cb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ case class GpuRoundRobinPartitioning(numPartitions: Int)

def partitionInternal(batch: ColumnarBatch): (Array[Int], Array[GpuColumnVector]) = {
val sparkTypes = GpuColumnVector.extractTypes(batch)
if (numPartitions == 1) {
if (1 == numPartitions) {
// Skip retry since partition number = 1
withResource(GpuColumnVector.from(batch)) { table =>
val columns = (0 until table.getNumberOfColumns).zip(sparkTypes).map {
Expand All @@ -54,9 +54,9 @@ case class GpuRoundRobinPartitioning(numPartitions: Int)
(Array(0), columns)
}
} else {
// Increase ref count since the caller will close the batch also.
val spillableBatch = SpillableColumnarBatch(GpuColumnVector.incRefCounts(batch),
SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
// Increase ref count since the caller will close the batch.
val spillableBatch = SpillableColumnarBatch(
GpuColumnVector.incRefCounts(batch), SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
withRetryNoSplit(spillableBatch) { sb =>
withResource(sb.getColumnarBatch()) { b =>
withResource(GpuColumnVector.from(b)) { table =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,16 @@ class ShufflePartitionerRetrySuite extends RmmSparkRetrySuiteBase {
private def testRoundRobinPartitioner(partNum: Int) = {
TestUtils.withGpuSparkSession(new SparkConf()) { _ =>
val rrp = GpuRoundRobinPartitioning(partNum)
val prebuiltBatch = buildBatch
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1)
withResource(prebuiltBatch) { batch =>
// Increase ref count since batch will be closed by rrp
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rrp.columnarEvalAny(GpuColumnVector.incRefCounts(batch))
.asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(partNum === ret.size)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rrp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(partNum === ret.size)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
}
}
Expand All @@ -69,17 +66,16 @@ class ShufflePartitionerRetrySuite extends RmmSparkRetrySuiteBase {
val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs))

val rp = GpuRangePartitioner(Array.apply(bounds), gpuSorter)
val prebuiltBatch = buildBatch
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1)
withResource(prebuiltBatch) { batch =>
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(2 === ret.size)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(ret.length === 2)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
}
}
Expand Down

0 comments on commit bb4cd76

Please sign in to comment.