diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 9ccb36ea923..1252dcecc7f 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -23,6 +23,12 @@ pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])] +all_gen_no_nulls = [StringGen(nullable=False), ByteGen(nullable=False), + ShortGen(nullable=False), IntegerGen(nullable=False), LongGen(nullable=False), + BooleanGen(nullable=False), DateGen(nullable=False), TimestampGen(nullable=False), + pytest.param(FloatGen(nullable=False), marks=[incompat]), + pytest.param(DoubleGen(nullable=False), marks=[incompat])] + double_gen = [pytest.param(DoubleGen(), marks=[incompat])] _sortmerge_join_conf = {'spark.sql.autoBroadcastJoinThreshold': '-1', @@ -36,10 +42,21 @@ def create_df(spark, data_gen, left_length, right_length): .withColumnRenamed("b", "r_b") return left, right +# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn) +@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn) +def test_sortmerge_join_no_nulls(data_gen, join_type): + def do_join(spark): + left, right = create_df(spark, data_gen, 500, 500) + return left.join(right, left.a == right.r_a, join_type) + assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf) + # local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84 @ignore_order(local=True) @pytest.mark.parametrize('data_gen', all_gen, ids=idfn) -@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn) +@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', + pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn) def test_sortmerge_join(data_gen, join_type): def do_join(spark): left, right = create_df(spark, data_gen, 500, 500) @@ -47,19 +64,29 @@ def do_join(spark): assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf) +# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn) +@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn) +def test_broadcast_join_no_nulls(data_gen, join_type): + def do_join(spark): + left, right = create_df(spark, data_gen, 500, 250) + return left.join(broadcast(right), left.a == right.r_a, join_type) + assert_gpu_and_cpu_are_equal_collect(do_join) + # For tests which include broadcast joins, right table is broadcasted and hence it is # made smaller than left table. # local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84 @ignore_order(local=True) @pytest.mark.parametrize('data_gen', all_gen, ids=idfn) -@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn) +@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', + pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn) def test_broadcast_join(data_gen, join_type): def do_join(spark): left, right = create_df(spark, data_gen, 500, 250) return left.join(broadcast(right), left.a == right.r_a, join_type) assert_gpu_and_cpu_are_equal_collect(do_join) - # local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84 @ignore_order(local=True) @pytest.mark.parametrize('data_gen', all_gen, ids=idfn) @@ -78,7 +105,7 @@ def do_join(spark): ('b', StringGen()), ('c', BooleanGen())] @ignore_order -@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn) +@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', 'FullOuter'], ids=idfn) def test_broadcast_join_mixed(join_type): def do_join(spark): left = gen_df(spark, _mixed_df1_with_nulls, length=500) diff --git a/integration_tests/src/main/python/qa_nightly_sql.py b/integration_tests/src/main/python/qa_nightly_sql.py index 0b9bee57810..92cefe27209 100644 --- a/integration_tests/src/main/python/qa_nightly_sql.py +++ b/integration_tests/src/main/python/qa_nightly_sql.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + SELECT_SQL = [ # (" FUNCTIONAL CHECKING", "FUNCTIONAL CHECKING"), # (" AGG functions", "AGG functions"), @@ -738,16 +740,16 @@ #("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF"), #("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF"), #("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), -#("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), -#("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), -#("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), -#("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), -#("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), -#("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), -#("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), -#("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), -#("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), -#("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), +pytest.param(("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), +pytest.param(("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')), ] diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala index 258aa9d9c28..614bc8a3440 100644 --- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala +++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala @@ -53,6 +53,11 @@ class JoinsSuite extends SparkQueryCompareTestSuite { (A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti") } + IGNORE_ORDER_testSparkResultsAreEqual2("Test hash full join", longsDf, biggerLongsDf, + conf = shuffledJoinConf) { + (A, B) => A.join(B, A("longs") === B("longs"), "FullOuter") + } + // test replacement of sort merge join with hash join // make sure broadcast size small enough it doesn't get used testSparkResultsAreEqual2("Test replace sort merge join with hash join", diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala index e6290d1e280..cdc8834ab23 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{NvtxColor, Table} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, HashJoin} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -27,8 +27,18 @@ object GpuHashJoin { def tagJoin( meta: RapidsMeta[_, _, _], joinType: JoinType, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], condition: Option[Expression]): Unit = joinType match { case Inner => + case FullOuter => + if (leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)) { + // https://github.com/rapidsai/cudf/issues/5563 + meta.willNotWorkOnGpu("Full outer join does not work on nullable keys") + } + if (condition.isDefined) { + meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions") + } case LeftOuter | LeftSemi | LeftAnti => if (condition.isDefined) { meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions") @@ -39,6 +49,25 @@ object GpuHashJoin { trait GpuHashJoin extends GpuExec with HashJoin { + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException(s"GpuHashJoin should not take $x as the JoinType") + } + } + protected lazy val (gpuBuildKeys, gpuStreamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") @@ -122,6 +151,9 @@ trait GpuHashJoin extends GpuExec with HashJoin { case LeftAnti => leftTable.onColumns(joinKeyIndices: _*) .leftAntiJoin(rightTable.onColumns(joinKeyIndices: _*)) + case FullOuter => + leftTable.onColumns(joinKeyIndices: _*) + .fullJoin(rightTable.onColumns(joinKeyIndices: _*)) case _ => throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" + s" supported") } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala index 5be8d1aae8d..4bd98418a5a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala @@ -45,7 +45,7 @@ class GpuShuffledHashJoinMeta( override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition override def tagPlanForGpu(): Unit = { - GpuHashJoin.tagJoin(this, join.joinType, join.condition) + GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition) } override def convertToGpu(): GpuExec = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala index 3b2b4473d42..29ba63d625e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala @@ -38,7 +38,7 @@ class GpuSortMergeJoinMeta( override def tagPlanForGpu(): Unit = { // Use conditions from Hash Join - GpuHashJoin.tagJoin(this, join.joinType, join.condition) + GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition) if (!conf.enableReplaceSortMergeJoin) { willNotWorkOnGpu(s"Not replacing sort merge join with hash join, " + diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala index 7be40e84809..7545f42f05e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala @@ -47,7 +47,7 @@ class GpuBroadcastHashJoinMeta( override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition override def tagPlanForGpu(): Unit = { - GpuHashJoin.tagJoin(this, join.joinType, join.condition) + GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition) val buildSide = join.buildSide match { case BuildLeft => childPlans(0)