-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support GpuSubqueryBroadcast for DPP [databricks] #4150
Changes from all commits
96ac231
49a6d7c
05a3cc5
ac123f0
57f19d7
489cee8
904af4f
4a73dc2
262664b
466d053
18ec34d
85c2cc7
a7d8036
f7529f9
544bc47
bbcaf43
809bf17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
import pytest | ||
|
||
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture | ||
from conftest import spark_tmp_table_factory | ||
from conftest import spark_tmp_table_factory, is_databricks_runtime | ||
from data_gen import * | ||
from marks import ignore_order | ||
from spark_session import is_before_spark_320, with_cpu_session | ||
|
@@ -81,7 +81,7 @@ def fn(spark): | |
''', | ||
''' | ||
SELECT f.key, sum(f.value) | ||
FROM (SELECT *, struct(key) AS keys FROM {0} fact) f | ||
FROM (SELECT *, struct(key) AS keys FROM {0} fact) f | ||
JOIN (SELECT *, struct(key) AS keys FROM {1} dim) d | ||
ON f.keys = d.keys | ||
WHERE d.filter = {2} | ||
|
@@ -91,32 +91,64 @@ def fn(spark): | |
|
||
|
||
# When BroadcastExchangeExec is available on filtering side, and it can be reused: | ||
# DynamicPruningExpression(InSubqueryExec(value, SubqueryBroadcastExec))) | ||
# DynamicPruningExpression(InSubqueryExec(value, GpuSubqueryBroadcastExec))) | ||
@ignore_order | ||
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn) | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled") | ||
def test_dpp_reuse_broadcast_exchange(aqe_on, store_format, s_index, spark_tmp_table_factory): | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
def test_dpp_reuse_broadcast_exchange_aqe_off(store_format, s_index, spark_tmp_table_factory): | ||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, store_format, length=10000) | ||
filter_val = create_dim_table(dim_table, store_format, length=2000) | ||
statement = _statements[s_index].format(fact_table, dim_table, filter_val) | ||
assert_cpu_and_gpu_are_equal_collect_with_capture( | ||
lambda spark: spark.sql(statement), | ||
# SubqueryBroadcastExec appears if we reuse broadcast exchange for DPP | ||
exist_classes='DynamicPruningExpression,SubqueryBroadcastExec', | ||
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_on)])) | ||
# The existence of GpuSubqueryBroadcastExec indicates the reuse works on the GPU | ||
exist_classes='DynamicPruningExpression,GpuSubqueryBroadcastExec,ReusedExchangeExec', | ||
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', 'false')])) | ||
|
||
|
||
# When BroadcastExchange is not available and non-broadcast DPPs are forbidden, Spark will bypass it: | ||
# DynamicPruningExpression(Literal.TrueLiteral) | ||
# The SubqueryBroadcast can work on GPU even if the scan who holds it fallbacks into CPU. | ||
@ignore_order | ||
@pytest.mark.allow_non_gpu('FileSourceScanExec') | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
def test_dpp_reuse_broadcast_exchange_cpu_scan(spark_tmp_table_factory): | ||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, 'parquet', length=10000) | ||
filter_val = create_dim_table(dim_table, 'parquet', length=2000) | ||
statement = _statements[0].format(fact_table, dim_table, filter_val) | ||
assert_cpu_and_gpu_are_equal_collect_with_capture( | ||
lambda spark: spark.sql(statement), | ||
# The existence of GpuSubqueryBroadcastExec indicates the reuse works on the GPU | ||
exist_classes='FileSourceScanExec,GpuSubqueryBroadcastExec,ReusedExchangeExec', | ||
conf=dict(_exchange_reuse_conf + [ | ||
('spark.sql.adaptive.enabled', 'false'), | ||
('spark.rapids.sql.format.parquet.read.enabled', 'false')])) | ||
|
||
|
||
# When AQE enabled, the broadcast exchange can not be reused in current, because spark-rapids | ||
# will plan GpuBroadcastToCpu for exchange reuse. Meanwhile, the original broadcast exchange is | ||
# simply replaced by GpuBroadcastExchange. Therefore, the reuse can not work since | ||
# GpuBroadcastToCpu is not semantically equal to GpuBroadcastExchange. | ||
@ignore_order | ||
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn) | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled") | ||
def test_dpp_bypass(aqe_on, store_format, s_index, spark_tmp_table_factory): | ||
def test_dpp_reuse_broadcast_exchange_aqe_on(store_format, s_index, spark_tmp_table_factory): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we miss the corresponding test with AQE off, or is that covered in some other existing test and was not really needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I named the test of AQE off as |
||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, store_format, length=10000) | ||
filter_val = create_dim_table(dim_table, store_format, length=2000) | ||
statement = _statements[s_index].format(fact_table, dim_table, filter_val) | ||
assert_cpu_and_gpu_are_equal_collect_with_capture( | ||
lambda spark: spark.sql(statement), | ||
exist_classes='DynamicPruningExpression,SubqueryBroadcastExec,GpuBroadcastToCpuExec', | ||
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', 'true')])) | ||
|
||
|
||
# When BroadcastExchange is not available and non-broadcast DPPs are forbidden, Spark will bypass it: | ||
# DynamicPruningExpression(Literal.TrueLiteral) | ||
def __dpp_bypass(store_format, s_index, spark_tmp_table_factory, aqe_enabled): | ||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, store_format) | ||
filter_val = create_dim_table(dim_table, store_format) | ||
|
@@ -126,18 +158,30 @@ def test_dpp_bypass(aqe_on, store_format, s_index, spark_tmp_table_factory): | |
# Bypass with a true literal, if we can not reuse broadcast exchange. | ||
exist_classes='DynamicPruningExpression', | ||
non_exist_classes='SubqueryExec,SubqueryBroadcastExec', | ||
conf=dict(_bypass_conf + [('spark.sql.adaptive.enabled', aqe_on)])) | ||
conf=dict(_bypass_conf + [('spark.sql.adaptive.enabled', aqe_enabled)])) | ||
|
||
|
||
@ignore_order | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
def test_dpp_bypass_aqe_off(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_bypass(store_format, s_index, spark_tmp_table_factory, 'false') | ||
|
||
|
||
# When BroadcastExchange is not available, but it is still worthwhile to run DPP, | ||
# then Spark will plan an extra Aggregate to collect filtering values: | ||
# DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...)))) | ||
@ignore_order | ||
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn) | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled") | ||
def test_dpp_via_aggregate_subquery(aqe_on, store_format, s_index, spark_tmp_table_factory): | ||
def test_dpp_bypass_aqe_on(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_bypass(store_format, s_index, spark_tmp_table_factory, 'true') | ||
|
||
|
||
# When BroadcastExchange is not available, but it is still worthwhile to run DPP, | ||
# then Spark will plan an extra Aggregate to collect filtering values: | ||
# DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...)))) | ||
def __dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, aqe_enabled): | ||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, store_format) | ||
filter_val = create_dim_table(dim_table, store_format) | ||
|
@@ -146,16 +190,28 @@ def test_dpp_via_aggregate_subquery(aqe_on, store_format, s_index, spark_tmp_tab | |
lambda spark: spark.sql(statement), | ||
# SubqueryExec appears if we plan extra subquery for DPP | ||
exist_classes='DynamicPruningExpression,SubqueryExec', | ||
conf=dict(_no_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_on)])) | ||
conf=dict(_no_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_enabled)])) | ||
|
||
|
||
# When BroadcastExchange is not available, Spark will skip DPP if there is no potential benefit | ||
@ignore_order | ||
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn) | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
def test_dpp_via_aggregate_subquery_aqe_off(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, 'false') | ||
|
||
|
||
@ignore_order | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled") | ||
def test_dpp_skip(aqe_on, store_format, s_index, spark_tmp_table_factory): | ||
def test_dpp_via_aggregate_subquery_aqe_on(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, 'true') | ||
|
||
|
||
# When BroadcastExchange is not available, Spark will skip DPP if there is no potential benefit | ||
def __dpp_skip(store_format, s_index, spark_tmp_table_factory, aqe_enabled): | ||
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get() | ||
create_fact_table(fact_table, store_format) | ||
filter_val = create_dim_table(dim_table, store_format) | ||
|
@@ -164,4 +220,21 @@ def test_dpp_skip(aqe_on, store_format, s_index, spark_tmp_table_factory): | |
lambda spark: spark.sql(statement), | ||
# SubqueryExec appears if we plan extra subquery for DPP | ||
non_exist_classes='DynamicPruningExpression', | ||
conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_on)])) | ||
conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_enabled)])) | ||
|
||
|
||
@ignore_order | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
def test_dpp_skip_aqe_off(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_skip(store_format, s_index, spark_tmp_table_factory, 'false') | ||
|
||
|
||
@ignore_order | ||
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn) | ||
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn) | ||
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime") | ||
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled") | ||
def test_dpp_skip_aqe_on(store_format, s_index, spark_tmp_table_factory): | ||
__dpp_skip(store_format, s_index, spark_tmp_table_factory, 'true') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something that we should fix? Should we combine the two classes together so that they are the same thing, and it does not matter whether you are reading the data on the CPU or the GPU?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. IMO, with the help of the new method
`SerializeConcatHostBuffersDeserializeBatch.hostBatches`,
we can change the role of `GpuBroadcastToCpu`,
making it a wrapper of `GpuBroadcastExchangeExec`.
Therefore, we can reuse the GpuBroadcast in terms of serialized host buffers. I tried it in my local environment, and it works. I would like to create a separate PR for this change. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like a good plan.