Skip to content

Commit

Permalink
Iterator to make it easier to work with a window of blocks in the RAP…
Browse files Browse the repository at this point in the history
…IDS shuffle (NVIDIA#934)

* Adds iterator to make it easier to work with ranges of blocks

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Oct 21, 2020
1 parent 7886cb4 commit d99086e
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shuffle

import scala.collection.mutable.ArrayBuffer

// Helper trait that callers can use to add blocks to the iterator
// as long as they can provide a size
trait BlockWithSize {
/**
* Abstract method to return the size in bytes of this block
* @return Long - size in bytes
*/
def size: Long
}

/**
* Specifies a start and end range of bytes for a block.
* @param block - a BlockWithSize instance
* @param rangeStart - byte offset for the start of the range (inclusive)
* @param rangeEnd - byte offset for the end of the range (exclusive)
* @tparam T - the specific type of `BlockWithSize`
*/
case class BlockRange[T <: BlockWithSize](
block: T, rangeStart: Long, rangeEnd: Long) {
require(rangeStart < rangeEnd,
s"Instantiated a BlockRange with invalid boundaries: $rangeStart to $rangeEnd")

/**
* Returns the size of this range in bytes
* @return - Long - size in bytes
*/
def rangeSize(): Long = rangeEnd - rangeStart

def isComplete(): Boolean = rangeEnd == block.size
}

/**
* Given a set of blocks, this iterator returns BlockRanges
* of such blocks that fit `windowSize`. The ranges are just logical
* chunks of the blocks, so this class performs no memory management or copying.
*
* If a block is too large for the window, the block will be
* returned in `next()` until the full block can be covered.
*
* For example, given a block that is 4 window-sizes in length:
* block = [sb1, sb2, sb3, sb4]
*
* The window will return on `next()` four "sub-blocks", governed by `windowSize`:
* window.next() // sb1
* window.next() // sb2
* window.next() // sb3
* window.next() // sb4
*
* If blocks are smaller than the `windowSize`, they will be packed:
* block1 = [b1]
* block2 = [b2]
* window.next() // [b1, b2]
*
* A mix of both scenarios above is possible:
* block1 = [sb11, sb12, sb13] // where sb13 is smaller than window length
* block2 = [b2]
*
* window.next() // sb11
* window.next() // sb12
* window.next() // [sb13, b2]
*
* @param blocks - sequence of blocks to manage
* @param windowSize - the size (in bytes) that block ranges should fit
* @tparam T - the specific type of `BlockWithSize`
* @note this class does not own `transferBlocks`
* @note this class is not thread safe
*/
class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long)
extends Iterator[Seq[BlockRange[T]]] {

require(windowSize > 0, s"Invalid window size specified $windowSize")

private case class BlockWindow(start: Long, size: Long) {
val end = start + size // exclusive end offset
def move(): BlockWindow = {
BlockWindow(start + size, size)
}
}

// start the window at byte 0
private[this] var window = BlockWindow(0, windowSize)
private[this] var done = false

// helper class that captures the start/end byte offset
// for `block` on creation
private case class BlockWithOffset[T <: BlockWithSize](
block: T, startOffset: Long, endOffset: Long)

private[this] val blocksWithOffsets = {
var lastOffset = 0L
blocks.map { block =>
require(block.size > 0, "Invalid 0-byte block")
val startOffset = lastOffset
val endOffset = startOffset + block.size
lastOffset = endOffset // for next block
BlockWithOffset(block, startOffset, endOffset)
}
}

// the last block index that made it into a window, which
// is an index into the `blocksWithOffsets` sequence
private[this] var lastSeenBlock = 0

case class BlocksForWindow(lastBlockIndex: Option[Int],
blockRanges: Seq[BlockRange[T]],
hasMoreBlocks: Boolean)

private def getBlocksForWindow(
window: BlockWindow,
startingBlock: Int = 0): BlocksForWindow = {
val blockRangesInWindow = new ArrayBuffer[BlockRange[T]]()
var continue = true
var thisBlock = startingBlock
var lastBlockIndex: Option[Int] = None
while (continue && thisBlock < blocksWithOffsets.size) {
val b = blocksWithOffsets(thisBlock)
// if at least 1 byte fits within the window, this block should be included
if (window.start < b.endOffset && window.end > b.startOffset) {
var rangeStart = window.start - b.startOffset
if (rangeStart < 0) {
rangeStart = 0
}
var rangeEnd = window.end - b.startOffset
if (window.end >= b.endOffset) {
rangeEnd = b.endOffset - b.startOffset
}
blockRangesInWindow.append(BlockRange[T](b.block, rangeStart, rangeEnd))
lastBlockIndex = Some(thisBlock)
} else {
// skip this block, unless it's before our window starts
continue = b.endOffset <= window.start
}
thisBlock = thisBlock + 1
}
val lastBlock = blockRangesInWindow.last
BlocksForWindow(lastBlockIndex,
blockRangesInWindow,
!continue || !lastBlock.isComplete())
}

def next(): Seq[BlockRange[T]] = {
if (!hasNext) {
throw new NoSuchElementException(s"BounceBufferWindow $window has been exhausted.")
}

val blocksForWindow = getBlocksForWindow(window, lastSeenBlock)
lastSeenBlock = blocksForWindow.lastBlockIndex.getOrElse(0)

if (blocksForWindow.hasMoreBlocks) {
window = window.move()
} else {
done = true
}

blocksForWindow.blockRanges
}

override def hasNext: Boolean = !done && blocksWithOffsets.nonEmpty
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shuffle

import java.util.NoSuchElementException

import org.mockito.Mockito._

class WindowedBlockIteratorSuite extends RapidsShuffleTestHelper {
test ("empty iterator throws on next") {
val wbi = new WindowedBlockIterator[BlockWithSize](Seq.empty, 1024)
assertResult(false)(wbi.hasNext)
assertThrows[NoSuchElementException](wbi.next)
}

test ("1-byte+ ranges are allowed, but 0-byte or negative ranges are not") {
assertResult(1)(BlockRange(null, 123, 124).rangeSize())
assertResult(2)(BlockRange(null, 123, 125).rangeSize())
assertThrows[IllegalArgumentException](BlockRange(null, 123, 123))
assertThrows[IllegalArgumentException](BlockRange(null, 123, 122))
}

test ("0-byte blocks are not allowed") {
val block = mock[BlockWithSize]
when(block.size).thenReturn(0)
assertThrows[IllegalArgumentException](
new WindowedBlockIterator[BlockWithSize](Seq(block), 1024))
}

test ("1024 1-byte blocks all fit in 1 1024-byte window") {
val mockBlocks = (0 until 1024).map { i =>
val block = mock[BlockWithSize]
when(block.size).thenReturn(1)
block
}
val wbi = new WindowedBlockIterator[BlockWithSize](mockBlocks, 1024)
assertResult(true)(wbi.hasNext)
val blockRange = wbi.next()
assertResult(1024)(blockRange.size)
blockRange.foreach { br =>
assertResult(1)(br.rangeSize())
assertResult(0)(br.rangeStart)
assertResult(1)(br.rangeEnd)
}
assertResult(false)(wbi.hasNext)
assertThrows[NoSuchElementException](wbi.next)
}

test ("a block larger than the window is split between calls to next") {
val block = mock[BlockWithSize]
when(block.size).thenReturn(2049)

val wbi = new WindowedBlockIterator[BlockWithSize](Seq(block), 1024)
assertResult(true)(wbi.hasNext)
val blockRanges = wbi.next()
assertResult(1)(blockRanges.size)

val blockRange = blockRanges.head
assertResult(1024)(blockRange.rangeSize())
assertResult(0)(blockRange.rangeStart)
assertResult(1024)(blockRange.rangeEnd)
assertResult(true)(wbi.hasNext)

val blockRangesMiddle = wbi.next()
val blockRangeMiddle = blockRangesMiddle.head
assertResult(1024)(blockRangeMiddle.rangeSize())
assertResult(1024)(blockRangeMiddle.rangeStart)
assertResult(2048)(blockRangeMiddle.rangeEnd)
assertResult(true)(wbi.hasNext)

val blockRangesLastByte = wbi.next()
val blockRangeLastByte = blockRangesLastByte.head
assertResult(1)(blockRangeLastByte.rangeSize())
assertResult(2048)(blockRangeLastByte.rangeStart)
assertResult(2049)(blockRangeLastByte.rangeEnd)

assertResult(false)(wbi.hasNext)
assertThrows[NoSuchElementException](wbi.next)
}

test ("a block fits entirely, but a subsequent block doesn't") {
val block = mock[BlockWithSize]
when(block.size).thenReturn(1000)

val block2 = mock[BlockWithSize]
when(block2.size).thenReturn(1000)

val wbi = new WindowedBlockIterator[BlockWithSize](Seq(block, block2), 1024)
assertResult(true)(wbi.hasNext)
val blockRanges = wbi.next()
assertResult(2)(blockRanges.size)

val firstBlock = blockRanges(0)
val secondBlock = blockRanges(1)

assertResult(1000)(firstBlock.rangeSize())
assertResult(0)(firstBlock.rangeStart)
assertResult(1000)(firstBlock.rangeEnd)
assertResult(true)(wbi.hasNext)

assertResult(24)(secondBlock.rangeSize())
assertResult(0)(secondBlock.rangeStart)
assertResult(24)(secondBlock.rangeEnd)
assertResult(true)(wbi.hasNext)

val blockRangesLastByte = wbi.next()
val blockRangeLastByte = blockRangesLastByte.head
assertResult(976)(blockRangeLastByte.rangeSize())
assertResult(24)(blockRangeLastByte.rangeStart)
assertResult(1000)(blockRangeLastByte.rangeEnd)

assertResult(false)(wbi.hasNext)
assertThrows[NoSuchElementException](wbi.next)
}
}

0 comments on commit d99086e

Please sign in to comment.