Handle minimum GPU architecture supported [databricks] #10540

Merged
14 commits merged on Mar 15, 2024
Changes from 13 commits
69 changes: 65 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -73,6 +73,7 @@ object RapidsPluginUtils extends Logging {
private val TASK_GPU_AMOUNT_KEY = "spark.task.resource.gpu.amount"
private val EXECUTOR_GPU_AMOUNT_KEY = "spark.executor.resource.gpu.amount"
private val SPARK_MASTER = "spark.master"
private val SPARK_RAPIDS_REPO_URL = "https://github.com/NVIDIA/spark-rapids"

{
val pluginProps = loadProps(PLUGIN_PROPS_FILENAME)
@@ -346,6 +347,63 @@ object RapidsPluginUtils extends Logging {
loadExtensions(classOf[SparkPlugin], pluginClasses)
}
}

/**
* Extracts supported GPU architectures from the given properties file
*/
private def getSupportedGpuArchitectures(propFileName: String): Set[Int] = {
val props = RapidsPluginUtils.loadProps(propFileName)
Option(props.getProperty("gpu_architectures"))
.getOrElse(throw new RuntimeException(s"GPU architectures not found in $propFileName"))
.split(";")
.map(_.toInt)
.toSet
}
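For reference, a minimal sketch of the input this parsing expects; the property key "gpu_architectures" and the semicolon-separated format come from the code above, while the file contents and architecture values here are purely illustrative:

import java.util.Properties

object GpuArchPropsSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative stand-in for a build-generated version-info properties file;
    // the real values come from the spark-rapids-jni and cuDF builds.
    val props = new Properties()
    props.setProperty("gpu_architectures", "70;75;80")

    // Same parsing as getSupportedGpuArchitectures: split on ';', convert each entry to Int.
    val supportedArchs: Set[Int] =
      props.getProperty("gpu_architectures").split(";").map(_.toInt).toSet
    println(supportedArchs) // Set(70, 75, 80)
  }
}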

/**
* Checks if the current GPU architecture is supported by the spark-rapids-jni
* and cuDF libraries.
*/
def validateGpuArchitecture(): Unit = {
val gpuArch = Cuda.getComputeCapabilityMajor * 10 + Cuda.getComputeCapabilityMinor
validateGpuArchitectureInternal(gpuArch, getSupportedGpuArchitectures(JNI_PROPS_FILENAME),
getSupportedGpuArchitectures(CUDF_PROPS_FILENAME))
}
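As a worked example of the encoding above: on a device with compute capability 8.6, Cuda.getComputeCapabilityMajor returns 8 and Cuda.getComputeCapabilityMinor returns 6, so gpuArch = 8 * 10 + 6 = 86, which is then checked against the architecture sets parsed from the JNI and cuDF property files.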

/**
* Checks whether the provided GPU architecture is supported by the provided sets of architectures.
*
* See: https://docs.nvidia.com/cuda/ampere-compatibility-guide/index.html
*/
def validateGpuArchitectureInternal(gpuArch: Int, jniSupportedGpuArchs: Set[Int],
cudfSupportedGpuArchs: Set[Int]): Unit = {
val supportedGpuArchs = jniSupportedGpuArchs.intersect(cudfSupportedGpuArchs)
if (supportedGpuArchs.isEmpty) {
val jniSupportedGpuArchsStr = jniSupportedGpuArchs.toSeq.sorted.mkString(", ")
val cudfSupportedGpuArchsStr = cudfSupportedGpuArchs.toSeq.sorted.mkString(", ")
throw new IllegalStateException(s"Compatibility check failed for GPU architecture " +
s"$gpuArch. Supported GPU architectures by JNI: $jniSupportedGpuArchsStr and " +
s"cuDF: $cudfSupportedGpuArchsStr. Please report this issue at $SPARK_RAPIDS_REPO_URL." +
s" This check can be disabled by setting `spark.rapids.skipGpuArchitectureCheck` to" +
s" `true`, but it may lead to functional failures.")
}

val minSupportedGpuArch = supportedGpuArchs.min
// Check if the device architecture is supported
if (gpuArch < minSupportedGpuArch) {
throw new RuntimeException(s"Device architecture $gpuArch is unsupported." +
s" Minimum supported architecture: $minSupportedGpuArch.")
}
val supportedMajorGpuArchs = supportedGpuArchs.map(_/10)
val majorGpuArch = gpuArch/10
// Warn the user if the device's major architecture is not available
if (!supportedMajorGpuArchs.contains(majorGpuArch)) {
val supportedMajorArchStr = supportedMajorGpuArchs.toSeq.sorted.mkString(", ")
logWarning(s"No precompiled binaries for device major architecture $majorGpuArch. " +
"This may lead to expensive JIT compile on startup. " +
s"Binaries available for architectures $supportedMajorArchStr.")
}
}
}
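To make the possible outcomes of validateGpuArchitectureInternal concrete, a hedged sketch with illustrative architecture values (not the values shipped in the real build properties):

// Assumes com.nvidia.spark.rapids is on the classpath.
// Suppose the JNI binaries were built for {70, 80} and cuDF for {70, 75, 80};
// the intersection, i.e. the effectively supported set, is {70, 80}.
val jniArchs = Set(70, 80)
val cudfArchs = Set(70, 75, 80)

// gpuArch = 86: at least min(70) and major 8 is among the supported majors {7, 8} => passes silently.
RapidsPluginUtils.validateGpuArchitectureInternal(86, jniArchs, cudfArchs)

// gpuArch = 90: at least min(70), but major 9 has no precompiled binaries => passes with a JIT warning.
RapidsPluginUtils.validateGpuArchitectureInternal(90, jniArchs, cudfArchs)

// gpuArch = 60: below min(70) => RuntimeException.
// Disjoint sets (e.g. Set(60) vs Set(70)) => IllegalStateException, regardless of the device.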

/**
@@ -427,17 +485,20 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
pluginContext: PluginContext,
extraConf: java.util.Map[String, String]): Unit = {
try {
if (Cuda.getComputeCapabilityMajor < 6) {
throw new RuntimeException(s"GPU compute capability ${Cuda.getComputeCapabilityMajor}" +
" is unsupported, requires 6.0+")
}
// if configured, re-register checking leaks hook.
reRegisterCheckLeakHook()

val sparkConf = pluginContext.conf()
val numCores = RapidsPluginUtils.estimateCoresOnExec(sparkConf)
val conf = new RapidsConf(extraConf.asScala.toMap)

// Checks if the current GPU architecture is supported by the
// spark-rapids-jni and cuDF libraries.
// Note: We allow this check to be skipped only for exceptional cases.
if (!conf.skipGpuArchCheck) {
RapidsPluginUtils.validateGpuArchitecture()
}

// Fail if there are multiple plugin jars in the classpath.
RapidsPluginUtils.detectMultipleJars(conf)

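For completeness, a hedged sketch of how a job could opt out of this check; spark.rapids.skipGpuArchitectureCheck is the internal config added in this PR (see the RapidsConf change below), while the rest of the session setup is standard spark-rapids boilerplate and only illustrative:

import org.apache.spark.sql.SparkSession

// Illustrative only: skipping the architecture check can lead to functional
// failures on GPUs the bundled binaries were not built for.
val spark = SparkSession.builder()
  .appName("rapids-skip-arch-check-example")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.enabled", "true")
  .config("spark.rapids.skipGpuArchitectureCheck", "true")
  .getOrCreate()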
@@ -2132,6 +2132,13 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.booleanConf
.createOptional

val SKIP_GPU_ARCH_CHECK = conf("spark.rapids.skipGpuArchitectureCheck")
.doc("When true, skips GPU architecture compatibility check. Note that this check " +
"might still be present in cuDF.")
.internal()
.booleanConf
.createWithDefault(false)

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

@@ -2895,6 +2902,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val splitUntilSizeOverride: Option[Long] = get(SPLIT_UNTIL_SIZE_OVERRIDE)

lazy val skipGpuArchCheck: Boolean = get(SKIP_GPU_ARCH_CHECK)

private val optimizerDefaults = Map(
// this is not accurate because CPU projections do have a cost due to appending values
// to each row that is produced, but this needs to be a really small number because
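As a small illustration of how the new flag surfaces through RapidsConf (the constructor signature and the skipGpuArchCheck val are taken from the diff above; the map contents are illustrative):

val rapidsConf = new RapidsConf(Map("spark.rapids.skipGpuArchitectureCheck" -> "true"))
// skipGpuArchCheck reads SKIP_GPU_ARCH_CHECK, so RapidsExecutorPlugin.init will
// bypass RapidsPluginUtils.validateGpuArchitecture() for this configuration.
println(rapidsConf.skipGpuArchCheck) // true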
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.RapidsPluginUtils.validateGpuArchitectureInternal
import org.scalatest.funsuite.AnyFunSuite

class GpuArchitectureTestSuite extends AnyFunSuite {
test("test supported architecture") {
val jniSupportedGpuArchs = Set(50, 60, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 60
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}

test("test unsupported architecture") {
assertThrows[RuntimeException] {
val jniSupportedGpuArchs = Set(50, 60, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 40
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}
}

test("test supported major architecture with higher minor version") {
val jniSupportedGpuArchs = Set(50, 60, 65, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 67
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}

test("test supported major architecture with lower minor version") {
val jniSupportedGpuArchs = Set(50, 60, 65, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 63
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}

test("test empty supported architecture set") {
assertThrows[IllegalStateException] {
val jniSupportedGpuArchs = Set(50, 60)
val cudfSupportedGpuArchs = Set(70, 80)
val gpuArch = 60
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}
}
}
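A further test in the same style could pin down the warning-only path, where the device's major architecture has no precompiled binaries but validation still succeeds; this is not part of the PR, just a sketch following the suite above:

import com.nvidia.spark.rapids.RapidsPluginUtils.validateGpuArchitectureInternal
import org.scalatest.funsuite.AnyFunSuite

class GpuArchitectureWarningSketchSuite extends AnyFunSuite {
  test("missing major architecture only warns") {
    val jniSupportedGpuArchs = Set(50, 60, 70)
    val cudfSupportedGpuArchs = Set(50, 60, 70)
    // Major 8 is not in {5, 6, 7}, but 80 >= min(50), so no exception is expected.
    val gpuArch = 80
    validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
  }
}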