diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 4d8602de360..0cefd608a6c 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
 from data_gen import *
 from pyspark.sql.types import *
 from spark_session import with_cpu_session
@@ -116,90 +116,166 @@ def test_sort_array_lit(data_gen, is_ascending):
         lambda spark: unary_op_df(spark, data_gen, length=10).select(
             f.sort_array(f.lit(array_lit), is_ascending)))
 
-# We must restrict the length of sequence, since we may suffer the exception
-# "Too long sequence: 2147483745. Should be <= 2147483632" or OOM.
-sequence_integral_gens = [
-    ByteGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
-    ShortGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
-    IntegerGen(nullable=False, min_val=-20, max_val=20, special_cases=[]),
-    LongGen(nullable=False, min_val=-20, max_val=20, special_cases=[])
+# For the functionality tests, the sequence length of each row should be limited
+# to avoid exceptions like
+# "Too long sequence: 2147483745. Should be <= 2147483632".
+# In addition, the input data must follow one of the rules below:
+#     (step > 0 && start <= stop)
+#  or (step < 0 && start >= stop)
+#  or (step == 0 && start == stop)
+sequence_normal_integral_gens = [
+    # (step > 0 && start <= stop)
+    (ByteGen(min_val=-10, max_val=20, special_cases=[]),
+     ByteGen(min_val=20, max_val=50, special_cases=[]),
+     ByteGen(min_val=1, max_val=5, special_cases=[])),
+    (ShortGen(min_val=-10, max_val=20, special_cases=[]),
+     ShortGen(min_val=20, max_val=50, special_cases=[]),
+     ShortGen(min_val=1, max_val=5, special_cases=[])),
+    (IntegerGen(min_val=-10, max_val=20, special_cases=[]),
+     IntegerGen(min_val=20, max_val=50, special_cases=[]),
+     IntegerGen(min_val=1, max_val=5, special_cases=[])),
+    (LongGen(min_val=-10, max_val=20, special_cases=[None]),
+     LongGen(min_val=20, max_val=50, special_cases=[None]),
+     LongGen(min_val=1, max_val=5, special_cases=[None])),
+    # (step < 0 && start >= stop)
+    (ByteGen(min_val=20, max_val=50, special_cases=[]),
+     ByteGen(min_val=-10, max_val=20, special_cases=[]),
+     ByteGen(min_val=-5, max_val=-1, special_cases=[])),
+    (ShortGen(min_val=20, max_val=50, special_cases=[]),
+     ShortGen(min_val=-10, max_val=20, special_cases=[]),
+     ShortGen(min_val=-5, max_val=-1, special_cases=[])),
+    (IntegerGen(min_val=20, max_val=50, special_cases=[]),
+     IntegerGen(min_val=-10, max_val=20, special_cases=[]),
+     IntegerGen(min_val=-5, max_val=-1, special_cases=[])),
+    (LongGen(min_val=20, max_val=50, special_cases=[None]),
+     LongGen(min_val=-10, max_val=20, special_cases=[None]),
+     LongGen(min_val=-5, max_val=-1, special_cases=[None])),
+    # (step == 0 && start == stop)
+    (ByteGen(min_val=20, max_val=20, special_cases=[]),
+     ByteGen(min_val=20, max_val=20, special_cases=[]),
+     ByteGen(min_val=0, max_val=0, special_cases=[])),
+    (ShortGen(min_val=20, max_val=20, special_cases=[]),
+     ShortGen(min_val=20, max_val=20, special_cases=[]),
+     ShortGen(min_val=0, max_val=0, special_cases=[])),
+    (IntegerGen(min_val=20, max_val=20, special_cases=[]),
+     IntegerGen(min_val=20, max_val=20, special_cases=[]),
+     IntegerGen(min_val=0, max_val=0, special_cases=[])),
+    (LongGen(min_val=20, max_val=20, special_cases=[None]),
+     LongGen(min_val=20, max_val=20, special_cases=[None]),
+     LongGen(min_val=0, max_val=0, special_cases=[None])),
 ]
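The three rules encoded by these generators mirror Spark's per-row boundary requirement for `sequence`. As a quick reference, a minimal plain-Scala sketch of that rule (illustrative only; the object and method names are made up and not part of this patch):

```scala
// Illustrative only: Spark's per-row validity rule for sequence(start, stop, step).
object SequenceRuleSketch {
  def isValidTriple(start: Long, stop: Long, step: Long): Boolean =
    (step > 0 && start <= stop) ||
    (step < 0 && start >= stop) ||
    (step == 0 && start == stop)

  def main(args: Array[String]): Unit = {
    println(isValidTriple(-10, 20, 1))   // true: increasing sequence
    println(isValidTriple(20, -10, -5))  // true: decreasing sequence
    println(isValidTriple(20, 20, 0))    // true: single-element sequence
    println(isValidTriple(20, -10, 1))   // false: "Illegal sequence boundaries"
  }
}
```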
 
-@pytest.mark.parametrize('data_gen', sequence_integral_gens, ids=idfn)
-def test_sequence_without_step(data_gen):
+sequence_normal_no_step_integral_gens = [(gens[0], gens[1]) for
+    gens in sequence_normal_integral_gens]
+
+@pytest.mark.parametrize('start_gen,stop_gen', sequence_normal_no_step_integral_gens, ids=idfn)
+def test_sequence_without_step(start_gen, stop_gen):
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark :
-        three_col_df(spark, data_gen, data_gen, data_gen)
-        .selectExpr("sequence(a, b)",
-                    "sequence(a, 0)",
-                    "sequence(0, b)"))
-
-# This function is to generate the correct sequence data according to below limitations.
-# (step > num.zero && start <= stop)
-# || (step < num.zero && start >= stop)
-# || (step == num.zero && start == stop)
-def get_sequence_data(data_gen, length=2048):
-    rand = random.Random(0)
-    data_gen.start(rand)
-    list = []
-    for index in range(length):
-        start = data_gen.gen()
-        stop = data_gen.gen()
-        step = data_gen.gen()
-        # decide the direction of step
-        if start < stop:
-            step = abs(step) + 1
-        elif start == stop:
-            step = 0
-        else:
-            step = -(abs(step) + 1)
-        list.append(tuple([start, stop, step]))
-    # add special case
-    list.append(tuple([2, 2, 0]))
-    return list
-
-def get_sequence_df(spark, data, data_type):
-    return spark.createDataFrame(
-        SparkContext.getOrCreate().parallelize(data),
-        StructType([StructField('a', data_type), StructField('b', data_type), StructField('c', data_type)]))
-
-# test below case
-# (2, -1, -1)
-# (2, 5, 2)
-# (2, 2, 0)
-@pytest.mark.parametrize('data_gen', sequence_integral_gens, ids=idfn)
-def test_sequence_with_step_case1(data_gen):
-    data = get_sequence_data(data_gen)
+        lambda spark: two_col_df(spark, start_gen, stop_gen).selectExpr(
+            "sequence(a, b)",
+            "sequence(a, 20)",
+            "sequence(20, b)"))
+
+@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_normal_integral_gens, ids=idfn)
+def test_sequence_with_step(start_gen, stop_gen, step_gen):
+    # Get a step scalar from 'step_gen', which follows the rules above.
+    step_gen.start(random.Random(0))
+    step_lit = step_gen.gen()
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark :
-        get_sequence_df(spark, data, data_gen.data_type)
-        .selectExpr("sequence(a, b, c)"))
+        lambda spark: three_col_df(spark, start_gen, stop_gen, step_gen).selectExpr(
+            "sequence(a, b, c)",
+            "sequence(a, b, {})".format(step_lit),
+            "sequence(a, 20, c)",
+            "sequence(a, 20, {})".format(step_lit),
+            "sequence(20, b, c)",
+            "sequence(20, 20, c)",
+            "sequence(20, b, {})".format(step_lit)))
 
-sequence_three_cols_integral_gens = [
-    (ByteGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
-     ByteGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
-     ByteGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
-    (ShortGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
-     ShortGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
-     ShortGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
-    (IntegerGen(nullable=False, min_val=-10, max_val=10, special_cases=[]),
-     IntegerGen(nullable=False, min_val=30, max_val=50, special_cases=[]),
-     IntegerGen(nullable=False, min_val=1, max_val=10, special_cases=[])),
-    (LongGen(nullable=False, min_val=-10, max_val=10, special_cases=[-10, 10]),
-     LongGen(nullable=False, min_val=30, max_val=50, special_cases=[30, 50]),
-     LongGen(nullable=False, min_val=1, max_val=10, special_cases=[1, 10])),
+# Illegal sequence boundaries:
+#   step > 0, but start > stop
+#   step < 0, but start < stop
+#   step == 0, but start != stop
+#
+# All integral types share the same check implementation, so each case
+# does not need to run over all of the types in the tests.
+sequence_illegal_boundaries_integral_gens = [
+    # step > 0, but start > stop
+    (ShortGen(min_val=20, max_val=50, special_cases=[]),
+     ShortGen(min_val=-10, max_val=19, special_cases=[]),
+     ShortGen(min_val=1, max_val=5, special_cases=[])),
+    (LongGen(min_val=20, max_val=50, special_cases=[None]),
+     LongGen(min_val=-10, max_val=19, special_cases=[None]),
+     LongGen(min_val=1, max_val=5, special_cases=[None])),
+    # step < 0, but start < stop
+    (ByteGen(min_val=-10, max_val=19, special_cases=[]),
+     ByteGen(min_val=20, max_val=50, special_cases=[]),
+     ByteGen(min_val=-5, max_val=-1, special_cases=[])),
+    (IntegerGen(min_val=-10, max_val=19, special_cases=[]),
+     IntegerGen(min_val=20, max_val=50, special_cases=[]),
+     IntegerGen(min_val=-5, max_val=-1, special_cases=[])),
+    # step == 0, but start != stop
+    (IntegerGen(min_val=-10, max_val=19, special_cases=[]),
+     IntegerGen(min_val=20, max_val=50, special_cases=[]),
+     IntegerGen(min_val=0, max_val=0, special_cases=[]))
 ]
 
-# Test the scalar case for the data start < stop and step > 0
-@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_three_cols_integral_gens, ids=idfn)
-def test_sequence_with_step_case2(start_gen, stop_gen, step_gen):
+@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_illegal_boundaries_integral_gens, ids=idfn)
+def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
+    assert_gpu_and_cpu_error(
+        lambda spark: three_col_df(spark, start_gen, stop_gen, step_gen).selectExpr(
+            "sequence(a, b, c)").collect(),
+        conf = {}, error_message = "Illegal sequence boundaries")
+
+# Exceed the max length of a sequence:
+# "Too long sequence: xxxxxxxxxx. Should be <= 2147483632"
+sequence_too_long_length_gens = [
+    IntegerGen(min_val=2147483633, max_val=2147483633, special_cases=[]),
+    LongGen(min_val=2147483635, max_val=2147483635, special_cases=[None])
+]
+
+@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
+def test_sequence_too_long_sequence(stop_gen):
+    assert_gpu_and_cpu_error(
+        # To avoid OOM, reduce the row number to 1; it is enough to verify this case.
+        lambda spark: unary_op_df(spark, stop_gen, 1).selectExpr(
+            "sequence(0, a)").collect(),
+        conf = {}, error_message = "Too long sequence")
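The limit 2147483632 is Spark's `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, i.e. `Int.MaxValue - 15`. A small plain-Scala sketch (names are illustrative, not from the patch) showing why the stop values used above must trip the check when starting from 0:

```scala
// Why the stop values above trip the "Too long sequence" check.
object SequenceLengthSketch {
  val MaxRoundedArrayLength: Long = Int.MaxValue - 15 // 2147483632

  // Length of sequence(start, stop) with the implicit step of 1.
  def lengthOf(start: Long, stop: Long): Long = stop - start + 1

  def main(args: Array[String]): Unit = {
    println(MaxRoundedArrayLength)     // 2147483632
    println(lengthOf(0, 2147483633L))  // 2147483634 > limit -> error
    println(lengthOf(0, 2147483635L))  // 2147483636 > limit -> error
  }
}
```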
 
+def get_sequence_cases_mixed_df(spark, length=2048):
+    # Generate sequence data following the 3 rules, mixed in a single dataset:
+    #   (step > num.zero && start <= stop) ||
+    #   (step < num.zero && start >= stop) ||
+    #   (step == num.zero && start == stop)
+    data_gen = IntegerGen(nullable=False, min_val=-10, max_val=10, special_cases=[])
+    def get_sequence_data(gen, len):
+        gen.start(random.Random(0))
+        list = []
+        for index in range(len):
+            start = gen.gen()
+            stop = gen.gen()
+            step = gen.gen()
+            # decide the direction of the step
+            if start < stop:
+                step = abs(step) + 1
+            elif start == stop:
+                step = 0
+            else:
+                step = -(abs(step) + 1)
+            list.append(tuple([start, stop, step]))
+        # add a special case
+        list.append(tuple([2, 2, 0]))
+        return list
+
+    mixed_schema = StructType([
+        StructField('a', data_gen.data_type),
+        StructField('b', data_gen.data_type),
+        StructField('c', data_gen.data_type)])
+    return spark.createDataFrame(
+        SparkContext.getOrCreate().parallelize(get_sequence_data(data_gen, length)),
+        mixed_schema)
+
+# test the 3 cases mixed in a single dataset
+def test_sequence_with_step_mixed_cases():
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark :
-        three_col_df(spark, start_gen, stop_gen, step_gen)
-        .selectExpr("sequence(a, b, c)",
-                    "sequence(a, b, 2)",
-                    "sequence(a, 20, c)",
-                    "sequence(a, 20, 2)",
-                    "sequence(0, b, c)",
-                    "sequence(0, 4, c)",
-                    "sequence(0, b, 3)"),)
+        lambda spark: get_sequence_cases_mixed_df(spark)
+            .selectExpr("sequence(a, b, c)"))
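The step fix-up inside `get_sequence_data` guarantees that every generated triple satisfies one of the three rules, whatever (start, stop) pair comes out of the generator. A plain-Scala model of that fix-up, for illustration only:

```scala
// CPU model of get_sequence_data's step-direction fix-up: any (start, stop)
// pair is turned into a valid (start, stop, step) triple.
object MixedCaseStepSketch {
  def fixStep(start: Int, stop: Int, rawStep: Int): Int =
    if (start < stop) math.abs(rawStep) + 1
    else if (start == stop) 0
    else -(math.abs(rawStep) + 1)

  def main(args: Array[String]): Unit = {
    println(fixStep(2, 5, 0))   // 1: forced positive for an increasing pair
    println(fixStep(2, -1, 4))  // -5: forced negative for a decreasing pair
    println(fixStep(2, 2, 9))   // 0: forced zero when start == stop
  }
}
```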
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BoolUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BoolUtils.scala
new file mode 100644
index 00000000000..30337ff8766
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/BoolUtils.scala
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import ai.rapids.cudf.{ColumnVector, DType}
+
+object BoolUtils extends Arm {
+
+  /**
+   * Whether all the valid rows in 'col' are true. An empty column returns true.
+   * Null rows are skipped.
+   */
+  def isAllValidTrue(col: ColumnVector): Boolean = {
+    assert(DType.BOOL8 == col.getType, "input column type is not bool")
+    if (col.getRowCount == 0) {
+      return true
+    }
+    if (col.getRowCount == col.getNullCount) {
+      // All rows are null, which is equivalent to empty, since nulls are skipped.
+      return true
+    }
+    withResource(col.all()) { allTrue =>
+      // There is guaranteed to be at least one row, and not all of the rows are
+      // null, so the result scalar must be valid.
+      allTrue.getBoolean
+    }
+  }
+}
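For reference, the semantics of `isAllValidTrue` can be modeled on the CPU with `Option[Boolean]` standing in for nullable BOOL8 rows. This is a hedged sketch of the contract, not the plugin's GPU code path:

```scala
// CPU model of BoolUtils.isAllValidTrue: nulls (None) are skipped,
// and an empty or all-null column counts as "all true".
object IsAllValidTrueSketch {
  def isAllValidTrue(col: Seq[Option[Boolean]]): Boolean =
    col.flatten.forall(identity)

  def main(args: Array[String]): Unit = {
    println(isAllValidTrue(Seq.empty))                          // true (empty)
    println(isAllValidTrue(Seq(None, None)))                    // true (all null)
    println(isAllValidTrue(Seq(Some(true), None, Some(true))))  // true
    println(isAllValidTrue(Seq(Some(true), Some(false))))       // false
  }
}
```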
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 7067e6fb3b8..3bf981d2c63 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -16,19 +16,23 @@
 
 package org.apache.spark.sql.rapids
 
+import java.util.Optional
+
 import scala.collection.mutable.ArrayBuffer
 
 import ai.rapids.cudf
-import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, GroupByAggregation, GroupByOptions, Scalar}
-import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprMeta, GpuBinaryExpression, GpuColumnVector, GpuComplexTypeMergingExpression, GpuExpression, GpuLiteral, GpuMapUtils, GpuScalar, GpuTernaryExpression, GpuUnaryExpression, RapidsConf, RapidsMeta}
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, GroupByAggregation, GroupByOptions, Scalar, Table}
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue
 import com.nvidia.spark.rapids.GpuExpressionsUtils.columnarEvalToColumn
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.RapidsErrorUtils
+import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimExpression}
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, RowOrdering, Sequence, TimeZoneAwareExpression}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.UTF8String
 
 case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExpression {
@@ -407,260 +411,243 @@ class GpuSequenceMeta(
   }
 
   override def convertToGpu(): GpuExpression = {
-    if (expr.stepOpt.isDefined) {
-      val Seq(start, stop, step) = childExprs.map(_.convertToGpu())
-      GpuSequenceWithStep(start, stop, step, expr.timeZoneId)
-    } else {
-      val Seq(start, stop) = childExprs.map(_.convertToGpu())
-      GpuSequence(start, stop, expr.timeZoneId)
-    }
-  }
-}
-
-object GpuSequenceUtil {
-
-  def numberScalar(dt: DataType, value: Int): Scalar = dt match {
-    case ByteType => Scalar.fromByte(value.toByte)
-    case ShortType => Scalar.fromShort(value.toShort)
-    case IntegerType => Scalar.fromInt(value)
-    case LongType => Scalar.fromLong(value.toLong)
-    case _ =>
-      throw new IllegalArgumentException("wrong data type: " + dt)
+    val (startExpr, stopExpr, stepOpt) = if (expr.stepOpt.isDefined) {
+      val Seq(start, stop, step) = childExprs.map(_.convertToGpu())
+      (start, stop, Some(step))
+    } else {
+      val Seq(start, stop) = childExprs.map(_.convertToGpu())
+      (start, stop, None)
+    }
+    GpuSequence(startExpr, stopExpr, stepOpt, expr.timeZoneId)
   }
 }
 
-/** GpuSequence without step */
-case class GpuSequence(start: Expression, stop: Expression, timeZoneId: Option[String] = None)
-    extends GpuBinaryExpression with TimeZoneAwareExpression {
-
-  override def left: Expression = start
-
-  override def right: Expression = stop
+object GpuSequenceUtil extends Arm {
+
+  private def checkSequenceInputs(
+      start: ColumnVector,
+      stop: ColumnVector,
+      step: ColumnVector): Unit = {
+    // Keep the same requirement as Spark:
+    //   (step > 0 && start <= stop) || (step < 0 && start >= stop) || (step == 0 && start == stop)
+    withResource(Scalar.fromByte(0.toByte)) { zero =>
+      // The check should ignore each row (Row(start, stop, step)) that contains at
+      // least one null element, to match Spark's behavior. Since the cudf binary ops
+      // already ignore nulls, skipping the null rows needs no additional processing:
+      // the filtered table (e.g. upTbl) in each rule check excludes the rows whose
+      // step is null; then comparing start with stop produces a null row when either
+      // of them is null, and such nulls are skipped by the final assertion
+      // 'isAllValidTrue'.
+      withResource(new Table(start, stop)) { startStopTable =>
+        // (step > 0 && start <= stop)
+        val upTbl = withResource(step.greaterThan(zero)) { positiveStep =>
+          startStopTable.filter(positiveStep)
+        }
+        val allUp = withResource(upTbl) { _ =>
+          upTbl.getColumn(0).lessOrEqualTo(upTbl.getColumn(1))
+        }
+        withResource(allUp) { _ =>
+          require(isAllValidTrue(allUp), "Illegal sequence boundaries: step > 0 but start > stop")
+        }
 
-  override def dataType: DataType = ArrayType(start.dataType, containsNull = false)
+        // (step < 0 && start >= stop)
+        val downTbl = withResource(step.lessThan(zero)) { negativeStep =>
+          startStopTable.filter(negativeStep)
+        }
+        val allDown = withResource(downTbl) { _ =>
+          downTbl.getColumn(0).greaterOrEqualTo(downTbl.getColumn(1))
+        }
+        withResource(allDown) { _ =>
+          require(isAllValidTrue(allDown),
+            "Illegal sequence boundaries: step < 0 but start < stop")
+        }
 
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Some(timeZoneId))
+        // (step == 0 && start == stop)
+        val equalTbl = withResource(step.equalTo(zero)) { zeroStep =>
+          startStopTable.filter(zeroStep)
+        }
+        val allEq = withResource(equalTbl) { _ =>
+          equalTbl.getColumn(0).equalTo(equalTbl.getColumn(1))
+        }
+        withResource(allEq) { _ =>
+          require(isAllValidTrue(allEq),
+            "Illegal sequence boundaries: step == 0 but start != stop")
+        }
+      }
+    } // end of zero
+  }
 
   /**
-   * Calculate the size and step (1 or -1) between start and stop both inclusive
-   * size = |stop - start| + 1
-   * step = 1 if stop >= start else -1
-   * @param start first values in the result sequences
-   * @param stop end values in the result sequences
-   * @return (size, step)
+   * Compute the size of each sequence according to 'start', 'stop' and 'step'.
+   * A row (Row[start, stop, step]) that contains at least one null element will
+   * produce a null in the output.
+   *
+   * The returned column should be closed.
    */
-  private def calculateSizeAndStep(start: BinaryOperable, stop: BinaryOperable, dt: DataType):
-      Seq[ColumnVector] = {
-    withResource(stop.sub(start)) { difference =>
-      withResource(GpuSequenceUtil.numberScalar(dt, 1)) { one =>
-        val step = withResource(GpuSequenceUtil.numberScalar(dt, -1)) { negativeOne =>
-          withResource(GpuSequenceUtil.numberScalar(dt, 0)) { zero =>
-            withResource(difference.greaterOrEqualTo(zero)) { pred =>
-              pred.ifElse(one, negativeOne)
-            }
-          }
-        }
-        val size = closeOnExcept(step) { _ =>
-          withResource(difference.abs()) { absDifference =>
-            absDifference.add(one)
-          }
-        }
-        Seq(size, step)
-      }
-    }
-  }
-
-  override def doColumnar(start: GpuColumnVector, stop: GpuColumnVector): ColumnVector = {
-    withResource(calculateSizeAndStep(start.getBase, stop.getBase, start.dataType())) { ret =>
-      ColumnVector.sequence(start.getBase, ret(0), ret(1))
-    }
-  }
-
-  override def doColumnar(start: GpuScalar, stop: GpuColumnVector): ColumnVector = {
-    withResource(calculateSizeAndStep(start.getBase, stop.getBase, stop.dataType())) { ret =>
-      withResource(ColumnVector.fromScalar(start.getBase, stop.getRowCount.toInt)) { startV =>
-        ColumnVector.sequence(startV, ret(0), ret(1))
-      }
-    }
-  }
+  def computeSequenceSizes(
+      start: ColumnVector,
+      stop: ColumnVector,
+      step: ColumnVector): ColumnVector = {
+    checkSequenceInputs(start, stop, step)
+
+    // Spark's algorithm to get the length (aka size)
+    // ``` Scala
+    // size = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
+    // require(size <= MAX_ROUNDED_ARRAY_LENGTH,
+    //   s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+    // size.toInt
+    // ```
+    val sizeAsLong = withResource(Scalar.fromLong(1L)) { one =>
+      val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
+        withResource(start.castTo(DType.INT64)) { startAsLong =>
+          stopAsLong.sub(startAsLong)
+        }
+      }
+      val quotient = withResource(diff) { _ =>
+        withResource(step.castTo(DType.INT64)) { stepAsLong =>
+          diff.div(stepAsLong)
+        }
+      }
+      // actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
+      val actualSize = withResource(quotient) { quotient =>
+        quotient.add(one, DType.INT64)
+      }
+      withResource(actualSize) { _ =>
+        val mergedEquals = withResource(start.equalTo(stop)) { equals =>
+          if (step.hasNulls) {
+            // Also set the row to null where step is null.
+            equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
+          } else {
+            equals.incRefCount()
+          }
+        }
+        withResource(mergedEquals) { _ =>
+          mergedEquals.ifElse(one, actualSize)
+        }
+      }
+    }
+
+    withResource(sizeAsLong) { _ =>
+      // check the max size
+      withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
+        withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
+          require(isAllValidTrue(allValid),
+            s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+        }
+      }
+      // cast to int and return
+      sizeAsLong.castTo(DType.INT32)
+    }
+  }
-
-  override def doColumnar(start: GpuColumnVector, stop: GpuScalar): ColumnVector = {
-    withResource(calculateSizeAndStep(start.getBase, stop.getBase, start.dataType())) { ret =>
-      ColumnVector.sequence(start.getBase, ret(0), ret(1))
-    }
-  }
-
-  override def doColumnar(numRows: Int, start: GpuScalar, stop: GpuScalar): ColumnVector = {
-    val startV = GpuColumnVector.from(ColumnVector.fromScalar(start.getBase, numRows),
-      start.dataType)
-    doColumnar(startV, stop)
-  }
 }
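To make the quoted Spark algorithm concrete, here is a CPU-only sketch of the size computation for a single non-null row (cudf specifics and null handling omitted; the names are illustrative, not the plugin's):

```scala
// CPU model of computeSequenceSizes for one non-null row.
object SequenceSizeSketch {
  val MaxRoundedArrayLength: Long = Int.MaxValue - 15 // 2147483632

  def sizeOf(start: Long, stop: Long, step: Long): Int = {
    // start == stop yields 1 directly, which also avoids dividing by a zero step.
    val size = if (start == stop) 1L else 1L + (stop - start) / step
    require(size <= MaxRoundedArrayLength,
      s"Too long sequence: $size. Should be <= $MaxRoundedArrayLength")
    size.toInt
  }

  def main(args: Array[String]): Unit = {
    println(sizeOf(0, 10, 3))   // 4 -> [0, 3, 6, 9]
    println(sizeOf(10, 0, -5))  // 3 -> [10, 5, 0]
    println(sizeOf(20, 20, 0))  // 1 -> [20]
  }
}
```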
 
-/** GpuSequence with step */
-case class GpuSequenceWithStep(start: Expression, stop: Expression, step: Expression,
-    timeZoneId: Option[String] = None) extends GpuTernaryExpression with TimeZoneAwareExpression {
+case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expression],
+    timeZoneId: Option[String] = None) extends TimeZoneAwareExpression with GpuExpression
+    with ShimExpression {
+
+  import GpuSequenceUtil._
+
+  override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Some(timeZoneId))
 
-  override def first: Expression = start
+  override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt
 
-  override def second: Expression = stop
-
-  override def third: Expression = step
+  override def nullable: Boolean = children.exists(_.nullable)
 
-  override def dataType: DataType = ArrayType(start.dataType, containsNull = false)
+  override def foldable: Boolean = children.forall(_.foldable)
 
-  private def calculateSize(
-      start: BinaryOperable,
-      stop: BinaryOperable,
-      step: BinaryOperable,
-      rows: Int,
-      dt: DataType): ColumnVector = {
-    // First, calculate sizeWithNegative=floor((stop-start)/step)+1.
-    // if step = 0, the div operation in cudf will get MIN_VALUE, which is ok for this case,
-    // since when size < 0, cudf will not generate sequence
-    // Second, calculate size = if(sizeWithNegative < 0) 0 else sizeWithNegative
-    // Third, if (start == stop && step == 0), let size = 1.
-    withResource(GpuSequenceUtil.numberScalar(dt, 1)) { one =>
-      withResource(GpuSequenceUtil.numberScalar(dt, 0)) { zero =>
+  override def columnarEval(batch: ColumnarBatch): Any = {
+    withResource(columnarEvalToColumn(start, batch)) { startGpuCol =>
+      withResource(stepOpt.map(columnarEvalToColumn(_, batch))) { stepGpuColOpt =>
+        val startCol = startGpuCol.getBase
 
-        val (sizeWithNegative, diffHasZero) = withResource(stop.sub(start)) { difference =>
-          // sizeWithNegative=floor((stop-start)/step)+1
-          val sizeWithNegative = withResource(difference.floorDiv(step)) { quotient =>
-            quotient.add(one)
-          }
-          val tmpDiffHasZero = closeOnExcept(sizeWithNegative) { _ =>
-            difference.equalTo(zero)
-          }
-          (sizeWithNegative, tmpDiffHasZero)
-        }
+        // 1) Compute the sequence size for each row.
+        val (sizeCol, stepCol) = withResource(columnarEvalToColumn(stop, batch)) { stopGpuCol =>
+          val stopCol = stopGpuCol.getBase
+          val steps = stepGpuColOpt.map(_.getBase.incRefCount())
+            .getOrElse(defaultStepsFunc(startCol, stopCol))
+          closeOnExcept(steps) { _ =>
+            (computeSequenceSizes(startCol, stopCol, steps), steps)
+          }
+        }
 
-        val tmpSize = closeOnExcept(diffHasZero) { _ =>
-          // tmpSize = if(sizeWithNegative < 0) 0 else sizeWithNegative
-          withResource(sizeWithNegative) { _ =>
-            withResource(sizeWithNegative.greaterOrEqualTo(zero)) { pred =>
-              pred.ifElse(sizeWithNegative, zero)
-            }
-          }
-        }
-
-        // when start==stop && step==0, size will be 0.
-        // but we should change size to 1
-        withResource(tmpSize) { tmpSize =>
-          withResource(diffHasZero) { diffHasZero =>
-            step match {
-              case stepScalar: Scalar =>
-                withResource(ColumnVector.fromScalar(stepScalar, rows)) { stepV =>
-                  withResource(stepV.equalTo(zero)) { stepHasZero =>
-                    withResource(diffHasZero.and(stepHasZero)) { predWithZero =>
-                      predWithZero.ifElse(one, tmpSize)
-                    }
-                  }
-                }
-              case _ =>
-                withResource(step.equalTo(zero)) { stepHasZero =>
-                  withResource(diffHasZero.and(stepHasZero)) { predWithZero =>
-                    predWithZero.ifElse(one, tmpSize)
-                  }
-                }
-            }
-          }
-        }
-      }
-    }
-  }
+        // 2) Generate the sequences.
+        //
+        // cudf 'sequence' requires 'step' to have the same type as 'start', and the
+        // step type may differ from it when the default steps are used.
+        val castedStepCol = withResource(stepCol) { _ =>
+          closeOnExcept(sizeCol) { _ =>
+            stepCol.castTo(startCol.getType)
+          }
+        }
+        withResource(Seq(sizeCol, castedStepCol)) { _ =>
+          GpuColumnVector.from(genSequence(startCol, sizeCol, castedStepCol), dataType)
+        }
+      }
+    }
+  }
 
-  override def doColumnar(
-      start: GpuColumnVector,
-      stop: GpuColumnVector,
-      step: GpuColumnVector): ColumnVector = {
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, start.getRowCount.toInt,
-        start.dataType())) { size =>
-      ColumnVector.sequence(start.getBase, size, step.getBase)
-    }
-  }
-
-  override def doColumnar(
-      start: GpuScalar,
-      stop: GpuColumnVector,
-      step: GpuColumnVector): ColumnVector = {
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, stop.getRowCount.toInt,
-        start.dataType)) { size =>
-      withResource(ColumnVector.fromScalar(start.getBase, stop.getRowCount.toInt)) { startV =>
-        ColumnVector.sequence(startV, size, step.getBase)
-      }
-    }
-  }
 
+  @transient
+  private lazy val defaultStepsFunc: (ColumnView, ColumnView) => ColumnVector =
+    dataType.elementType match {
+      case _: IntegralType =>
+        // Default step:
+        //   start > stop:  step == -1
+        //   start <= stop: step == 1
+        (starts, stops) => {
+          // It is ok to always use a byte scalar here, since it will be casted to
+          // the same type as 'start' before going into cudf 'sequence'. Besides,
+          // byte saves memory and does not cause any type promotion during the
+          // computation.
+          withResource(Scalar.fromByte((-1).toByte)) { minusOne =>
+            withResource(Scalar.fromByte(1.toByte)) { one =>
+              withResource(starts.greaterThan(stops)) { decrease =>
+                decrease.ifElse(minusOne, one)
+              }
+            }
+          }
+        }
+      // Timestamp and Date will come soon
+      // case TimestampType =>
+      // case DateType =>
+    }
 
-  override def doColumnar(
-      start: GpuScalar,
-      stop: GpuScalar,
-      step: GpuColumnVector): ColumnVector = {
-    withResource(ColumnVector.fromScalar(start.getBase, step.getRowCount.toInt)) { startV =>
-      withResource(calculateSize(startV, stop.getBase, step.getBase, step.getRowCount.toInt,
-          start.dataType)) { size =>
-        ColumnVector.sequence(startV, size, step.getBase)
-      }
-    }
-  }
-
-  override def doColumnar(
-      start: GpuScalar,
-      stop: GpuColumnVector,
-      step: GpuScalar): ColumnVector = {
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, stop.getRowCount.toInt,
-        start.dataType)) { size =>
-      withResource(ColumnVector.fromScalar(start.getBase, stop.getRowCount.toInt)) { startV =>
-        withResource(ColumnVector.fromScalar(step.getBase, stop.getRowCount.toInt)) { stepV =>
-          ColumnVector.sequence(startV, size, stepV)
-        }
-      }
-    }
-  }
-
-  override def doColumnar(
-      start: GpuColumnVector,
-      stop: GpuScalar,
-      step: GpuColumnVector): ColumnVector = {
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, start.getRowCount.toInt,
-        start.dataType())) { size =>
-      ColumnVector.sequence(start.getBase, size, step.getBase)
-    }
-  }
-
-  override def doColumnar(
-      start: GpuColumnVector,
-      stop: GpuScalar,
-      step: GpuScalar): ColumnVector = {
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, start.getRowCount.toInt,
-        start.dataType())) { size =>
-      withResource(ColumnVector.fromScalar(step.getBase, start.getRowCount.toInt)) { stepV =>
-        ColumnVector.sequence(start.getBase, size, stepV)
-      }
-    }
-  }
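The default-step rule implemented by `defaultStepsFunc` above is just a per-row comparison of the bounds. A plain-Scala model, for illustration only:

```scala
// CPU model of defaultStepsFunc for integral types:
// the implicit step is 1 for non-decreasing bounds and -1 otherwise.
object DefaultStepSketch {
  def defaultStep(start: Long, stop: Long): Int =
    if (start > stop) -1 else 1

  def main(args: Array[String]): Unit = {
    println(defaultStep(0, 10))  // 1  -> sequence(0, 10) counts up
    println(defaultStep(10, 0))  // -1 -> sequence(10, 0) counts down
    println(defaultStep(5, 5))   // 1  -> single-element sequence [5]
  }
}
```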
 
+  private def genSequence(
+      start: ColumnView,
+      size: ColumnView,
+      step: ColumnView): ColumnVector = {
+    // 'size' is calculated from start, stop and step, so its validity mask is the
+    // merged validity of the three input columns, and it can be used directly as
+    // the validity mask of the final output. Then checking for nulls only in the
+    // size column is enough.
+    if (size.getNullCount > 0) {
+      // Nulls are not acceptable in cudf 'list::sequences'. (Please refer to
+      // https://github.com/rapidsai/cudf/issues/10012)
+      //
+      // So replace the nulls with 0 for size, and create temp views for start and
+      // step with the null count forced to be 0.
+      val sizeNoNull = withResource(Scalar.fromInt(0)) { zero =>
+        size.replaceNulls(zero)
+      }
+      val ret = withResource(sizeNoNull) { _ =>
+        val startNoNull = new ColumnView(start.getType, start.getRowCount, Optional.of(0L),
+          start.getData, null)
+        withResource(startNoNull) { _ =>
+          val stepNoNull = new ColumnView(step.getType, step.getRowCount, Optional.of(0L),
+            step.getData, null)
+          withResource(stepNoNull) { _ =>
+            ColumnVector.sequence(startNoNull, sizeNoNull, stepNoNull)
+          }
+        }
+      }
+      withResource(ret) { _ =>
+        // Restore the null rows by setting the validity mask.
+        ret.mergeAndSetValidity(BinaryOp.BITWISE_AND, size)
+      }
+    } else {
+      ColumnVector.sequence(start, size, step)
+    }
+  }
-
-  override def doColumnar(
-      start: GpuColumnVector,
-      stop: GpuColumnVector,
-      step: GpuScalar): ColumnVector =
-    withResource(calculateSize(start.getBase, stop.getBase, step.getBase, start.getRowCount.toInt,
-        start.dataType())) { size =>
-      withResource(ColumnVector.fromScalar(step.getBase, start.getRowCount.toInt)) { stepV =>
-        ColumnVector.sequence(start.getBase, size, stepV)
-      }
-    }
-
-  override def doColumnar(
-      numRows: Int,
-      start: GpuScalar,
-      stop: GpuScalar,
-      step: GpuScalar): ColumnVector = {
-    val startV = GpuColumnVector.from(ColumnVector.fromScalar(start.getBase, numRows),
-      start.dataType)
-    doColumnar(startV, stop, step)
-  }
 }
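The null semantics that `genSequence` preserves can be mirrored on the CPU: a null in start, stop, or step nullifies that row's output, while the remaining rows materialize normally. A hedged sketch of the contract, not the GPU code path:

```scala
// CPU model of GpuSequence's null semantics: a null (None) in start, stop or
// step yields a null output row; other rows produce the materialized sequence.
object SequenceNullSketch {
  def sequenceRow(start: Option[Long], stop: Option[Long],
      step: Option[Long]): Option[Seq[Long]] =
    for (b <- start; e <- stop; s <- step) yield {
      val size = if (b == e) 1L else 1L + (e - b) / s
      Seq.iterate(b, size.toInt)(_ + s)
    }

  def main(args: Array[String]): Unit = {
    println(sequenceRow(Some(2), Some(8), Some(3)))  // Some(List(2, 5, 8))
    println(sequenceRow(Some(2), None, Some(3)))     // None: null stop row
  }
}
```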