diff --git a/docs/configs.md b/docs/configs.md
index 5a89b08a9b9..d80e3574bd1 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -339,7 +339,7 @@ Name | Description | Default Value | Notes
spark.rapids.sql.exec.BroadcastExchangeExec|The backend for broadcast exchange of data|true|None|
spark.rapids.sql.exec.ShuffleExchangeExec|The backend for most data being exchanged between processes|true|None|
spark.rapids.sql.exec.BroadcastHashJoinExec|Implementation of join using broadcast data|true|None|
-spark.rapids.sql.exec.BroadcastNestedLoopJoinExec|Implementation of join using brute force|true|None|
+spark.rapids.sql.exec.BroadcastNestedLoopJoinExec|Implementation of join using brute force. Full outer joins and joins where the broadcast side matches the join side (e.g.: LeftOuter with left broadcast) are not supported. A non-inner join is only supported if the join condition expression can be converted to a GPU AST expression|true|None|
spark.rapids.sql.exec.CartesianProductExec|Implementation of join using brute force|true|None|
spark.rapids.sql.exec.ShuffledHashJoinExec|Implementation of join using hashed shuffled data|true|None|
spark.rapids.sql.exec.SortMergeJoinExec|Sort merge join, replacing with shuffled hash join|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index b97efdad160..de40a8ca321 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -729,7 +729,7 @@ Accelerator supports are described below.
BroadcastNestedLoopJoinExec |
-Implementation of join using brute force |
+Implementation of join using brute force. Full outer joins and joins where the broadcast side matches the join side (e.g.: LeftOuter with left broadcast) are not supported. A non-inner join is only supported if the join condition expression can be converted to a GPU AST expression |
Input |
None |
S |
diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index ddc9a14381f..98d94e68263 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -20,10 +20,11 @@
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
from spark_session import with_cpu_session, with_spark_session
-
# Mark all tests in current file as slow test since it would require ~30mins in total
pytestmark = pytest.mark.slow_test
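+# The full set of join types exercised by most of the tests below; tests that do not
+# support a particular join type parametrize over a narrower list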
+all_join_types = ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter']
+
all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
BooleanGen(), DateGen(), TimestampGen(), null_gen,
pytest.param(FloatGen(), marks=[incompat]),
@@ -98,7 +99,7 @@ def create_nested_df(spark, key_data_gen, data_gen, left_length, right_length):
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_sortmerge_join(data_gen, join_type, batch_size):
def do_join(spark):
@@ -110,7 +111,7 @@ def do_join(spark):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', single_level_array_gens_no_decimal, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_sortmerge_join_array(data_gen, join_type, batch_size):
def do_join(spark):
@@ -122,7 +123,7 @@ def do_join(spark):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_sortmerge_join_map(data_gen, join_type, batch_size):
def do_join(spark):
@@ -138,7 +139,7 @@ def do_join(spark):
'NamedLambdaVariable', 'NormalizeNaNAndZero', 'ShuffleExchangeExec', 'HashPartitioning')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_sortmerge_join_array_as_key(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 500)
@@ -147,7 +148,7 @@ def do_join(spark):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [all_basic_struct_gen], ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test out of core joins too
def test_sortmerge_join_struct(data_gen, join_type, batch_size):
def do_join(spark):
@@ -168,7 +169,7 @@ def do_join(spark):
@validate_execs_in_gpu_plan('GpuShuffledHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', single_level_array_gens_no_decimal, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_hash_join_array(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 50, 500)
@@ -178,7 +179,7 @@ def do_join(spark):
@validate_execs_in_gpu_plan('GpuShuffledHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_hash_join_map(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 50, 500)
@@ -191,7 +192,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_right_table(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
@@ -202,7 +203,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', single_level_array_gens_no_decimal, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_right_table_array(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
@@ -213,7 +214,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_right_table_map(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
@@ -224,7 +225,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', [all_basic_struct_gen], ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_right_table_struct(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
@@ -237,7 +238,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_right_table_with_job_group(data_gen, join_type):
with_cpu_session(lambda spark : spark.sparkContext.setJobGroup("testjob1", "test", False))
def do_join(spark):
@@ -293,7 +294,7 @@ def do_join(spark):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
-def test_cartesian_join_with_conditionals(data_gen, batch_size):
+def test_cartesian_join_with_condition(data_gen, batch_size):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
@@ -349,18 +350,17 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Inner', 'Cross'], ids=idfn)
+@pytest.mark.parametrize('data_gen', ast_gen, ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross'], ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
-def test_broadcast_nested_loop_innerlike_join_with_conditionals(data_gen, join_type, batch_size):
+def test_right_broadcast_nested_loop_join_with_ast_condition(data_gen, join_type, batch_size):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
# if the sizes are large enough to have both 0.0 and -0.0 show up 500 and 250
# but these take a long time to verify so we run with smaller numbers by default
# that do not expose the error
- return left.join(broadcast(right),
- (left.b >= right.r_b), join_type)
+ return left.join(broadcast(right), (left.b >= right.r_b), join_type)
conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
conf.update(allow_negative_scale_of_decimal_conf)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
@@ -369,24 +369,54 @@ def do_join(spark):
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', ast_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
-def test_broadcast_nested_loop_join_with_ast_condition(data_gen, join_type, batch_size):
+def test_left_broadcast_nested_loop_join_with_ast_condition(data_gen, batch_size):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
# if the sizes are large enough to have both 0.0 and -0.0 show up 500 and 250
# but these take a long time to verify so we run with smaller numbers by default
# that do not expose the error
- return left.join(broadcast(right),
- (left.b >= right.r_b), join_type)
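+        # a Right join with the left table broadcast makes the left table the build side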
+ return broadcast(left).join(right, (left.b >= right.r_b), 'Right')
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
+ conf.update(allow_negative_scale_of_decimal_conf)
+ assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+
+# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
+# After 3.1.0 is the min spark version we can drop this
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('join_type', ['Inner', 'Cross'], ids=idfn)
+@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
+def test_broadcast_nested_loop_join_with_condition_post_filter(data_gen, join_type, batch_size):
+ def do_join(spark):
+ left, right = create_df(spark, data_gen, 50, 25)
+ # This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
+ # if the sizes are large enough to have both 0.0 and -0.0 show up 500 and 250
+ # but these take a long time to verify so we run with smaller numbers by default
+ # that do not expose the error
+ # AST does not support logical AND, so this must be implemented as a post-filter
+ return left.join(broadcast(right), (left.a >= right.r_a) & (left.b >= right.r_b), join_type)
conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
conf.update(allow_negative_scale_of_decimal_conf)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+@allow_non_gpu('And', 'BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'CheckOverflow', 'Divide', 'GreaterThanOrEqual')
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
+ def do_join(spark):
+ left, right = create_df(spark, data_gen, 50, 25)
+ # AST does not support logical AND yet
+ return broadcast(left).join(right, (left.a >= right.r_a) & (left.b >= right.r_b), join_type)
+ conf = allow_negative_scale_of_decimal_conf
+ assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec', conf=conf)
+
+@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
-def test_broadcast_nested_loop_join_condition_missing(data_gen, join_type):
+@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_right_broadcast_nested_loop_join_condition_missing(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
@@ -397,21 +427,44 @@ def do_join(spark):
conf = allow_negative_scale_of_decimal_conf
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
+@pytest.mark.parametrize('join_type', ['Right'], ids=idfn)
+def test_left_broadcast_nested_loop_join_condition_missing(data_gen, join_type):
+ def do_join(spark):
+ left, right = create_df(spark, data_gen, 50, 25)
+ # This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
+ # if the sizes are large enough to have both 0.0 and -0.0 show up 500 and 250
+ # but these take a long time to verify so we run with smaller numbers by default
+ # that do not expose the error
+ return broadcast(left).join(right, how=join_type)
+ conf = allow_negative_scale_of_decimal_conf
+ assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+
@pytest.mark.xfail(condition=is_databricks_runtime(),
reason='https://github.com/NVIDIA/spark-rapids/issues/3244')
@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
-def test_broadcast_nested_loop_join_condition_missing_count(data_gen, join_type):
+@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_right_broadcast_nested_loop_join_condition_missing_count(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
return left.join(broadcast(right), how=join_type).selectExpr('COUNT(*)')
conf = allow_negative_scale_of_decimal_conf
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
+@pytest.mark.parametrize('join_type', ['Right'], ids=idfn)
+def test_left_broadcast_nested_loop_join_condition_missing_count(data_gen, join_type):
+ def do_join(spark):
+ left, right = create_df(spark, data_gen, 50, 25)
+ return broadcast(left).join(right, how=join_type).selectExpr('COUNT(*)')
+ conf = allow_negative_scale_of_decimal_conf
+ assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)
+
@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'GreaterThanOrEqual')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
+@pytest.mark.parametrize('join_type', ['LeftOuter', 'LeftSemi', 'LeftAnti', 'FullOuter'], ids=idfn)
def test_broadcast_nested_loop_join_with_conditionals_build_left_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
@@ -419,13 +472,24 @@ def do_join(spark):
conf = allow_negative_scale_of_decimal_conf
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec', conf=conf)
+@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'GreaterThanOrEqual')
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('join_type', ['RightOuter', 'FullOuter'], ids=idfn)
+def test_broadcast_nested_loop_join_with_conditionals_build_right_fallback(data_gen, join_type):
+ def do_join(spark):
+ left, right = create_df(spark, data_gen, 50, 25)
+ return left.join(broadcast(right), (left.b >= right.r_b), join_type)
+ conf = allow_negative_scale_of_decimal_conf
+ assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec', conf=conf)
+
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
def test_broadcast_join_left_table(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 250, 500)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
new file mode 100644
index 00000000000..0f412a5781c
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
@@ -0,0 +1,310 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import scala.collection.mutable
+
+import ai.rapids.cudf.{GatherMap, NvtxColor}
+import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Base class for iterators producing the results of a join.
+ * @param gatherNvtxName name to use for the NVTX range when producing the join gather maps
+ * @param targetSize configured target batch size in bytes
+ * @param joinTime metric to record GPU time spent in join
+ * @param totalTime metric to record total time in the iterator
+ */
+abstract class AbstractGpuJoinIterator(
+ gatherNvtxName: String,
+ targetSize: Long,
+ joinTime: GpuMetric,
+ totalTime: GpuMetric) extends Iterator[ColumnarBatch] with Arm with AutoCloseable {
+ private[this] var nextCb: Option[ColumnarBatch] = None
+ private[this] var gathererStore: Option[JoinGatherer] = None
+
+ protected[this] var closed = false
+
+ TaskContext.get().addTaskCompletionListener[Unit](_ => close())
+
+ /** Returns whether there are any more batches on the stream side of the join */
+ protected def hasNextStreamBatch: Boolean
+
+ /**
+   * Called to set up the next join gatherer instance when the previous instance is done or
+ * there is no previous instance.
+ * @param startNanoTime system nanoseconds timestamp at the top of the iterator loop, useful for
+ * calculating the time spent producing the next stream batch
+ * @return some gatherer to use next or None if there is no next gatherer or the loop should try
+ * to build the gatherer again (e.g.: to skip a degenerate join result batch)
+ */
+ protected def setupNextGatherer(startNanoTime: Long): Option[JoinGatherer]
+
+ override def hasNext: Boolean = {
+ if (closed) {
+ return false
+ }
+ var mayContinue = true
+ while (nextCb.isEmpty && mayContinue) {
+ val startNanoTime = System.nanoTime()
+ if (gathererStore.exists(!_.isDone)) {
+ nextCb = nextCbFromGatherer()
+ } else if (hasNextStreamBatch) {
+ // Need to refill the gatherer
+ gathererStore.foreach(_.close())
+ gathererStore = None
+ gathererStore = setupNextGatherer(startNanoTime)
+ nextCb = nextCbFromGatherer()
+ } else {
+ mayContinue = false
+ }
+ totalTime += (System.nanoTime() - startNanoTime)
+ }
+ if (nextCb.isEmpty) {
+ // Nothing is left to return so close ASAP.
+ close()
+ }
+ nextCb.isDefined
+ }
+
+ override def next(): ColumnarBatch = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ val ret = nextCb.get
+ nextCb = None
+ ret
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ nextCb.foreach(_.close())
+ nextCb = None
+ gathererStore.foreach(_.close())
+ gathererStore = None
+ closed = true
+ }
+ }
+
+ private def nextCbFromGatherer(): Option[ColumnarBatch] = {
+ withResource(new NvtxWithMetrics(gatherNvtxName, NvtxColor.DARK_GREEN, joinTime)) { _ =>
+ val ret = gathererStore.map { gather =>
+ val nextRows = JoinGatherer.getRowsInNextBatch(gather, targetSize)
+ gather.gatherNext(nextRows)
+ }
+ if (gathererStore.exists(_.isDone)) {
+ gathererStore.foreach(_.close())
+ gathererStore = None
+ }
+
+ if (ret.isDefined) {
+ // We are about to return something. We got everything we need from it so now let it spill
+ // if there is more to be gathered later on.
+ gathererStore.foreach(_.allowSpilling())
+ }
+ ret
+ }
+ }
+}
+
+/**
+ * Base class for join iterators that split and spill batches to avoid GPU OOM errors.
+ * @param gatherNvtxName name to use for the NVTX range when producing the join gather maps
+ * @param stream iterator to produce the batches for the streaming side input of the join
+ * @param streamAttributes attributes corresponding to the streaming side input
+ * @param builtBatch batch for the built side input of the join
+ * @param targetSize configured target batch size in bytes
+ * @param spillCallback callback to use when spilling
+ * @param joinTime metric to record GPU time spent in join
+ * @param streamTime metric to record time spent producing streaming side batches
+ * @param totalTime metric to record total time in the iterator
+ */
+abstract class SplittableJoinIterator(
+ gatherNvtxName: String,
+ stream: Iterator[LazySpillableColumnarBatch],
+ streamAttributes: Seq[Attribute],
+ builtBatch: LazySpillableColumnarBatch,
+ targetSize: Long,
+ spillCallback: SpillCallback,
+ joinTime: GpuMetric,
+ streamTime: GpuMetric,
+ totalTime: GpuMetric)
+ extends AbstractGpuJoinIterator(
+ gatherNvtxName,
+ targetSize,
+ joinTime = joinTime,
+ totalTime = totalTime) with Logging {
+  // For some join types, even if there is no stream data, we might output something
+  private var isInitialJoin = true
+  // If the join explodes, this holds stream-side batches that were split into smaller pieces.
+ private val pendingSplits = mutable.Queue[SpillableColumnarBatch]()
+
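+  /**
+   * Compute (or estimate) the number of join output rows that the next stream-side batch
+   * will produce, used to decide whether the stream batch needs to be split before joining.
+   * @param cb stream-side batch that will be joined with the built batch
+   * @return number of join output rows expected for this batch
+   */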
+ protected def computeNumJoinRows(cb: ColumnarBatch): Long
+
+ /**
+ * Create a join gatherer.
+ * @param cb next column batch from the streaming side of the join
+ * @param numJoinRows if present, the number of join output rows computed for this batch
+ * @return some gatherer to use next or None if there is no next gatherer or the loop should try
+ * to build the gatherer again (e.g.: to skip a degenerate join result batch)
+ */
+ protected def createGatherer(cb: ColumnarBatch, numJoinRows: Option[Long]): Option[JoinGatherer]
+
+ override def hasNextStreamBatch: Boolean = {
+ isInitialJoin || pendingSplits.nonEmpty || stream.hasNext
+ }
+
+ override def setupNextGatherer(startNanoTime: Long): Option[JoinGatherer] = {
+ val wasInitialJoin = isInitialJoin
+ isInitialJoin = false
+ if (pendingSplits.nonEmpty || stream.hasNext) {
+ val cb = if (pendingSplits.nonEmpty) {
+ withResource(pendingSplits.dequeue()) {
+ _.getColumnarBatch()
+ }
+ } else {
+ val batch = withResource(stream.next()) { lazyBatch =>
+ lazyBatch.releaseBatch()
+ }
+ streamTime += (System.nanoTime() - startNanoTime)
+ batch
+ }
+ withResource(cb) { cb =>
+ val numJoinRows = computeNumJoinRows(cb)
+
+ // We want the gather maps size to be around the target size. There are two gather maps
+ // that are made up of ints, so compute how many rows on the stream side will produce the
+ // desired gather maps size.
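+        // e.g.: a 1 GiB target allows 1 GiB / (2 maps * 4 bytes) = ~134M rows in the maps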
+ val maxJoinRows = Math.max(1, targetSize / (2 * Integer.BYTES))
+ if (numJoinRows > maxJoinRows && cb.numRows() > 1) {
+ // Need to split the batch to reduce the gather maps size. This takes a simplistic
+ // approach of assuming the data is uniformly distributed in the stream table.
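+          // e.g.: a batch expected to produce 10x maxJoinRows is split into 10 pieces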
+ val numSplits = Math.min(cb.numRows(),
+ Math.ceil(numJoinRows.toDouble / maxJoinRows).toInt)
+ splitAndSave(cb, numSplits)
+
+ // Return no gatherer so the outer loop will try again
+ return None
+ }
+
+ createGatherer(cb, Some(numJoinRows))
+ }
+ } else {
+ assert(wasInitialJoin)
+ import scala.collection.JavaConverters._
+ withResource(GpuColumnVector.emptyBatch(streamAttributes.asJava)) { cb =>
+ createGatherer(cb, None)
+ }
+ }
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ builtBatch.close()
+ pendingSplits.foreach(_.close())
+ pendingSplits.clear()
+ }
+ }
+
+ /**
+   * Split a stream-side input batch, making all splits spillable, and queue the splits to be
+   * processed in place of the original batch
+ * @param cb stream-side input batch to split
+ * @param numBatches number of splits to produce with approximately the same number of rows each
+ * @param oom a prior OOM exception that this will try to recover from by splitting
+ */
+ protected def splitAndSave(
+ cb: ColumnarBatch,
+ numBatches: Int,
+ oom: Option[OutOfMemoryError] = None): Unit = {
+ val batchSize = cb.numRows() / numBatches
+ if (oom.isDefined && batchSize < 100) {
+      // We just need some kind of cutoff to not get stuck in a loop if the batches get to be
+      // too small, but we want to at least give it a chance to work (mostly for tests where
+      // the targetSize can be set really small)
+ throw oom.get
+ }
+ val msg = s"Split stream batch into $numBatches batches of about $batchSize rows"
+ if (oom.isDefined) {
+ logWarning(s"OOM Encountered: $msg")
+ } else {
+ logInfo(msg)
+ }
+ val splits = withResource(GpuColumnVector.from(cb)) { tab =>
+ val splitIndexes = (1 until numBatches).map(num => num * batchSize)
+ tab.contiguousSplit(splitIndexes: _*)
+ }
+ withResource(splits) { splits =>
+ val schema = GpuColumnVector.extractTypes(cb)
+ pendingSplits ++= splits.map { ct =>
+ SpillableColumnarBatch(ct, schema,
+ SpillPriorities.ACTIVE_ON_DECK_PRIORITY, spillCallback)
+ }
+ }
+ }
+
+ /**
+ * Create a join gatherer from gather maps.
+ * @param maps gather maps produced from a cudf join
+ * @param leftData batch corresponding to the left table in the join
+ * @param rightData batch corresponding to the right table in the join
+   * @return some gatherer or None if there are no rows to gather in this join batch
+ */
+ protected def makeGatherer(
+ maps: Array[GatherMap],
+ leftData: LazySpillableColumnarBatch,
+ rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
+ assert(maps.length > 0 && maps.length <= 2)
+ try {
+ val leftMap = maps.head
+ val rightMap = if (maps.length > 1) {
+ if (rightData.numCols == 0) {
+ // No data so don't bother with it
+ None
+ } else {
+ Some(maps(1))
+ }
+ } else {
+ None
+ }
+
+ val lazyLeftMap = LazySpillableGatherMap(leftMap, spillCallback, "left_map")
+ val gatherer = rightMap match {
+ case None =>
+ rightData.close()
+ JoinGatherer(lazyLeftMap, leftData)
+ case Some(right) =>
+ val lazyRightMap = LazySpillableGatherMap(right, spillCallback, "right_map")
+ JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData)
+ }
+ if (gatherer.isDone) {
+ // Nothing matched...
+ gatherer.close()
+ None
+ } else {
+ Some(gatherer)
+ }
+ } finally {
+ maps.foreach(_.close())
+ }
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index c2dbebd6cf9..fb823a0f17d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3082,7 +3082,10 @@ object GpuOverrides {
TypeSig.DECIMAL_64 + TypeSig.STRUCT), TypeSig.all),
(exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)),
exec[BroadcastNestedLoopJoinExec](
- "Implementation of join using brute force",
+ "Implementation of join using brute force. Full outer joins and joins where the " +
+ "broadcast side matches the join side (e.g.: LeftOuter with left broadcast) are not " +
+      "supported. A non-inner join is only supported if the join condition expression " +
+      "can be converted to a GPU AST expression",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64),
TypeSig.all),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
index 4117d2b16f3..a2cfdb9a421 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
@@ -201,6 +201,12 @@ trait LazySpillableColumnarBatch extends LazySpillable {
* Get the batch that this wraps and unspill it if needed.
*/
def getBatch: ColumnarBatch
+
+ /**
+ * Release the underlying batch to the caller who is responsible for closing it. The resulting
+ * batch will NOT be closed when this instance is closed.
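+   *
+   * For example (a sketch, where `process` stands for any consumer of the batch):
+   * {{{
+   *   withResource(lazyBatch.releaseBatch()) { cb => process(cb) }
+   * }}}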
+ */
+ def releaseBatch(): ColumnarBatch
}
object LazySpillableColumnarBatch {
@@ -218,13 +224,20 @@ object LazySpillableColumnarBatch {
/**
* A version of `LazySpillableColumnarBatch` where instead of closing the underlying
* batch it is only spilled. This is used for cases, like with a streaming hash join
- * where the data itself needs to out live the JoinGatherer it is haded off to.
+ * where the data itself needs to outlive the JoinGatherer it is handed off to.
*/
case class AllowSpillOnlyLazySpillableColumnarBatchImpl(wrapped: LazySpillableColumnarBatch)
- extends LazySpillableColumnarBatch {
+ extends LazySpillableColumnarBatch with Arm {
override def getBatch: ColumnarBatch =
wrapped.getBatch
+ override def releaseBatch(): ColumnarBatch = {
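+    // Bump the column reference counts so the caller effectively gets its own batch to
+    // close, then let the wrapped batch spill since this handle no longer needs it resident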
+ closeOnExcept(GpuColumnVector.incRefCounts(wrapped.getBatch)) { batch =>
+ wrapped.allowSpilling()
+ batch
+ }
+ }
+
override def numRows: Int = wrapped.numRows
override def numCols: Int = wrapped.numCols
override def deviceMemorySize: Long = wrapped.deviceMemorySize
@@ -262,7 +275,15 @@ class LazySpillableColumnarBatchImpl(
cached = spill.map(_.getColumnarBatch())
}
}
- cached.get
+ cached.getOrElse(throw new IllegalStateException("batch is closed"))
+ }
+
+ override def releaseBatch(): ColumnarBatch = {
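+    // Unspill if needed, then drop the cached reference before closing so close() does
+    // not close the batch being handed to the caller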
+ closeOnExcept(getBatch) { batch =>
+ cached = None
+ close()
+ batch
+ }
}
override def allowSpilling(): Unit = {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
index b55bed09378..bcd72ee6921 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
@@ -21,7 +21,6 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable
import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange}
-import ai.rapids.cudf.ast
import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, LazySpillableColumnarBatch, MetricsLevel}
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -115,6 +114,7 @@ class GpuCartesianRDD(
sc: SparkContext,
boundCondition: Option[GpuExpression],
numFirstTableColumns: Int,
+ streamAttributes: Seq[Attribute],
spillCallback: SpillCallback,
targetSize: Long,
joinTime: GpuMetric,
@@ -186,9 +186,13 @@ class GpuCartesianRDD(
}
GpuBroadcastNestedLoopJoinExecBase.nestedLoopJoin(
- Cross, numFirstTableColumns, batch, streamIterator, targetSize, GpuBuildLeft,
- boundCondition, spillCallback, numOutputRows, joinOutputRows, numOutputBatches,
- joinTime, totalTime)
+ Cross, GpuBuildLeft, numFirstTableColumns, batch, streamIterator, streamAttributes,
+ targetSize, boundCondition, spillCallback,
+ numOutputRows = numOutputRows,
+ joinOutputRows = joinOutputRows,
+ numOutputBatches = numOutputBatches,
+ joinTime = joinTime,
+ totalTime = totalTime)
}
}
@@ -273,6 +277,7 @@ case class GpuCartesianProductExec(
new GpuCartesianRDD(sparkContext,
boundCondition,
numFirstTableColumns,
+ right.output,
spillCallback,
targetSizeBytes,
joinTime,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
index e3c7f01997f..a119e7766ea 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
@@ -16,8 +16,7 @@
package org.apache.spark.sql.rapids.execution
-import ai.rapids.cudf.NvtxColor
-import ai.rapids.cudf.ast
+import ai.rapids.cudf.{ast, GatherMap, NvtxColor, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback
@@ -26,7 +25,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, IdentityBroadcastMode, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
@@ -47,23 +46,26 @@ class GpuBroadcastNestedLoopJoinMeta(
override val childExprs: Seq[BaseExprMeta[_]] = conditionMeta.toSeq
override def tagPlanForGpu(): Unit = {
+ val gpuBuildSide = ShimLoader.getSparkShims.getBuildSide(join)
JoinTypeChecks.tagForGpu(join.joinType, this)
conditionMeta.foreach(_.tagForAst())
join.joinType match {
case _: InnerLike =>
- case LeftSemi | LeftAnti =>
+ case LeftOuter | RightOuter | LeftSemi | LeftAnti =>
if (conditionMeta.exists(!_.canThisBeAst)) {
val astInfo = conditionMeta.get.explainAst(false)
willNotWorkOnGpu(s"AST cannot support join condition:\n$astInfo")
}
- val gpuBuildSide = ShimLoader.getSparkShims.getBuildSide(join)
- if (gpuBuildSide == GpuBuildLeft) {
- willNotWorkOnGpu(s"build left not supported for ${join.joinType}")
- }
case _ => willNotWorkOnGpu(s"${join.joinType} currently is not supported")
}
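+    // Joins where the broadcast side matches the join side (e.g.: LeftOuter with a
+    // left-side broadcast) are not supported on the GPU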
+ join.joinType match {
+ case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
+ willNotWorkOnGpu(s"build left not supported for ${join.joinType}")
+ case RightOuter if gpuBuildSide == GpuBuildRight =>
+ willNotWorkOnGpu(s"build right not supported for ${join.joinType}")
+ case _ =>
+ }
- val gpuBuildSide = ShimLoader.getSparkShims.getBuildSide(join)
val Seq(leftPlan, rightPlan) = childPlans
val buildSide = gpuBuildSide match {
case GpuBuildLeft => leftPlan
@@ -91,15 +93,18 @@ class GpuBroadcastNestedLoopJoinMeta(
verifyBuildSideWasReplaced(buildSide)
val condition = conditionMeta.map(_.convertToGpu())
- // Do not yet support AST conditions on anything but semi/anti joins
- val isAstCondition = join.joinType match {
- case _: InnerLike => false
- case LeftSemi | LeftAnti =>
- conditionMeta.foreach(_.tagForAst())
- val isAst = conditionMeta.forall(_.canThisBeAst)
- assert(isAst, s"Non-AST condition in ${join.joinType}")
- isAst
- case _ => throw new IllegalStateException(s"${join.joinType} nested loop join not supported")
+ conditionMeta.foreach(_.tagForAst())
+ val isAstCondition = conditionMeta.forall(_.canThisBeAst)
+ join.joinType match {
+ case _: InnerLike =>
+ case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
+ throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
+ case RightOuter if gpuBuildSide == GpuBuildRight =>
+ throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
+ case LeftOuter | RightOuter | LeftSemi | LeftAnti =>
+ // Cannot post-filter these types of joins
+ assert(isAstCondition, s"Non-AST condition in ${join.joinType}")
+ case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}")
}
val joinExec = ShimLoader.getSparkShims.getGpuBroadcastNestedLoopJoinShim(
@@ -116,93 +121,6 @@ class GpuBroadcastNestedLoopJoinMeta(
}
}
-/** Base class for the join iterators based on a nested loop join */
-abstract class NestedLoopJoinIterator(
- joinType: JoinType,
- builtBatch: LazySpillableColumnarBatch,
- private val stream: Iterator[LazySpillableColumnarBatch],
- targetSize: Long,
- private val joinTime: GpuMetric,
- private val totalTime: GpuMetric) extends Iterator[ColumnarBatch] with Arm {
-
- private var nextCb: Option[ColumnarBatch] = None
- private var gathererStore: Option[JoinGatherer] = None
- private var closed = false
- private val nvtxName = s"$joinType gather"
-
- def close(): Unit = {
- if (!closed) {
- nextCb.foreach(_.close())
- nextCb = None
- gathererStore.foreach(_.close())
- gathererStore = None
- // Close the build batch we are done with it.
- builtBatch.close()
- closed = true
- }
- }
-
- TaskContext.get().addTaskCompletionListener[Unit](_ => close())
-
- private def nextCbFromGatherer(): Option[ColumnarBatch] = {
- withResource(new NvtxWithMetrics(nvtxName, NvtxColor.DARK_GREEN, joinTime)) { _ =>
- val ret = gathererStore.map { gather =>
- val nextRows = JoinGatherer.getRowsInNextBatch(gather, targetSize)
- gather.gatherNext(nextRows)
- }
- if (gathererStore.exists(_.isDone)) {
- gathererStore.foreach(_.close())
- gathererStore = None
- }
-
- if (ret.isDefined) {
- // We are about to return something. We got everything we need from it so now let it spill
- // if there is more to be gathered later on.
- gathererStore.foreach(_.allowSpilling())
- }
- ret
- }
- }
-
- protected def makeGatherer(streamBatch: LazySpillableColumnarBatch): Option[JoinGatherer]
-
- override def hasNext: Boolean = {
- if (closed) {
- return false
- }
- var mayContinue = true
- while (nextCb.isEmpty && mayContinue) {
- val startTime = System.nanoTime()
- if (gathererStore.exists(!_.isDone)) {
- nextCb = nextCbFromGatherer()
- } else if (stream.hasNext) {
- // Need to refill the gatherer
- gathererStore.foreach(_.close())
- gathererStore = None
- gathererStore = makeGatherer(stream.next())
- nextCb = nextCbFromGatherer()
- } else {
- mayContinue = false
- }
- totalTime += (System.nanoTime() - startTime)
- }
- if (nextCb.isEmpty) {
- // Nothing is left to return so close ASAP.
- close()
- }
- nextCb.isDefined
- }
-
- override def next(): ColumnarBatch = {
- if (!hasNext) {
- throw new NoSuchElementException()
- }
- val ret = nextCb.get
- nextCb = None
- ret
- }
-}
-
/**
* An iterator that does a cross join against a stream of batches.
*/
@@ -211,17 +129,25 @@ class CrossJoinIterator(
stream: Iterator[LazySpillableColumnarBatch],
targetSize: Long,
buildSide: GpuBuildSide,
- private val joinTime: GpuMetric,
- private val totalTime: GpuMetric)
- extends NestedLoopJoinIterator(
- Cross,
- builtBatch,
- stream,
+ joinTime: GpuMetric,
+ totalTime: GpuMetric)
+ extends AbstractGpuJoinIterator(
+ "Cross join gather",
targetSize,
joinTime,
totalTime) {
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ builtBatch.close()
+ }
+ }
+
+ override def hasNextStreamBatch: Boolean = stream.hasNext
+
+ override def setupNextGatherer(startTime: Long): Option[JoinGatherer] = {
+ val streamBatch = stream.next()
- override def makeGatherer(streamBatch: LazySpillableColumnarBatch): Option[JoinGatherer] = {
// Don't close the built side because it will be used for each stream and closed
// when the iterator is done.
val (leftBatch, rightBatch) = buildSide match {
@@ -252,67 +178,127 @@ class CrossJoinIterator(
}
}
-class ConditionalSemiOrAntiJoinIterator(
+class ConditionalNestedLoopJoinIterator(
joinType: JoinType,
+ buildSide: GpuBuildSide,
builtBatch: LazySpillableColumnarBatch,
stream: Iterator[LazySpillableColumnarBatch],
+ streamAttributes: Seq[Attribute],
targetSize: Long,
condition: ast.CompiledExpression,
spillCallback: SpillCallback,
joinTime: GpuMetric,
totalTime: GpuMetric)
- extends NestedLoopJoinIterator(
- joinType,
- builtBatch,
+ extends SplittableJoinIterator(
+ s"$joinType join gather",
stream,
+ streamAttributes,
+ builtBatch,
targetSize,
- joinTime,
- totalTime) {
- private[this] var compiledExpr: Option[ast.CompiledExpression] = Some(condition)
- TaskContext.get().addTaskCompletionListener[Unit](_ => close())
-
- override def makeGatherer(streamBatch: LazySpillableColumnarBatch): Option[JoinGatherer] = {
- withResource(GpuColumnVector.from(streamBatch.getBatch)) { leftTable =>
- withResource(GpuColumnVector.from(builtBatch.getBatch)) { rightTable =>
- val map = joinType match {
- case LeftSemi =>
- compiledExpr.map(leftTable.conditionalLeftSemiJoinGatherMap(rightTable, _, false))
- .getOrElse(leftTable.leftSemiJoinGatherMap(rightTable, false))
- case LeftAnti =>
- compiledExpr.map(leftTable.conditionalLeftAntiJoinGatherMap(rightTable, _, false))
- .getOrElse(leftTable.leftAntiJoinGatherMap(rightTable, false))
- case _ =>
- throw new IllegalStateException(s"Unexpected join type $joinType")
+ spillCallback,
+ joinTime = joinTime,
+ streamTime = NoopMetric,
+ totalTime = totalTime) {
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ condition.close()
+ }
+ }
+
+ override def computeNumJoinRows(cb: ColumnarBatch): Long = {
+ withResource(GpuColumnVector.from(builtBatch.getBatch)) { builtTable =>
+ withResource(GpuColumnVector.from(cb)) { streamTable =>
+ val (left, right) = buildSide match {
+ case GpuBuildLeft => (builtTable, streamTable)
+ case GpuBuildRight => (streamTable, builtTable)
+ }
+ joinType match {
+        case _: InnerLike => left.conditionalInnerJoinRowCount(right, condition, false)
+ case LeftOuter => left.conditionalLeftJoinRowCount(right, condition, false)
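+        // a right outer join is computed as a left outer join with the tables swapped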
+ case RightOuter => right.conditionalLeftJoinRowCount(left, condition, false)
+ case LeftSemi => left.conditionalLeftSemiJoinRowCount(right, condition, false)
+ case LeftAnti => left.conditionalLeftAntiJoinRowCount(right, condition, false)
+ case _ => throw new IllegalStateException(s"Unsupported join type $joinType")
}
- withResource(map) { map =>
- val lazyMap = LazySpillableGatherMap(map, spillCallback, "left_map")
- val gatherer = JoinGatherer(lazyMap, streamBatch)
- if (gatherer.isDone) {
- gatherer.close()
- None
- } else {
- Some(gatherer)
+ }
+ }
+ }
+
+ override def createGatherer(
+ cb: ColumnarBatch,
+ numJoinRows: Option[Long]): Option[JoinGatherer] = {
+ if (numJoinRows.contains(0)) {
+ // nothing matched
+ return None
+ }
+ withResource(GpuColumnVector.from(builtBatch.getBatch)) { builtTable =>
+ withResource(GpuColumnVector.from(cb)) { streamTable =>
+ closeOnExcept(LazySpillableColumnarBatch(cb, spillCallback, "stream_data")) { streamBatch =>
+ val builtSpillOnly = LazySpillableColumnarBatch.spillOnly(builtBatch)
+ val (leftTable, leftBatch, rightTable, rightBatch) = buildSide match {
+ case GpuBuildLeft => (builtTable, builtSpillOnly, streamTable, streamBatch)
+ case GpuBuildRight => (streamTable, streamBatch, builtTable, builtSpillOnly)
}
+ val maps = computeGatherMaps(leftTable, rightTable, numJoinRows)
+ makeGatherer(maps, leftBatch, rightBatch)
}
}
}
}
- override def close(): Unit = {
- super.close()
- compiledExpr.foreach(_.close())
- compiledExpr = None
+ private def computeGatherMaps(
+ left: Table,
+ right: Table,
+ numJoinRows: Option[Long]): Array[GatherMap] = {
+ joinType match {
+ case _: InnerLike =>
+ numJoinRows.map { rowCount =>
+ left.conditionalInnerJoinGatherMaps(right, condition, false, rowCount)
+ }.getOrElse {
+ left.conditionalInnerJoinGatherMaps(right, condition, false)
+ }
+ case LeftOuter =>
+ numJoinRows.map { rowCount =>
+ left.conditionalLeftJoinGatherMaps(right, condition, false, rowCount)
+ }.getOrElse {
+ left.conditionalLeftJoinGatherMaps(right, condition, false)
+ }
+ case RightOuter =>
+ val maps = numJoinRows.map { rowCount =>
+ right.conditionalLeftJoinGatherMaps(left, condition, false, rowCount)
+ }.getOrElse {
+ right.conditionalLeftJoinGatherMaps(left, condition, false)
+ }
+ // Reverse the output of the join, because we expect the right gather map to
+ // always be on the right
+ maps.reverse
+ case LeftSemi =>
+ numJoinRows.map { rowCount =>
+ Array(left.conditionalLeftSemiJoinGatherMap(right, condition, false, rowCount))
+ }.getOrElse {
+ Array(left.conditionalLeftSemiJoinGatherMap(right, condition, false))
+ }
+ case LeftAnti =>
+ numJoinRows.map { rowCount =>
+ Array(left.conditionalLeftAntiJoinGatherMap(right, condition, false, rowCount))
+ }.getOrElse {
+ Array(left.conditionalLeftAntiJoinGatherMap(right, condition, false))
+ }
+ case _ => throw new IllegalStateException(s"Unsupported join type $joinType")
+ }
}
}
object GpuBroadcastNestedLoopJoinExecBase extends Arm {
def nestedLoopJoin(
joinType: JoinType,
+ buildSide: GpuBuildSide,
numFirstTableColumns: Int,
builtBatch: LazySpillableColumnarBatch,
stream: Iterator[LazySpillableColumnarBatch],
+ streamAttributes: Seq[Attribute],
targetSize: Long,
- buildSide: GpuBuildSide,
boundCondition: Option[GpuExpression],
spillCallback: SpillCallback,
numOutputRows: GpuMetric,
@@ -327,14 +313,9 @@ object GpuBroadcastNestedLoopJoinExecBase extends Arm {
new CrossJoinIterator(builtBatch, stream, targetSize, buildSide, joinTime, totalTime)
} else {
val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns).compile()
- joinType match {
- case LeftAnti | LeftSemi =>
- assert(buildSide == GpuBuildRight)
- new ConditionalSemiOrAntiJoinIterator(joinType, builtBatch, stream, targetSize,
- compiledAst, spillCallback, joinTime, totalTime)
- case _ =>
- throw new UnsupportedOperationException("not supported yet")
- }
+ new ConditionalNestedLoopJoinIterator(joinType, buildSide, builtBatch,
+ stream, streamAttributes, targetSize, compiledAst, spillCallback,
+ joinTime = joinTime, totalTime = totalTime)
}
joinIterator.map { cb =>
joinOutputRows += cb.numRows()
@@ -478,28 +459,25 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
broadcastExchange.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
if (boundCondition.isEmpty) {
- doUnconditionalJoin(broadcastRelation, numFirstTableColumns)
+ doUnconditionalJoin(broadcastRelation)
} else {
doConditionalJoin(broadcastRelation, boundCondition, numFirstTableColumns)
}
}
private def doUnconditionalJoin(
- broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch],
- numFirstTableColumns: Int): RDD[ColumnarBatch] = {
+ broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch]
+ ): RDD[ColumnarBatch] = {
if (output.isEmpty) {
doUnconditionalJoinRowCount(broadcastRelation)
} else {
- val nestedLoopJoinType = joinType
- val buildTime = gpuLongMetric(BUILD_TIME)
- val buildDataSize = gpuLongMetric(BUILD_DATA_SIZE)
val joinOutputRows = gpuLongMetric(JOIN_OUTPUT_ROWS)
- val joinTime = gpuLongMetric(JOIN_TIME)
val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
- val totalTime = gpuLongMetric(TOTAL_TIME)
+ val buildTime = gpuLongMetric(BUILD_TIME)
+ val buildDataSize = gpuLongMetric(BUILD_DATA_SIZE)
lazy val builtBatch = makeBuiltBatch(broadcastRelation, buildTime, buildDataSize)
- nestedLoopJoinType match {
+ joinType match {
case LeftSemi =>
// just return the left table
left.executeColumnar().mapPartitions { leftIter =>
@@ -516,19 +494,24 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
Iterator.single(new ColumnarBatch(Array(), 0))
}
case _ =>
+ // Everything else is treated like an unconditional cross join
+ val buildSide = getGpuBuildSide
val spillCallback = GpuMetric.makeSpillCallback(allMetrics)
+ val joinTime = gpuLongMetric(JOIN_TIME)
+ val totalTime = gpuLongMetric(TOTAL_TIME)
streamed.executeColumnar().mapPartitions { streamedIter =>
val lazyStream = streamedIter.map { cb =>
withResource(cb) { cb =>
LazySpillableColumnarBatch(cb, spillCallback, "stream_batch")
}
}
- GpuBroadcastNestedLoopJoinExecBase.nestedLoopJoin(
- nestedLoopJoinType, numFirstTableColumns,
+ new CrossJoinIterator(
LazySpillableColumnarBatch(builtBatch, spillCallback, "built_batch"),
- lazyStream, targetSizeBytes, getGpuBuildSide, None,
- spillCallback, numOutputRows, joinOutputRows, numOutputBatches,
- joinTime, totalTime)
+ lazyStream,
+ targetSizeBytes,
+ buildSide,
+ joinTime = joinTime,
+ totalTime = totalTime)
}
}
}
@@ -579,12 +562,14 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
val buildDataSize = gpuLongMetric(BUILD_DATA_SIZE)
lazy val builtBatch = makeBuiltBatch(broadcastRelation, buildTime, buildDataSize)
+ val streamAttributes = streamed.output
val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
val totalTime = gpuLongMetric(TOTAL_TIME)
val joinTime = gpuLongMetric(JOIN_TIME)
val joinOutputRows = gpuLongMetric(JOIN_OUTPUT_ROWS)
val nestedLoopJoinType = joinType
+ val buildSide = getGpuBuildSide
val spillCallback = GpuMetric.makeSpillCallback(allMetrics)
streamed.executeColumnar().mapPartitions { streamedIter =>
val lazyStream = streamedIter.map { cb =>
@@ -593,11 +578,14 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}
}
GpuBroadcastNestedLoopJoinExecBase.nestedLoopJoin(
- nestedLoopJoinType, numFirstTableColumns,
+ nestedLoopJoinType, buildSide, numFirstTableColumns,
LazySpillableColumnarBatch(builtBatch, spillCallback, "built_batch"),
- lazyStream, targetSizeBytes, getGpuBuildSide, boundCondition,
- spillCallback, numOutputRows, joinOutputRows, numOutputBatches,
- joinTime, totalTime)
+ lazyStream, streamAttributes, targetSizeBytes, boundCondition, spillCallback,
+ numOutputRows = numOutputRows,
+ joinOutputRows = joinOutputRows,
+ numOutputBatches = numOutputBatches,
+ joinTime = joinTime,
+ totalTime = totalTime)
}
}
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index 2165c8c524e..57d095a0c0a 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -15,14 +15,10 @@
*/
package org.apache.spark.sql.rapids.execution
-import scala.collection.mutable
-
-import ai.rapids.cudf.{DType, GatherMap, GroupByAggregation, NullPolicy, NvtxColor, ReductionAggregation, Table}
+import ai.rapids.cudf.{DType, GroupByAggregation, NullPolicy, NvtxColor, ReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback
-import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.SparkPlan
@@ -176,9 +172,9 @@ object GpuHashJoin extends Arm {
* An iterator that does a hash join against a stream of batches.
*/
class HashJoinIterator(
- builtInput: ColumnarBatch,
+ built: LazySpillableColumnarBatch,
val boundBuiltKeys: Seq[Expression],
- private val stream: Iterator[ColumnarBatch],
+ private val stream: Iterator[LazySpillableColumnarBatch],
val boundStreamKeys: Seq[Expression],
val streamAttributes: Seq[Attribute],
val targetSize: Long,
@@ -188,104 +184,59 @@ class HashJoinIterator(
private val spillCallback: SpillCallback,
private val streamTime: GpuMetric,
private val joinTime: GpuMetric,
- private val totalTime: GpuMetric) extends Iterator[ColumnarBatch] with Arm with Logging {
- import scala.collection.JavaConverters._
-
- // For some join types even if there is no stream data we might output something
- private var initialJoin = true
- // If the join explodes this holds batches from the stream side split into smaller
- // pieces.
- private val pendingSplits = mutable.Queue[SpillableColumnarBatch]()
- private var nextCb: Option[ColumnarBatch] = None
- private var gathererStore: Option[JoinGatherer] = None
- // Close the input data, the lazy spillable batch now owns it.
- private val built = withResource(builtInput) { builtInput =>
- LazySpillableColumnarBatch(builtInput, spillCallback, "built")
- }
-
+ private val totalTime: GpuMetric)
+ extends SplittableJoinIterator(
+ s"hash $joinType gather",
+ stream,
+ streamAttributes,
+ built,
+ targetSize,
+ spillCallback,
+ joinTime = joinTime,
+ streamTime = streamTime,
+ totalTime = totalTime) {
// We can cache this because the build side is not changing
- private lazy val estimatedRowsPerStreamBatch = joinType match {
+ private lazy val streamMagnificationFactor = joinType match {
case _: InnerLike | LeftOuter | RightOuter =>
withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
- guessJoinRowsForTargetSize(builtKeys)
+ guessStreamMagnificationFactor(builtKeys)
}
case _ =>
// existence joins don't change size, and FullOuter cannot be split
1.0
}
- private var closed = false
-
- def close(): Unit = {
- if (!closed) {
- built.close()
- nextCb.foreach(_.close())
- nextCb = None
- gathererStore.foreach(_.close())
- gathererStore = None
- pendingSplits.foreach(_.close())
- pendingSplits.clear()
- closed = true
- }
- }
-
- TaskContext.get().addTaskCompletionListener[Unit](_ => close())
-
- private def nextCbFromGatherer(): Option[ColumnarBatch] = {
- withResource(new NvtxWithMetrics("hash join gather", NvtxColor.DARK_GREEN, joinTime)) { _ =>
- val ret = gathererStore.map { gather =>
- val nextRows = JoinGatherer.getRowsInNextBatch(gather, targetSize)
- gather.gatherNext(nextRows)
- }
- if (gathererStore.exists(_.isDone)) {
- gathererStore.foreach(_.close())
- gathererStore = None
- }
- if (ret.isDefined) {
- // We are about to return something. We got everything we need from it so now let it spill
- // if there is more to be gathered later on.
- gathererStore.foreach(_.allowSpilling())
- }
- ret
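+  /**
+   * Estimate the number of output rows the join will produce for a stream-side
+   * batch. For joins that can explode the output (inner and one-sided outer
+   * joins) this scales the batch's row count by the cached magnification factor.
+   */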
+ override def computeNumJoinRows(cb: ColumnarBatch): Long = {
+ // TODO: Replace this estimate with exact join row counts using the corresponding cudf APIs
+ // being added in https://github.com/rapidsai/cudf/issues/9053.
+ joinType match {
+ case _: InnerLike | LeftOuter | RightOuter =>
+ Math.ceil(cb.numRows() * streamMagnificationFactor).toLong
+ case _ => cb.numRows()
}
}
- private def makeGatherer(
- maps: Array[GatherMap],
- leftData: LazySpillableColumnarBatch,
- rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
- assert(maps.length > 0 && maps.length <= 2)
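+  /**
+   * Try to create a JoinGatherer for a stream-side batch. If the join attempt
+   * runs out of GPU memory and the join type allows it, the batch is split and
+   * saved to be retried, and None is returned.
+   */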
+ override def createGatherer(
+ cb: ColumnarBatch,
+ numJoinRows: Option[Long]): Option[JoinGatherer] = {
try {
- val leftMap = maps.head
- val rightMap = if (maps.length > 1) {
- if (rightData.numCols == 0) {
- // No data so don't bother with it
- None
- } else {
- Some(maps(1))
- }
- } else {
- None
- }
-
- val lazyLeftMap = LazySpillableGatherMap(leftMap, spillCallback, "left_map")
- val gatherer = rightMap match {
- case None =>
- rightData.close()
- JoinGatherer(lazyLeftMap, leftData)
- case Some(right) =>
- val lazyRightMap = LazySpillableGatherMap(right, spillCallback, "right_map")
- JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData)
+ withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
+ joinGatherer(builtKeys, built, cb)
}
- if (gatherer.isDone) {
- // Nothing matched...
- gatherer.close()
+ } catch {
+      // Splitting the batch and retrying works for all join types except FullOuter.
+      // There is no need to do this for the existence joins because the number of
+      // output rows never exceeds the number of input rows on the stream side.
+ case oom: OutOfMemoryError if joinType.isInstanceOf[InnerLike]
+ || joinType == LeftOuter
+ || joinType == RightOuter =>
+ // Because this is just an estimate, it is possible for us to get this wrong, so
+ // make sure we at least split the batch in half.
+ val numBatches = Math.max(2, estimatedNumBatches(cb))
+
+          // Split the batch and return no gatherer so the outer loop will try again
+ splitAndSave(cb, numBatches, Some(oom))
None
- } else {
- Some(gatherer)
- }
- } finally {
- maps.foreach(_.close())
}
}
@@ -357,145 +308,32 @@ class HashJoinIterator(
}
/**
- * Guess how large a stream side batch should be to avoid really huge gather maps.
+   * Guess the magnification factor: the estimated number of join output rows
+   * produced per input row on the stream side.
* This is temporary until cudf gives us APIs to get the actual gather map size.
*/
- private def guessJoinRowsForTargetSize(builtKeys: ColumnarBatch): Double = {
+ private def guessStreamMagnificationFactor(builtKeys: ColumnarBatch): Double = {
    // Based on the keys on the build side, guess how many output rows there
    // will be for each input row on the stream side. This does not take into
    // account the join type, data skew, or even whether the keys actually match.
- val averageStreamSizeExpansion = withResource(countGroups(builtKeys)) { builtCount =>
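+    // For example, build-side keys [a, a, a, b] form groups of sizes 3 and 1,
+    // so the mean is 2.0 and each stream-side row is estimated to produce two
+    // output rows.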
+ withResource(countGroups(builtKeys)) { builtCount =>
val counts = builtCount.getColumn(builtCount.getNumberOfColumns - 1)
withResource(counts.reduce(ReductionAggregation.mean(), DType.FLOAT64)) { scalarAverage =>
scalarAverage.getDouble
}
}
-
- // We want the gather map size to be around the target size. There are two gather maps
- // that are made up of ints, so estimate how many rows per batch on the stream side
- // will produce the desired gather map size.
- val approximateStreamRowCount = ((targetSize.toDouble / 2) /
- DType.INT32.getSizeInBytes) / averageStreamSizeExpansion
- Math.min(Int.MaxValue, approximateStreamRowCount)
}
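+  /**
+   * Estimate how many pieces a stream-side batch needs to be split into so that
+   * each piece produces gather maps near the target batch size.
+   */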
private def estimatedNumBatches(cb: ColumnarBatch): Int = joinType match {
case _: InnerLike | LeftOuter | RightOuter =>
+ // We want the gather map size to be around the target size. There are two gather maps
+ // that are made up of ints, so estimate how many rows per batch on the stream side
+ // will produce the desired gather map size.
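+      // For example, with a 512 MiB target size each of the two gather maps gets
+      // ~256 MiB, or ~64M INT32 entries; a magnification factor of 4.0 then
+      // allows roughly 16M stream-side rows per batch.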
+ val approximateStreamRowCount = ((targetSize.toDouble / 2) /
+ DType.INT32.getSizeInBytes) / streamMagnificationFactor
+ val estimatedRowsPerStreamBatch = Math.min(Int.MaxValue, approximateStreamRowCount)
Math.ceil(cb.numRows() / estimatedRowsPerStreamBatch).toInt
case _ => 1
}
-
- private def splitAndSave(cb: ColumnarBatch,
- numBatches: Int,
- oom: Option[OutOfMemoryError] = None): Unit = {
- val batchSize = cb.numRows() / numBatches
- if (oom.isDefined && batchSize < 100) {
- // We just need some kind of cutoff to not get stuck in a loop if the batches get to be too
- // small but we want to at least give it a chance to work (mostly for tests where the
- // targetSize can be set really small)
- throw oom.get
- }
- val msg = s"Split stream batch into $numBatches batches of about $batchSize rows"
- if (oom.isDefined) {
- logWarning(s"OOM Encountered: $msg")
- } else {
- logInfo(msg)
- }
- val splits = withResource(GpuColumnVector.from(cb)) { tab =>
- val splitIndexes = (1 until numBatches).map(num => num * batchSize)
- tab.contiguousSplit(splitIndexes: _*)
- }
- withResource(splits) { splits =>
- val schema = GpuColumnVector.extractTypes(cb)
- pendingSplits ++= splits.map { ct =>
- SpillableColumnarBatch(ct, schema,
- SpillPriorities.ACTIVE_ON_DECK_PRIORITY, spillCallback)
- }
- }
- }
-
- override def hasNext: Boolean = {
- if (closed) {
- return false
- }
- var mayContinue = true
- while (nextCb.isEmpty && mayContinue) {
- val startTime = System.nanoTime()
- if (gathererStore.exists(!_.isDone)) {
- nextCb = nextCbFromGatherer()
- } else if (pendingSplits.nonEmpty || stream.hasNext) {
- // Need to refill the gatherer
- gathererStore.foreach(_.close())
- gathererStore = None
- val cb = if (pendingSplits.isEmpty) {
- val cb = stream.next()
- val estimatedBatches = estimatedNumBatches(cb)
- // The cutoff is arbitrary, just to avoid doing duplicate work
- if (estimatedBatches > 2) {
- withResource(cb) { cb =>
- splitAndSave(cb, estimatedBatches)
- }
- withResource(pendingSplits.dequeue()) { scb =>
- scb.getColumnarBatch()
- }
- } else {
- cb
- }
- } else {
- withResource(pendingSplits.dequeue()) { scb =>
- scb.getColumnarBatch()
- }
- }
- withResource(cb) { cb =>
- streamTime += (System.nanoTime() - startTime)
- try {
- withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
- gathererStore = joinGatherer(builtKeys, built, cb)
- }
- } catch {
- // This should work for all join types except for FullOuter. There should be no need
- // to do this for any of the existence joins because the output rows will never be
- // larger than the input rows on the stream side.
- case oom: OutOfMemoryError if joinType.isInstanceOf[InnerLike]
- || joinType == LeftOuter
- || joinType == RightOuter =>
- // Because this is just an estimate, it is possible for us to get this wrong, so
- // make sure we at least split the batch in half.
- val numBatches = Math.max(2, estimatedNumBatches(cb))
- // Now try again nextCbFromGatherer will return None, but there will be more to do
- // so the loop will not finish
- splitAndSave(cb, numBatches, Some(oom))
- }
- }
- nextCb = nextCbFromGatherer()
- } else if (initialJoin) {
- withResource(GpuColumnVector.emptyBatch(streamAttributes.asJava)) { cb =>
- withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
- gathererStore = joinGatherer(builtKeys, built, cb)
- }
- }
- nextCb = nextCbFromGatherer()
- } else {
- mayContinue = false
- }
- totalTime += (System.nanoTime() - startTime)
- initialJoin = false
- }
- if (nextCb.isEmpty) {
- // Nothing is left to return so close ASAP.
- close()
- }
- nextCb.isDefined
- }
-
- override def next(): ColumnarBatch = {
- if (!hasNext) {
- throw new NoSuchElementException()
- }
- val ret = nextCb.get
- nextCb = None
- ret
- }
}
trait GpuHashJoin extends GpuExec {
@@ -628,10 +466,20 @@ trait GpuHashJoin extends GpuExec {
GpuColumnVector.incRefCounts(builtBatch)
}
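+
+    // Wrap the build side in a lazy spillable batch so it can be spilled out of
+    // GPU memory while stream batches are being joined.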
+ val spillableBuiltBatch = withResource(nullFiltered) {
+ LazySpillableColumnarBatch(_, spillCallback, "built")
+ }
+
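+    // Wrap each stream-side batch the same way so batches waiting to be joined
+    // can also be spilled.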
+ val lazyStream = stream.map { cb =>
+ withResource(cb) { cb =>
+ LazySpillableColumnarBatch(cb, spillCallback, "stream_batch")
+ }
+ }
+
    // The HashJoinIterator takes ownership of the built batch and the stream batches.
    // It will close them when it is done.
val joinIterator =
- new HashJoinIterator(nullFiltered, boundBuildKeys, stream, boundStreamKeys,
+ new HashJoinIterator(spillableBuiltBatch, boundBuildKeys, lazyStream, boundStreamKeys,
streamedPlan.output, realTarget, joinType, buildSide, compareNullsEqual, spillCallback,
streamTime, joinTime, totalTime)
if (boundCondition.isDefined) {