From 7e0771819480d76ed4552fd09eb8c09f3b2f696d Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 17 Sep 2020 18:22:51 -0500 Subject: [PATCH 1/2] Fix shims provider override config not being seen by executors Signed-off-by: Jason Lowe --- .../scala/com/nvidia/spark/rapids/Plugin.scala | 9 ++++++++- .../com/nvidia/spark/rapids/ShimLoader.scala | 15 +++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 0d6efd08f12..7f60f3d860b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -107,7 +107,11 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { override def init(sc: SparkContext, pluginContext: PluginContext): util.Map[String, String] = { val sparkConf = pluginContext.conf RapidsPluginUtils.fixupConfigs(sparkConf) - new RapidsConf(sparkConf).rapidsConfMap + val conf = new RapidsConf(sparkConf) + if (conf.shimsProviderOverride.isDefined) { + ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get) + } + conf.rapidsConfMap } } @@ -120,6 +124,9 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { extraConf: util.Map[String, String]): Unit = { try { val conf = new RapidsConf(extraConf.asScala.toMap) + if (conf.shimsProviderOverride.isDefined) { + ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get) + } // we rely on the Rapids Plugin being run with 1 GPU per executor so we can initialize // on executor startup. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index 2a9e1525d26..b806910dd2c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -20,10 +20,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ -import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf} +import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION} import org.apache.spark.internal.Logging object ShimLoader extends Logging { + private var shimProviderClass: String = null private var sparkShims: SparkShims = null private def detectShimProvider(): SparkShimServiceProvider = { @@ -47,14 +48,12 @@ object ShimLoader extends Logging { } private def findShimProvider(): SparkShimServiceProvider = { - val conf = new RapidsConf(new SparkConf()) - if (conf.shimsProviderOverride.isEmpty) { + if (shimProviderClass == null) { detectShimProvider() } else { - val classname = conf.shimsProviderOverride.get - logWarning(s"Overriding Spark shims provider to $classname. " + + logWarning(s"Overriding Spark shims provider to $shimProviderClass. " + "This may be an untested configuration!") - val providerClass = Class.forName(classname) + val providerClass = Class.forName(shimProviderClass) val constructor = providerClass.getConstructor() constructor.newInstance().asInstanceOf[SparkShimServiceProvider] } @@ -76,4 +75,8 @@ object ShimLoader extends Logging { SPARK_VERSION } } + + def setSparkShimProviderClass(classname: String): Unit = { + shimProviderClass = classname + } } From caf10e04290774fb4e8aa1a09932f5f64e134724 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 17 Sep 2020 19:04:54 -0500 Subject: [PATCH 2/2] Also update shim provider class override from shuffle manager since it loads early Signed-off-by: Jason Lowe --- .../spark/sql/rapids/RapidsShuffleInternalManager.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala index 55b87de4e23..b640b9d1dc9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala @@ -201,6 +201,11 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole private val rapidsConf = new RapidsConf(conf) + // set the shim override if specified since the shuffle manager loads early + if (rapidsConf.shimsProviderOverride.isDefined) { + ShimLoader.setSparkShimProviderClass(rapidsConf.shimsProviderOverride.get) + } + protected val wrapped = new SortShuffleManager(conf) GpuShuffleEnv.setRapidsShuffleManagerInitialized(true, this.getClass.getCanonicalName) logWarning("Rapids Shuffle Plugin Enabled")