Throw SparkArrayIndexOutOfBoundsException for Spark 3.3.0+ (NVIDIA#4464)
* Throw the out of bounds exception to match the CPU

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review concerns

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed the logic for dynamic loading of shims; instead, use a compiled class directly

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed more review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added test for element_at and throw the correct exception

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added the missed copyrights

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Jan 7, 2022
1 parent 54c0f94 commit 01d386c
Showing 5 changed files with 76 additions and 17 deletions.
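For context before the file diffs: a minimal spark-shell sketch (not part of this commit) of the CPU behavior being matched. It assumes an existing local SparkSession bound to `spark`; the exact error message text varies by release.

// Minimal sketch, assuming a spark-shell session where `spark` is predefined.
// With ANSI mode on, an out-of-range array index raises an error on the CPU.
spark.conf.set("spark.sql.ansi.enabled", "true")

// On Spark 3.3.0+ this raises org.apache.spark.SparkArrayIndexOutOfBoundsException;
// on Spark 3.1.1 through 3.2.x it raises java.lang.ArrayIndexOutOfBoundsException.
spark.sql("SELECT array(1, 2, 3)[100]").collect()

// element_at is gated the same way; the new test below checks both exception names.
spark.sql("SELECT element_at(array(1, 2, 3), 100)").collect()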
19 changes: 14 additions & 5 deletions integration_tests/src/main/python/array_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2021, NVIDIA CORPORATION.
+# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
-from functools import reduce
-from spark_session import is_before_spark_311
-from marks import allow_non_gpu
+from spark_session import is_before_spark_311, is_before_spark_330
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at
@@ -127,11 +125,22 @@ def main_df(spark):
@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes")
@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
def test_get_array_item_ansi_fail(data_gen):
+    message = "org.apache.spark.SparkArrayIndexOutOfBoundsException" if not is_before_spark_330() else "java.lang.ArrayIndexOutOfBoundsException"
    assert_gpu_and_cpu_error(lambda spark: unary_op_df(
        spark, data_gen).select(col('a')[100]).collect(),
        conf={'spark.sql.ansi.enabled':True,
              'spark.sql.legacy.allowNegativeScaleOfDecimal': True},
-        error_message='java.lang.ArrayIndexOutOfBoundsException')
+        error_message=message)
+
+@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes")
+@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
+def test_element_at_index_ansi_fail(data_gen):
+    message = "org.apache.spark.SparkArrayIndexOutOfBoundsException" if not is_before_spark_330() else "java.lang.ArrayIndexOutOfBoundsException"
+    assert_gpu_and_cpu_error(lambda spark: unary_op_df(
+        spark, data_gen).select(element_at(col('a'), 100)).collect(),
+        conf={'spark.sql.ansi.enabled':True,
+              'spark.sql.legacy.allowNegativeScaleOfDecimal': True},
+        error_message=message)

@pytest.mark.skipif(not is_before_spark_311(), reason="For Spark before 3.1.1 + ANSI mode, null will be returned instead of an exception if index is out of range")
@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.ColumnVector

object RapidsErrorUtils {
  def throwArrayIndexOutOfBoundsException(index: Int, numElements: Int): ColumnVector = {
    throw new ArrayIndexOutOfBoundsException(s"index $index is beyond the max index allowed " +
      s"${numElements - 1}")
  }
}
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.ColumnVector

import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils {
  def throwArrayIndexOutOfBoundsException(index: Int, numElements: Int): ColumnVector = {
    throw QueryExecutionErrors.invalidArrayIndexError(index, numElements)
  }
}
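Note on the two new files above: both define an object with the same fully qualified name, com.nvidia.spark.rapids.shims.v2.RapidsErrorUtils. Per the commit message, shim resolution is no longer dynamic; the build is assumed to compile exactly one of the two sources into each Spark-version-specific shim, so the call sites below bind to it as an ordinary compiled class. A hedged sketch of what a caller observes follows; the try/catch and println are illustrative only, not plugin code.

import com.nvidia.spark.rapids.shims.v2.RapidsErrorUtils

try {
  // Signal that index 100 is out of range for an array with 3 elements.
  RapidsErrorUtils.throwArrayIndexOutOfBoundsException(100, 3)
} catch {
  case e: Throwable =>
    // Pre-3.3.0 shims throw java.lang.ArrayIndexOutOfBoundsException;
    // the 3.3.0+ shim throws org.apache.spark.SparkArrayIndexOutOfBoundsException
    // via QueryExecutionErrors.invalidArrayIndexError.
    println(e.getClass.getName)
}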
@@ -23,6 +23,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, GroupByAggregat
import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprMeta, GpuBinaryExpression, GpuColumnVector, GpuComplexTypeMergingExpression, GpuExpression, GpuLiteral, GpuMapUtils, GpuScalar, GpuTernaryExpression, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.GpuExpressionsUtils.columnarEvalToColumn
import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.shims.v2.RapidsErrorUtils

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, RowOrdering, Sequence, TimeZoneAwareExpression}
@@ -141,11 +142,9 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea
// Note: when the column is containing all null arrays, CPU will not throw, so make
// GPU to behave the same.
if (failOnError &&
-    minNumElements < math.abs(ordinalValue) &&
-    lhs.getBase.getNullCount != lhs.getBase.getRowCount) {
-  throw new ArrayIndexOutOfBoundsException(
-    s"Invalid index: $ordinalValue, minimum numElements in this ColumnVector: " +
-    s"$minNumElements")
+  minNumElements < math.abs(ordinalValue) &&
+  lhs.getBase.getNullCount != lhs.getBase.getRowCount) {
+  RapidsErrorUtils.throwArrayIndexOutOfBoundsException(ordinalValue, minNumElements)
} else {
if (ordinalValue > 0) {
// Positive index
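The getNullCount != getRowCount guard above exists because the CPU never checks an index against a null array, so a column containing only null arrays must not throw even under ANSI mode. A hedged spark-shell sketch of that corner case, again assuming `spark` is predefined:

// Sketch: an all-null array column yields null rather than raising, even with ANSI on.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT element_at(a, 100) FROM VALUES (CAST(NULL AS ARRAY<INT>)) AS t(a)").collect()
// expected: one row containing null, with no exception, matching the CPU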
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, DataTypeUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExpression
+import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimUnaryExpression}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -124,11 +124,9 @@ case class GpuGetArrayItem(child: Expression, ordinal: Expression, failOnError:
withResource(numElementsCV.min) { minScalar =>
val minNumElements = minScalar.getInt
if (failOnError &&
-    (ordinal < 0 || minNumElements < ordinal + 1) &&
-    numElementsCV.getRowCount != numElementsCV.getNullCount) {
-  throw new ArrayIndexOutOfBoundsException(
-    s"Invalid index: ${ordinal}, minimum numElements in this ColumnVector: " +
-    s"$minNumElements")
+  (ordinal < 0 || minNumElements < ordinal + 1) &&
+  numElementsCV.getRowCount != numElementsCV.getNullCount) {
+  RapidsErrorUtils.throwArrayIndexOutOfBoundsException(ordinal, minNumElements)
} else if (!failOnError && ordinal < 0) {
GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType)
} else {
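One asymmetry between the two call sites is worth noting: GpuElementAt compares against math.abs(ordinalValue) because element_at accepts negative indexes that count from the end of the array, while GpuGetArrayItem treats any negative ordinal as out of range, returning null with ANSI off and erroring with ANSI on. A hedged sketch, assuming `spark` as before:

// Sketch: negative indexes are valid for element_at but not for the [] subscript.
spark.sql("SELECT element_at(array(1, 2, 3), -1)").collect() // last element: 3

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT array(1, 2, 3)[-1]").collect() // null with ANSI off

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT array(1, 2, 3)[-1]").collect() // raises with ANSI on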
