forked from NVIDIA/spark-rapids
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Iterator to make it easier to work with a window of blocks in the RAPIDS shuffle (NVIDIA#934)
* Adds iterator to make it easier to work with ranges of blocks. Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
- Loading branch information
Showing
2 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
179 changes: 179 additions & 0 deletions
179
sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIterator.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
/* | ||
* Copyright (c) 2020, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.shuffle | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/**
 * Minimal contract for anything that can be handed to the windowed iterator:
 * the implementer only has to be able to report how many bytes it occupies.
 */
trait BlockWithSize {
  /**
   * Size of this block.
   * @return the number of bytes in this block, as a Long
   */
  def size: Long
}
|
||
/**
 * A logical slice `[rangeStart, rangeEnd)` of a single block, expressed as byte
 * offsets relative to the beginning of that block.
 *
 * Construction fails with an `IllegalArgumentException` unless
 * `rangeStart` is strictly less than `rangeEnd`.
 *
 * @param block - the `BlockWithSize` instance this range refers to
 * @param rangeStart - byte offset where the range begins (inclusive)
 * @param rangeEnd - byte offset where the range stops (exclusive)
 * @tparam T - the specific type of `BlockWithSize`
 */
case class BlockRange[T <: BlockWithSize](
    block: T,
    rangeStart: Long,
    rangeEnd: Long) {

  // empty and inverted ranges are rejected up front
  require(
    rangeEnd > rangeStart,
    s"Instantiated a BlockRange with invalid boundaries: $rangeStart to $rangeEnd")

  /**
   * Number of bytes covered by this range.
   * @return - Long - `rangeEnd - rangeStart`
   */
  def rangeSize(): Long = {
    val bytesCovered = rangeEnd - rangeStart
    bytesCovered
  }

  /**
   * Whether this range reaches the end of its block.
   * @return - Boolean - true when `rangeEnd` equals `block.size`
   */
  def isComplete(): Boolean = block.size == rangeEnd
}
|
||
/**
 * Given a set of blocks, this iterator returns BlockRanges
 * of such blocks that fit `windowSize`. The ranges are just logical
 * chunks of the blocks, so this class performs no memory management or copying.
 *
 * If a block is too large for the window, the block will be
 * returned in `next()` until the full block can be covered.
 *
 * For example, given a block that is 4 window-sizes in length:
 *   block = [sb1, sb2, sb3, sb4]
 *
 * The window will return on `next()` four "sub-blocks", governed by `windowSize`:
 *   window.next() // sb1
 *   window.next() // sb2
 *   window.next() // sb3
 *   window.next() // sb4
 *
 * If blocks are smaller than the `windowSize`, they will be packed:
 *   block1 = [b1]
 *   block2 = [b2]
 *   window.next() // [b1, b2]
 *
 * A mix of both scenarios above is possible:
 *   block1 = [sb11, sb12, sb13] // where sb13 is smaller than window length
 *   block2 = [b2]
 *
 *   window.next() // sb11
 *   window.next() // sb12
 *   window.next() // [sb13, b2]
 *
 * Construction fails with an `IllegalArgumentException` if `windowSize` is not
 * positive, or if any block reports a size that is not positive.
 *
 * @param blocks - sequence of blocks to manage
 * @param windowSize - the size (in bytes) that block ranges should fit
 * @tparam T - the specific type of `BlockWithSize`
 * @note this class does not own `blocks`
 * @note this class is not thread safe
 */
class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long)
    extends Iterator[Seq[BlockRange[T]]] {

  require(windowSize > 0, s"Invalid window size specified $windowSize")

  // a fixed-size span of bytes over the concatenation of all blocks
  private case class BlockWindow(start: Long, size: Long) {
    val end: Long = start + size // exclusive end offset
    // the window that begins exactly where this one stops
    def move(): BlockWindow = BlockWindow(end, size)
  }

  // start the window at byte 0
  private[this] var window = BlockWindow(0, windowSize)
  private[this] var done = false

  // helper class that captures the start/end byte offset for `block` on creation.
  // It reuses the enclosing class' `T` rather than declaring a shadowing type parameter.
  private case class BlockWithOffset(block: T, startOffset: Long, endOffset: Long)

  // Offsets are computed eagerly, exactly once, into a Vector: the scan in
  // `getBlocksForWindow` looks blocks up by index, which would be linear per lookup
  // if `blocks` were a List, and a lazy `blocks` must not re-run this stateful pass.
  private[this] val blocksWithOffsets: Vector[BlockWithOffset] = {
    val builder = Vector.newBuilder[BlockWithOffset]
    var nextOffset = 0L
    blocks.foreach { block =>
      val blockSize = block.size
      require(blockSize > 0, "Invalid 0-byte block")
      val endOffset = nextOffset + blockSize
      builder += BlockWithOffset(block, nextOffset, endOffset)
      nextOffset = endOffset // for next block
    }
    builder.result()
  }

  // the last block index that made it into a window, which
  // is an index into the `blocksWithOffsets` sequence
  private[this] var lastSeenBlock = 0

  /**
   * Result of intersecting one window with the blocks.
   * @param lastBlockIndex - index of the last block that contributed a range, if any
   * @param blockRanges - the ranges that fall inside the window, in block order
   * @param hasMoreBlocks - true if bytes remain past the end of the window
   */
  case class BlocksForWindow(lastBlockIndex: Option[Int],
      blockRanges: Seq[BlockRange[T]],
      hasMoreBlocks: Boolean)

  /**
   * Collects the ranges of all blocks that overlap `window`.
   * @param window - the byte window to intersect with the blocks
   * @param startingBlock - index of the first block to inspect
   * @return the overlapping ranges, plus whether more bytes follow the window
   */
  private def getBlocksForWindow(
      window: BlockWindow,
      startingBlock: Int = 0): BlocksForWindow = {
    val blockRangesInWindow = new ArrayBuffer[BlockRange[T]]()
    val numBlocks = blocksWithOffsets.length
    var continue = true
    var thisBlock = startingBlock
    var lastBlockIndex: Option[Int] = None
    while (continue && thisBlock < numBlocks) {
      val b = blocksWithOffsets(thisBlock)
      // if at least 1 byte fits within the window, this block should be included
      if (window.start < b.endOffset && window.end > b.startOffset) {
        // clamp the window to the block, then make the offsets relative to the block start
        val rangeStart = math.max(window.start, b.startOffset) - b.startOffset
        val rangeEnd = math.min(window.end, b.endOffset) - b.startOffset
        blockRangesInWindow += BlockRange[T](b.block, rangeStart, rangeEnd)
        lastBlockIndex = Some(thisBlock)
      } else {
        // skip this block, unless it's before our window starts
        continue = b.endOffset <= window.start
      }
      thisBlock += 1
    }
    // More work remains if the scan stopped at a block past the window, or if the
    // last range in the window did not reach the end of its block.
    val lastRangeIncomplete = blockRangesInWindow.lastOption.exists(r => !r.isComplete())
    BlocksForWindow(lastBlockIndex,
      // `toSeq` avoids depending on ArrayBuffer itself conforming to `Seq`
      blockRangesInWindow.toSeq,
      !continue || lastRangeIncomplete)
  }

  /**
   * Returns the block ranges for the current window, then either advances the
   * window or marks the iterator as finished.
   * @return the ranges that fit in the current window
   * @throws NoSuchElementException if the iterator has no more windows
   */
  def next(): Seq[BlockRange[T]] = {
    if (!hasNext) {
      throw new NoSuchElementException(s"BounceBufferWindow $window has been exhausted.")
    }

    val blocksForWindow = getBlocksForWindow(window, lastSeenBlock)
    // resume the next scan from the last block that contributed to this window
    lastSeenBlock = blocksForWindow.lastBlockIndex.getOrElse(0)

    if (blocksForWindow.hasMoreBlocks) {
      window = window.move()
    } else {
      done = true
    }

    blocksForWindow.blockRanges
  }

  /** True until the last window has been returned; always false when there are no blocks. */
  override def hasNext: Boolean = !done && blocksWithOffsets.nonEmpty
}
129 changes: 129 additions & 0 deletions
129
tests/src/test/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIteratorSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
/* | ||
* Copyright (c) 2020, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.shuffle | ||
|
||
import java.util.NoSuchElementException | ||
|
||
import org.mockito.Mockito._ | ||
|
||
class WindowedBlockIteratorSuite extends RapidsShuffleTestHelper {

  /** Builds a mocked block whose `size` is stubbed to return `bytes`. */
  private def blockOfSize(bytes: Long): BlockWithSize = {
    val mocked = mock[BlockWithSize]
    when(mocked.size).thenReturn(bytes)
    mocked
  }

  /** Asserts the start, end and derived size of a single range. */
  private def assertRange(
      range: BlockRange[BlockWithSize],
      start: Long,
      end: Long): Unit = {
    assertResult(end - start)(range.rangeSize())
    assertResult(start)(range.rangeStart)
    assertResult(end)(range.rangeEnd)
  }

  /** Asserts that the iterator reports no more windows and throws on `next`. */
  private def assertExhausted(wbi: WindowedBlockIterator[BlockWithSize]): Unit = {
    assertResult(false)(wbi.hasNext)
    assertThrows[NoSuchElementException](wbi.next())
  }

  test ("empty iterator throws on next") {
    val wbi = new WindowedBlockIterator[BlockWithSize](Seq.empty, 1024)
    assertExhausted(wbi)
  }

  test ("1-byte+ ranges are allowed, but 0-byte or negative ranges are not") {
    assertResult(1)(BlockRange(null, 123, 124).rangeSize())
    assertResult(2)(BlockRange(null, 123, 125).rangeSize())
    assertThrows[IllegalArgumentException](BlockRange(null, 123, 123))
    assertThrows[IllegalArgumentException](BlockRange(null, 123, 122))
  }

  test ("0-byte blocks are not allowed") {
    val emptyBlock = blockOfSize(0)
    assertThrows[IllegalArgumentException](
      new WindowedBlockIterator[BlockWithSize](Seq(emptyBlock), 1024))
  }

  test ("1024 1-byte blocks all fit in 1 1024-byte window") {
    val oneByteBlocks = (0 until 1024).map(_ => blockOfSize(1))
    val wbi = new WindowedBlockIterator[BlockWithSize](oneByteBlocks, 1024)
    assertResult(true)(wbi.hasNext)

    val ranges = wbi.next()
    assertResult(1024)(ranges.size)
    // every block is covered in full by its own 1-byte range
    ranges.foreach(assertRange(_, 0, 1))

    assertExhausted(wbi)
  }

  test ("a block larger than the window is split between calls to next") {
    val wbi = new WindowedBlockIterator[BlockWithSize](Seq(blockOfSize(2049)), 1024)
    assertResult(true)(wbi.hasNext)

    // first full window
    val firstWindow = wbi.next()
    assertResult(1)(firstWindow.size)
    assertRange(firstWindow.head, 0, 1024)
    assertResult(true)(wbi.hasNext)

    // second full window
    val middleWindow = wbi.next()
    assertRange(middleWindow.head, 1024, 2048)
    assertResult(true)(wbi.hasNext)

    // the single trailing byte
    val lastWindow = wbi.next()
    assertRange(lastWindow.head, 2048, 2049)

    assertExhausted(wbi)
  }

  test ("a block fits entirely, but a subsequent block doesn't") {
    val bothBlocks = Seq(blockOfSize(1000), blockOfSize(1000))
    val wbi = new WindowedBlockIterator[BlockWithSize](bothBlocks, 1024)
    assertResult(true)(wbi.hasNext)

    val firstWindow = wbi.next()
    assertResult(2)(firstWindow.size)

    // the first block is fully covered
    assertRange(firstWindow(0), 0, 1000)
    assertResult(true)(wbi.hasNext)

    // only the first 24 bytes of the second block fit in the first window
    assertRange(firstWindow(1), 0, 24)
    assertResult(true)(wbi.hasNext)

    // the remainder of the second block comes out in the next window
    val secondWindow = wbi.next()
    assertRange(secondWindow.head, 24, 1000)

    assertExhausted(wbi)
  }
}