Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shims provider override config not being seen by executors #798

Merged
merged 2 commits into from
Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
override def init(sc: SparkContext, pluginContext: PluginContext): util.Map[String, String] = {
val sparkConf = pluginContext.conf
RapidsPluginUtils.fixupConfigs(sparkConf)
new RapidsConf(sparkConf).rapidsConfMap
val conf = new RapidsConf(sparkConf)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}
conf.rapidsConfMap
}
}

Expand All @@ -120,6 +124,9 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
extraConf: util.Map[String, String]): Unit = {
try {
val conf = new RapidsConf(extraConf.asScala.toMap)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

// we rely on the Rapids Plugin being run with 1 GPU per executor so we can initialize
// on executor startup.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
import org.apache.spark.internal.Logging

object ShimLoader extends Logging {
private var shimProviderClass: String = null
private var sparkShims: SparkShims = null

private def detectShimProvider(): SparkShimServiceProvider = {
Expand All @@ -47,14 +48,12 @@ object ShimLoader extends Logging {
}

private def findShimProvider(): SparkShimServiceProvider = {
val conf = new RapidsConf(new SparkConf())
if (conf.shimsProviderOverride.isEmpty) {
if (shimProviderClass == null) {
detectShimProvider()
} else {
val classname = conf.shimsProviderOverride.get
logWarning(s"Overriding Spark shims provider to $classname. " +
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
val providerClass = Class.forName(classname)
val providerClass = Class.forName(shimProviderClass)
val constructor = providerClass.getConstructor()
constructor.newInstance().asInstanceOf[SparkShimServiceProvider]
}
Expand All @@ -76,4 +75,8 @@ object ShimLoader extends Logging {
SPARK_VERSION
}
}

def setSparkShimProviderClass(classname: String): Unit = {
shimProviderClass = classname
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole

private val rapidsConf = new RapidsConf(conf)

// set the shim override if specified since the shuffle manager loads early
if (rapidsConf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(rapidsConf.shimsProviderOverride.get)
}

protected val wrapped = new SortShuffleManager(conf)
GpuShuffleEnv.setRapidsShuffleManagerInitialized(true, this.getClass.getCanonicalName)
logWarning("Rapids Shuffle Plugin Enabled")
Expand Down