Add GpuCheckOverflowInTableInsert to Databricks 11.3+ #9800

Merged (1 commit, Nov 22, 2023)
21 changes: 21 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -818,3 +818,24 @@ def test_parquet_write_column_name_with_dots(spark_tmp_path):
        lambda spark, path: gen_df(spark, gens).coalesce(1).write.parquet(path),
        lambda spark, path: spark.read.parquet(path),
        data_path)

@ignore_order
def test_parquet_append_with_downcast(spark_tmp_table_factory, spark_tmp_path):
    data_path = spark_tmp_path + "/PARQUET_DATA"
    cpu_table = spark_tmp_table_factory.get()
    gpu_table = spark_tmp_table_factory.get()
    def setup_tables(spark):
        df = unary_op_df(spark, int_gen, length=10)
        df.write.format("parquet").option("path", data_path + "/CPU").saveAsTable(cpu_table)
        df.write.format("parquet").option("path", data_path + "/GPU").saveAsTable(gpu_table)
    with_cpu_session(setup_tables)
    def do_append(spark, path):
        table = cpu_table
        if path.endswith("/GPU"):
            table = gpu_table
        unary_op_df(spark, LongGen(min_val=0, max_val=128, special_cases=[]), length=10)\
            .write.mode("append").saveAsTable(table)
    assert_gpu_and_cpu_writes_are_equal_collect(
        do_append,
        lambda spark, path: spark.read.parquet(path),
        data_path)
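
The new test above exercises exactly the path this change targets: a table created from INT data receives an append of a LONG column, so Spark's store assignment wraps the narrowing cast in `CheckOverflowInTableInsert`, and the GPU plan needs a matching `GpuCheckOverflowInTableInsert` on the Databricks 11.3+ shims. Below is a minimal standalone sketch of the CPU-side behavior (not part of this PR), assuming Spark 3.3+ with the default ANSI store-assignment policy and a hypothetical table name `int_table`: in-range values downcast cleanly, while an out-of-range value fails the insert instead of silently wrapping.

```python
# Standalone sketch (not from this PR): appending a LONG column into an INT table goes
# through store assignment, where CheckOverflowInTableInsert guards the narrowing cast.
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StructField, StructType

spark = SparkSession.builder.master("local[*]").appName("overflow-insert-sketch").getOrCreate()

spark.sql("CREATE TABLE int_table (a INT) USING parquet")  # hypothetical table name

schema = StructType([StructField("a", LongType())])

# In-range values (like the 0..128 LongGen in the test) downcast without error.
spark.createDataFrame([(1,), (128,)], schema).write.mode("append").saveAsTable("int_table")

# An out-of-range value trips the overflow check under the default ANSI store-assignment policy.
try:
    spark.createDataFrame([(2**31,)], schema).write.mode("append").saveAsTable("int_table")
except Exception as err:  # e.g. a CAST_OVERFLOW_IN_TABLE_INSERT error on Spark 3.3+
    print(type(err).__name__, err)
```

The integration test keeps its generated values in 0..128 so both the CPU and GPU appends succeed and their results can be compared row for row.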
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330db"}
{"spark": "332db"}
{"spark": "341db"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.{GpuCheckOverflowInTableInsert, GpuElementAtMeta}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

trait Spark330PlusDBShims extends Spark321PlusDBShims {
  override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    val shimExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
      GpuOverrides.expr[CheckOverflowInTableInsert](
        "Casting a numeric value as another numeric type in store assignment",
        ExprChecks.unaryProjectInputMatchesOutput(
          TypeSig.all,
          TypeSig.all),
        (t, conf, p, r) => new UnaryExprMeta[CheckOverflowInTableInsert](t, conf, p, r) {
          override def convertToGpu(child: Expression): GpuExpression = {
            child match {
              case c: GpuCast => GpuCheckOverflowInTableInsert(c, t.columnName)
              case _ =>
                throw new IllegalStateException("Expression child is not of Type GpuCast")
            }
          }
        }),
      GpuElementAtMeta.elementAtRule(true)
    ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
    super.getExprs ++ shimExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs
  }

  override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
    super.getExecs ++ PythonMapInArrowExecShims.execs

  override def reproduceEmptyStringBug: Boolean = false

  override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean = {
    shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
  }

  override def shuffleParentReadsShuffleData(shuffle: ShuffleExchangeLike,
      parent: SparkPlan): Boolean = {
    parent match {
      case _: GpuBroadcastHashJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _: GpuBroadcastNestedLoopJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _ => false
    }
  }

  override def addRowShuffleToQueryStageTransitionIfNeeded(c2r: ColumnarToRowTransition,
      sqse: ShuffleQueryStageExec): SparkPlan = {
    val plan = GpuTransitionOverrides.getNonQueryStagePlan(sqse)
    plan match {
      case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
        ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
      case _ =>
        c2r
    }
  }
}
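
The shim above registers the replacement rule. A hypothetical way to observe it in action (not part of this PR), assuming a session launched with the spark-rapids plugin on a Databricks 11.3+ / Spark 3.3+ runtime and an existing INT table named `int_table`: the `spark.rapids.sql.explain` setting logs how each expression was planned, so a downcasting append should show `CheckOverflowInTableInsert` handled on the GPU rather than keeping the write on the CPU.

```python
# Hypothetical usage sketch (not from this PR): observing the replacement when the
# RAPIDS plugin is active. Assumes a session launched with the spark-rapids plugin.
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.rapids.sql.enabled", "true")
spark.conf.set("spark.rapids.sql.explain", "ALL")  # log how each expression/exec was planned

schema = StructType([StructField("a", LongType())])
# With this change, the append's CheckOverflowInTableInsert should be planned as
# GpuCheckOverflowInTableInsert instead of forcing the insert back onto the CPU.
spark.createDataFrame([(1,), (128,)], schema) \
    .write.mode("append").saveAsTable("int_table")  # hypothetical existing INT table
```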
@@ -21,29 +21,13 @@ package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

object SparkShimImpl extends Spark321PlusDBShims {
object SparkShimImpl extends Spark330PlusDBShims {
  // AnsiCast is removed from Spark3.4.0
  override def ansiCastRule: ExprRule[_ <: Expression] = null

  override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    val elementAtExpr: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
      GpuElementAtMeta.elementAtRule(true)
    ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
    super.getExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs ++ elementAtExpr
  }

  override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
    super.getExecs ++ PythonMapInArrowExecShims.execs

  override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
      DataWritingCommandRule[_ <: DataWritingCommand]] = {
    Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
@@ -56,32 +40,4 @@ object SparkShimImpl extends Spark321PlusDBShims {
      RunnableCommandRule[_ <: RunnableCommand]] = {
    Map.empty
  }

  override def reproduceEmptyStringBug: Boolean = false

  override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean = {
    shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
  }

  override def shuffleParentReadsShuffleData(shuffle: ShuffleExchangeLike,
      parent: SparkPlan): Boolean = {
    parent match {
      case _: GpuBroadcastHashJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _: GpuBroadcastNestedLoopJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _ => false
    }
  }

  override def addRowShuffleToQueryStageTransitionIfNeeded(c2r: ColumnarToRowTransition,
      sqse: ShuffleQueryStageExec): SparkPlan = {
    val plan = GpuTransitionOverrides.getNonQueryStagePlan(sqse)
    plan match {
      case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
        ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
      case _ =>
        c2r
    }
  }
}
}
@@ -15,12 +15,15 @@
*/

/*** spark-rapids-shim-json-lines
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids
@@ -30,7 +30,7 @@ import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuCast, GpuExpression, Gp
import org.apache.spark.sql.catalyst.expressions.{CheckOverflowInTableInsert, Expression}
import org.apache.spark.sql.rapids.GpuCheckOverflowInTableInsert

trait Spark331PlusShims extends Spark330PlusNonDBShims {
trait Spark331PlusNonDBShims extends Spark330PlusNonDBShims {
  override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
      // Add expression CheckOverflowInTableInsert starting Spark-3.3.1+
@@ -25,7 +25,7 @@ import com.nvidia.spark.rapids._

import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}

object SparkShimImpl extends Spark331PlusShims with AnsiCastRuleShims {
object SparkShimImpl extends Spark331PlusNonDBShims with AnsiCastRuleShims {
  override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
      DataWritingCommandRule[_ <: DataWritingCommand]] = {
    Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
@@ -19,4 +19,4 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object SparkShimImpl extends Spark33cdhShims with Spark331PlusShims {}
object SparkShimImpl extends Spark33cdhShims with Spark331PlusNonDBShims {}
@@ -23,16 +23,11 @@ package com.nvidia.spark.rapids.shims
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

trait Spark332PlusDBShims extends Spark321PlusDBShims {
trait Spark332PlusDBShims extends Spark330PlusDBShims {
  // AnsiCast is removed from Spark3.4.0
  override def ansiCastRule: ExprRule[_ <: Expression] = null

@@ -45,10 +40,9 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
        (a, conf, p, r) => new UnaryExprMeta[KnownNullable](a, conf, p, r) {
          override def convertToGpu(child: Expression): GpuExpression = GpuKnownNullable(child)
        }
      ),
      GpuElementAtMeta.elementAtRule(true)
      )
    ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
    super.getExprs ++ shimExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs
    super.getExprs ++ shimExprs
  }

  private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
@@ -63,7 +57,7 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
  ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap

  override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
    super.getExecs ++ shimExecs ++ PythonMapInArrowExecShims.execs
    super.getExecs ++ shimExecs

  override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
      DataWritingCommandRule[_ <: DataWritingCommand]] = {
@@ -78,32 +72,4 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
      (a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
    ).map(r => (r.getClassFor.asSubclass(classOf[RunnableCommand]), r)).toMap
  }

  override def reproduceEmptyStringBug: Boolean = false

  override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean = {
    shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
  }

  override def shuffleParentReadsShuffleData(shuffle: ShuffleExchangeLike,
      parent: SparkPlan): Boolean = {
    parent match {
      case _: GpuBroadcastHashJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _: GpuBroadcastNestedLoopJoinExec =>
        shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
      case _ => false
    }
  }

  override def addRowShuffleToQueryStageTransitionIfNeeded(c2r: ColumnarToRowTransition,
      sqse: ShuffleQueryStageExec): SparkPlan = {
    val plan = GpuTransitionOverrides.getNonQueryStagePlan(sqse)
    plan match {
      case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
        ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
      case _ =>
        c2r
    }
  }
}
}
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null

trait Spark340PlusShims extends Spark331PlusShims {
trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {

  private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
    GpuOverrides.exec[GlobalLimitExec](
@@ -20,4 +20,4 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object SparkShimImpl extends Spark340PlusShims
object SparkShimImpl extends Spark340PlusNonDBShims
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, ToPret
import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
import org.apache.spark.sql.types.StringType

object SparkShimImpl extends Spark340PlusShims {
object SparkShimImpl extends Spark340PlusNonDBShims {

  override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    val shimExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(