Support GpuSubqueryBroadcast for DPP [databricks] #4150

Merged: 17 commits, Dec 17, 2021
1 change: 1 addition & 0 deletions docs/configs.md
@@ -351,6 +351,7 @@ Name | Description | Default Value | Notes
<a name="sql.exec.RangeExec"></a>spark.rapids.sql.exec.RangeExec|The backend for range operator|true|None|
<a name="sql.exec.SampleExec"></a>spark.rapids.sql.exec.SampleExec|The backend for the sample operator|true|None|
<a name="sql.exec.SortExec"></a>spark.rapids.sql.exec.SortExec|The backend for the sort operator|true|None|
<a name="sql.exec.SubqueryBroadcastExec"></a>spark.rapids.sql.exec.SubqueryBroadcastExec|Plan to collect and transform the broadcast key values|true|None|
<a name="sql.exec.TakeOrderedAndProjectExec"></a>spark.rapids.sql.exec.TakeOrderedAndProjectExec|Take the first limit elements as defined by the sortOrder, and do projection if needed|true|None|
<a name="sql.exec.UnionExec"></a>spark.rapids.sql.exec.UnionExec|The backend for the union operator|true|None|
<a name="sql.exec.CustomShuffleReaderExec"></a>spark.rapids.sql.exec.CustomShuffleReaderExec|A wrapper of shuffle query stage|true|None|
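The new config row documents `spark.rapids.sql.exec.SubqueryBroadcastExec`, which defaults to on. A minimal sketch of building a conf map that turns it off (the key name comes from the table above; the surrounding entry is an illustrative assumption):

```python
# Sketch: disabling the new SubqueryBroadcastExec backend via the config key
# documented in docs/configs.md. base_conf contents are illustrative only.
base_conf = {'spark.rapids.sql.enabled': 'true'}
conf = dict(base_conf, **{'spark.rapids.sql.exec.SubqueryBroadcastExec': 'false'})
```

Passing such a dict as the per-test `conf` is how the integration tests below toggle individual operators.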
100 changes: 62 additions & 38 deletions docs/supported_ops.md
@@ -420,8 +420,8 @@ Accelerator supports are described below.
<td><b>NS</b></td>
</tr>
<tr>
<td rowspan="1">TakeOrderedAndProjectExec</td>
<td rowspan="1">Take the first limit elements as defined by the sortOrder, and do projection if needed</td>
<td rowspan="1">SubqueryBroadcastExec</td>
<td rowspan="1">Plan to collect and transform the broadcast key values</td>
<td rowspan="1">None</td>
<td>Input/Output</td>
<td>S</td>
@@ -436,16 +436,16 @@ Accelerator supports are described below.
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td>S</td>
</tr>
<tr>
<td rowspan="1">UnionExec</td>
<td rowspan="1">The backend for the union operator</td>
<td rowspan="1">TakeOrderedAndProjectExec</td>
<td rowspan="1">Take the first limit elements as defined by the sortOrder, and do projection if needed</td>
<td rowspan="1">None</td>
<td>Input/Output</td>
<td>S</td>
@@ -464,12 +464,12 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>unionByName will not optionally impute nulls for missing struct fields when the column is a struct and there are non-overlapping fields;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowspan="1">CustomShuffleReaderExec</td>
<td rowspan="1">A wrapper of shuffle query stage</td>
<td rowspan="1">UnionExec</td>
<td rowspan="1">The backend for the union operator</td>
<td rowspan="1">None</td>
<td>Input/Output</td>
<td>S</td>
@@ -488,7 +488,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>unionByName will not optionally impute nulls for missing struct fields when the column is a struct and there are non-overlapping fields;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -516,6 +516,30 @@ Accelerator supports are described below.
<th>UDT</th>
</tr>
<tr>
<td rowspan="1">CustomShuffleReaderExec</td>
<td rowspan="1">A wrapper of shuffle query stage</td>
<td rowspan="1">None</td>
<td>Input/Output</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowspan="1">HashAggregateExec</td>
<td rowspan="1">The backend for hash based aggregations</td>
<td rowspan="1">None</td>
@@ -840,6 +864,30 @@ Accelerator supports are described below.
<td><b>NS</b></td>
</tr>
<tr>
<th>Executor</th>
<th>Description</th>
<th>Notes</th>
<th>Param(s)</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowspan="4">ShuffledHashJoinExec</td>
<td rowspan="4">Implementation of join using hashed shuffled data</td>
<td rowspan="4">None</td>
@@ -927,30 +975,6 @@ Accelerator supports are described below.
<td><b>NS</b></td>
</tr>
<tr>
<th>Executor</th>
<th>Description</th>
<th>Notes</th>
<th>Param(s)</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowspan="4">SortMergeJoinExec</td>
<td rowspan="4">Sort merge join, replacing with shuffled hash join</td>
<td rowspan="4">None</td>
121 changes: 97 additions & 24 deletions integration_tests/src/main/python/dpp_test.py
@@ -15,7 +15,7 @@
import pytest

from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import spark_tmp_table_factory
from conftest import spark_tmp_table_factory, is_databricks_runtime
from data_gen import *
from marks import ignore_order
from spark_session import is_before_spark_320, with_cpu_session
@@ -81,7 +81,7 @@ def fn(spark):
''',
'''
SELECT f.key, sum(f.value)
FROM (SELECT *, struct(key) AS keys FROM {0} fact) f
FROM (SELECT *, struct(key) AS keys FROM {0} fact) f
JOIN (SELECT *, struct(key) AS keys FROM {1} dim) d
ON f.keys = d.keys
WHERE d.filter = {2}
@@ -91,32 +91,64 @@ def fn(spark):


# When BroadcastExchangeExec is available on filtering side, and it can be reused:
# DynamicPruningExpression(InSubqueryExec(value, SubqueryBroadcastExec)))
# DynamicPruningExpression(InSubqueryExec(value, GpuSubqueryBroadcastExec)))
@ignore_order
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn)
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled")
def test_dpp_reuse_broadcast_exchange(aqe_on, store_format, s_index, spark_tmp_table_factory):
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
def test_dpp_reuse_broadcast_exchange_aqe_off(store_format, s_index, spark_tmp_table_factory):
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, store_format, length=10000)
filter_val = create_dim_table(dim_table, store_format, length=2000)
statement = _statements[s_index].format(fact_table, dim_table, filter_val)
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark: spark.sql(statement),
# SubqueryBroadcastExec appears if we reuse broadcast exchange for DPP
exist_classes='DynamicPruningExpression,SubqueryBroadcastExec',
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_on)]))
# The existence of GpuSubqueryBroadcastExec indicates the reuse works on the GPU
exist_classes='DynamicPruningExpression,GpuSubqueryBroadcastExec,ReusedExchangeExec',
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', 'false')]))
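The `conf=dict(_exchange_reuse_conf + [...])` idiom above merges a shared list of `(key, value)` pairs with a per-test AQE toggle; with `dict()`, later pairs win on duplicate keys. A small sketch with assumed pair values (the real `_exchange_reuse_conf` is defined earlier in this file):

```python
# dict() over a concatenated list of (key, value) pairs, as the DPP tests do.
# The pair values below are assumptions for illustration only.
_exchange_reuse_conf = [
    ('spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly', 'true'),
    ('spark.sql.exchange.reuse', 'true'),
]
conf = dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', 'false')])
```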


# When BroadcastExchange is not available and non-broadcast DPPs are forbidden, Spark will bypass it:
# DynamicPruningExpression(Literal.TrueLiteral)
# The SubqueryBroadcast can work on the GPU even if the scan that holds it falls back to the CPU.
@ignore_order
@pytest.mark.allow_non_gpu('FileSourceScanExec')
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
def test_dpp_reuse_broadcast_exchange_cpu_scan(spark_tmp_table_factory):
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, 'parquet', length=10000)
filter_val = create_dim_table(dim_table, 'parquet', length=2000)
statement = _statements[0].format(fact_table, dim_table, filter_val)
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark: spark.sql(statement),
# The existence of GpuSubqueryBroadcastExec indicates the reuse works on the GPU
exist_classes='FileSourceScanExec,GpuSubqueryBroadcastExec,ReusedExchangeExec',
conf=dict(_exchange_reuse_conf + [
('spark.sql.adaptive.enabled', 'false'),
('spark.rapids.sql.format.parquet.read.enabled', 'false')]))


# When AQE is enabled, the broadcast exchange can not currently be reused, because spark-rapids
# plans GpuBroadcastToCpu for the exchange reuse while the original broadcast exchange is
# simply replaced by GpuBroadcastExchange. The reuse therefore fails, since GpuBroadcastToCpu
# is not semantically equal to GpuBroadcastExchange.
Collaborator: Is this something that we should fix? Should we combine the two classes so that they are the same thing, and it does not matter whether you are reading the data on the CPU or the GPU?

Collaborator (Author): I think so. IMO, with the help of the new method SerializeConcatHostBuffersDeserializeBatch.hostBatches, we can change the role of GpuBroadcastToCpu, making it a wrapper of GpuBroadcastExchangeExec, so the GPU broadcast can be reused in terms of serialized host buffers. I tried it in my local environment and it works. I would like to create a separate PR for this change.

Collaborator: Sounds like a good plan.
@ignore_order
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn)
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled")
def test_dpp_bypass(aqe_on, store_format, s_index, spark_tmp_table_factory):
def test_dpp_reuse_broadcast_exchange_aqe_on(store_format, s_index, spark_tmp_table_factory):
Collaborator: Did we miss the corresponding test with AQE off, or is that covered in some other existing test and not really needed?

Collaborator (Author): I named the AQE-off test test_dpp_reuse_broadcast_exchange and appended the suffix _aqe_off to clarify the intention of the tests.

fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, store_format, length=10000)
filter_val = create_dim_table(dim_table, store_format, length=2000)
statement = _statements[s_index].format(fact_table, dim_table, filter_val)
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark: spark.sql(statement),
exist_classes='DynamicPruningExpression,SubqueryBroadcastExec,GpuBroadcastToCpuExec',
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', 'true')]))


# When BroadcastExchange is not available and non-broadcast DPPs are forbidden, Spark will bypass it:
# DynamicPruningExpression(Literal.TrueLiteral)
def __dpp_bypass(store_format, s_index, spark_tmp_table_factory, aqe_enabled):
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, store_format)
filter_val = create_dim_table(dim_table, store_format)
@@ -126,18 +158,30 @@ def test_dpp_bypass(aqe_on, store_format, s_index, spark_tmp_table_factory):
# Bypass with a true literal, if we can not reuse broadcast exchange.
exist_classes='DynamicPruningExpression',
non_exist_classes='SubqueryExec,SubqueryBroadcastExec',
conf=dict(_bypass_conf + [('spark.sql.adaptive.enabled', aqe_on)]))
conf=dict(_bypass_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@ignore_order
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
def test_dpp_bypass_aqe_off(store_format, s_index, spark_tmp_table_factory):
__dpp_bypass(store_format, s_index, spark_tmp_table_factory, 'false')


# When BroadcastExchange is not available, but it is still worthwhile to run DPP,
# then Spark will plan an extra Aggregate to collect filtering values:
# DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...))))
@ignore_order
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn)
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled")
def test_dpp_via_aggregate_subquery(aqe_on, store_format, s_index, spark_tmp_table_factory):
def test_dpp_bypass_aqe_on(store_format, s_index, spark_tmp_table_factory):
__dpp_bypass(store_format, s_index, spark_tmp_table_factory, 'true')


# When BroadcastExchange is not available, but it is still worthwhile to run DPP,
# then Spark will plan an extra Aggregate to collect filtering values:
# DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...))))
def __dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, aqe_enabled):
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, store_format)
filter_val = create_dim_table(dim_table, store_format)
@@ -146,16 +190,28 @@ def test_dpp_via_aggregate_subquery(aqe_on, store_format, s_index, spark_tmp_tab
lambda spark: spark.sql(statement),
# SubqueryExec appears if we plan extra subquery for DPP
exist_classes='DynamicPruningExpression,SubqueryExec',
conf=dict(_no_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_on)]))
conf=dict(_no_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


# When BroadcastExchange is not available, Spark will skip DPP if there is no potential benefit
@ignore_order
@pytest.mark.parametrize('aqe_on', ['true', 'false'], ids=idfn)
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
def test_dpp_via_aggregate_subquery_aqe_off(store_format, s_index, spark_tmp_table_factory):
__dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, 'false')


@ignore_order
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled")
def test_dpp_skip(aqe_on, store_format, s_index, spark_tmp_table_factory):
def test_dpp_via_aggregate_subquery_aqe_on(store_format, s_index, spark_tmp_table_factory):
__dpp_via_aggregate_subquery(store_format, s_index, spark_tmp_table_factory, 'true')


# When BroadcastExchange is not available, Spark will skip DPP if there is no potential benefit
def __dpp_skip(store_format, s_index, spark_tmp_table_factory, aqe_enabled):
fact_table, dim_table = spark_tmp_table_factory.get(), spark_tmp_table_factory.get()
create_fact_table(fact_table, store_format)
filter_val = create_dim_table(dim_table, store_format)
@@ -164,4 +220,21 @@ def test_dpp_skip(aqe_on, store_format, s_index, spark_tmp_table_factory):
lambda spark: spark.sql(statement),
# SubqueryExec appears if we plan extra subquery for DPP
non_exist_classes='DynamicPruningExpression',
conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_on)]))
conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@ignore_order
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
def test_dpp_skip_aqe_off(store_format, s_index, spark_tmp_table_factory):
__dpp_skip(store_format, s_index, spark_tmp_table_factory, 'false')


@ignore_order
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="DPP can not cooperate with rapids plugin on Databricks runtime")
@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0+ AQE and DPP can be both enabled")
def test_dpp_skip_aqe_on(store_format, s_index, spark_tmp_table_factory):
__dpp_skip(store_format, s_index, spark_tmp_table_factory, 'true')
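The exist_classes/non_exist_classes assertions used throughout this file boil down to membership checks against the captured plan. A hedged stand-in (the real helper lives in integration_tests/src/main/python/asserts.py; this is not its implementation):

```python
# Illustrative check mirroring what the capture assertions verify: every class
# named in exist_classes appears in the plan string, none from non_exist_classes.
def check_captured_plan(plan_str, exist_classes='', non_exist_classes=''):
    for cls in filter(None, exist_classes.split(',')):
        assert cls in plan_str, 'missing ' + cls
    for cls in filter(None, non_exist_classes.split(',')):
        assert cls not in plan_str, 'unexpected ' + cls
    return True

plan = 'DynamicPruningExpression GpuSubqueryBroadcastExec ReusedExchangeExec'
check_captured_plan(plan, exist_classes='DynamicPruningExpression,GpuSubqueryBroadcastExec')
```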