Transition to v2 shims [Databricks] #4857

Merged
merged 35 commits into from
Mar 8, 2022
Changes from 11 commits
6137ef7
301 compiling
razajafri Feb 2, 2022
b4ceda0
Merged the shims
razajafri Feb 3, 2022
68c9a3b
db build errors
razajafri Feb 18, 2022
06c98dd
renamed folder
razajafri Feb 18, 2022
bb5c3b8
modify binary-dedupe.sh to reflect the new package and remove call to…
razajafri Feb 19, 2022
c003b06
some more changes to return the static classname for shuffle managers
razajafri Feb 23, 2022
cc671b5
init shim when getting shufflemanager
razajafri Feb 23, 2022
ca5c0be
getVersion changes
razajafri Feb 23, 2022
e4dd885
add hypot back
razajafri Feb 23, 2022
2282417
removed buildShim
razajafri Feb 24, 2022
69dc48d
clean up
razajafri Feb 24, 2022
2c23552
removed package v2
razajafri Feb 24, 2022
de8a56b
reference the correct package
razajafri Feb 24, 2022
980eba3
removed duplicate versions of RapidsShuffleManager
razajafri Feb 25, 2022
a2a6b7e
addressed review comments
razajafri Feb 25, 2022
cf5bd29
fix db build
razajafri Feb 26, 2022
6dc37c3
Revert "fix db build"
razajafri Feb 28, 2022
d1c0fd8
Revert "addressed review comments"
razajafri Feb 28, 2022
e485926
Revert "removed duplicate versions of RapidsShuffleManager"
razajafri Feb 28, 2022
3014c21
removed the non-existent folder
razajafri Feb 28, 2022
38aca27
removed unused import
razajafri Feb 28, 2022
1c6b7f8
reverted shuffle manager and internal manager change
razajafri Feb 28, 2022
f1bbbed
revert spark2diffs changes
razajafri Mar 1, 2022
26f1368
Fix 301db build
razajafri Mar 1, 2022
3b3ed5a
removed reference of ShimLoader.getSparkShims from doc
razajafri Mar 1, 2022
d941b83
Revert 312db build fix
razajafri Mar 1, 2022
1846c58
merge
razajafri Mar 7, 2022
568e4f6
merge conflicts
razajafri Mar 7, 2022
b873c53
fix db build
razajafri Mar 7, 2022
02bd9e5
fix 301db
razajafri Mar 7, 2022
3f8fb9b
fixed 304
razajafri Mar 7, 2022
8f6dce4
fixed 330 build errors
razajafri Mar 8, 2022
57b5dbd
Merge remote-tracking branch 'origin/branch-22.04' into shim-work-2
razajafri Mar 8, 2022
6469765
Merge remote-tracking branch 'origin/branch-22.04' into shim-work-2
razajafri Mar 8, 2022
e36248d
fixed imports
razajafri Mar 8, 2022
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import scala.reflect.api
import scala.reflect.runtime.universe._

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2.SparkShimImpl

import org.apache.spark.internal.Logging

@@ -70,7 +71,7 @@ object ApiValidation extends Logging {
var printNewline = false

val sparkToShimMap = Map("3.0.1" -> "spark301", "3.1.1" -> "spark311")
val sparkVersion = ShimLoader.getSparkShims.getSparkShimVersion.toString
val sparkVersion = SparkShimImpl.getSparkShimVersion.toString
val shimVersion = sparkToShimMap(sparkVersion)

gpuKeys.foreach { e =>
4 changes: 2 additions & 2 deletions dist/scripts/binary-dedupe.sh
@@ -1,6 +1,6 @@
#!/bin/bash

# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -173,7 +173,7 @@ function verify_same_sha_for_unshimmed() {
# TODO currently RapidsShuffleManager is "removed" from /spark3* by construction in
# dist pom.xml via ant. We could delegate this logic to this script
# and make both simpler
if [[ ! "$class_file_quoted" =~ (com/nvidia/spark/rapids/spark3.*/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/spark3.*/ProxyRapidsShuffleInternalManager.class) ]]; then
if [[ ! "$class_file_quoted" =~ (com/nvidia/spark/rapids/shims/v2/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/v2/ProxyRapidsShuffleInternalManager.class) ]]; then

if ! grep -q "/spark.\+/$class_file_quoted" "$SPARK3XX_COMMON_TXT"; then
echo >&2 "$class_file is not bitwise-identical across shims"
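Reviewer note: a minimal sketch of what the updated exclusion pattern in this hunk accepts and rejects, assuming plain class-file paths as input. The helper name `is_excluded_shuffle_class` and the sample paths are hypothetical and not part of `binary-dedupe.sh`; only the regular expression is copied from the changed line.

```shell
#!/bin/bash

# Hypothetical helper (not in the PR): succeeds when the given class-file path
# matches the exclusion pattern from the changed line, i.e. when the
# bitwise-identity check would be skipped for it.
is_excluded_shuffle_class() {
  local class_file="$1"
  # Pattern copied from the new version of the line in this hunk.
  local pattern='(com/nvidia/spark/rapids/shims/v2/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/v2/ProxyRapidsShuffleInternalManager.class)'
  [[ "$class_file" =~ $pattern ]]
}

# Sample paths are illustrative assumptions, not taken from a real jar listing.
for path in \
  "com/nvidia/spark/rapids/shims/v2/RapidsShuffleManager.class" \
  "org/apache/spark/sql/rapids/shims/v2/ProxyRapidsShuffleInternalManager.class" \
  "com/nvidia/spark/rapids/spark301/RapidsShuffleManager.class"; do
  if is_excluded_shuffle_class "$path"; then
    echo "skip   $path"
  else
    echo "verify $path"
  fi
done
```

Under these assumptions, a path in a per-shim package such as `spark301` no longer matches the pattern, so it would fall through to the bitwise-identity check.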
2 changes: 1 addition & 1 deletion jenkins/spark-premerge-build.sh
@@ -95,7 +95,7 @@ rapids_shuffle_smoke_test() {
PYSP_TEST_spark_cores_max=2 \
PYSP_TEST_spark_executor_cores=1 \
SPARK_SUBMIT_FLAGS="--conf spark.executorEnv.UCX_ERROR_SIGNALS=" \
PYSP_TEST_spark_shuffle_manager=com.nvidia.spark.rapids.$SHUFFLE_SPARK_SHIM.RapidsShuffleManager \
PYSP_TEST_spark_shuffle_manager=com.nvidia.spark.rapids.shims.v2.RapidsShuffleManager \
PYSP_TEST_spark_rapids_memory_gpu_minAllocFraction=0 \
PYSP_TEST_spark_rapids_memory_gpu_maxAllocFraction=0.1 \
PYSP_TEST_spark_rapids_memory_gpu_allocFraction=0.1 \
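Reviewer note: a small sketch of the naming change in this hunk, where the shuffle manager class no longer embeds `$SHUFFLE_SPARK_SHIM`. The function names and the shim ids in the loop are hypothetical, chosen only for illustration. The commit list also mentions a later revert of the shuffle manager change, so this reflects only the hunk as shown here.

```shell
#!/bin/bash

# Old scheme: the class name embeds the shim id, so it differs per Spark build.
old_shuffle_manager_class() {
  local shim="$1"
  echo "com.nvidia.spark.rapids.${shim}.RapidsShuffleManager"
}

# Scheme in this hunk: one fixed class name under shims.v2.
new_shuffle_manager_class() {
  echo "com.nvidia.spark.rapids.shims.v2.RapidsShuffleManager"
}

# Shim ids below are illustrative assumptions only.
for shim in spark301 spark311; do
  echo "$shim: $(old_shuffle_manager_class "$shim") -> $(new_shuffle_manager_class)"
done
```

With a fixed name, the smoke test would not need to know which shim is being built in order to set `PYSP_TEST_spark_shuffle_manager`.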
11 changes: 9 additions & 2 deletions pom.xml
@@ -115,6 +115,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/301/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -164,6 +165,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/302/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -222,6 +224,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/303/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -327,7 +330,7 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/311/scala</source>
<source>${project.basedir}/src/main/311-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-noncdh/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
@@ -509,7 +512,7 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/312/scala</source>
<source>${project.basedir}/src/main/312-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-noncdh/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
@@ -758,6 +761,7 @@
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/322+/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
</configuration>
@@ -819,6 +823,7 @@
<source>${project.basedir}/src/main/311+-nondb/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/322+/scala</source>
<source>${project.basedir}/src/main/330+/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
@@ -876,6 +881,8 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/311-nondb/scala</source>
<source>${project.basedir}/src/main/311cdh/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
<source>${project.basedir}/src/main/301until330-all/scala</source>
@@ -106,7 +106,7 @@ final class CastExprMeta[INPUT <: Cast](
// NOOP for anything prior to 3.2.0
case (_: StringType, dt:DecimalType) =>
// Spark 2.x: removed check for
// !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported
// !SparkShimImpl.isCastingStringToNegDecimalScaleSupported
// this dealt with handling a bug fix that is only in newer versions of Spark
// (https://issues.apache.org/jira/browse/SPARK-37451)
// Since we don't know what version of Spark 3 they will be using
@@ -1397,7 +1397,7 @@ object GpuOverrides extends Logging {
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) {
override def shouldFallbackOnAnsiTimestamp: Boolean = false
// ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp
// SparkShimImpl.shouldFallbackOnAnsiTimestamp
}),
expr[UnixTimestamp](
"Returns the UNIX timestamp of current or specified time",
@@ -1410,7 +1410,7 @@ object GpuOverrides extends Logging {
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) {
override def shouldFallbackOnAnsiTimestamp: Boolean = false
// ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp
// SparkShimImpl.shouldFallbackOnAnsiTimestamp

}),
expr[Hour](
@@ -2865,8 +2865,8 @@ object GpuOverrides extends Logging {
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) {}
),
// ShimLoader.getSparkShims.aqeShuffleReaderExec,
// ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand,
// SparkShimImpl.aqeShuffleReaderExec,
// SparkShimImpl.neverReplaceShowCurrentNamespaceCommand,
neverReplaceExec[ExecutedCommandExec]("Table metadata operation")
).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap

@@ -2955,7 +2955,7 @@ object GpuOverrides extends Logging {
// case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child)
case re: ReusedExchangeExec => prepareExplainOnly(re.child)
// case aqe: AdaptiveSparkPlanExec =>
// prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe))
// prepareExplainOnly(SparkShimImpl.getAdaptiveInputPlan(aqe))
case sub: SubqueryExec => prepareExplainOnly(sub.child)
}
planAfter
@@ -79,7 +79,7 @@ object GpuParquetFileFormat {
// they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good
// for us.
/*
ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match {
SparkShimImpl.int96ParquetRebaseWrite(sqlConf) match {
case "EXCEPTION" =>
case "CORRECTED" =>
case "LEGACY" =>
@@ -90,7 +90,7 @@ object GpuParquetFileFormat {
meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96")
}

ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match {
SparkShimImpl.parquetRebaseWrite(sqlConf) match {
case "EXCEPTION" => //Good
case "CORRECTED" => //Good
case "LEGACY" =>
@@ -97,31 +97,31 @@ object GpuParquetScanBase {

// Spark 2.x doesn't support the rebase mode
/*
sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match {
sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION")
s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY")
s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}

sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match {
sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION")
s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY")
s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.execution.python.shims.v2

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
import com.nvidia.spark.rapids.shims.v2.{ShimUnaryExecNode, SparkShimImpl}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -96,7 +96,7 @@ case class GpuFlatMapGroupsInPandasExec(
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(ShimLoader.getSparkShims.sortOrder(_, Ascending)))
Seq(groupingAttributes.map(SparkShimImpl.sortOrder(_, Ascending)))

private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,19 +16,17 @@

package com.nvidia.spark.rapids.shims.spark301

import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion}
import com.nvidia.spark.rapids.SparkShimVersion

object SparkShimServiceProvider {
val VERSION = SparkShimVersion(3, 0, 1)
val VERSIONNAMES = Seq(s"$VERSION")
}
class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {

override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION

def matchesVersion(version: String): Boolean = {
SparkShimServiceProvider.VERSIONNAMES.contains(version)
}

def buildShim: SparkShims = {
new Spark301Shims()
}
}

This file was deleted.

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.shims.spark301db

import com.nvidia.spark.rapids.{DatabricksShimVersion, SparkShims}
import com.nvidia.spark.rapids.{DatabricksShimVersion, SparkShimVersion}

object SparkShimServiceProvider {
val VERSION = DatabricksShimVersion(3, 0, 1)
@@ -25,11 +25,9 @@ object SparkShimServiceProvider {

class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {

override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION

def matchesVersion(version: String): Boolean = {
SparkShimServiceProvider.VERSIONNAMES.contains(version)
}

def buildShim: SparkShims = {
new Spark301dbShims()
}
}

This file was deleted.

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
* limitations under the License.
*/

package com.nvidia.spark.rapids.spark303
package com.nvidia.spark.rapids.v2

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.shims.spark303.ProxyRapidsShuffleInternalManager
import org.apache.spark.sql.rapids.shims.v2.ProxyRapidsShuffleInternalManager

/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
sealed class RapidsShuffleManager(