Skip to content

Commit

Permalink
updated join code to treat null equality properly (NVIDIA#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
revans2 authored Jul 14, 2020
1 parent ab80adb commit 46c8b41
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 60 deletions.
29 changes: 3 additions & 26 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,45 +43,23 @@ def create_df(spark, data_gen, left_length, right_length):
.withColumnRenamed("b", "r_b")
return left, right

# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn)
@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn)
def test_sortmerge_join_no_nulls(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 500)
return left.join(right, left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross',
pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_sortmerge_join(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 500)
return left.join(right, left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)


# Once https://github.com/NVIDIA/spark-rapids/issues/280 is fixed this test should be deleted
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen_no_nulls, ids=idfn)
@pytest.mark.parametrize('join_type', ['FullOuter'], ids=idfn)
def test_broadcast_join_no_nulls(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
return left.join(broadcast(right), left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross',
pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_broadcast_join_right_table(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 500, 250)
Expand Down Expand Up @@ -147,8 +125,7 @@ def do_join(spark):
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross',
pytest.param('FullOuter', marks=[pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_broadcast_join_left_table(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 250, 500)
Expand Down
20 changes: 10 additions & 10 deletions integration_tests/src/main/python/qa_nightly_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,16 +744,16 @@
("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF"),
("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF"),
("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"),
pytest.param(("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
pytest.param(("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/280')),
("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"),
("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"),
("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"),
("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"),
("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"),
("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"),
("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"),
("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"),
("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"),
("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"),
]

SELECT_PRE_ORDER_SQL=[
Expand Down
37 changes: 13 additions & 24 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,7 @@ object GpuHashJoin {
rightKeys: Seq[Expression],
condition: Option[Expression]): Unit = joinType match {
case _: InnerLike =>
case FullOuter =>
if (leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)) {
// https://github.com/rapidsai/cudf/issues/5563
meta.willNotWorkOnGpu("Full outer join does not work on nullable keys")
}
if (condition.isDefined) {
meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
}
case RightOuter | LeftOuter | LeftSemi | LeftAnti =>
case FullOuter | RightOuter | LeftOuter | LeftSemi | LeftAnti =>
if (condition.isDefined) {
meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
}
Expand Down Expand Up @@ -134,7 +126,7 @@ trait GpuHashJoin extends GpuExec with HashJoin {
TaskContext.get().addTaskCompletionListener[Unit](_ => closeCb())

def closeCb(): Unit = {
nextCb.map(_.close())
nextCb.foreach(_.close())
nextCb = None
}

Expand Down Expand Up @@ -224,22 +216,19 @@ trait GpuHashJoin extends GpuExec with HashJoin {
private[this] def doJoinLeftRight(leftTable: Table, rightTable: Table): ColumnarBatch = {
val joinedTable = joinType match {
case LeftOuter => leftTable.onColumns(joinKeyIndices: _*)
.leftJoin(rightTable.onColumns(joinKeyIndices: _*))
.leftJoin(rightTable.onColumns(joinKeyIndices: _*), false)
case RightOuter => rightTable.onColumns(joinKeyIndices: _*)
.leftJoin(leftTable.onColumns(joinKeyIndices: _*))
case _: InnerLike =>
leftTable.onColumns(joinKeyIndices: _*).innerJoin(rightTable.onColumns(joinKeyIndices: _*))
case LeftSemi =>
leftTable.onColumns(joinKeyIndices: _*)
.leftSemiJoin(rightTable.onColumns(joinKeyIndices: _*))
case LeftAnti =>
leftTable.onColumns(joinKeyIndices: _*)
.leftAntiJoin(rightTable.onColumns(joinKeyIndices: _*))
case FullOuter =>
leftTable.onColumns(joinKeyIndices: _*)
.fullJoin(rightTable.onColumns(joinKeyIndices: _*))
.leftJoin(leftTable.onColumns(joinKeyIndices: _*), false)
case _: InnerLike => leftTable.onColumns(joinKeyIndices: _*)
.innerJoin(rightTable.onColumns(joinKeyIndices: _*), false)
case LeftSemi => leftTable.onColumns(joinKeyIndices: _*)
.leftSemiJoin(rightTable.onColumns(joinKeyIndices: _*), false)
case LeftAnti => leftTable.onColumns(joinKeyIndices: _*)
.leftAntiJoin(rightTable.onColumns(joinKeyIndices: _*), false)
case FullOuter => leftTable.onColumns(joinKeyIndices: _*)
.fullJoin(rightTable.onColumns(joinKeyIndices: _*), false)
case _ => throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" +
s" supported")
s" supported")
}
try {
val result = joinIndices.map(joinIndex =>
Expand Down

0 comments on commit 46c8b41

Please sign in to comment.