Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterator to make it easier to work with a window of blocks in the RAPIDS shuffle #934

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shuffle

import scala.collection.mutable.ArrayBuffer

/**
 * Minimal contract a block must satisfy to be handed to the windowed
 * iterator: it only has to be able to report how many bytes it occupies.
 */
trait BlockWithSize {
  /**
   * Size of this block.
   * @return Long - size in bytes
   */
  def size: Long
}

/**
 * A logical, contiguous range of bytes inside a single block. Offsets are
 * relative to the start of `block`; no data is copied or owned by the range.
 *
 * @param block - a BlockWithSize instance
 * @param rangeStart - byte offset for the start of the range (inclusive),
 *                   must be non-negative
 * @param rangeEnd - byte offset for the end of the range (exclusive),
 *                 must be strictly greater than `rangeStart`
 * @tparam T - the specific type of `BlockWithSize`
 * @throws IllegalArgumentException if the range is empty, inverted, or starts
 *                                  at a negative offset
 */
case class BlockRange[T <: BlockWithSize](
    block: T, rangeStart: Long, rangeEnd: Long) {
  // A range must cover at least one byte and cannot begin before the block does.
  require(rangeStart >= 0 && rangeStart < rangeEnd,
    s"Instantiated a BlockRange with invalid boundaries: $rangeStart to $rangeEnd")

  /**
   * Returns the size of this range in bytes
   * @return - Long - size in bytes
   */
  def rangeSize(): Long = rangeEnd - rangeStart

  /**
   * Whether this range reaches the end of its block, i.e. no bytes of `block`
   * remain after `rangeEnd`. This does not imply the range starts at offset 0.
   * @return - Boolean - true if `rangeEnd` equals the block's size
   */
  def isComplete(): Boolean = rangeEnd == block.size
}

/**
 * Given a set of blocks, this iterator returns BlockRanges
 * of such blocks that fit `windowSize`. The ranges are just logical
 * chunks of the blocks, so this class performs no memory management or copying.
 *
 * The blocks are treated as if laid out back to back in one byte stream, and a
 * fixed-size window slides over that stream one `windowSize` at a time.
 *
 * If a block is too large for the window, the block will be
 * returned in `next()` until the full block can be covered.
 *
 * For example, given a block that is 4 window-sizes in length:
 *   block = [sb1, sb2, sb3, sb4]
 *
 * The window will return on `next()` four "sub-blocks", governed by `windowSize`:
 *   window.next() // sb1
 *   window.next() // sb2
 *   window.next() // sb3
 *   window.next() // sb4
 *
 * If blocks are smaller than the `windowSize`, they will be packed:
 *   block1 = [b1]
 *   block2 = [b2]
 *   window.next() // [b1, b2]
 *
 * A mix of both scenarios above is possible:
 *   block1 = [sb11, sb12, sb13] // where sb13 is smaller than window length
 *   block2 = [b2]
 *
 *   window.next() // sb11
 *   window.next() // sb12
 *   window.next() // [sb13, b2]
 *
 * @param blocks - sequence of blocks to manage
 * @param windowSize - the size (in bytes) that block ranges should fit
 * @tparam T - the specific type of `BlockWithSize`
 * @throws IllegalArgumentException if `windowSize` is not positive, or if any
 *                                  block reports a size that is not positive
 * @note this class does not own `blocks`
 * @note this class is not thread safe
 */
