diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py
index 2b7485b42eb..942f2460d4e 100644
--- a/integration_tests/src/main/python/array_test.py
+++ b/integration_tests/src/main/python/array_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2021, NVIDIA CORPORATION.
+# Copyright (c) 2020-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,9 +16,7 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
 from data_gen import *
-from functools import reduce
-from spark_session import is_before_spark_311
-from marks import allow_non_gpu
+from spark_session import is_before_spark_311, is_before_spark_330
 from pyspark.sql.types import *
 from pyspark.sql.types import IntegralType
 from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at
 
@@ -127,11 +125,22 @@ def main_df(spark):
 @pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes")
 @pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
 def test_get_array_item_ansi_fail(data_gen):
+    message = "org.apache.spark.SparkArrayIndexOutOfBoundsException" if not is_before_spark_330() else "java.lang.ArrayIndexOutOfBoundsException"
     assert_gpu_and_cpu_error(lambda spark: unary_op_df(
         spark, data_gen).select(col('a')[100]).collect(),
         conf={'spark.sql.ansi.enabled':True,
               'spark.sql.legacy.allowNegativeScaleOfDecimal': True},
-        error_message='java.lang.ArrayIndexOutOfBoundsException')
+        error_message=message)
+
+@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes")
+@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
+def test_element_at_index_ansi_fail(data_gen):
+    message = "org.apache.spark.SparkArrayIndexOutOfBoundsException" if not is_before_spark_330() else "java.lang.ArrayIndexOutOfBoundsException"
+    assert_gpu_and_cpu_error(lambda spark: unary_op_df(
+        spark, data_gen).select(element_at(col('a'), 100)).collect(),
+        conf={'spark.sql.ansi.enabled':True,
+              'spark.sql.legacy.allowNegativeScaleOfDecimal': True},
+        error_message=message)
 
 @pytest.mark.skipif(not is_before_spark_311(), reason="For Spark before 3.1.1 + ANSI mode, null will be returned instead of an exception if index is out of range")
 @pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala
new file mode 100644
index 00000000000..8105b2349df
--- /dev/null
+++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import ai.rapids.cudf.ColumnVector
+
+object RapidsErrorUtils {
+  def throwArrayIndexOutOfBoundsException(index: Int, numElements: Int): ColumnVector = {
+    throw new ArrayIndexOutOfBoundsException(s"index $index is beyond the max index allowed " +
+      s"${numElements - 1}")
+  }
+}
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala
new file mode 100644
index 00000000000..de67b5e5cf7
--- /dev/null
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import ai.rapids.cudf.ColumnVector
+
+import org.apache.spark.sql.errors.QueryExecutionErrors
+
+object RapidsErrorUtils {
+  def throwArrayIndexOutOfBoundsException(index: Int, numElements: Int): ColumnVector = {
+    throw QueryExecutionErrors.invalidArrayIndexError(index, numElements)
+  }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 270b2cef152..7067e6fb3b8 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -23,6 +23,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, GroupByAggregat
 import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprMeta, GpuBinaryExpression, GpuColumnVector, GpuComplexTypeMergingExpression, GpuExpression, GpuLiteral, GpuMapUtils, GpuScalar, GpuTernaryExpression, GpuUnaryExpression, RapidsConf, RapidsMeta}
 import com.nvidia.spark.rapids.GpuExpressionsUtils.columnarEvalToColumn
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.shims.v2.RapidsErrorUtils
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, RowOrdering, Sequence, TimeZoneAwareExpression}
@@ -141,11 +142,9 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea
           // Note: when the column is containing all null arrays, CPU will not throw, so make
           // GPU to behave the same.
           if (failOnError &&
-            minNumElements < math.abs(ordinalValue) &&
-            lhs.getBase.getNullCount != lhs.getBase.getRowCount) {
-            throw new ArrayIndexOutOfBoundsException(
-              s"Invalid index: $ordinalValue, minimum numElements in this ColumnVector: " +
-              s"$minNumElements")
+              minNumElements < math.abs(ordinalValue) &&
+              lhs.getBase.getNullCount != lhs.getBase.getRowCount) {
+            RapidsErrorUtils.throwArrayIndexOutOfBoundsException(ordinalValue, minNumElements)
           } else {
             if (ordinalValue > 0) {
               // Positive index
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
index 47d90b4b72e..8dd0635c988 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
 import ai.rapids.cudf.ColumnVector
 import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, DataTypeUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExpression
+import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimUnaryExpression}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -124,11 +124,9 @@ case class GpuGetArrayItem(child: Expression, ordinal: Expression, failOnError:
       withResource(numElementsCV.min) { minScalar =>
         val minNumElements = minScalar.getInt
         if (failOnError &&
-          (ordinal < 0 || minNumElements < ordinal + 1) &&
-          numElementsCV.getRowCount != numElementsCV.getNullCount) {
-          throw new ArrayIndexOutOfBoundsException(
-            s"Invalid index: ${ordinal}, minimum numElements in this ColumnVector: " +
-            s"$minNumElements")
+            (ordinal < 0 || minNumElements < ordinal + 1) &&
+            numElementsCV.getRowCount != numElementsCV.getNullCount) {
+          RapidsErrorUtils.throwArrayIndexOutOfBoundsException(ordinal, minNumElements)
         } else if (!failOnError && ordinal < 0) {
           GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType)
         } else {