diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py index 38518385f00..94ed718434b 100644 --- a/integration_tests/src/main/python/cache_test.py +++ b/integration_tests/src/main/python/cache_test.py @@ -202,6 +202,9 @@ def write_read_parquet_cached(spark): # rapids-spark doesn't support LEGACY read for parquet conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', 'spark.sql.legacy.parquet.datetimeRebaseModeInRead' : 'CORRECTED', + # set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on GPU + 'spark.sql.legacy.parquet.int96RebaseModeInWrite' : 'CORRECTED', + 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', 'spark.sql.inMemoryColumnarStorage.enableVectorizedReader' : enable_vectorized, 'spark.sql.parquet.outputTimestampType': ts_write} diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index d353c9c4a17..bcb1417dced 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1160,6 +1160,7 @@ def do_it(spark): @pytest.mark.parametrize('data_gen', _no_overflow_ansi_gens, ids=idfn) +@ignore_order(local=True) def test_no_fallback_when_ansi_enabled(data_gen): def do_it(spark): df = gen_df(spark, [('a', data_gen), ('b', data_gen)], length=100) diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index 37434815913..b13c786a702 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -82,6 +82,8 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs, conf=rebase_write_corrected_conf) all_confs = copy_and_update(reader_confs, { 'spark.sql.sources.useV1SourceList': v1_enabled_list, + # set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on GPU + 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 11ff2b19d4a..a75ebfd4718 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -23,6 +23,7 @@ import pyspark.sql.functions as f import pyspark.sql.utils import random +from spark_session import is_before_spark_311 # test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for # non-cloud @@ -41,6 +42,11 @@ def limited_timestamp(nullable=True): return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc), nullable=nullable) +# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS +# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070 +def limited_int96(): + return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc)) + parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, # we are limiting TimestampGen to avoid overflowing the INT96 value @@ -214,24 +220,44 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_t data_path, conf=all_confs) -def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, ts_rebase, ts_write): +def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write): spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write) - spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', ts_rebase) - spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', ts_rebase) # for spark 310 + spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase) + spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310 with pytest.raises(Exception) as e_info: df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get()) assert e_info.match(r".*SparkUpgradeException.*") # TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS # TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070 -@pytest.mark.parametrize('ts_write_data_gen', [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))), - ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))]) -@pytest.mark.parametrize('ts_rebase', ['EXCEPTION']) -def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, ts_rebase, spark_tmp_table_factory): +@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))]) +@pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"]) +def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase): ts_write, gen = ts_write_data_gen data_path = spark_tmp_path + '/PARQUET_DATA' - with_gpu_session( - lambda spark : writeParquetUpgradeCatchException(spark, unary_op_df(spark, gen), data_path, spark_tmp_table_factory, ts_rebase, ts_write)) + int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase + date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase + if is_before_spark_311() and ts_write == 'INT96': + all_confs = {'spark.sql.parquet.outputTimestampType': ts_write} + all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase}) + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + else: + with_gpu_session( + lambda spark : writeParquetUpgradeCatchException(spark, + unary_op_df(spark, gen), + data_path, + spark_tmp_table_factory, + int96_rebase, date_time_rebase, ts_write)) + with_cpu_session( + lambda spark: writeParquetUpgradeCatchException(spark, + unary_op_df(spark, gen), data_path, + spark_tmp_table_factory, + int96_rebase, date_time_rebase, ts_write)) def writeParquetNoOverwriteCatchException(spark, df, data_path, table_name): with pytest.raises(Exception) as e_info: @@ -319,6 +345,27 @@ def 
generate_map_with_empty_validity(spark, path): lambda spark, path: spark.read.parquet(path), data_path) +@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))]) +@pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"]) +@pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"]) +@pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"]) +@pytest.mark.parametrize('int96_rebase_read', ["EXCEPTION", "CORRECTED"]) +def test_roundtrip_with_rebase_values(spark_tmp_path, ts_write_data_gen, date_time_rebase_read, + date_time_rebase_write, int96_rebase_read, int96_rebase_write): + ts_write, gen = ts_write_data_gen + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs = {'spark.sql.parquet.outputTimestampType': ts_write} + all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase_write, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase_write}) + all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInRead': date_time_rebase_read, + 'spark.sql.legacy.parquet.int96RebaseModeInRead': int96_rebase_read}) + + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + @pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/3476') @pytest.mark.allow_non_gpu("DataWritingCommandExec", "HiveTableScanExec") @pytest.mark.parametrize('allow_non_empty', [True, False]) diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala index 1efe3f54e3e..6b32121cb32 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala @@ -32,6 +32,22 @@ class Spark311Shims extends SparkBaseShims { classOf[RapidsShuffleManager].getCanonicalName } + override def int96ParquetRebaseRead(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ) + } + + override def int96ParquetRebaseWrite(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE) + } + + override def int96ParquetRebaseReadKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key + } + + override def int96ParquetRebaseWriteKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key + } + override def hasCastFloatTimestampUpcast: Boolean = false override def getParquetFilters( diff --git a/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala index 5027446335c..02b0c479212 100644 --- a/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala +++ b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala @@ -97,6 +97,22 @@ class Spark311CDHShims extends SparkBaseShims { sessionCatalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) } + override def int96ParquetRebaseRead(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ) + } + + 
override def int96ParquetRebaseWrite(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE) + } + + override def int96ParquetRebaseReadKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key + } + + override def int96ParquetRebaseWriteKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key + } + override def hasCastFloatTimestampUpcast: Boolean = false override def getParquetFilters( diff --git a/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkBaseShims.scala b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkBaseShims.scala new file mode 100644 index 00000000000..a2bd1f7c254 --- /dev/null +++ b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkBaseShims.scala @@ -0,0 +1,877 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.spark311cdh + +import java.net.URI +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.nvidia.spark.ParquetCachedBatchSerializer +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2._ +import org.apache.arrow.memory.ReferenceManager +import org.apache.arrow.vector.ValueVector +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.SparkEnv +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Average +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters +import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ReusedExchangeExec, 
ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.python._ +import org.apache.spark.sql.execution.window.WindowExecBase +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.rapids._ +import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, JoinTypeChecks, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} +import org.apache.spark.sql.rapids.execution.python.GpuPythonUDF +import org.apache.spark.sql.rapids.execution.python.shims.spark311cdh._ +import org.apache.spark.sql.rapids.shims.spark311cdh._ +import org.apache.spark.sql.rapids.shims.v2.{GpuInMemoryTableScanExec, HadoopFSUtilsShim} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types._ +import org.apache.spark.storage.{BlockId, BlockManagerId} + +/** + * Base Shim for Spark 3.1.1 that can be used by other 3.1.x versions and to easily diff + */ +abstract class SparkBaseShims extends Spark31XShims { + + override def parquetRebaseReadKey: String = + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key + override def parquetRebaseWriteKey: String = + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key + override def avroRebaseReadKey: String = + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key + override def avroRebaseWriteKey: String = + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key + override def parquetRebaseRead(conf: SQLConf): String = + conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ) + override def parquetRebaseWrite(conf: SQLConf): String = + conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE) + + override def getParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters = { + new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, + pushDownInFilterThreshold, caseSensitive) + } + + override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand = + AlterTableRecoverPartitionsCommand(tableName) + + override def getScalaUDFAsExpression( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, + outputEncoder: Option[ExpressionEncoder[_]] = None, + udfName: Option[String] = None, + nullable: Boolean = true, + udfDeterministic: Boolean = true): Expression = { + ScalaUDF(function, dataType, children, inputEncoders, outputEncoder, udfName, nullable, + udfDeterministic) + } + + override def getMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, + startMapIndex, endMapIndex, startPartition, endPartition) + } + + override def getGpuBroadcastNestedLoopJoinShim( + left: SparkPlan, + right: SparkPlan, + join: BroadcastNestedLoopJoinExec, + joinType: JoinType, + condition: Option[Expression], + targetSizeBytes: Long): GpuBroadcastNestedLoopJoinExecBase = { + GpuBroadcastNestedLoopJoinExec(left, right, join, joinType, condition, targetSizeBytes) + } + + override def getGpuBroadcastExchangeExec( + mode: BroadcastMode, + child: SparkPlan): GpuBroadcastExchangeExecBase 
= { + GpuBroadcastExchangeExec(mode, child) + } + + override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = { + plan match { + case _: GpuBroadcastHashJoinExec => true + case _ => false + } + } + + override def isWindowFunctionExec(plan: SparkPlan): Boolean = plan.isInstanceOf[WindowExecBase] + + override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = { + plan match { + case _: GpuShuffledHashJoinExec => true + case _ => false + } + } + + override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int = + sqlConf.maxMetadataStringLength + + override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( + GpuOverrides.expr[Cast]( + "Convert a column of one type of data into another type", + new CastChecks(), + (cast, conf, p, r) => new CastExprMeta[Cast](cast, SparkSession.active.sessionState.conf + .ansiEnabled, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (!conf.isCastFloatToIntegralTypesEnabled && + (fromType == DataTypes.FloatType || fromType == DataTypes.DoubleType) && + (toType == DataTypes.ByteType || toType == DataTypes.ShortType || + toType == DataTypes.IntegerType || toType == DataTypes.LongType)) { + willNotWorkOnGpu(buildTagMessage(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)) + } + super.tagExprForGpu() + } + }), + GpuOverrides.expr[AnsiCast]( + "Convert a column of one type of data into another type", + new CastChecks { + import TypeSig._ + // nullChecks are the same + + override val booleanChecks: TypeSig = integral + fp + BOOLEAN + STRING + override val sparkBooleanSig: TypeSig = numeric + BOOLEAN + STRING + + override val integralChecks: TypeSig = gpuNumeric + BOOLEAN + STRING + override val sparkIntegralSig: TypeSig = numeric + BOOLEAN + STRING + + override val fpChecks: TypeSig = (gpuNumeric + BOOLEAN + STRING) + .withPsNote(TypeEnum.STRING, fpToStringPsNote) + override val sparkFpSig: TypeSig = numeric + BOOLEAN + STRING + + override val dateChecks: TypeSig = TIMESTAMP + DATE + STRING + override val sparkDateSig: TypeSig = TIMESTAMP + DATE + STRING + + override val timestampChecks: TypeSig = TIMESTAMP + DATE + STRING + override val sparkTimestampSig: TypeSig = TIMESTAMP + DATE + STRING + + // stringChecks are the same + // binaryChecks are the same + override val decimalChecks: TypeSig = DECIMAL_64 + STRING + override val sparkDecimalSig: TypeSig = numeric + BOOLEAN + STRING + + // calendarChecks are the same + + override val arrayChecks: TypeSig = + ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + + "the desired child type") + override val sparkArraySig: TypeSig = ARRAY.nested(all) + + override val mapChecks: TypeSig = + MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + + psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " + + "desired child types") + override val sparkMapSig: TypeSig = MAP.nested(all) + + override val structChecks: TypeSig = + STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + + "desired child type(s)") + override val sparkStructSig: TypeSig = STRUCT.nested(all) + + override val udtChecks: TypeSig = none + override val sparkUdtSig: TypeSig = UDT + }, + (cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, true, conf, p, r) { + override def tagExprForGpu(): Unit = { + if 
(!conf.isCastFloatToIntegralTypesEnabled && + (fromType == DataTypes.FloatType || fromType == DataTypes.DoubleType) && + (toType == DataTypes.ByteType || toType == DataTypes.ShortType || + toType == DataTypes.IntegerType || toType == DataTypes.LongType)) { + willNotWorkOnGpu(buildTagMessage(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)) + } + super.tagExprForGpu() + } + }), + GpuOverrides.expr[Average]( + "Average aggregate operator", + ExprChecks.fullAgg( + TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL, + Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))), + (a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + val dataType = a.child.dataType + GpuOverrides.checkAndTagFloatAgg(dataType, conf, this) + } + + override def convertToGpu(child: Expression): GpuExpression = GpuAverage(child) + }), + GpuOverrides.expr[Abs]( + "Absolute value", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric, TypeSig.numeric), + (a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) { + // ANSI support for ABS was added in 3.2.0 SPARK-33275 + override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false) + }), + GpuOverrides.expr[RegExpReplace]( + "RegExpReplace support for string literal input patterns", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("pos", TypeSig.lit(TypeEnum.INT) + .withPsNote(TypeEnum.INT, "only a value of 1 is supported"), + TypeSig.lit(TypeEnum.INT)))), + (a, conf, p, r) => new ExprMeta[RegExpReplace](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (GpuOverrides.isNullOrEmptyOrRegex(a.regexp)) { + willNotWorkOnGpu( + "Only non-null, non-empty String literals that are not regex patterns " + + "are supported by RegExpReplace on the GPU") + } + GpuOverrides.extractLit(a.pos).foreach { lit => + if (lit.value.asInstanceOf[Int] != 1) { + willNotWorkOnGpu("Only a search starting position of 1 is supported") + } + } + } + override def convertToGpu(): GpuExpression = { + // ignore the pos expression which must be a literal 1 after tagging check + require(childExprs.length == 4, + s"Unexpected child count for RegExpReplace: ${childExprs.length}") + val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) + GpuStringReplace(subject, regexp, rep) + } + }), + // Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta. 
+ GpuOverrides.expr[Lead]( + "Window function that returns N entries ahead of this one", + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all, + Seq( + ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + ParamCheck("offset", TypeSig.INT, TypeSig.INT), + ParamCheck("default", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all) + ) + ), + (lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) { + override def convertToGpu(): GpuExpression = + GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu()) + }), + // Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta. + GpuOverrides.expr[Lag]( + "Window function that returns N entries behind this one", + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all, + Seq( + ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + ParamCheck("offset", TypeSig.INT, TypeSig.INT), + ParamCheck("default", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all) + ) + ), + (lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) { + override def convertToGpu(): GpuExpression = { + GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu()) + } + }), + GpuOverrides.expr[GetArrayItem]( + "Gets the field at `ordinal` in the Array", + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_64 + TypeSig.MAP).nested(), + TypeSig.all, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)), + (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){ + override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression = + GpuGetArrayItem(arr, ordinal, SQLConf.get.ansiEnabled) + }), + GpuOverrides.expr[GetMapValue]( + "Gets Value from a Map based on a key", + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){ + override def convertToGpu(map: Expression, key: Expression): GpuExpression = + GpuGetMapValue(map, key, SQLConf.get.ansiEnabled) + }), + GpuOverrides.expr[ElementAt]( + "Returns element of array at given(1-based) index in value if column is array. 
" + + "Returns value for the given key in value if column is map.", + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_64 + TypeSig.MAP).nested(), TypeSig.all, + ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.MAP) + + TypeSig.MAP.nested(TypeSig.STRING) + .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."), + TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), + ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) + .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " + + "not as maps keys") + .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " + + "not array indexes"), + TypeSig.all)), + (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + // To distinguish the supported nested type between Array and Map + val checks = in.left.dataType match { + case _: MapType => + // Match exactly with the checks for GetMapValue + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)) + case _: ArrayType => + // Match exactly with the checks for GetArrayItem + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_64 + TypeSig.MAP).nested(), + TypeSig.all, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)) + case _ => throw new IllegalStateException("Only Array or Map is supported as input.") + } + checks.tag(this) + } + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { + GpuElementAt(lhs, rhs, SQLConf.get.ansiEnabled) + } + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + + override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { + Seq( + GpuOverrides.exec[WindowInPandasExec]( + "The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between" + + " the Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled. 
For now it only supports row based window frame.", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested(TypeSig.commonCudfTypes), + TypeSig.all), + (winPy, conf, p, r) => new GpuWindowInPandasExecMetaBase(winPy, conf, p, r) { + override val windowExpressions: Seq[BaseExprMeta[NamedExpression]] = + winPy.windowExpression.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override def convertToGpu(): GpuExec = { + GpuWindowInPandasExec( + windowExpressions.map(_.convertToGpu()), + partitionSpec.map(_.convertToGpu()), + orderSpec.map(_.convertToGpu().asInstanceOf[SortOrder]), + childPlans.head.convertIfNeeded() + ) + } + }).disabledByDefault("it only supports row based frame for now"), + GpuOverrides.exec[FileSourceScanExec]( + "Reading data from files, often from Hive tables", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_64).nested(), TypeSig.all), + (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) { + // partition filters and data filters are not run on the GPU + override val childExprs: Seq[ExprMeta[_]] = Seq.empty + + override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this) + + override def convertToGpu(): GpuExec = { + val sparkSession = wrapped.relation.sparkSession + val options = wrapped.relation.options + + val location = replaceWithAlluxioPathIfNeeded( + conf, + wrapped.relation, + wrapped.partitionFilters, + wrapped.dataFilters) + + val newRelation = HadoopFsRelation( + location, + wrapped.relation.partitionSchema, + wrapped.relation.dataSchema, + wrapped.relation.bucketSpec, + GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat), + options)(sparkSession) + + GpuFileSourceScanExec( + newRelation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier)(conf) + } + }), + GpuOverrides.exec[InMemoryTableScanExec]( + "Implementation of InMemoryTableScanExec to use GPU accelerated Caching", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(), + TypeSig.all), + (scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) { + override def tagPlanForGpu(): Unit = { + if (!scan.relation.cacheBuilder.serializer.isInstanceOf[ParquetCachedBatchSerializer]) { + willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used") + } + } + + /** + * Convert InMemoryTableScanExec to a GPU enabled version. + */ + override def convertToGpu(): GpuExec = { + GpuInMemoryTableScanExec(scan.attributes, scan.predicates, scan.relation) + } + }), + GpuOverrides.exec[SortMergeJoinExec]( + "Sort merge join, replacing with shuffled hash join", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)), + GpuOverrides.exec[BroadcastHashJoinExec]( + "Implementation of join using broadcast data", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)), + GpuOverrides.exec[ShuffledHashJoinExec]( + "Implementation of join using hashed shuffled data", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)), + GpuOverrides.exec[ArrowEvalPythonExec]( + "The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" + + " Java process and the Python process. 
It also supports scheduling GPU resources" + + " for the Python process when enabled", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + (e, conf, p, r) => + new SparkPlanMeta[ArrowEvalPythonExec](e, conf, p, r) { + val udfs: Seq[BaseExprMeta[PythonUDF]] = + e.udfs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val resultAttrs: Seq[BaseExprMeta[Attribute]] = + e.resultAttrs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = udfs ++ resultAttrs + + override def replaceMessage: String = "partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + override def convertToGpu(): GpuExec = + GpuArrowEvalPythonExec(udfs.map(_.convertToGpu()).asInstanceOf[Seq[GpuPythonUDF]], + resultAttrs.map(_.convertToGpu()).asInstanceOf[Seq[Attribute]], + childPlans.head.convertIfNeeded(), + e.evalType) + }), + GpuOverrides.exec[MapInPandasExec]( + "The backend for Map Pandas Iterator UDF. Accelerates the data transfer between the" + + " Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + (mapPy, conf, p, r) => new GpuMapInPandasExecMeta(mapPy, conf, p, r)), + GpuOverrides.exec[FlatMapGroupsInPandasExec]( + "The backend for Flat Map Groups Pandas UDF, Accelerates the data transfer between the" + + " Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks(TypeSig.commonCudfTypes, TypeSig.all), + (flatPy, conf, p, r) => new GpuFlatMapGroupsInPandasExecMeta(flatPy, conf, p, r)), + GpuOverrides.exec[AggregateInPandasExec]( + "The backend for an Aggregation Pandas UDF, this accelerates the data transfer between" + + " the Java process and the Python process. 
It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks(TypeSig.commonCudfTypes, TypeSig.all), + (aggPy, conf, p, r) => new GpuAggregateInPandasExecMeta(aggPy, conf, p, r)) + ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap + } + + override def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Seq( + GpuOverrides.scan[ParquetScan]( + "Parquet parsing", + (a, conf, p, r) => new ScanMeta[ParquetScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = GpuParquetScanBase.tagSupport(this) + + override def convertToGpu(): Scan = { + GpuParquetScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.pushedFilters, + a.options, + a.partitionFilters, + a.dataFilters, + conf) + } + }), + GpuOverrides.scan[OrcScan]( + "ORC parsing", + (a, conf, p, r) => new ScanMeta[OrcScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = + GpuOrcScanBase.tagSupport(this) + + override def convertToGpu(): Scan = + GpuOrcScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.options, + a.pushedFilters, + a.partitionFilters, + a.dataFilters, + conf) + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap + + override def getBuildSide(join: HashJoin): GpuBuildSide = { + GpuJoinUtils.getGpuBuildSide(join.buildSide) + } + + override def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide = { + GpuJoinUtils.getGpuBuildSide(join.buildSide) + } + + override def getPartitionFileNames( + partitions: Seq[PartitionDirectory]): Seq[String] = { + val files = partitions.flatMap(partition => partition.files) + files.map(_.getPath.getName) + } + + override def getPartitionFileStatusSize(partitions: Seq[PartitionDirectory]): Long = { + partitions.map(_.files.map(_.getLen).sum).sum + } + + override def getPartitionedFiles( + partitions: Array[PartitionDirectory]): Array[PartitionedFile] = { + partitions.flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + } + } + + override def getPartitionSplitFiles( + partitions: Array[PartitionDirectory], + maxSplitBytes: Long, + relation: HadoopFsRelation): Array[PartitionedFile] = { + partitions.flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + } + } + } + + override def getFileScanRDD( + sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition]): RDD[InternalRow] = { + new FileScanRDD(sparkSession, readFunction, filePartitions) + } + + override def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition = { + FilePartition(index, files) + } + + override def copyBatchScanExec( + batchScanExec: GpuBatchScanExec, + queryUsesInputFile: Boolean): GpuBatchScanExec = { + val scanCopy = batchScanExec.scan match { + case parquetScan: GpuParquetScan => + parquetScan.copy(queryUsesInputFile=queryUsesInputFile) + case orcScan: GpuOrcScan => + orcScan.copy(queryUsesInputFile=queryUsesInputFile) + case _ => throw 
new RuntimeException("Wrong format") // never reach here + } + batchScanExec.copy(scan=scanCopy) + } + + override def copyFileSourceScanExec( + scanExec: GpuFileSourceScanExec, + queryUsesInputFile: Boolean): GpuFileSourceScanExec = { + scanExec.copy(queryUsesInputFile=queryUsesInputFile)(scanExec.rapidsConf) + } + + override def getGpuColumnarToRowTransition(plan: SparkPlan, + exportColumnRdd: Boolean): GpuColumnarToRowExecParent = { + val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER) + val serClass = Class.forName(serName) + if (serClass == classOf[ParquetCachedBatchSerializer]) { + GpuColumnarToRowTransitionExec(plan) + } else { + GpuColumnarToRowExec(plan) + } + } + + override def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit = { + GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver) + } + + override def getGpuShuffleExchangeExec( + outputPartitioning: Partitioning, + child: SparkPlan, + cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = { + val shuffleOrigin = cpuShuffle.map(_.shuffleOrigin).getOrElse(ENSURE_REQUIREMENTS) + GpuShuffleExchangeExec(outputPartitioning, child, shuffleOrigin) + } + + override def getGpuShuffleExchangeExec( + queryStage: ShuffleQueryStageExec): GpuShuffleExchangeExecBase = { + queryStage.shuffle.asInstanceOf[GpuShuffleExchangeExecBase] + } + + override def sortOrderChildren(s: SortOrder): Seq[Expression] = s.children + + override def sortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering): SortOrder = SortOrder(child, direction, nullOrdering, Seq.empty) + + override def copySortOrderWithNewChild(s: SortOrder, child: Expression) = { + s.copy(child = child) + } + + override def alias(child: Expression, name: String)( + exprId: ExprId, + qualifier: Seq[String], + explicitMetadata: Option[Metadata]): Alias = { + Alias(child, name)(exprId, qualifier, explicitMetadata) + } + + override def shouldIgnorePath(path: String): Boolean = { + HadoopFSUtilsShim.shouldIgnorePath(path) + } + + override def getLegacyComplexTypeToString(): Boolean = { + SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING) + } + + // Arrow version changed between Spark versions + override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getDataBuffer() + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) + } + + override def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getValidityBuffer + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) + } + + override def createTable(table: CatalogTable, + sessionCatalog: SessionCatalog, + tableLocation: Option[URI], + result: BaseRelation) = { + val newTable = table.copy( + storage = table.storage.copy(locationUri = tableLocation), + // We will use the schema of resolved.relation as the schema of the table (instead of + // the schema of df). It is important since the nullability may be changed by the relation + // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). + schema = result.schema) + // Table location is already validated. No need to check it again during table creation. 
+ sessionCatalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) + } + + override def getArrowOffsetsBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getOffsetBuffer + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) + } + + /** matches SPARK-33008 fix in 3.1.1 */ + override def shouldFailDivByZero(): Boolean = SQLConf.get.ansiEnabled + + override def replaceWithAlluxioPathIfNeeded( + conf: RapidsConf, + relation: HadoopFsRelation, + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression]): FileIndex = { + + val alluxioPathsReplace: Option[Seq[String]] = conf.getAlluxioPathsToReplace + + if (alluxioPathsReplace.isDefined) { + // alluxioPathsReplace: Seq("key->value", "key1->value1") + // turn the rules to the Map with eg + // { s3:/foo -> alluxio://0.1.2.3:19998/foo, + // gs:/bar -> alluxio://0.1.2.3:19998/bar, + // /baz -> alluxio://0.1.2.3:19998/baz } + val replaceMapOption = alluxioPathsReplace.map(rules => { + rules.map(rule => { + val split = rule.split("->") + if (split.size == 2) { + split(0).trim -> split(1).trim + } else { + throw new IllegalArgumentException(s"Invalid setting for " + + s"${RapidsConf.ALLUXIO_PATHS_REPLACE.key}") + } + }).toMap + }) + + replaceMapOption.map(replaceMap => { + + def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + val partitionDirs = relation.location.listFiles( + partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + + // replacement func to check if the file path is prefixed with the string user configured + // if yes, replace it + val replaceFunc = (f: Path) => { + val pathStr = f.toString + val matchedSet = replaceMap.keySet.filter(reg => pathStr.startsWith(reg)) + if (matchedSet.size > 1) { + // never reach here since replaceMap is a Map + throw new IllegalArgumentException(s"Found ${matchedSet.size} same replacing rules " + + s"from ${RapidsConf.ALLUXIO_PATHS_REPLACE.key} which requires only 1 rule for each " + + s"file path") + } else if (matchedSet.size == 1) { + new Path(pathStr.replaceFirst(matchedSet.head, replaceMap(matchedSet.head))) + } else { + f + } + } + + // replace all of input files + val inputFiles: Seq[Path] = partitionDirs.flatMap(partitionDir => { + replacePartitionDirectoryFiles(partitionDir, replaceFunc) + }) + + // replace all of rootPaths which are already unique + val rootPaths = relation.location.rootPaths.map(replaceFunc) + + val parameters: Map[String, String] = relation.options + + // infer PartitionSpec + val partitionSpec = GpuPartitioningUtils.inferPartitioning( + relation.sparkSession, + rootPaths, + inputFiles, + parameters, + Option(relation.dataSchema), + replaceFunc) + + // generate a new InMemoryFileIndex holding paths with alluxio schema + new InMemoryFileIndex( + relation.sparkSession, + inputFiles, + parameters, + Option(relation.dataSchema), + userSpecifiedPartitionSpec = Some(partitionSpec)) + }).getOrElse(relation.location) + + } else { + relation.location + } + } + + override def replacePartitionDirectoryFiles(partitionDir: PartitionDirectory, + replaceFunc: Path => Path): Seq[Path] = { + partitionDir.files.map(f => replaceFunc(f.getPath)) + } + + override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = { + case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e + } + + /** dropped by SPARK-34234 */ + override def attachTreeIfSupported[TreeType <: TreeNode[_], A]( + tree: 
TreeType, + msg: String)( + f: => A + ): A = { + attachTree(tree, msg)(f) + } + + override def hasAliasQuoteFix: Boolean = false + + override def hasCastFloatTimestampUpcast: Boolean = false + + override def filesFromFileIndex(fileIndex: PartitioningAwareFileIndex): Seq[FileStatus] = { + fileIndex.allFiles() + } + + override def broadcastModeTransform(mode: BroadcastMode, rows: Array[InternalRow]): Any = + mode.transform(rows) + + override def registerKryoClasses(kryo: Kryo): Unit = { + kryo.register(classOf[SerializeConcatHostBuffersDeserializeBatch], + new KryoJavaSerializer()) + kryo.register(classOf[SerializeBatchDeserializeHostBuffer], + new KryoJavaSerializer()) + } + + override def shouldFallbackOnAnsiTimestamp(): Boolean = SQLConf.get.ansiEnabled +} diff --git a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala index a8a569fe9b6..cabe045b109 100644 --- a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala +++ b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala @@ -43,4 +43,20 @@ class Spark311dbShims extends SparkBaseShims { datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters = new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) + + override def int96ParquetRebaseRead(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ) + } + + override def int96ParquetRebaseWrite(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE) + } + + override def int96ParquetRebaseReadKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key + } + + override def int96ParquetRebaseWriteKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key + } } diff --git a/shims/spark312/src/main/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala b/shims/spark312/src/main/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala index 0ad515324c8..9d330fd930c 100644 --- a/shims/spark312/src/main/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala +++ b/shims/spark312/src/main/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala @@ -34,6 +34,22 @@ class Spark312Shims extends SparkBaseShims { override def hasCastFloatTimestampUpcast: Boolean = true + override def int96ParquetRebaseRead(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ) + } + + override def int96ParquetRebaseWrite(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE) + } + + override def int96ParquetRebaseReadKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key + } + + override def int96ParquetRebaseWriteKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key + } + override def getParquetFilters( schema: MessageType, pushDownDate: Boolean, diff --git a/shims/spark313/src/main/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala b/shims/spark313/src/main/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala index 4f53386c330..4c7f95911e4 100644 --- a/shims/spark313/src/main/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala +++ 
b/shims/spark313/src/main/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala @@ -45,4 +45,20 @@ class Spark313Shims extends SparkBaseShims { pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) override def hasCastFloatTimestampUpcast: Boolean = true + + override def int96ParquetRebaseRead(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ) + } + + override def int96ParquetRebaseWrite(conf: SQLConf): String = { + conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE) + } + + override def int96ParquetRebaseReadKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key + } + + override def int96ParquetRebaseWriteKey: String = { + SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key + } } diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index a7a2ab84f26..ba27b7ae80b 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -113,6 +113,8 @@ trait Spark30XShims extends SparkShims { override def shouldFailDivOverflow(): Boolean = false + override def hasSeparateINT96RebaseConf: Boolean = false + override def leafNodeDefaultParallelism(ss: SparkSession): Int = { ss.sparkContext.defaultParallelism } diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index dfce42b3895..70364892fd0 100644 --- a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -126,5 +126,7 @@ trait Spark30XShims extends SparkShims { ss.sparkContext.defaultParallelism } + override def hasSeparateINT96RebaseConf: Boolean = false + override def shouldFallbackOnAnsiTimestamp(): Boolean = false } diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala new file mode 100644 index 00000000000..79e5be6b18d --- /dev/null +++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +trait Spark31XShims extends Spark30XShims { + override def hasSeparateINT96RebaseConf: Boolean = true +} diff --git a/sql-plugin/src/main/311db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/311db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index a7a2ab84f26..51fd6ec5374 100644 --- a/sql-plugin/src/main/311db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/311db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -111,6 +111,8 @@ trait Spark30XShims extends SparkShims { override def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean = false + override def hasSeparateINT96RebaseConf: Boolean = true + override def shouldFailDivOverflow(): Boolean = false override def leafNodeDefaultParallelism(ss: SparkSession): Int = { diff --git a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala index 731d5d05b17..05c8ea8614d 100644 --- a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala +++ b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} -abstract class SparkBaseShims extends Spark30XShims { +abstract class SparkBaseShims extends Spark31XShims { override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand = AlterTableRecoverPartitionsCommand(tableName) diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index f6c2dfef7ad..1953bfb959f 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -54,6 +54,8 @@ trait Spark32XShims extends SparkShims { override final def parquetRebaseWrite(conf: SQLConf): String = conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE) + override def hasSeparateINT96RebaseConf: Boolean = true + override final def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] = exec[AQEShuffleReadExec]( "A wrapper of shuffle query stage", ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.ARRAY + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala index 374dffb456f..f2d739c8eba 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.rapids.execution.TrampolineUtil object RebaseHelper extends Arm { - private[this] def isDateTimeRebaseNeeded(column: ColumnVector, - startDay: Int, - startTs: Long): Boolean = { + private[this] def isDateRebaseNeeded(column: ColumnVector, + startDay: Int): Boolean = { // TODO update this for nested column checks // https://github.com/NVIDIA/spark-rapids/issues/1126 val dtype = column.getType @@ -37,7 +36,15 @@ object RebaseHelper extends Arm { } } } - } else if (dtype.isTimestampType) { + } else { + false + } + } + + private[this] def 
isTimeRebaseNeeded(column: ColumnVector, + startTs: Long): Boolean = { + val dtype = column.getType + if (dtype.hasTimeResolution) { // TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to properly handle // TIMESTAMP_MILLIS, for use require so we fail if that happens require(dtype == DType.TIMESTAMP_MICROSECONDS) @@ -54,15 +61,17 @@ } } - def isDateTimeRebaseNeededWrite(column: ColumnVector): Boolean = - isDateTimeRebaseNeeded(column, - RebaseDateTime.lastSwitchGregorianDay, - RebaseDateTime.lastSwitchGregorianTs) + def isDateRebaseNeededInRead(column: ColumnVector): Boolean = + isDateRebaseNeeded(column, RebaseDateTime.lastSwitchJulianDay) + + def isTimeRebaseNeededInRead(column: ColumnVector): Boolean = + isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchJulianTs) + + def isDateRebaseNeededInWrite(column: ColumnVector): Boolean = + isDateRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianDay) - def isDateTimeRebaseNeededRead(column: ColumnVector): Boolean = - isDateTimeRebaseNeeded(column, - RebaseDateTime.lastSwitchJulianDay, - RebaseDateTime.lastSwitchJulianTs) + def isTimeRebaseNeededInWrite(column: ColumnVector): Boolean = + isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianTs) def newRebaseExceptionInRead(format: String): Exception = { val config = if (format == "Parquet") { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 202a0917af6..89a305fa1d4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -87,6 +87,18 @@ object GpuParquetFileFormat { TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType]) } + + ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match { + case "EXCEPTION" => + case "CORRECTED" => + case "LEGACY" => + if (schemaHasTimestamps) { + meta.willNotWorkOnGpu("LEGACY rebase mode for int96 timestamps is not supported") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96") + } + ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match { case "EXCEPTION" => //Good case "CORRECTED" => //Good @@ -202,8 +214,19 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { val conf = ContextUtil.getConfiguration(job) + val outputTimestampType = sparkSession.sessionState.conf.parquetOutputTimestampType val dateTimeRebaseException = "EXCEPTION".equals( - sparkSession.sqlContext.getConf(ShimLoader.getSparkShims.parquetRebaseWriteKey)) + sparkSession.sqlContext.getConf(ShimLoader.getSparkShims.parquetRebaseWriteKey)) + // prior to Spark 3.1.1, INT96 writes don't check for the rebase exception + // https://github.com/apache/spark/blob/068465d016447ef0dbf7974b1a3f992040f4d64d/sql/core/src/ + // main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala#L195 + val hasSeparateInt96RebaseConf = ShimLoader.getSparkShims.hasSeparateINT96RebaseConf + val timestampRebaseException = + outputTimestampType.equals(ParquetOutputTimestampType.INT96) && + "EXCEPTION".equals(sparkSession.sqlContext + .getConf(ShimLoader.getSparkShims.int96ParquetRebaseWriteKey)) && + hasSeparateInt96RebaseConf || + !outputTimestampType.equals(ParquetOutputTimestampType.INT96) && dateTimeRebaseException val committerClass = conf.getClass( @@ -243,7 +266,6 @@ class GpuParquetFileFormat extends
ColumnarFileFormat with Logging { SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) - val outputTimestampType = sparkSession.sessionState.conf.parquetOutputTimestampType if(!GpuParquetFileFormat.isOutputTimestampTypeSupported(outputTimestampType)) { val hasTimestamps = dataSchema.exists { field => TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType]) @@ -283,7 +305,8 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { path: String, dataSchema: StructType, context: TaskAttemptContext): ColumnarOutputWriter = { - new GpuParquetWriter(path, dataSchema, compressionType, dateTimeRebaseException, context) + new GpuParquetWriter(path, dataSchema, compressionType, dateTimeRebaseException, + timestampRebaseException, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -297,18 +320,23 @@ class GpuParquetWriter( path: String, dataSchema: StructType, compressionType: CompressionType, - dateTimeRebaseException: Boolean, + dateRebaseException: Boolean, + timestampRebaseException: Boolean, context: TaskAttemptContext) extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") { val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key) override def scanTableBeforeWrite(table: Table): Unit = { - if (dateTimeRebaseException) { - (0 until table.getNumberOfColumns).foreach { i => - if (RebaseHelper.isDateTimeRebaseNeededWrite(table.getColumn(i))) { - throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") - } + (0 until table.getNumberOfColumns).foreach { i => + val col = table.getColumn(i) + // if col is a day + if (dateRebaseException && RebaseHelper.isDateRebaseNeededInWrite(col)) { + throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") + } + // if col is a time + else if (timestampRebaseException && RebaseHelper.isTimeRebaseNeededInWrite(col)) { + throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala index 54cbd2a581c..8263b9ccc69 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala @@ -42,7 +42,8 @@ import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat} import org.apache.parquet.hadoop.metadata._ -import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type, Types} +import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, PrimitiveType, Type, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast @@ -52,15 +53,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetReadSupport} +import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory import 
org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -114,6 +116,27 @@ object GpuParquetScanBase { tagSupport(scan.sparkSession, schema, scanMeta) } + + def throwIfNeeded( + table: Table, + isCorrectedInt96Rebase: Boolean, + isCorrectedDateTimeRebase: Boolean, + hasInt96Timestamps: Boolean): Unit = { + (0 until table.getNumberOfColumns).foreach { i => + val col = table.getColumn(i) + // if col is a day + if (!isCorrectedDateTimeRebase && RebaseHelper.isDateRebaseNeededInRead(col)) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet") + } + // if col is a time + else if (hasInt96Timestamps && !isCorrectedInt96Rebase || + !hasInt96Timestamps && !isCorrectedDateTimeRebase) { + if (RebaseHelper.isTimeRebaseNeededInRead(col)) { + throw DataSourceUtils.newRebaseExceptionInRead("Parquet") + } + } + } + } + def tagSupport( sparkSession: SparkSession, readSchema: StructType, @@ -170,6 +193,21 @@ object GpuParquetScanBase { meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion") } + sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match { + case "EXCEPTION" => if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION") + } + case "CORRECTED" => // Good + case "LEGACY" => // really is EXCEPTION for us... + if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") + } + sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match { case "EXCEPTION" => if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + @@ -223,6 +261,27 @@ object GpuParquetPartitionReaderFactoryBase { private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" // Copied from Spark private val SPARK_LEGACY_DATETIME = "org.apache.spark.legacyDateTime" + // Copied from Spark + private val SPARK_LEGACY_INT96 = "org.apache.spark.legacyINT96" + + def isCorrectedInt96RebaseMode( + lookupFileMeta: String => String, + isCorrectedInt96ModeConfig: Boolean): Boolean = { + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and later may also need the rebase if they were written with + // the "LEGACY" rebase mode.
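+ // For example, a file whose "org.apache.spark.version" metadata is 3.1.2 but that was written + // with the LEGACY INT96 mode also carries the "org.apache.spark.legacyINT96" marker, so the + // check below returns false and the read path still has to rebase (or raise) the INT96 values.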
+ if (version >= "3.1.0") { + lookupFileMeta(SPARK_LEGACY_INT96) == null + } else if (version >= "3.0.0") { + lookupFileMeta(SPARK_LEGACY_DATETIME) == null + } else { + false + } + }.getOrElse(isCorrectedInt96ModeConfig) + } def isCorrectedRebaseMode( lookupFileMeta: String => String, @@ -240,7 +299,8 @@ object GpuParquetPartitionReaderFactoryBase { // contains meta about all the blocks in a file private case class ParquetFileInfoWithBlockMeta(filePath: Path, blocks: Seq[BlockMetaData], - partValues: InternalRow, schema: MessageType, isCorrectedRebaseMode: Boolean) + partValues: InternalRow, schema: MessageType, isCorrectedInt96RebaseMode: Boolean, + isCorrectedRebaseMode: Boolean, hasInt96Timestamps: Boolean) private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) extends Arm { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis @@ -252,6 +312,19 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold private val rebaseMode = ShimLoader.getSparkShims.parquetRebaseRead(sqlConf) private val isCorrectedRebase = "CORRECTED" == rebaseMode + val int96RebaseMode = ShimLoader.getSparkShims.int96ParquetRebaseRead(sqlConf) + private val isInt96CorrectedRebase = "CORRECTED" == int96RebaseMode + + + def isParquetTimeInInt96(parquetType: Type): Boolean = { + parquetType match { + case p:PrimitiveType => + p.getPrimitiveTypeName == PrimitiveTypeName.INT96 + case g:GroupType => //GroupType + g.getFields.asScala.exists(t => isParquetTimeInInt96(t)) + case _ => false + } + } def filterBlocks( file: PartitionedFile, @@ -275,9 +348,15 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte None } + val hasInt96Timestamps = isParquetTimeInInt96(fileSchema) + val isCorrectedRebaseForThisFile = - GpuParquetPartitionReaderFactoryBase.isCorrectedRebaseMode( - footer.getFileMetaData.getKeyValueMetaData.get, isCorrectedRebase) + GpuParquetPartitionReaderFactoryBase.isCorrectedRebaseMode( + footer.getFileMetaData.getKeyValueMetaData.get, isCorrectedRebase) + + val isCorrectedInt96RebaseForThisFile = + GpuParquetPartitionReaderFactoryBase.isCorrectedInt96RebaseMode( + footer.getFileMetaData.getKeyValueMetaData.get, isInt96CorrectedRebase) val blocks = if (pushedFilters.isDefined) { // Use the ParquetFileReader to perform dictionary-level filtering @@ -301,7 +380,8 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte val columnPaths = clippedSchema.getPaths.asScala.map(x => ColumnPath.get(x: _*)) val clipped = ParquetPartitionReader.clipBlocks(columnPaths, blocks.asScala) ParquetFileInfoWithBlockMeta(filePath, clipped, file.partitionValues, - clippedSchema, isCorrectedRebaseForThisFile) + clippedSchema, isCorrectedInt96RebaseForThisFile, isCorrectedRebaseForThisFile, + hasInt96Timestamps) } } @@ -372,7 +452,8 @@ case class GpuParquetMultiFilePartitionReaderFactory( ParquetDataBlock(block), file.partitionValues, ParquetSchemaWrapper(singleFileInfo.schema), - ParquetExtraInfo(singleFileInfo.isCorrectedRebaseMode))) + ParquetExtraInfo(singleFileInfo.isCorrectedRebaseMode, + singleFileInfo.isCorrectedInt96RebaseMode, singleFileInfo.hasInt96Timestamps))) } new MultiFileParquetPartitionReader(conf, files, clippedBlocks, isCaseSensitive, readDataSchema, debugDumpPrefix, @@ -425,7 +506,8 @@ case class GpuParquetPartitionReaderFactory( new ParquetPartitionReader(conf, file, singleFileInfo.filePath, singleFileInfo.blocks, 
singleFileInfo.schema, isCaseSensitive, readDataSchema, debugDumpPrefix, maxReadBatchSizeRows, - maxReadBatchSizeBytes, metrics, singleFileInfo.isCorrectedRebaseMode) + maxReadBatchSizeBytes, metrics, singleFileInfo.isCorrectedInt96RebaseMode, + singleFileInfo.isCorrectedRebaseMode, singleFileInfo.hasInt96Timestamps) } } @@ -751,7 +833,8 @@ private case class ParquetDataBlock(dataBlock: BlockMetaData) extends DataBlockB } /** Parquet extra information containing isCorrectedRebaseMode */ -case class ParquetExtraInfo(isCorrectedRebaseMode: Boolean) extends ExtraInfo +case class ParquetExtraInfo(isCorrectedRebaseMode: Boolean, + isCorrectedInt96RebaseMode: Boolean, hasInt96Timestamps: Boolean) extends ExtraInfo // contains meta about a single block in a file private case class ParquetSingleDataBlockMeta( @@ -842,9 +925,11 @@ class MultiFileParquetPartitionReader( // We need to ensure all files we are going to combine have the same datetime // rebase mode. if (nextBlockInfo.extraInfo.isCorrectedRebaseMode != - currentBlockInfo.extraInfo.isCorrectedRebaseMode) { - logInfo(s"datetime rebase mode for the next file ${nextBlockInfo.filePath} is " + - s"different then current file ${currentBlockInfo.filePath}, splitting into another batch.") + currentBlockInfo.extraInfo.isCorrectedRebaseMode && + nextBlockInfo.extraInfo.isCorrectedInt96RebaseMode != + currentBlockInfo.extraInfo.isCorrectedInt96RebaseMode) { + logInfo(s"datetime rebase mode for the next file ${nextBlockInfo.filePath} is different " + + s"than current file ${currentBlockInfo.filePath}, splitting into another batch.") return true } @@ -902,13 +987,11 @@ class MultiFileParquetPartitionReader( } closeOnExcept(table) { _ => - if (!extraInfo.isCorrectedRebaseMode) { - (0 until table.getNumberOfColumns).foreach { i => - if (RebaseHelper.isDateTimeRebaseNeededRead(table.getColumn(i))) { - throw RebaseHelper.newRebaseExceptionInRead("Parquet") - } - } - } + GpuParquetScanBase.throwIfNeeded( + table, + extraInfo.isCorrectedInt96RebaseMode, + extraInfo.isCorrectedRebaseMode, + extraInfo.hasInt96Timestamps) } evolveSchemaIfNeededAndClose(table, splits.mkString(","), clippedSchema) } @@ -988,11 +1071,13 @@ class MultiFileCloudParquetPartitionReader( execMetrics) with ParquetPartitionReaderBase { case class HostMemoryBuffersWithMetaData( - override val partitionedFile: PartitionedFile, - override val memBuffersAndSizes: Array[(HostMemoryBuffer, Long)], - override val bytesRead: Long, - isCorrectRebaseMode: Boolean, - clippedSchema: MessageType) extends HostMemoryBuffersWithMetaDataBase + override val partitionedFile: PartitionedFile, + override val memBuffersAndSizes: Array[(HostMemoryBuffer, Long)], + override val bytesRead: Long, + isCorrectRebaseMode: Boolean, + isCorrectInt96RebaseMode: Boolean, + hasInt96Timestamps: Boolean, + clippedSchema: MessageType) extends HostMemoryBuffersWithMetaDataBase private class ReadBatchRunner(filterHandler: GpuParquetFileFilterHandler, file: PartitionedFile, @@ -1018,21 +1103,24 @@ class MultiFileCloudParquetPartitionReader( val bytesRead = fileSystemBytesRead() - startingBytesRead // no blocks so return null buffer and size 0 return HostMemoryBuffersWithMetaData(file, Array((null, 0)), bytesRead, - fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.schema) + fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.isCorrectedInt96RebaseMode, + fileBlockMeta.hasInt96Timestamps, fileBlockMeta.schema) } blockChunkIter = fileBlockMeta.blocks.iterator.buffered if (isDone) { val bytesRead = fileSystemBytesRead() -
startingBytesRead // got close before finishing HostMemoryBuffersWithMetaData(file, Array((null, 0)), bytesRead, - fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.schema) + fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.isCorrectedInt96RebaseMode, + fileBlockMeta.hasInt96Timestamps, fileBlockMeta.schema) } else { if (readDataSchema.isEmpty) { val bytesRead = fileSystemBytesRead() - startingBytesRead val numRows = fileBlockMeta.blocks.map(_.getRowCount).sum.toInt // overload size to be number of rows with null buffer HostMemoryBuffersWithMetaData(file, Array((null, numRows)), bytesRead, - fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.schema) + fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.isCorrectedInt96RebaseMode, + fileBlockMeta.hasInt96Timestamps, fileBlockMeta.schema) } else { val filePath = new Path(new URI(file.filePath)) while (blockChunkIter.hasNext) { @@ -1045,10 +1133,12 @@ class MultiFileCloudParquetPartitionReader( // got close before finishing hostBuffers.foreach(_._1.safeClose()) HostMemoryBuffersWithMetaData(file, Array((null, 0)), bytesRead, - fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.schema) + fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.isCorrectedInt96RebaseMode, + fileBlockMeta.hasInt96Timestamps, fileBlockMeta.schema) } else { HostMemoryBuffersWithMetaData(file, hostBuffers.toArray, bytesRead, - fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.schema) + fileBlockMeta.isCorrectedRebaseMode, fileBlockMeta.isCorrectedInt96RebaseMode, + fileBlockMeta.hasInt96Timestamps, fileBlockMeta.schema) } } } @@ -1104,7 +1194,8 @@ class MultiFileCloudParquetPartitionReader( val memBuffersAndSize = buffer.memBuffersAndSizes val (hostBuffer, size) = memBuffersAndSize.head val nextBatch = readBufferToTable(buffer.isCorrectRebaseMode, - buffer.clippedSchema, buffer.partitionedFile.partitionValues, + buffer.isCorrectInt96RebaseMode, buffer.hasInt96Timestamps, buffer.clippedSchema, + buffer.partitionedFile.partitionValues, hostBuffer, size, buffer.partitionedFile.filePath) if (memBuffersAndSize.length > 1) { val updatedBuffers = memBuffersAndSize.drop(1) @@ -1117,8 +1208,11 @@ class MultiFileCloudParquetPartitionReader( } } + private def readBufferToTable( isCorrectRebaseMode: Boolean, + isCorrectInt96RebaseMode: Boolean, + hasInt96Timestamps: Boolean, clippedSchema: MessageType, partValues: InternalRow, hostBuffer: HostMemoryBuffer, @@ -1149,17 +1243,12 @@ class MultiFileCloudParquetPartitionReader( Table.readParquet(parseOpts, hostBuffer, 0, dataSize) } closeOnExcept(table) { _ => - if (!isCorrectRebaseMode) { - (0 until table.getNumberOfColumns).foreach { i => - if (RebaseHelper.isDateTimeRebaseNeededRead(table.getColumn(i))) { - throw RebaseHelper.newRebaseExceptionInRead("Parquet") - } - } - } + GpuParquetScanBase.throwIfNeeded(table, isCorrectInt96RebaseMode, isCorrectRebaseMode, + hasInt96Timestamps) maxDeviceMemory = max(GpuColumnVector.getTotalDeviceMemoryUsed(table), maxDeviceMemory) if (readDataSchema.length < table.getNumberOfColumns) { throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " + - s"but read ${table.getNumberOfColumns} from $fileName") + s"but read ${table.getNumberOfColumns} from $fileName") } } metrics(NUM_OUTPUT_BATCHES) += 1 @@ -1209,7 +1298,9 @@ class ParquetPartitionReader( maxReadBatchSizeRows: Integer, maxReadBatchSizeBytes: Long, execMetrics: Map[String, GpuMetric], - isCorrectedRebaseMode: Boolean) extends FilePartitionReaderBase(conf, execMetrics) + isCorrectedInt96RebaseMode: Boolean, + 
isCorrectedRebaseMode: Boolean, + hasInt96Timestamps: Boolean) extends FilePartitionReaderBase(conf, execMetrics) with ParquetPartitionReaderBase { private val blockIterator: BufferedIterator[BlockMetaData] = clippedBlocks.iterator.buffered @@ -1291,13 +1382,8 @@ class ParquetPartitionReader( Table.readParquet(parseOpts, dataBuffer, 0, dataSize) } closeOnExcept(table) { _ => - if (!isCorrectedRebaseMode) { - (0 until table.getNumberOfColumns).foreach { i => - if (RebaseHelper.isDateTimeRebaseNeededRead(table.getColumn(i))) { - throw RebaseHelper.newRebaseExceptionInRead("Parquet") - } - } - } + GpuParquetScanBase.throwIfNeeded(table, isCorrectedInt96RebaseMode, isCorrectedRebaseMode, + hasInt96Timestamps) maxDeviceMemory = max(GpuColumnVector.getTotalDeviceMemoryUsed(table), maxDeviceMemory) if (readDataSchema.length < table.getNumberOfColumns) { throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 5cebe599e91..36548e281f6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -90,6 +90,23 @@ trait SparkShims { def parquetRebaseRead(conf: SQLConf): String def parquetRebaseWrite(conf: SQLConf): String def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand + def hasSeparateINT96RebaseConf: Boolean + + def int96ParquetRebaseRead(conf: SQLConf): String = { + parquetRebaseRead(conf) + } + + def int96ParquetRebaseWrite(conf: SQLConf): String = { + parquetRebaseWrite(conf) + } + + def int96ParquetRebaseReadKey: String = { + parquetRebaseReadKey + } + + def int96ParquetRebaseWriteKey: String = { + parquetRebaseWriteKey + } def getParquetFilters( schema: MessageType, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RebaseHelperSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RebaseHelperSuite.scala index e543a51e4bb..4b039fa1484 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RebaseHelperSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RebaseHelperSuite.scala @@ -23,15 +23,15 @@ import org.scalatest.FunSuite class RebaseHelperSuite extends FunSuite with Arm { test("all null timestamp days column rebase check") { withResource(ColumnVector.timestampDaysFromBoxedInts(null, null, null)) { c => - assertResult(false)(RebaseHelper.isDateTimeRebaseNeededWrite(c)) - assertResult(false)(RebaseHelper.isDateTimeRebaseNeededRead(c)) + assertResult(false)(RebaseHelper.isDateRebaseNeededInWrite(c)) + assertResult(false)(RebaseHelper.isDateRebaseNeededInRead(c)) } } test("all null timestamp microseconds column rebase check") { withResource(ColumnVector.timestampMicroSecondsFromBoxedLongs(null, null, null)) { c => - assertResult(false)(RebaseHelper.isDateTimeRebaseNeededWrite(c)) - assertResult(false)(RebaseHelper.isDateTimeRebaseNeededRead(c)) + assertResult(false)(RebaseHelper.isTimeRebaseNeededInWrite(c)) + assertResult(false)(RebaseHelper.isTimeRebaseNeededInRead(c)) } } }
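
For reference, a minimal self-contained sketch (not part of the patch) of how the two separate write-side rebase configs are set on a Spark 3.1+ style session, where the shim reports hasSeparateINT96RebaseConf = true. The config keys are the standard Spark SQL ones used above; the session setup, output path, and data are hypothetical.

    import org.apache.spark.sql.SparkSession

    object Int96RebaseWriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("int96-rebase-sketch")
          .master("local[*]") // hypothetical local session, for illustration only
          .getOrCreate()
        // INT96 output is governed by the int96 rebase key, not the datetime one,
        // on versions that expose a separate INT96 rebase conf.
        spark.conf.set("spark.sql.parquet.outputTimestampType", "INT96")
        spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInWrite", "CORRECTED")
        spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
        spark.range(10)
          .selectExpr("cast(id as timestamp) as ts") // seconds since epoch, well inside the supported range
          .write.mode("overwrite").parquet("/tmp/int96_rebase_sketch") // hypothetical output path
        spark.stop()
      }
    }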
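The default implementations added to SparkShims simply fall back to the existing datetime rebase accessors on versions without a separate INT96 conf. The toy trait below illustrates that pattern in isolation; the method names mirror the shim methods, but the classes themselves are hypothetical and not part of the plugin.

    // A toy model of the shim fallback: pre-3.1 shims reuse the datetime rebase key for INT96,
    // while 3.1+ shims report a separate conf and override the int96 accessor.
    trait RebaseShimSketch {
      def parquetRebaseReadKey: String = "spark.sql.legacy.parquet.datetimeRebaseModeInRead"
      def hasSeparateINT96RebaseConf: Boolean
      // Default: no separate INT96 conf, so reuse the datetime key.
      def int96ParquetRebaseReadKey: String = parquetRebaseReadKey
    }

    object Spark301ToyShim extends RebaseShimSketch {
      override def hasSeparateINT96RebaseConf: Boolean = false
    }

    object Spark311ToyShim extends RebaseShimSketch {
      override def hasSeparateINT96RebaseConf: Boolean = true
      override def int96ParquetRebaseReadKey: String =
        "spark.sql.legacy.parquet.int96RebaseModeInRead"
    }

    object ShimFallbackSketch {
      def main(args: Array[String]): Unit = {
        Seq(Spark301ToyShim, Spark311ToyShim).foreach { shim =>
          println(s"separate INT96 conf=${shim.hasSeparateINT96RebaseConf}, " +
            s"read key=${shim.int96ParquetRebaseReadKey}")
        }
      }
    }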