From 3410f68791695d8f2493595eafc30aa86fded902 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Wed, 7 Jul 2021 13:52:26 -0500 Subject: [PATCH] Refactor window operations to do them in the exec (#2882) Signed-off-by: Robert (Bobby) Evans --- .../scala/com/nvidia/spark/rapids/Arm.scala | 2 + .../nvidia/spark/rapids/GpuWindowExec.scala | 685 ++++++++++++++---- .../spark/rapids/GpuWindowExpression.scala | 327 +-------- 3 files changed, 558 insertions(+), 456 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 75411908ad1..6aa449134ce 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -74,6 +74,8 @@ trait Arm { } finally { r match { case c: AutoCloseable => c.close() + case scala.util.Left(c: AutoCloseable) => c.close() + case scala.util.Right(c: AutoCloseable) => c.close() case _ => //NOOP } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala index 0aa5eff6b1b..a875696e7d1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala @@ -16,22 +16,25 @@ package com.nvidia.spark.rapids +import java.util.concurrent.TimeUnit + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.Scalar +import ai.rapids.cudf.{AggregationOverWindow, DType, NvtxColor, Scalar, Table, WindowOptions} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, AttributeSet, CurrentRow, Expression, NamedExpression, RowFrame, SortOrder, UnboundedPreceding} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, CurrentRow, Expression, FrameType, NamedExpression, RangeFrame, RowFrame, SortOrder, UnboundedPreceding} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.rapids.GpuAggregateExpression -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.CalendarInterval /** * Base class for GPU Execs that implement window functions. 
This abstracts the method @@ -353,61 +356,413 @@ object GpuWindowExec extends Arm { GpuSpecialFrameBoundary(UnboundedPreceding), GpuSpecialFrameBoundary(CurrentRow))) => true case _ => false } +} - def fixerIndexMap(windowExpressionAliases: Seq[Expression]): Map[Int, BatchedRunningWindowFixer] = - windowExpressionAliases.zipWithIndex.flatMap { - case (GpuAlias(GpuWindowExpression(func, _), _), index) => - func match { - case f: GpuBatchedRunningWindowFunction[_] => - Some((index, f.newFixer())) - case GpuAggregateExpression(f: GpuBatchedRunningWindowFunction[_], _, _, _, _) => - Some((index, f.newFixer())) - case _ => None +trait GpuWindowBaseExec extends UnaryExecNode with GpuExec { + val windowOps: Seq[NamedExpression] + val partitionSpec: Seq[Expression] + val orderSpec: Seq[SortOrder] + + import GpuMetric._ + + override lazy val additionalMetrics: Map[String, GpuMetric] = Map( + OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, OP_TIME) + ) + + override def output: Seq[Attribute] = windowOps.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MiB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil + } + + lazy val partitionOrdering: Seq[SortOrder] = { + val shims = ShimLoader.getSparkShims + partitionSpec.map(shims.sortOrder(_, Ascending)) + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionOrdering ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def doExecute(): RDD[InternalRow] = + throw new IllegalStateException(s"Row-based execution should not happen, in $this.") +} + +/** + * The class represents a window function and the locations of its deduped inputs after an initial + * projection. + */ +case class BoundGpuWindowFunction(windowFunc: GpuWindowFunction, boundInputLocations: Array[Int]) { + def aggOverWindow(cb: ColumnarBatch, + windowOpts: WindowOptions): AggregationOverWindow[Nothing] = { + val aggFunc = windowFunc.asInstanceOf[GpuAggregateWindowFunction[_]] + val inputs = boundInputLocations.map { pos => + (cb.column(pos).asInstanceOf[GpuColumnVector].getBase, pos) + } + aggFunc.windowAggregation(inputs).overWindow(windowOpts) + } + + val dataType: DataType = windowFunc.dataType +} + +case class ParsedBoundary(isUnbounded: Boolean, valueAsLong: Long) + +object GroupedAggregations extends Arm { + /** + * Get the window options for an aggregation + * @param orderSpec the order by spec + * @param orderPositions the positions of the order by columns + * @param frame the frame to translate + * @return the options to use when doing the aggregation. + */ + private def getWindowOptions( + orderSpec: Seq[SortOrder], + orderPositions: Seq[Int], + frame: GpuSpecifiedWindowFrame): WindowOptions = { + frame.frameType match { + case RowFrame => + withResource(getRowBasedLower(frame)) { lower => + withResource(getRowBasedUpper(frame)) { upper => + WindowOptions.builder() + .minPeriods(1) + .window(lower, upper).build() + } } - case _ => None - }.toMap + case RangeFrame => + // This gets to be a little more complicated + + // We only support a single column to order by right now, so just verify that. 
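+          // (cudf's range-window builder takes a single orderByColumnIndex, and
+          // GpuWindowExpressionMeta already falls back to the CPU when more than one
+          // order-by column is present, so a plain require is enough here.)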
+ require(orderSpec.length == 1) + require(orderPositions.length == orderSpec.length) + val orderExpr = orderSpec.head + + // We only support basic types for now too + val orderType = GpuColumnVector.getNonNestedRapidsType(orderExpr.dataType) + + val orderByIndex = orderPositions.head + val lower = getRangeBoundaryValue(frame.lower) + val upper = getRangeBoundaryValue(frame.upper) + + withResource(asScalarRangeBoundary(orderType, lower)) { preceding => + withResource(asScalarRangeBoundary(orderType, upper)) { following => + val windowOptionBuilder = WindowOptions.builder() + .minPeriods(1) + .orderByColumnIndex(orderByIndex) + + if (preceding.isEmpty) { + windowOptionBuilder.unboundedPreceding() + } else { + windowOptionBuilder.preceding(preceding.get) + } + + if (following.isEmpty) { + windowOptionBuilder.unboundedFollowing() + } else { + windowOptionBuilder.following(following.get) + } - def computeRunningNoPartitioning( - iter: Iterator[ColumnarBatch], - boundWindowOps: Seq[GpuExpression], - numOutputBatches: GpuMetric, - numOutputRows: GpuMetric, - opTime: GpuMetric): Iterator[ColumnarBatch] = { - val fixers = fixerIndexMap(boundWindowOps) - TaskContext.get().addTaskCompletionListener[Unit](_ => fixers.values.foreach(_.close())) - - iter.flatMap { cb => - val numRows = cb.numRows - numOutputBatches += 1 - numOutputRows += numRows - withResource(new MetricRange(opTime)) { _ => - if (numRows > 0) { - withResource(GpuProjectExec.projectAndClose(cb, boundWindowOps, NoopMetric)) { full => - closeOnExcept(ArrayBuffer[ColumnVector]()) { newColumns => - boundWindowOps.indices.foreach { idx => - val column = full.column(idx).asInstanceOf[GpuColumnVector] - fixers.get(idx) match { - case Some(fixer) => - closeOnExcept(fixer.fixUp(scala.util.Right(true), column)) { finalOutput => - fixer.updateState(finalOutput) - newColumns += finalOutput - } - case None => - newColumns += column.incRefCount() - } - } - Some(new ColumnarBatch(newColumns.toArray, full.numRows())) + if (orderExpr.isAscending) { + windowOptionBuilder.orderByAscending() + } else { + windowOptionBuilder.orderByDescending() } + + windowOptionBuilder.build() + } + } + } + } + + private def getRowBasedLower(windowFrameSpec : GpuSpecifiedWindowFrame): Scalar = { + val lower = getRowBoundaryValue(windowFrameSpec.lower) + + // Translate the lower bound value to CUDF semantics: + // In spark 0 is the current row and lower bound is negative relative to that + // In CUDF the preceding window starts at the current row with 1 and up from there the + // further from the current row. + val ret = if (lower >= Int.MaxValue) { + Int.MinValue + } else if (lower <= Int.MinValue) { + Int.MaxValue + } else { + -(lower-1) + } + Scalar.fromInt(ret) + } + + private def getRowBasedUpper(windowFrameSpec : GpuSpecifiedWindowFrame): Scalar = + Scalar.fromInt(getRowBoundaryValue(windowFrameSpec.upper)) + + private def getRowBoundaryValue(boundary : Expression) : Int = boundary match { + case literal: GpuLiteral if literal.dataType.equals(IntegerType) => + literal.value.asInstanceOf[Int] + case special: GpuSpecialFrameBoundary => + special.value + case anythingElse => + throw new UnsupportedOperationException(s"Unsupported window frame expression $anythingElse") + } + + /** + * Create a Scalar from boundary value according to order by column type. + * + * Timestamp types will be converted into interval types. 
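+   * (Concretely, TIMESTAMP_DAYS maps to a DURATION_DAYS scalar and TIMESTAMP_MICROSECONDS
+   * to a DURATION_MICROSECONDS scalar in the match below.)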
+ * + * @param orderByType the type of order by column + * @param bound boundary value + * @return a Scalar holding boundary value or None if the boundary is unbounded. + */ + private def asScalarRangeBoundary(orderByType: DType, bound: ParsedBoundary): Option[Scalar] = { + if (bound.isUnbounded) { + None + } else { + val value = bound.valueAsLong + val s = orderByType match { + case DType.INT8 => Scalar.fromByte(value.toByte) + case DType.INT16 => Scalar.fromShort(value.toShort) + case DType.INT32 => Scalar.fromInt(value.toInt) + case DType.INT64 => Scalar.fromLong(value) + // Interval is not working for DateType + case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, value) + case DType.TIMESTAMP_MICROSECONDS => + Scalar.durationFromLong(DType.DURATION_MICROSECONDS, value) + case _ => throw new RuntimeException(s"Not supported order by type, Found $orderByType") + } + Some(s) + } + } + + private def getRangeBoundaryValue(boundary: Expression): ParsedBoundary = boundary match { + case special: GpuSpecialFrameBoundary => + val isUnBounded = special.isUnbounded + ParsedBoundary(isUnBounded, special.value) + case GpuLiteral(ci: CalendarInterval, CalendarIntervalType) => + // Get the total microseconds for TIMESTAMP_MICROSECONDS + var x = TimeUnit.DAYS.toMicros(ci.days) + ci.microseconds + if (x == Long.MinValue) x = Long.MaxValue + ParsedBoundary(isUnbounded = false, Math.abs(x)) + case GpuLiteral(value, ByteType) => + var x = value.asInstanceOf[Byte] + if (x == Byte.MinValue) x = Byte.MaxValue + ParsedBoundary(isUnbounded = false, Math.abs(x)) + case GpuLiteral(value, ShortType) => + var x = value.asInstanceOf[Short] + if (x == Short.MinValue) x = Short.MaxValue + ParsedBoundary(isUnbounded = false, Math.abs(x)) + case GpuLiteral(value, IntegerType) => + var x = value.asInstanceOf[Int] + if (x == Int.MinValue) x = Int.MaxValue + ParsedBoundary(isUnbounded = false, Math.abs(x)) + case GpuLiteral(value, LongType) => + var x = value.asInstanceOf[Long] + if (x == Long.MinValue) x = Long.MaxValue + ParsedBoundary(isUnbounded = false, Math.abs(x)) + case anything => throw new UnsupportedOperationException("Unsupported window frame" + + s" expression $anything") + } +} + +class GroupedAggregations extends Arm { + import GroupedAggregations._ + + // The window frame to a map of the window function to the output locations for the result + private val data = mutable.HashMap[GpuSpecifiedWindowFrame, + mutable.HashMap[BoundGpuWindowFunction, ArrayBuffer[Int]]]() + + def addAggregation(win: GpuWindowExpression, inputLocs: Array[Int], outputIndex: Int): Unit = { + val forSpec = + data.getOrElseUpdate(win.normalizedFrameSpec, mutable.HashMap.empty) + forSpec.getOrElseUpdate(BoundGpuWindowFunction(win.wrappedWindowFunc, inputLocs), + ArrayBuffer.empty) += outputIndex + } + + private def copyResultToFinalOutput(result: Table, + functions: mutable.HashMap[BoundGpuWindowFunction, ArrayBuffer[Int]], + outputColumns: Array[ColumnVector]): Unit = { + functions.zipWithIndex.foreach { + case ((winFunc, outputIndexes), resultIndex) => + val aggColumn = result.getColumn(resultIndex) + // For nested type, do not cast + val finalCol = aggColumn.getType match { + case dType if dType.isNestedType => + GpuColumnVector.from(aggColumn.incRefCount(), winFunc.dataType) + case _ => + val expectedType = GpuColumnVector.getNonNestedRapidsType(winFunc.dataType) + // The API 'castTo' will take care of the 'from' type and 'to' type, and + // just increase the reference count by one when they are the same. 
+ // so it is OK to always call it here. + GpuColumnVector.from(aggColumn.castTo(expectedType), winFunc.dataType) + } + + withResource(finalCol) { finalCol => + outputIndexes.foreach { outIndex => + outputColumns(outIndex) = finalCol.incRefCount() + } + } + } + } + + private def doAggInternal( + frameType: FrameType, + boundOrderSpec: Seq[SortOrder], + orderByPositions: Array[Int], + partByPositions: Array[Int], + inputCb: ColumnarBatch, + outputColumns: Array[ColumnVector], + aggIt: (Table.GroupByOperation, Seq[AggregationOverWindow[Nothing]]) => Table): Unit = { + data.foreach { + case (frameSpec, functions) => + if (frameSpec.frameType == frameType) { + // For now I am going to assume that we don't need to combine calls across frame specs + // because it would just not help that much + val result = withResource( + getWindowOptions(boundOrderSpec, orderByPositions, frameSpec)) { windowOpts => + val allAggs = functions.map { + case (winFunc, _) => winFunc.aggOverWindow(inputCb, windowOpts) + }.toSeq + withResource(GpuColumnVector.from(inputCb)) { initProjTab => + aggIt(initProjTab.groupBy(partByPositions: _*), allAggs) + } + } + withResource(result) { result => + copyResultToFinalOutput(result, functions, outputColumns) } - } else { - // Now rows so just filter it out - cb.close() - None } + } + } + + private def doRowAggs(boundOrderSpec: Seq[SortOrder], + orderByPositions: Array[Int], + partByPositions: Array[Int], + inputCb: ColumnarBatch, + outputColumns: Array[ColumnVector]): Unit = { + doAggInternal( + RowFrame, boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns, + (groupBy, aggs) => groupBy.aggregateWindows(aggs: _*)) + } + + private def doRangeAggs(boundOrderSpec: Seq[SortOrder], + orderByPositions: Array[Int], + partByPositions: Array[Int], + inputCb: ColumnarBatch, + outputColumns: Array[ColumnVector]): Unit = { + doAggInternal( + RangeFrame, boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns, + (groupBy, aggs) => groupBy.aggregateWindowsOverRanges(aggs: _*)) + } + + def doAggs(boundOrderSpec: Seq[SortOrder], + orderByPositions: Array[Int], + partByPositions: Array[Int], + inputCb: ColumnarBatch, + outputColumns: Array[ColumnVector]): Unit = { + doRowAggs(boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns) + doRangeAggs(boundOrderSpec, orderByPositions, partByPositions, inputCb, outputColumns) + } +} + +/** + * Calculates the results of window operations + */ +trait BasicWindowCalc extends Arm { + val boundWindowOps: Seq[GpuExpression] + val boundPartitionSpec: Seq[GpuExpression] + val boundOrderSpec: Seq[SortOrder] + + private val (initialProjections, + passThrough, + aggregations, + orderByPositions, + partByPositions) = { + val initialProjections = ArrayBuffer[Expression]() + val dedupedInitialProjections = mutable.HashMap[Expression, Int]() + + def getOrAddInitialProjectionIndex(expr: Expression): Int = + dedupedInitialProjections.getOrElseUpdate(expr, { + val at = initialProjections.length + initialProjections += expr + at + }) + + val passThrough = ArrayBuffer[(Int, Int)]() + val aggregations = new GroupedAggregations() + + boundWindowOps.zipWithIndex.foreach { + case (GpuAlias(GpuBoundReference(inputIndex, _, _), _), outputIndex) => + passThrough.append((inputIndex, outputIndex)) + case (GpuBoundReference(inputIndex, _, _), outputIndex) => + passThrough.append((inputIndex, outputIndex)) + case (GpuAlias(win: GpuWindowExpression, _), outputIndex) => + val inputLocations = win.wrappedWindowFunc. 
+ windowInputProjection.map(getOrAddInitialProjectionIndex).toArray + aggregations.addAggregation(win, inputLocations, outputIndex) + case _ => + throw new IllegalArgumentException("Unexpected operation found in window expression") + } + + val partByPositions = boundPartitionSpec.map(getOrAddInitialProjectionIndex).toArray + val orderByPositions = boundOrderSpec.map { so => + getOrAddInitialProjectionIndex(so.child) + }.toArray + + (initialProjections, passThrough, aggregations, orderByPositions, partByPositions) + } + + def computeBasicWindow(cb: ColumnarBatch): ColumnarBatch = { + closeOnExcept(new Array[ColumnVector](boundWindowOps.length)) { outputColumns => + // First the pass through unchanged columns + passThrough.foreach { + case (inputIndex, outputIndex) => + outputColumns(outputIndex) = + cb.column(inputIndex).asInstanceOf[GpuColumnVector].incRefCount() + } + + withResource(GpuProjectExec.project(cb, initialProjections)) { initProjCb => + aggregations.doAggs(boundOrderSpec, orderByPositions, + partByPositions, initProjCb, outputColumns) + } + + new ColumnarBatch(outputColumns, cb.numRows()) + } + } +} + +/** + * An Iterator that performs window operations on the input data. It is required that the input + * data is batched so all of the data for a given key is in the same batch. The input data must + * also be sorted by both partition by keys and order by keys. + */ +class GpuWindowIterator( + input: Iterator[ColumnarBatch], + override val boundWindowOps: Seq[GpuExpression], + override val boundPartitionSpec: Seq[GpuExpression], + override val boundOrderSpec: Seq[SortOrder], + numOutputBatches: GpuMetric, + numOutputRows: GpuMetric, + opTime: GpuMetric) extends Iterator[ColumnarBatch] with BasicWindowCalc { + + override def hasNext: Boolean = input.hasNext + + override def next(): ColumnarBatch = { + withResource(input.next()) { cb => + withResource(new NvtxWithMetrics("window", NvtxColor.CYAN, opTime)) { _ => + val ret = computeBasicWindow(cb) + numOutputBatches += 1 + numOutputRows += ret.numRows() + ret } } } +} +object GpuRunningWindowIterator extends Arm { private def cudfAnd(lhs: ai.rapids.cudf.ColumnVector, rhs: ai.rapids.cudf.ColumnVector): ai.rapids.cudf.ColumnVector = { withResource(lhs) { lhs => @@ -420,8 +775,10 @@ object GpuWindowExec extends Arm { private def arePartsEqual( scalars: Seq[Scalar], columns: Seq[ai.rapids.cudf.ColumnVector]): Either[GpuColumnVector, Boolean] = { - if (scalars.isEmpty) { + if (scalars.length != columns.length) { scala.util.Right(false) + } else if (scalars.isEmpty && columns.isEmpty) { + scala.util.Right(true) } else { val ret = scalars.zip(columns).map { case (scalar, column) => scalar.equalToNullAware(column) @@ -432,113 +789,131 @@ object GpuWindowExec extends Arm { private def getScalarRow(index: Int, columns: Seq[ai.rapids.cudf.ColumnVector]): Array[Scalar] = columns.map(_.getScalarElement(index)).toArray +} - def computeRunning( - iter: Iterator[ColumnarBatch], - boundWindowOps: Seq[GpuExpression], - boundPartitionSpec: Seq[Expression], - numOutputBatches: GpuMetric, - numOutputRows: GpuMetric, - opTime: GpuMetric): Iterator[ColumnarBatch] = { - var lastParts: Array[Scalar] = Array.empty - val fixers = fixerIndexMap(boundWindowOps) - - def saveLastParts(newLastParts: Array[Scalar]): Unit = { - lastParts.foreach(_.close()) - lastParts = newLastParts - } +/** + * An iterator that can do row based aggregations on running window queries (Unbounded preceding to + * current row) if and only if the aggregations are instances of 
GpuBatchedRunningWindowFunction + * which can fix up the window output when an aggregation is only partly done in one batch of data. + * Because of this there is no requirement about how the input data is batched, but it must + * be sorted by both partitioning and ordering. + */ +class GpuRunningWindowIterator( + input: Iterator[ColumnarBatch], + override val boundWindowOps: Seq[GpuExpression], + override val boundPartitionSpec: Seq[GpuExpression], + override val boundOrderSpec: Seq[SortOrder], + numOutputBatches: GpuMetric, + numOutputRows: GpuMetric, + opTime: GpuMetric) extends Iterator[ColumnarBatch] with BasicWindowCalc { + import GpuRunningWindowIterator._ + TaskContext.get().addTaskCompletionListener[Unit](_ => close()) + + // This should only ever be cached in between calls to `hasNext` and `next`. This is just + // to let us filter out empty batches. + private var cachedBatch: Option[ColumnarBatch] = None + private var lastParts: Array[Scalar] = Array.empty + private var isClosed: Boolean = false + + private def saveLastParts(newLastParts: Array[Scalar]): Unit = { + lastParts.foreach(_.close()) + lastParts = newLastParts + } - def closeState(): Unit = { + def close(): Unit = { + if (!isClosed) { + isClosed = true + fixerIndexMap.values.foreach(_.close()) saveLastParts(Array.empty) - fixers.values.foreach(_.close()) } + } - TaskContext.get().addTaskCompletionListener[Unit](_ => closeState()) - - iter.map { cb => - val numRows = cb.numRows - numOutputBatches += 1 - numOutputRows += numRows - withResource(new MetricRange(opTime)) { _ => - val fullWindowProjectList = boundWindowOps ++ boundPartitionSpec - withResource( - GpuProjectExec.projectAndClose(cb, fullWindowProjectList, NoopMetric)) { full => - // part columns are owned by full and do not need to be closed, but should not be used - // if full is closed - val partColumns = boundPartitionSpec.indices.map { idx => - full.column(idx + boundWindowOps.length).asInstanceOf[GpuColumnVector].getBase - } + private lazy val fixerIndexMap: Map[Int, BatchedRunningWindowFixer] = + boundWindowOps.zipWithIndex.flatMap { + case (GpuAlias(GpuWindowExpression(func, _), _), index) => + func match { + case f: GpuBatchedRunningWindowFunction[_] => + Some((index, f.newFixer())) + case GpuAggregateExpression(f: GpuBatchedRunningWindowFunction[_], _, _, _, _) => + Some((index, f.newFixer())) + case _ => None + } + case _ => None + }.toMap - // We need to fix up the rows that are part of the same batch as the end of the - // last batch - val partsEqual = arePartsEqual(lastParts, partColumns) - try { - closeOnExcept(ArrayBuffer[ColumnVector]()) { newColumns => - boundWindowOps.indices.foreach { idx => - val column = full.column(idx).asInstanceOf[GpuColumnVector] - val fixer = fixers.get(idx) - if (fixer.isDefined) { - val f = fixer.get - closeOnExcept(f.fixUp(partsEqual, column)) { finalOutput => - f.updateState(finalOutput) - newColumns += finalOutput - } - } else { - newColumns += column.incRefCount() - } - } - saveLastParts(getScalarRow(numRows - 1, partColumns)) - - new ColumnarBatch(newColumns.toArray, numRows) - } - } finally { - partsEqual match { - case scala.util.Left(cv) => cv.close() - case _ => // Nothing + private def fixUpAll(computedWindows: ColumnarBatch, + fixers: Map[Int, BatchedRunningWindowFixer], + samePartitionMask: Either[GpuColumnVector, Boolean]): ColumnarBatch = { + closeOnExcept(ArrayBuffer[ColumnVector]()) { newColumns => + boundWindowOps.indices.foreach { idx => + val column = 
computedWindows.column(idx).asInstanceOf[GpuColumnVector] + fixers.get(idx) match { + case Some(fixer) => + closeOnExcept(fixer.fixUp(samePartitionMask, column)) { finalOutput => + fixer.updateState(finalOutput) + newColumns += finalOutput } - } + case None => + newColumns += column.incRefCount() } } + new ColumnarBatch(newColumns.toArray, computedWindows.numRows()) } } -} -trait GpuWindowBaseExec extends UnaryExecNode with GpuExec { - val windowOps: Seq[NamedExpression] - val partitionSpec: Seq[Expression] - val orderSpec: Seq[SortOrder] - - import GpuMetric._ - - override lazy val additionalMetrics: Map[String, GpuMetric] = Map( - OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, OP_TIME) - ) - - override def output: Seq[Attribute] = windowOps.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil + def computeRunning(cb: ColumnarBatch): ColumnarBatch = { + val fixers = fixerIndexMap + val numRows = cb.numRows() + withResource(computeBasicWindow(cb)) { basic => + withResource(GpuProjectExec.project(cb, boundPartitionSpec)) { parts => + val partColumns = GpuColumnVector.extractBases(parts) + // We need to fix up the rows that are part of the same batch as the end of the + // last batch + withResourceIfAllowed(arePartsEqual(lastParts, partColumns)) { partsEqual => + val ret = fixUpAll(basic, fixers, partsEqual) + saveLastParts(getScalarRow(numRows - 1, partColumns)) + ret + } + } + } } - lazy val partitionOrdering: Seq[SortOrder] = { - val shims = ShimLoader.getSparkShims - partitionSpec.map(shims.sortOrder(_, Ascending)) + private def cacheBatchIfNeeded(): Unit = { + while (cachedBatch.isEmpty && input.hasNext) { + closeOnExcept(input.next()) { cb => + if (cb.numRows() > 0) { + cachedBatch = Some(cb) + } else { + cb.close() + } + } + } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionOrdering ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + def readNextInputBatch(): ColumnarBatch = { + cacheBatchIfNeeded() + val ret = cachedBatch.getOrElse { + throw new NoSuchElementException() + } + cachedBatch = None + ret + } - override def outputPartitioning: Partitioning = child.outputPartitioning + override def hasNext: Boolean = { + cacheBatchIfNeeded() + cachedBatch.isDefined + } - override protected def doExecute(): RDD[InternalRow] = - throw new IllegalStateException(s"Row-based execution should not happen, in $this.") + override def next(): ColumnarBatch = { + withResource(readNextInputBatch()) { cb => + withResource(new NvtxWithMetrics("RunningWindow", NvtxColor.CYAN, opTime)) { _ => + val ret = computeRunning(cb) + numOutputBatches += 1 + numOutputRows += ret.numRows() + ret + } + } + } } case class GpuRunningWindowExec( @@ -553,23 +928,13 @@ case class GpuRunningWindowExec( val numOutputRows = gpuLongMetric(GpuMetric.NUM_OUTPUT_ROWS) val opTime = gpuLongMetric(GpuMetric.OP_TIME) - val boundWindowOps = - GpuBindReferences.bindGpuReferences(windowOps, child.output) - - val boundPartitionSpec = - GpuBindReferences.bindGpuReferences(partitionSpec, child.output) + val boundWindowOps = GpuBindReferences.bindGpuReferences(windowOps, child.output) + val boundPartitionSpec = 
GpuBindReferences.bindGpuReferences(partitionSpec, child.output) + val boundOrderSpec = GpuBindReferences.bindReferences(orderSpec, child.output) - if (partitionSpec.isEmpty) { - child.executeColumnar().mapPartitions { - iter => GpuWindowExec.computeRunningNoPartitioning(iter, - boundWindowOps, numOutputBatches, numOutputRows, opTime) - } - } else { - child.executeColumnar().mapPartitions { - iter => GpuWindowExec.computeRunning(iter, - boundWindowOps, boundPartitionSpec, numOutputBatches, - numOutputRows, opTime) - } + child.executeColumnar().mapPartitions { iter => + new GpuRunningWindowIterator(iter, boundWindowOps, boundPartitionSpec, boundOrderSpec, + numOutputBatches, numOutputRows, opTime) } } } @@ -594,13 +959,13 @@ case class GpuWindowExec( val numOutputRows = gpuLongMetric(GpuMetric.NUM_OUTPUT_ROWS) val opTime = gpuLongMetric(GpuMetric.OP_TIME) - val boundWindowOps = - GpuBindReferences.bindGpuReferences(windowOps, child.output) + val boundWindowOps = GpuBindReferences.bindGpuReferences(windowOps, child.output) + val boundPartitionSpec = GpuBindReferences.bindGpuReferences(partitionSpec, child.output) + val boundOrderSpec = GpuBindReferences.bindReferences(orderSpec, child.output) - child.executeColumnar().map { cb => - numOutputBatches += 1 - numOutputRows += cb.numRows - GpuProjectExec.projectAndClose(cb, boundWindowOps, opTime) + child.executeColumnar().mapPartitions { iter => + new GpuWindowIterator(iter, boundWindowOps, boundPartitionSpec, boundOrderSpec, + numOutputBatches, numOutputRows, opTime) } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index 439804c4e88..110d8c5c5eb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -20,7 +20,7 @@ import java.util.concurrent.TimeUnit import scala.language.{existentials, implicitConversions} -import ai.rapids.cudf.{Aggregation, AggregationOnColumn, BinaryOp, ColumnVector, DType, RollingAggregation, Scalar, WindowOptions} +import ai.rapids.cudf.{Aggregation, AggregationOnColumn, BinaryOp, ColumnVector, RollingAggregation, Scalar} import ai.rapids.cudf.Aggregation.{LagAggregation, LeadAggregation, RowNumberAggregation} import com.nvidia.spark.rapids.GpuOverrides.wrapExpr @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.rapids.GpuAggregateExpression import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.CalendarInterval class GpuWindowExpressionMeta( @@ -40,17 +39,11 @@ class GpuWindowExpressionMeta( rule: DataFromReplacementRule) extends ExprMeta[WindowExpression](windowExpression, conf, parent, rule) { - private def getBoundaryValue(boundary : Expression) : Int = boundary match { + private def getAndCheckRowBoundaryValue(boundary: Expression) : Int = boundary match { case literal: Literal => literal.dataType match { case IntegerType => literal.value.asInstanceOf[Int] - case CalendarIntervalType => - val ci = literal.value.asInstanceOf[CalendarInterval] - if (ci.months != 0 || ci.microseconds != 0) { - willNotWorkOnGpu("only days are supported for window range intervals") - } - ci.days case t => willNotWorkOnGpu(s"unsupported window boundary type $t") -1 @@ -75,8 +68,8 @@ class 
GpuWindowExpressionMeta( spec.frameType match { case RowFrame => // Will also verify that the types are what we expect. - val lower = getBoundaryValue(spec.lower) - val upper = getBoundaryValue(spec.upper) + val lower = getAndCheckRowBoundaryValue(spec.lower) + val upper = getAndCheckRowBoundaryValue(spec.upper) windowFunction match { case Lead(_, _, _) | Lag(_, _, _) => // ignored we are good case _ => @@ -92,31 +85,28 @@ class GpuWindowExpressionMeta( } case RangeFrame => // Spark by default does a RangeFrame if no RowFrame is given - // even for columns that are not time type columns. We can switch this back to row - // based iff the ranges we are looking at both unbounded. We do this for all range - // queries because https://github.com/NVIDIA/spark-rapids/issues/1039 makes it so - // we cannot support nulls in range queries - // Will also verify that the types are what we expect. + // even for columns that are not time type columns. We can switch this to row + // based iff the ranges we are looking at both unbounded. if (spec.isUnbounded) { // this is okay because we will translate it to be a row query } else { // check whether order by column is supported or not val orderSpec = wrapped.windowSpec.orderSpec if (orderSpec.length > 1) { - // We only support a single time column + // We only support a single order by column willNotWorkOnGpu("only a single date/time or integral (Boolean exclusive)" + "based column in window range functions is supported") } val orderByTypeSupported = orderSpec.forall { so => so.dataType match { - case ByteType | ShortType | IntegerType | LongType => true - case DateType | TimestampType => true + case ByteType | ShortType | IntegerType | LongType | + DateType | TimestampType => true case _ => false } } if (!orderByTypeSupported) { willNotWorkOnGpu(s"the type of orderBy column is not supported in a window" + - s" range function, found ${orderSpec(0).dataType}") + s" range function, found ${orderSpec.head.dataType}") } def checkRangeBoundaryConfig(dt: DataType): Unit = { @@ -173,7 +163,7 @@ class GpuWindowExpressionMeta( } case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindowSpecDefinition) - extends GpuExpression { + extends GpuUnevaluable { override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil @@ -187,9 +177,17 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql + lazy val normalizedFrameSpec: GpuSpecifiedWindowFrame = { + val fs = windowFrameSpec.canonicalized.asInstanceOf[GpuSpecifiedWindowFrame] + fs.frameType match { + case RangeFrame if fs.isUnbounded => + GpuSpecifiedWindowFrame(RowFrame, fs.lower, fs.upper) + case _ => fs + } + } + private val windowFrameSpec = windowSpec.frameSpecification.asInstanceOf[GpuSpecifiedWindowFrame] - private val frameType : FrameType = windowFrameSpec.frameType - private val windowFunc = windowFunction match { + lazy val wrappedWindowFunc = windowFunction match { case func: GpuAggregateWindowFunction[_] => func case agg: GpuAggregateExpression => agg.aggregateFunction match { case func: GpuAggregateWindowFunction[_] => func @@ -199,276 +197,6 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow case other => throw new IllegalStateException(s"${other.getClass} is not a supported window function") } - private lazy val boundRowProjectList = windowSpec.partitionSpec ++ - windowFunc.windowInputProjection - private lazy val 
boundRangeProjectList = windowSpec.partitionSpec ++ - windowSpec.orderSpec.map(_.child.asInstanceOf[GpuExpression]) ++ - windowFunc.windowInputProjection - - override def columnarEval(cb: ColumnarBatch) : Any = { - frameType match { - case RowFrame => evaluateRowBasedWindowExpression(cb) - case RangeFrame => - if (windowFrameSpec.isUnbounded) { - // We already verified that this will be okay... - evaluateRowBasedWindowExpression(cb) - } else { - evaluateRangeBasedWindowExpression(cb) - } - - case allElse => - throw new UnsupportedOperationException( - s"Unsupported window expression frame type: $allElse") - } - } - - private def evaluateRowBasedWindowExpression(cb : ColumnarBatch) : GpuColumnVector = { - val numGroupingColumns = windowSpec.partitionSpec.length - val totalExtraColumns = numGroupingColumns - - val aggColumn = withResource(GpuProjectExec.project(cb, boundRowProjectList)) { projected => - - // in case boundRowProjectList is empty - val finalCb = if (boundRowProjectList.nonEmpty) projected else cb - - withResource(GpuColumnVector.from(finalCb)) { table => - val bases = GpuColumnVector.extractBases(finalCb).zipWithIndex - .slice(totalExtraColumns, boundRowProjectList.length) - - withResource(GpuWindowExpression.getRowBasedWindowOptions(windowFrameSpec)) { - windowOptions => - - val agg = windowFunc.windowAggregation(bases) - .overWindow(windowOptions) - - withResource(table - .groupBy(0 until numGroupingColumns: _*) - .aggregateWindows(agg)) { aggResultTable => - aggResultTable.getColumn(0).incRefCount() - } - } - } - } - // For nested type, do not cast - aggColumn.getType match { - case dType if dType.isNestedType => - GpuColumnVector.from(aggColumn, windowFunc.dataType) - case _ => - val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType) - // The API 'castTo' will take care of the 'from' type and 'to' type, and - // just increase the reference count by one when they are the same. - // so it is OK to always call it here. 
- withResource(aggColumn) { aggColumn => - GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType) - } - } - } - - private def evaluateRangeBasedWindowExpression(cb : ColumnarBatch) : GpuColumnVector = { - val numGroupingColumns = windowSpec.partitionSpec.length - val numSortColumns = windowSpec.orderSpec.length - assert(numSortColumns == 1) - val totalExtraColumns = numGroupingColumns + numSortColumns - - val aggColumn = withResource(GpuProjectExec.project(cb, boundRangeProjectList)) { projected => - withResource(GpuColumnVector.from(projected)) { table => - val bases = GpuColumnVector.extractBases(projected).zipWithIndex - .slice(totalExtraColumns, boundRangeProjectList.length) - - // Since boundRangeProjectList = windowSpec.partitionSpec ++ - // windowSpec.orderSpec.map(_.child.asInstanceOf[GpuExpression]) ++ - // windowFunc.windowInputProjection - // Here table.getColumn(numGroupingColumns) is the orderBy column - val orderByType = table.getColumn(numGroupingColumns).getType - // get the preceding/following scalar to construct WindowOptions - val (isUnboundedPreceding, preceding) = GpuWindowExpression.getRangeBasedLower( - windowFrameSpec, Some(orderByType)) - val (isUnBoundedFollowing, following) = GpuWindowExpression.getRangeBasedUpper( - windowFrameSpec, Some(orderByType)) - - withResource(preceding) { preceding => - withResource(following) { following => - withResource(GpuWindowExpression.getRangeBasedWindowOptions(windowSpec.orderSpec, - numGroupingColumns, - isUnboundedPreceding, - preceding.orNull, - isUnBoundedFollowing, - following.orNull)) { windowOptions => - val agg = windowFunc.windowAggregation(bases).overWindow(windowOptions) - withResource(table - .groupBy(0 until numGroupingColumns: _*) - .aggregateWindowsOverRanges(agg)) { aggResultTable => - aggResultTable.getColumn(0).incRefCount() - } - } - } - } - } - } - // For nested type, do not cast - aggColumn.getType match { - case dType if dType.isNestedType => - GpuColumnVector.from(aggColumn, windowFunc.dataType) - case _ => - val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType) - // The API 'castTo' will take care of the 'from' type and 'to' type, and - // just increase the reference count by one when they are the same. - // so it is OK to always call it here. - withResource(aggColumn) { aggColumn => - GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType) - } - } - } -} - -object GpuWindowExpression extends Arm { - - def getRowBasedLower(windowFrameSpec : GpuSpecifiedWindowFrame): Scalar = { - val lower = getBoundaryValue(windowFrameSpec.lower) - - // Translate the lower bound value to CUDF semantics: - // In spark 0 is the current row and lower bound is negative relative to that - // In CUDF the preceding window starts at the current row with 1 and up from there the - // further from the current row. 
- val ret = if (lower >= Int.MaxValue) { - Int.MinValue - } else if (lower <= Int.MinValue) { - Int.MaxValue - } else { - -(lower-1) - } - Scalar.fromInt(ret) - } - - def getRowBasedUpper(windowFrameSpec : GpuSpecifiedWindowFrame): Scalar = - Scalar.fromInt(getBoundaryValue(windowFrameSpec.upper)) - - def getRowBasedWindowOptions(windowFrameSpec : GpuSpecifiedWindowFrame): WindowOptions = { - withResource(getRowBasedLower(windowFrameSpec)) { lower => - withResource(getRowBasedUpper(windowFrameSpec)) { upper => - WindowOptions.builder().minPeriods(1) - .window(lower, upper).build() - } - } - } - - def getRangeBasedLower(windowFrameSpec: GpuSpecifiedWindowFrame, orderByType: Option[DType]): - (Boolean, Option[Scalar]) = { - getRangeBoundaryValue(windowFrameSpec.lower, orderByType) - } - - def getRangeBasedUpper(windowFrameSpec: GpuSpecifiedWindowFrame, orderByType: Option[DType]): - (Boolean, Option[Scalar]) = { - getRangeBoundaryValue(windowFrameSpec.upper, orderByType) - } - - def getRangeBasedWindowOptions( - orderSpec: Seq[SortOrder], - orderByColumnIndex : Int, - isUnboundedPreceding: Boolean, - preceding: Scalar, - isUnBoundedFollowing: Boolean, - following: Scalar): WindowOptions = { - val windowOptionBuilder = WindowOptions.builder() - .minPeriods(1) - .orderByColumnIndex(orderByColumnIndex) - - if (isUnboundedPreceding) { - windowOptionBuilder.unboundedPreceding() - } else { - windowOptionBuilder.preceding(preceding) - } - - if (isUnBoundedFollowing) { - windowOptionBuilder.unboundedFollowing() - } else { - windowOptionBuilder.following(following) - } - - // We only support a single time based column to order by right now, so just verify - // that it is correct. - assert(orderSpec.length == 1) - if (orderSpec.head.isAscending) { - windowOptionBuilder.orderByAscending() - } else { - windowOptionBuilder.orderByDescending() - } - - windowOptionBuilder.build() - } - - def getBoundaryValue(boundary : Expression) : Int = boundary match { - case literal: GpuLiteral if literal.dataType.equals(IntegerType) => - literal.value.asInstanceOf[Int] - case literal: GpuLiteral if literal.dataType.equals(CalendarIntervalType) => - literal.value.asInstanceOf[CalendarInterval].days - case special: GpuSpecialFrameBoundary => - special.value - case anythingElse => - throw new UnsupportedOperationException(s"Unsupported window frame expression $anythingElse") - } - - /** - * Create a Scalar from boundary value according to order by column type. - * - * For the timestamp types, only days are supported for window range intervals - * - * @param orderByType the type of order by column - * @param value boundary value - * @return Scalar holding boundary value - */ - def createRangeWindowBoundary(orderByType: DType, value: Long): Scalar = { - orderByType match { - case DType.INT8 => Scalar.fromByte(value.toByte) - case DType.INT16 => Scalar.fromShort(value.toShort) - case DType.INT32 => Scalar.fromInt(value.toInt) - case DType.INT64 => Scalar.fromLong(value) - // Interval is not working for DateType - case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, value) - case DType.TIMESTAMP_MICROSECONDS => - Scalar.durationFromLong(DType.DURATION_MICROSECONDS, value) - case _ => throw new RuntimeException(s"Not supported order by type, Found $orderByType") - } - } - - /** - * Get the range boundary tuple - * @param boundary boundary expression - * @param orderByType the type of order by column - * @return ret: (Boolean, Option[Scalar]). 
the first element of tuple specifies if the boundary is - * unBounded, the second element of tuple specifies the Scalar created from boundary. - * When orderByType is None, the Scalar will be None. - */ - def getRangeBoundaryValue(boundary: Expression, orderByType: Option[DType]): - (Boolean, Option[Scalar]) = boundary match { - case special: GpuSpecialFrameBoundary => - val isUnBounded = special.isUnBounded - (isUnBounded, if (isUnBounded) None else orderByType.map( - createRangeWindowBoundary(_, special.value))) - case GpuLiteral(ci: CalendarInterval, CalendarIntervalType) => - // Get the total microseconds for TIMESTAMP_MICROSECONDS - var x = ci.days * TimeUnit.DAYS.toMicros(1) + ci.microseconds - if (x == Long.MinValue) x = Long.MaxValue - (false, orderByType.map(createRangeWindowBoundary(_, Math.abs(x)))) - case GpuLiteral(value, ByteType) => - var x = value.asInstanceOf[Byte] - if (x == Byte.MinValue) x = Byte.MaxValue - (false, orderByType.map(createRangeWindowBoundary(_, Math.abs(x)))) - case GpuLiteral(value, ShortType) => - var x = value.asInstanceOf[Short] - if (x == Short.MinValue) x = Short.MaxValue - (false, orderByType.map(createRangeWindowBoundary(_, Math.abs(x)))) - case GpuLiteral(value, IntegerType) => - var x = value.asInstanceOf[Int] - if (x == Int.MinValue) x = Int.MaxValue - (false, orderByType.map(createRangeWindowBoundary(_, Math.abs(x)))) - case GpuLiteral(value, LongType) => - var x = value.asInstanceOf[Long] - if (x == Long.MinValue) x = Long.MaxValue - (false, orderByType.map(createRangeWindowBoundary(_, Math.abs(x)))) - case anything => throw new UnsupportedOperationException("Unsupported window frame" + - s" expression $anything") - } } class GpuWindowSpecDefinitionMeta( @@ -815,6 +543,14 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary) override def foldable: Boolean = false override def nullable: Boolean = false + /** + * Maps boundary to an Int value that in some cases can be used to build up the window options + * for a window aggregation. UnboundedPreceding and UnboundedFollowing produce Int.MinValue and + * Int.MaxValue respectively. In row based operations this should be fine because we cannot have + * a batch with that many rows in it anyways. For range based queries isUnbounded should be + * called too, to properly interpret the data. CurrentRow produces 0 which works for both row and + * range based queries. + */ def value : Int = { boundary match { case UnboundedPreceding => Int.MinValue @@ -825,10 +561,9 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary) } } - def isUnBounded: Boolean = { + def isUnbounded: Boolean = { boundary match { - case UnboundedPreceding => true - case UnboundedFollowing => true + case UnboundedPreceding | UnboundedFollowing => true case _ => false } }
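
As a quick reference for reviewers, the following is a small, dependency-free sketch (not part of the patch; the object and method names are made up for illustration) of the boundary translation performed by getRowBasedLower and getRangeBoundaryValue above: cudf counts the current row as 1 in a preceding window, so Spark's non-positive row offsets are negated and shifted, the Int.MinValue/Int.MaxValue unbounded sentinels swap places, and range-frame literal offsets are reduced to a non-negative magnitude with Long.MinValue saturated to Long.MaxValue.

// Standalone sketch with hypothetical names; plain Scala, no cudf dependency.
object WindowBoundarySketch {

  // Spark row-frame lower bound -> cudf "preceding" value.
  // Spark: 0 is the current row and negative offsets reach backwards.
  // cudf:  the preceding window starts at 1 for the current row and grows the
  //        further back it reaches, so a Spark lower bound of -3 becomes 4.
  // Int.MinValue / Int.MaxValue are the unbounded sentinels; they swap places
  // because the sign convention flips.
  def rowLowerToCudfPreceding(lower: Int): Int =
    if (lower >= Int.MaxValue) Int.MinValue
    else if (lower <= Int.MinValue) Int.MaxValue
    else -(lower - 1)

  // Range-frame literal offset -> the non-negative magnitude used to build the
  // cudf boundary scalar. Long.MinValue has no positive counterpart, so it is
  // saturated to Long.MaxValue before taking the absolute value.
  def rangeOffsetMagnitude(offset: Long): Long = {
    val v = if (offset == Long.MinValue) Long.MaxValue else offset
    math.abs(v)
  }

  def main(args: Array[String]): Unit = {
    assert(rowLowerToCudfPreceding(0) == 1)                       // CURRENT ROW
    assert(rowLowerToCudfPreceding(-3) == 4)                      // 3 PRECEDING
    assert(rowLowerToCudfPreceding(Int.MinValue) == Int.MaxValue) // UNBOUNDED PRECEDING
    assert(rangeOffsetMagnitude(-5L) == 5L)
    assert(rangeOffsetMagnitude(Long.MinValue) == Long.MaxValue)
  }
}

Running the main method exercises the same corner cases the patch special-cases (CURRENT ROW, UNBOUNDED PRECEDING, and minimum-value literals).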