From 7af65a0360b9af1fb32d4b00806dc1bd0f274062 Mon Sep 17 00:00:00 2001 From: Niranjan Artal <50492963+nartal1@users.noreply.github.com> Date: Wed, 4 Nov 2020 06:52:59 -0800 Subject: [PATCH] Sanity checks for cudf jar mismatch (#1047) * Sanity checks for cudf jar mismatch Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal * log warnings if the config is set but versions mismatch Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal * refactored code and addressed review comments Signed-off-by: Niranjan Artal * remove unwanted comment Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal --- build/build-info | 3 +- pom.xml | 1 + .../com/nvidia/spark/rapids/Plugin.scala | 56 +++++++++++++++++++ .../com/nvidia/spark/rapids/RapidsConf.scala | 10 ++++ 4 files changed, 69 insertions(+), 1 deletion(-) diff --git a/build/build-info b/build/build-info index 057f137892c..a905661920b 100755 --- a/build/build-info +++ b/build/build-info @@ -22,6 +22,7 @@ echo_build_properties() { echo version=$1 + echo cudf_version=$2 echo user=$USER echo revision=$(git rev-parse HEAD) echo branch=$(git rev-parse --abbrev-ref HEAD) @@ -29,4 +30,4 @@ echo_build_properties() { echo url=$(git config --get remote.origin.url) } -echo_build_properties $1 +echo_build_properties $1 $2 diff --git a/pom.xml b/pom.xml index 763b7512494..3429082a90b 100644 --- a/pom.xml +++ b/pom.xml @@ -538,6 +538,7 @@ + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 7f60f3d860b..337d9dc2bb5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import java.util +import 
/**
 * Checks that the cudf jar found on the classpath is the version this plugin was built for.
 *
 * The cudf version is read from the `version` property of
 * `cudf-java-version-info.properties`, and the version the plugin expects is read from the
 * `cudf_version` property of `rapids4spark-version-info.properties`. Both files are looked up
 * through the class loader of [[RapidsExecutorPlugin]].
 *
 * @param conf plugin configuration; when `conf.cudfVersionOverride` is true a failed check is
 *             logged as a warning instead of being propagated
 * @throws CudfVersionMismatchException if either properties file or property is missing, or
 *                                      the two versions differ, and the override is not set
 */
private def checkCudfVersion(conf: RapidsConf): Unit = {
  val cudfPropertiesFileName = "cudf-java-version-info.properties"
  val pluginPropertiesFileName = "rapids4spark-version-info.properties"
  val classLoader = classOf[RapidsExecutorPlugin].getClassLoader

  // Loads one classpath resource into its own Properties instance. A separate instance per
  // file keeps keys of one file from being visible when the other file is queried.
  def loadProperties(fileName: String, jarDescription: String): java.util.Properties = {
    val stream = classLoader.getResourceAsStream(fileName)
    if (stream == null) {
      throw CudfVersionMismatchException(s"Could not find properties file " +
        s"$fileName in the $jarDescription. Cannot verify cudf version compatibility " +
        s"with RAPIDS Accelerator version.")
    }
    try {
      val loaded = new java.util.Properties
      loaded.load(stream)
      loaded
    } finally {
      // Properties.load leaves the stream open, so it has to be closed here.
      stream.close()
    }
  }

  // Returns the value stored under `key`, or fails naming the file that lacks it.
  def requiredProperty(
      props: java.util.Properties,
      key: String,
      fileName: String): String = {
    Option(props.getProperty(key)).getOrElse {
      throw CudfVersionMismatchException(s"Property name `$key` not found in " +
        s"$fileName file.")
    }
  }

  try {
    // The cudf side is fully resolved before the plugin side is touched, so the first
    // problem reported is always about the cudf jar when both jars are broken.
    val cudfVersion = requiredProperty(
      loadProperties(cudfPropertiesFileName, "cudf jar"),
      "version",
      cudfPropertiesFileName)

    val expectedCudfVersion = requiredProperty(
      loadProperties(pluginPropertiesFileName, "RAPIDS Accelerator jar"),
      "cudf_version",
      pluginPropertiesFileName)

    // compare cudf version in the classpath with the cudf version expected by plugin
    if (cudfVersion != expectedCudfVersion) {
      throw CudfVersionMismatchException(s"Cudf version in the classpath is different. " +
        s"Found $cudfVersion, RAPIDS Accelerator expects $expectedCudfVersion")
    }
  } catch {
    // Only version-check failures are downgraded, and only when the override is set;
    // anything else (for example an IOException while reading) still propagates.
    case x: CudfVersionMismatchException if conf.cudfVersionOverride =>
      logWarning(x.errorMsg)
  }
}

/**
 * Raised when the cudf version on the classpath cannot be verified against, or differs from,
 * the version the plugin expects.
 *
 * @param errorMsg description of the failed check; also used as the exception message
 */
case class CudfVersionMismatchException(errorMsg: String) extends RuntimeException(errorMsg)