Skip to content

Commit

Permalink
Add in support for Full Outer Join on non-null keys (NVIDIA#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
revans2 authored Jun 24, 2020
1 parent 1f947f4 commit 62756bf
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 18 deletions.
35 changes: 31 additions & 4 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
pytest.param(FloatGen(), marks=[incompat]),
pytest.param(DoubleGen(), marks=[incompat])]

all_gen_no_nulls = [StringGen(nullable=False), ByteGen(nullable=False),
ShortGen(nullable=False), IntegerGen(nullable=False), LongGen(nullable=False),
BooleanGen(nullable=False), DateGen(nullable=False), TimestampGen(nullable=False),
pytest.param(FloatGen(nullable=False), marks=[incompat]),
pytest.param(DoubleGen(nullable=False), marks=[incompat])]

double_gen = [pytest.param(DoubleGen(), marks=[incompat])]

_sortmerge_join_conf = {'spark.sql.autoBroadcastJoinThreshold': '-1',
Expand All @@ -36,30 +42,51 @@ def create_df(spark, data_gen, left_length, right_length):
.withColumnRenamed("b", "r_b")
return left, right

# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn)
@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn)
def test_sortmerge_join_no_nulls(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 500)
return left.join(right, left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)

# local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti',
pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn)
def test_sortmerge_join(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 500)
return left.join(right, left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)


# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn)
@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn)
def test_broadcast_join_no_nulls(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(broadcast(right), left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

# For tests which include broadcast joins, right table is broadcasted and hence it is
# made smaller than left table.
# local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti',
pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn)
def test_broadcast_join(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(broadcast(right), left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)


# local sort becasue of https://github.com/NVIDIA/spark-rapids/issues/84
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
Expand All @@ -78,7 +105,7 @@ def do_join(spark):
('b', StringGen()), ('c', BooleanGen())]

@ignore_order
@pytest.mark.parametrize('join_type', ['Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', 'FullOuter'], ids=idfn)
def test_broadcast_join_mixed(join_type):
def do_join(spark):
left = gen_df(spark, _mixed_df1_with_nulls, length=500)
Expand Down
22 changes: 12 additions & 10 deletions integration_tests/src/main/python/qa_nightly_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

SELECT_SQL = [
# (" FUNCTIONAL CHECKING", "FUNCTIONAL CHECKING"),
# (" AGG functions", "AGG functions"),
Expand Down Expand Up @@ -738,16 +740,16 @@
#("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF"),
#("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF"),
#("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"),
#("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"),
#("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"),
#("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"),
#("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"),
#("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"),
#("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"),
#("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"),
#("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"),
#("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"),
#("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"),
pytest.param(("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
(A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
}

IGNORE_ORDER_testSparkResultsAreEqual2("Test hash full join", longsDf, biggerLongsDf,
conf = shuffledJoinConf) {
(A, B) => A.join(B, A("longs") === B("longs"), "FullOuter")
}

// test replacement of sort merge join with hash join
// make sure broadcast size small enough it doesn't get used
testSparkResultsAreEqual2("Test replace sort merge join with hash join",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{NvtxColor, Table}

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftOuter, LeftSemi}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, HashJoin}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
Expand All @@ -27,8 +27,18 @@ object GpuHashJoin {
def tagJoin(
meta: RapidsMeta[_, _, _],
joinType: JoinType,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
condition: Option[Expression]): Unit = joinType match {
case Inner =>
case FullOuter =>
if (leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)) {
// https://github.com/rapidsai/cudf/issues/5563
meta.willNotWorkOnGpu("Full outer join does not work on nullable keys")
}
if (condition.isDefined) {
meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
}
case LeftOuter | LeftSemi | LeftAnti =>
if (condition.isDefined) {
meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
Expand All @@ -39,6 +49,25 @@ object GpuHashJoin {

trait GpuHashJoin extends GpuExec with HashJoin {

override def output: Seq[Attribute] = {
joinType match {
case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
case j: ExistenceJoin =>
left.output :+ j.exists
case LeftExistence(_) =>
left.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
case x =>
throw new IllegalArgumentException(s"GpuHashJoin should not take $x as the JoinType")
}
}

protected lazy val (gpuBuildKeys, gpuStreamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
Expand Down Expand Up @@ -122,6 +151,9 @@ trait GpuHashJoin extends GpuExec with HashJoin {
case LeftAnti =>
leftTable.onColumns(joinKeyIndices: _*)
.leftAntiJoin(rightTable.onColumns(joinKeyIndices: _*))
case FullOuter =>
leftTable.onColumns(joinKeyIndices: _*)
.fullJoin(rightTable.onColumns(joinKeyIndices: _*))
case _ => throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" +
s" supported")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GpuShuffledHashJoinMeta(
override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
}

override def convertToGpu(): GpuExec =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class GpuSortMergeJoinMeta(

override def tagPlanForGpu(): Unit = {
// Use conditions from Hash Join
GpuHashJoin.tagJoin(this, join.joinType, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)

if (!conf.enableReplaceSortMergeJoin) {
willNotWorkOnGpu(s"Not replacing sort merge join with hash join, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class GpuBroadcastHashJoinMeta(
override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)

val buildSide = join.buildSide match {
case BuildLeft => childPlans(0)
Expand Down

0 comments on commit 62756bf

Please sign in to comment.