diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 3fed04b3503..3ebe83bc113 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -26,11 +26,19 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.vectorized.ColumnarBatch /** A collection of utility methods useful in tests. */ object TestUtils extends Assertions with Arm { + // Need to set a legacy config to allow clearing the active session + private val clearSessionConf = { + val conf = new SQLConf + conf.setConfString("spark.sql.legacy.allowModifyActiveSession", "true") + conf + } + def getTempDir(basename: String): File = new File( System.getProperty("test.build.data", System.getProperty("java.io.tmpdir", "/tmp")), basename) @@ -134,8 +142,10 @@ object TestUtils extends Assertions with Arm { f(spark) } finally { spark.stop() - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() + SQLConf.withExistingConf(clearSessionConf) { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } GpuShuffleEnv.setRapidsShuffleManagerInitialized(false, GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS) } }