[Audit] Add bucketed scan info in query plan of data source v1 [databricks] #4461

Merged
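This change passes `disableBucketedScan` from the wrapped `FileSourceScanExec` through the shims into `GpuFileSourceScanExec` and adds bucketed-scan information to the scan node's plan metadata, mirroring Apache Spark commit 79515e4b6c: the plan now reports `Bucketed: true` when a bucketed scan is used, and otherwise `Bucketed: false` together with the reason (disabled by configuration, disabled by query planner, or bucket column(s) not read). Below is a rough sketch of how the new field surfaces, assuming an active SparkSession `spark` with the plugin enabled and hypothetical table names `t1`/`t2` (not part of the change itself):

# Sketch only: assumes the RAPIDS Accelerator is on the classpath and enabled.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "0")  # force a non-broadcast join so bucketing matters
spark.createDataFrame([(1, 2), (2, 3)], ("i", "j")).write.bucketBy(8, "i").saveAsTable("t1")
spark.createDataFrame([(2,), (3,)], ("i",)).write.bucketBy(8, "i").saveAsTable("t2")
df1, df2 = spark.table("t1"), spark.table("t2")
df1.join(df2, df1.i == df2.i, "inner").explain()
# The file scan nodes in the printed physical plan now include "Bucketed: true";
# with spark.sql.sources.bucketing.enabled=false they would instead report
# "Bucketed: false (disabled by configuration)".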
82 changes: 80 additions & 2 deletions integration_tests/src/main/python/explain_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
from marks import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
from spark_session import with_cpu_session
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_311

def create_df(spark, data_gen, left_length, right_length):
left = binary_op_df(spark, data_gen, length=left_length)
@@ -91,3 +91,81 @@ def do_explain(spark):

with_cpu_session(do_explain)


@allow_non_gpu(any=True)
def test_explain_bucketed_scan(spark_tmp_table_factory):
"""
Test the physical plan includes the info of enabling bucketed scan.
The code is copied from:
https://github.com/apache/spark/commit/79515e4b6c#diff-03f119698c3637b87c9ce2634c34c14bb0f7efc043ea37a0891c1ab9fbc3ebadR688
"""
def do_explain(spark):
tbl_1 = spark_tmp_table_factory.get()
tbl_2 = spark_tmp_table_factory.get()
spark.createDataFrame([(1, 2), (2, 3)], ("i", "j")).write.bucketBy(8, "i").saveAsTable(tbl_1)
spark.createDataFrame([(2,), (3,)], ("i",)).write.bucketBy(8, "i").saveAsTable(tbl_2)
df1 = spark.table(tbl_1)
df2 = spark.table(tbl_2)
joined_df = df1.join(df2, df1.i == df2.i, "inner")

assert "Bucketed: true" in joined_df._sc._jvm.PythonSQLUtils.explainString(joined_df._jdf.queryExecution(), "simple")

with_gpu_session(do_explain, {"spark.sql.autoBroadcastJoinThreshold": "0"})


@allow_non_gpu(any=True)
def test_explain_bucket_column_not_read(spark_tmp_table_factory):
"""
Test the physical plan includes the info of disabling bucketed scan and the reason.
The code is copied from:
https://github.com/apache/spark/commit/79515e4b6c#diff-03f119698c3637b87c9ce2634c34c14bb0f7efc043ea37a0891c1ab9fbc3ebadR702
"""
def do_explain(spark):
tbl = spark_tmp_table_factory.get()
spark.createDataFrame([(1, 2), (2, 3)], ("i", "j")).write.bucketBy(8, "i").saveAsTable(tbl)
df = spark.table(tbl).select(col("j"))

assert "Bucketed: false (bucket column(s) not read)" in df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "simple")

with_gpu_session(do_explain)


@allow_non_gpu(any=True)
def test_explain_bucket_disabled_by_conf(spark_tmp_table_factory):
"""
Test the physical plan includes the info of disabling bucketed scan and the reason.
The code is copied from:
https://github.com/apache/spark/commit/79515e4b6c#diff-03f119698c3637b87c9ce2634c34c14bb0f7efc043ea37a0891c1ab9fbc3ebadR694
"""
def do_explain(spark):
tbl_1 = spark_tmp_table_factory.get()
tbl_2 = spark_tmp_table_factory.get()
spark.createDataFrame([(1, 2), (2, 3)], ("i", "j")).write.bucketBy(8, "i").saveAsTable(tbl_1)
spark.createDataFrame([(2,), (3,)], ("i",)).write.bucketBy(8, "i").saveAsTable(tbl_2)
df1 = spark.table(tbl_1)
df2 = spark.table(tbl_2)
joined_df = df1.join(df2, df1.i == df2.i, "inner")

assert "Bucketed: false (disabled by configuration)" in joined_df._sc._jvm.PythonSQLUtils.explainString(joined_df._jdf.queryExecution(), "simple")

with_gpu_session(do_explain, {"spark.sql.sources.bucketing.enabled": "false"})


@allow_non_gpu(any=True)
@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.0+ does `GpuFileSourceScanExec` have the attribute `disableBucketedScan`")
def test_explain_bucket_disabled_by_query_planner(spark_tmp_table_factory):
"""
Test the physical plan includes the info of disabling bucketed scan and the reason.
The code is copied from:
https://github.com/apache/spark/commit/79515e4b6c#diff-03f119698c3637b87c9ce2634c34c14bb0f7efc043ea37a0891c1ab9fbc3ebadR700

This test is skipped if the Spark version is before 3.1.0, because the attribute `disableBucketedScan` is not included in `GpuFileSourceScanExec` before 3.1.0.
"""
def do_explain(spark):
tbl = spark_tmp_table_factory.get()
spark.createDataFrame([(1, 2), (2, 3)], ("i", "j")).write.bucketBy(8, "i").saveAsTable(tbl)
df = spark.table(tbl)

assert "Bucketed: false (disabled by query planner)" in df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "simple")

with_gpu_session(do_explain)
@@ -417,7 +417,8 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
wrapped.optionalBucketSet,
wrapped.optionalNumCoalescedBuckets,
wrapped.dataFilters,
wrapped.tableIdentifier)(conf)
wrapped.tableIdentifier,
wrapped.disableBucketedScan)(conf)
}
}),
GpuOverrides.exec[InMemoryTableScanExec](
@@ -441,7 +441,8 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
// TODO: Does Databricks have coalesced bucketing implemented?
None,
wrapped.dataFilters,
wrapped.tableIdentifier)(conf)
wrapped.tableIdentifier,
wrapped.disableBucketedScan)(conf)
}
}),
GpuOverrides.exec[InMemoryTableScanExec](
@@ -581,7 +581,8 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
wrapped.optionalBucketSet,
wrapped.optionalNumCoalescedBuckets,
wrapped.dataFilters,
wrapped.tableIdentifier)(conf)
wrapped.tableIdentifier,
wrapped.disableBucketedScan)(conf)
}
}),
GpuOverrides.exec[InMemoryTableScanExec](
@@ -41,4 +41,4 @@ trait Spark320until322Shims extends Spark320PlusShims with RebaseShims with Logg
new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
}
}
}
@@ -41,4 +41,4 @@ trait Spark322PlusShims extends Spark320PlusShims with RebaseShims with Logging
new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
}
}
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -56,6 +56,7 @@ import org.apache.spark.util.collection.BitSet
* @param queryUsesInputFile This is a parameter to easily allow turning it
* off in GpuTransitionOverrides if InputFileName,
* InputFileBlockStart, or InputFileBlockLength are used
* @param disableBucketedScan Disable bucketed scan based on physical query plan.
*/
case class GpuFileSourceScanExec(
@transient relation: HadoopFsRelation,
@@ -66,6 +67,7 @@ case class GpuFileSourceScanExec(
optionalNumCoalescedBuckets: Option[Int],
dataFilters: Seq[Expression],
tableIdentifier: Option[TableIdentifier],
disableBucketedScan: Boolean = false,
queryUsesInputFile: Boolean = false)(@transient val rapidsConf: RapidsConf)
extends GpuDataSourceScanExec with GpuExec {
import GpuMetric._
@@ -153,7 +155,8 @@ case class GpuFileSourceScanExec(

// exposed for testing
lazy val bucketedScan: Boolean = {
if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined) {
if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined
&& !disableBucketedScan) {
val spec = relation.bucketSpec.get
val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n))
bucketColumns.size == spec.bucketColumnNames.size
@@ -244,20 +247,31 @@ case class GpuFileSourceScanExec(
"DataFilters" -> seqToString(dataFilters),
"Location" -> locationDesc)

val withSelectedBucketsCount = relation.bucketSpec.map { spec =>
val numSelectedBuckets = optionalBucketSet.map { b =>
b.cardinality()
} getOrElse {
spec.numBuckets


relation.bucketSpec.map { spec =>
val bucketedKey = "Bucketed"
if (bucketedScan) {
val numSelectedBuckets = optionalBucketSet.map { b =>
b.cardinality()
} getOrElse {
spec.numBuckets
}

metadata ++ Map(
bucketedKey -> "true",
"SelectedBucketsCount" -> (s"$numSelectedBuckets out of ${spec.numBuckets}" +
optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)"}.getOrElse("")))
} else if (!relation.sparkSession.sessionState.conf.bucketingEnabled) {
metadata + (bucketedKey -> "false (disabled by configuration)")
} else if (disableBucketedScan) {
metadata + (bucketedKey -> "false (disabled by query planner)")
} else {
metadata + (bucketedKey -> "false (bucket column(s) not read)")
}
metadata + ("SelectedBucketsCount" ->
(s"$numSelectedBuckets out of ${spec.numBuckets}" +
optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)"}.getOrElse("")))
} getOrElse {
metadata
}

withSelectedBucketsCount
}

override def verboseStringWithOperatorId(): String = {
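For reference, a rough usage sketch of the three `Bucketed: false` reasons produced by the metadata logic above, assuming a Spark 3.1+ session with the plugin enabled and an illustrative table `t` bucketed by column `i` and containing another column `j`:

spark.table("t").select("j").explain()  # Bucketed: false (bucket column(s) not read)
spark.conf.set("spark.sql.sources.bucketing.enabled", "false")
spark.table("t").explain()              # Bucketed: false (disabled by configuration)
# "Bucketed: false (disabled by query planner)" shows up when Spark itself sets
# disableBucketedScan, e.g. when no operator benefits from the bucketed output
# (Spark 3.1+ DisableUnnecessaryBucketedScan rule), as exercised by
# test_explain_bucket_disabled_by_query_planner above.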