diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala index eb11ecd7a4a..f9a1fc20254 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala @@ -531,7 +531,9 @@ class GpuCoalesceBatchesSuite extends SparkQueryCompareTestSuite { val codec = TableCompressionCodec.getCodec(CodecType.NVCOMP_LZ4) withResource(codec.createBatchCompressor(0, Cuda.DEFAULT_STREAM)) { compressor => compressor.addTableToCompress(buildContiguousTable(start, numRows)) - GpuCompressedColumnVector.from(compressor.finish().head) + withResource(compressor.finish()) { compressed => + GpuCompressedColumnVector.from(compressed.head) + } } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala index 6e3939c43ba..6ba766436e9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala @@ -49,8 +49,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(new RapidsDeviceMemoryStore(catalog)) { store => val spillPriority = 3 val bufferId = MockRapidsBufferId(7) - closeOnExcept(buildContiguousTable()) { ct => - // store takes ownership of the table + withResource(buildContiguousTable()) { ct => store.addContiguousTable(bufferId, ct, spillPriority) } val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) @@ -142,7 +141,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { assertResult(0)(store.currentSize) val bufferSizes = new Array[Long](2) bufferSizes.indices.foreach { i => - closeOnExcept(buildContiguousTable()) { ct => + withResource(buildContiguousTable()) { ct => bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table store.addContiguousTable(MockRapidsBufferId(i), ct, initialSpillPriority = 0) @@ -164,7 +163,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(new RapidsDeviceMemoryStore(catalog)) { store => store.setSpillStore(spillStore) spillPriorities.indices.foreach { i => - closeOnExcept(buildContiguousTable()) { ct => + withResource(buildContiguousTable()) { ct => bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table store.addContiguousTable(MockRapidsBufferId(i), ct, spillPriorities(i)) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala index 8245ea8e99c..44286fdc3e0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala @@ -196,7 +196,7 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit devStore: RapidsDeviceMemoryStore, bufferId: RapidsBufferId, spillPriority: Long): Long = { - closeOnExcept(buildContiguousTable()) { ct => + withResource(buildContiguousTable()) { ct => val bufferSize = ct.getBuffer.getLength // store takes ownership of the table devStore.addContiguousTable(bufferId, ct, spillPriority) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index 7b84429272c..c66b0bb6411 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -63,7 +63,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { assertResult(hostStoreMaxSize)(hostStore.numBytesFree) devStore.setSpillStore(hostStore) - val bufferSize = closeOnExcept(buildContiguousTable()) { ct => + val bufferSize = withResource(buildContiguousTable()) { ct => val len = ct.getBuffer.getLength // store takes ownership of the table devStore.addContiguousTable(bufferId, ct, spillPriority) @@ -93,14 +93,10 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - var ct = buildContiguousTable() - try { + withResource(buildContiguousTable()) { ct => withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedBuffer => expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) - // store takes ownership of the table devStore.addContiguousTable(bufferId, ct, spillPriority) - ct = null - devStore.synchronousSpill(0) withResource(catalog.acquireBuffer(bufferId)) { buffer => withResource(buffer.getMemoryBuffer) { actualBuffer => @@ -111,10 +107,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } } } - } finally { - if (ct != null) { - ct.close() - } } } } @@ -130,14 +122,10 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - var ct = buildContiguousTable() - try { + withResource(buildContiguousTable()) { ct => withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) { expectedBatch => - // store takes ownership of the table devStore.addContiguousTable(bufferId, ct, spillPriority) - ct = null - devStore.synchronousSpill(0) withResource(catalog.acquireBuffer(bufferId)) { buffer => assertResult(StorageTier.HOST)(buffer.storageTier) @@ -146,10 +134,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } } } - } finally { - if (ct != null) { - ct.close() - } } } } @@ -167,35 +151,28 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) hostStore.setSpillStore(mockStore) - var bigTable: ContiguousTable = null - var smallTable: ContiguousTable = null - try { - bigTable = buildContiguousTable(1024 * 1024) - smallTable = buildContiguousTable(1) - withResource(GpuColumnVector.from(bigTable.getTable, sparkTypes)) { expectedBatch => - // store takes ownership of the table - devStore.addContiguousTable(bigBufferId, bigTable, spillPriority) - bigTable = null - - devStore.synchronousSpill(0) - verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer], - ArgumentMatchers.any[Cuda.Stream]) - withResource(catalog.acquireBuffer(bigBufferId)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) + withResource(buildContiguousTable(1024 * 1024)) { bigTable => + withResource(buildContiguousTable(1)) { smallTable => + withResource(GpuColumnVector.from(bigTable.getTable, sparkTypes)) { expectedBatch => + // store takes ownership of the table + devStore.addContiguousTable(bigBufferId, bigTable, spillPriority) + devStore.synchronousSpill(0) + verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer], + ArgumentMatchers.any[Cuda.Stream]) + withResource(catalog.acquireBuffer(bigBufferId)) { buffer => + assertResult(StorageTier.HOST)(buffer.storageTier) + withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } } - } - devStore.addContiguousTable(smallBufferId, smallTable, spillPriority) - smallTable = null - devStore.synchronousSpill(0) - val ac: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(mockStore).copyBuffer(ac.capture(), ArgumentMatchers.any[Cuda.Stream]) - assertResult(bigBufferId)(ac.getValue.id) + devStore.addContiguousTable(smallBufferId, smallTable, spillPriority) + devStore.synchronousSpill(0) + val ac: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) + verify(mockStore).copyBuffer(ac.capture(), ArgumentMatchers.any[Cuda.Stream]) + assertResult(bigBufferId)(ac.getValue.id) + } } - } finally { - Seq(bigTable, smallTable).safeClose() } } }