class WindowedBlockIterator[T <: BlockWithSize](blocks: Seq[T], windowSize: Long)
    extends Iterator[Seq[BlockRange[T]]] {

  require(windowSize > 0, s"Invalid window size specified $windowSize")

  // A [start, end) byte window over the concatenation of all blocks.
  private case class BlockWindow(start: Long, size: Long) {
    val end: Long = start + size // exclusive end offset

    // The window immediately following this one (same size, no overlap).
    def move(): BlockWindow = BlockWindow(start + size, size)
  }

  // start the window at byte 0
  private[this] var window = BlockWindow(0, windowSize)
  private[this] var done = false

  // A block together with the [startOffset, endOffset) byte span it occupies
  // in the concatenation of all blocks.
  private case class BlockWithOffset(block: T, startOffset: Long, endOffset: Long)

  // Built eagerly into a Vector: validation of every block happens at
  // construction time even if `blocks` is lazy, and the positional lookups in
  // `getBlocksForWindow` do not degrade to linear scans for list-like inputs.
  private[this] val blocksWithOffsets: IndexedSeq[BlockWithOffset] = {
    val builder = Vector.newBuilder[BlockWithOffset]
    var nextOffset = 0L
    blocks.foreach { block =>
      val blockSize = block.size
      require(blockSize > 0, "Invalid 0-byte block")
      val endOffset = nextOffset + blockSize
      builder += BlockWithOffset(block, nextOffset, endOffset)
      nextOffset = endOffset // the next block starts where this one ends
    }
    builder.result()
  }

  // the last block index that made it into a window, which
  // is an index into the `blocksWithOffsets` sequence
  private[this] var lastSeenBlock = 0

  /**
   * Result of intersecting one window with the blocks.
   * @param lastBlockIndex - index of the last block that contributed a range, if any
   * @param blockRanges - the ranges that fall inside the window, in block order
   * @param hasMoreBlocks - true if bytes remain beyond this window
   */
  case class BlocksForWindow(lastBlockIndex: Option[Int],
      blockRanges: Seq[BlockRange[T]],
      hasMoreBlocks: Boolean)

  /**
   * Collects the ranges of all blocks that overlap `window`.
   * @param window - the byte window to intersect with the blocks
   * @param startingBlock - index of the first block to examine
   */
  private def getBlocksForWindow(
      window: BlockWindow,
      startingBlock: Int = 0): BlocksForWindow = {
    val rangesInWindow = new ArrayBuffer[BlockRange[T]]()
    var keepScanning = true
    var blockIndex = startingBlock
    var lastBlockIndex: Option[Int] = None
    while (keepScanning && blockIndex < blocksWithOffsets.size) {
      val b = blocksWithOffsets(blockIndex)
      // if at least 1 byte fits within the window, this block should be included
      if (window.start < b.endOffset && window.end > b.startOffset) {
        // clamp the window to the block and express it relative to the block start
        val rangeStart = math.max(window.start - b.startOffset, 0L)
        val rangeEnd = math.min(window.end, b.endOffset) - b.startOffset
        rangesInWindow += BlockRange[T](b.block, rangeStart, rangeEnd)
        lastBlockIndex = Some(blockIndex)
      } else {
        // A block entirely before the window is skipped; a block entirely after
        // it ends the scan, and also proves more data exists past this window.
        keepScanning = b.endOffset <= window.start
      }
      blockIndex += 1
    }
    // More data remains if the scan stopped at a block beyond the window, or if
    // the last block included was only partially covered.
    val lastRangeIncomplete = rangesInWindow.lastOption.exists(!_.isComplete())
    BlocksForWindow(lastBlockIndex,
      rangesInWindow.toVector, // hand out an immutable snapshot, not the buffer
      !keepScanning || lastRangeIncomplete)
  }

  /**
   * Returns the block ranges covered by the current window and advances the
   * window, or marks the iterator as finished if no bytes remain.
   * @throws NoSuchElementException if the iterator is exhausted
   */
  def next(): Seq[BlockRange[T]] = {
    if (!hasNext) {
      throw new NoSuchElementException(s"BounceBufferWindow $window has been exhausted.")
    }

    val blocksForWindow = getBlocksForWindow(window, lastSeenBlock)
    // Resume the next scan from the last block that contributed, since it may
    // only have been partially covered; never move the scan position backwards.
    lastSeenBlock = blocksForWindow.lastBlockIndex.getOrElse(lastSeenBlock)

    if (blocksForWindow.hasMoreBlocks) {
      window = window.move()
    } else {
      done = true
    }

    blocksForWindow.blockRanges
  }

  /** True until the window that covers the final byte has been returned. */
  override def hasNext: Boolean = !done && blocksWithOffsets.nonEmpty
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shuffle

import java.util.NoSuchElementException

import org.mockito.Mockito._

/**
 * Tests for `BlockRange` validation and for how `WindowedBlockIterator`
 * packs and splits blocks across windows.
 */
class WindowedBlockIteratorSuite extends RapidsShuffleTestHelper {
  // Creates a mocked block that reports `bytes` as its size.
  private def mockBlockOfSize(bytes: Long): BlockWithSize = {
    val mocked = mock[BlockWithSize]
    when(mocked.size).thenReturn(bytes)
    mocked
  }

  // Checks the size and both boundaries of a single range.
  private def assertRangeIs(
      range: BlockRange[BlockWithSize],
      expectedStart: Long,
      expectedEnd: Long): Unit = {
    assertResult(expectedEnd - expectedStart)(range.rangeSize())
    assertResult(expectedStart)(range.rangeStart)
    assertResult(expectedEnd)(range.rangeEnd)
  }

  // Checks that the iterator reports exhaustion and refuses further calls.
  private def assertExhausted(iter: WindowedBlockIterator[BlockWithSize]): Unit = {
    assertResult(false)(iter.hasNext)
    assertThrows[NoSuchElementException](iter.next)
  }

  test ("empty iterator throws on next") {
    assertExhausted(new WindowedBlockIterator[BlockWithSize](Seq.empty, 1024))
  }

  test ("1-byte+ ranges are allowed, but 0-byte or negative ranges are not") {
    assertResult(1)(BlockRange(null, 123, 124).rangeSize())
    assertResult(2)(BlockRange(null, 123, 125).rangeSize())
    assertThrows[IllegalArgumentException](BlockRange(null, 123, 123))
    assertThrows[IllegalArgumentException](BlockRange(null, 123, 122))
  }

  test ("0-byte blocks are not allowed") {
    val emptyBlock = mockBlockOfSize(0)
    assertThrows[IllegalArgumentException](
      new WindowedBlockIterator[BlockWithSize](Seq(emptyBlock), 1024))
  }

  test ("1024 1-byte blocks all fit in 1 1024-byte window") {
    val oneByteBlocks = (0 until 1024).map(_ => mockBlockOfSize(1))
    val wbi = new WindowedBlockIterator[BlockWithSize](oneByteBlocks, 1024)
    assertResult(true)(wbi.hasNext)

    val ranges = wbi.next()
    assertResult(1024)(ranges.size)
    ranges.foreach(assertRangeIs(_, 0, 1))

    assertExhausted(wbi)
  }

  test ("a block larger than the window is split between calls to next") {
    val bigBlock = mockBlockOfSize(2049)
    val wbi = new WindowedBlockIterator[BlockWithSize](Seq(bigBlock), 1024)
    assertResult(true)(wbi.hasNext)

    // first full window
    val firstWindow = wbi.next()
    assertResult(1)(firstWindow.size)
    assertRangeIs(firstWindow.head, 0, 1024)
    assertResult(true)(wbi.hasNext)

    // second full window
    assertRangeIs(wbi.next().head, 1024, 2048)
    assertResult(true)(wbi.hasNext)

    // the single trailing byte
    assertRangeIs(wbi.next().head, 2048, 2049)

    assertExhausted(wbi)
  }

  test ("a block fits entirely, but a subsequent block doesn't") {
    val blocks = Seq(mockBlockOfSize(1000), mockBlockOfSize(1000))
    val wbi = new WindowedBlockIterator[BlockWithSize](blocks, 1024)
    assertResult(true)(wbi.hasNext)

    val firstWindow = wbi.next()
    assertResult(2)(firstWindow.size)

    // the whole first block, then the first 24 bytes of the second one
    assertRangeIs(firstWindow(0), 0, 1000)
    assertResult(true)(wbi.hasNext)
    assertRangeIs(firstWindow(1), 0, 24)
    assertResult(true)(wbi.hasNext)

    // the remainder of the second block
    assertRangeIs(wbi.next().head, 24, 1000)

    assertExhausted(wbi)
  }
}