[SPARK-38684][SS] Fix correctness issue on stream-stream outer join with RocksDB state store provider

### What changes were proposed in this pull request?

(Credit to alex-balikov for inspiring the root cause observation, and to anishshri-db for looking into the issue together.)

This PR fixes the correctness issue on stream-stream outer join with the RocksDB state store provider, which can occur under conditions like the following (a sketch of a query shape that can hit this appears right after the list):

* stream-stream time interval outer join
  * a left outer join has the issue on the left side, a right outer join on the right side, and a full outer join on both sides
* at batch N, non-late row(s) are produced on the problematic side
* in the same batch (batch N), some row(s) on the problematic side are evicted by the watermark condition
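For concreteness, here is a minimal sketch of a query shape satisfying the first condition, mirroring the new unit test added by this PR; `left`, `right`, and the column names are illustrative:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.expr

// Stream-stream time interval LEFT OUTER join; the problematic side is the
// left side. `left` and `right` are streaming DataFrames with columns
// `eventTime` (timestamp) and `id`.
def timeIntervalLeftOuterJoin(left: DataFrame, right: DataFrame): DataFrame = {
  left.withWatermark("eventTime", "0 second").as("left")
    .join(right.withWatermark("eventTime", "0 second").as("right"),
      expr("""
        |left.id = right.id AND left.eventTime BETWEEN
        |  right.eventTime - INTERVAL 30 seconds AND
        |  right.eventTime + INTERVAL 30 seconds
        """.stripMargin),
      joinType = "leftOuter")
}
```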

The root cause is the same as [SPARK-38320](https://issues.apache.org/jira/browse/SPARK-38320) - weak read consistency on the iterator, especially with the RocksDB state store provider. (Quoting from SPARK-38320: "The problem is due to the StateStore.iterator not reflecting StateStore changes made after its creation.")

More specifically, if processing the input rows performs updates that change the number of values for a grouping key, the update is not seen by SymmetricHashJoinStateManager.removeByValueCondition, and that method performs the eviction with an out-of-sync number of values.

Making things worse, when the method performs the eviction and then updates the number of values for the grouping key, it "overwrites" the count, effectively dropping all rows inserted in the same batch.
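To make the failure mode concrete, here is a plain-Scala illustration of the hazard (not Spark code; the map and all names are stand-ins for the state store's key-to-number-of-values state):

```scala
import scala.collection.mutable

// State: grouping key -> number of values stored for that key.
val numValues = mutable.Map("key" -> 2L)
// An iterator over a point-in-time copy: like StateStore.iterator, it does
// not reflect writes made after its creation.
val snapshot = numValues.toSeq.iterator

// Processing input rows in batch N inserts a value: the real count is now 3.
numValues("key") = 3L

// Eviction reads the stale count (2) from the snapshot, evicts one value,
// and writes back 2 - 1 = 1, "overwriting" the real count of 3 and
// effectively dropping the row inserted above in the same batch.
val (key, staleCount) = snapshot.next()
numValues(key) = staleCount - 1

println(numValues)  // Map(key -> 1), where the correct count would be 2
```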

The code blocks below are references for understanding the details of the issue.

https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L327-L339

https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L619-L627

https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala#L195-L201

https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala#L208-L223

This PR makes the outer iterators lazily evaluated, ensuring that all updates from processing input rows are reflected "before" the outer iterators are initialized.
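A sketch of the pattern, with illustrative names: wrapping the iterator creation in a lazily initializing iterator defers the underlying state-store scan until the first pull, which happens only after everything before it in a `++` chain has been exhausted.

```scala
// Runs initFn() only when the iterator is first pulled.
class LazilyInitializingIterator[A](initFn: () => Iterator[A]) extends Iterator[A] {
  private lazy val iter: Iterator[A] = initFn()

  override def hasNext: Boolean = iter.hasNext
  override def next(): A = iter.next()
}

val inner = Iterator("inner-1", "inner-2")
val outer = new LazilyInitializingIterator(() => {
  // In the actual fix, this is where removeOldState() runs -- after all
  // state updates from processing input rows have been applied.
  println("outer iterator initialized")
  Iterator("outer-1")
})

// Prints inner-1, inner-2, "outer iterator initialized", outer-1.
(inner ++ outer).foreach(println)
```

The `LazilyInitializingJoinedRowIterator` added in the diff below follows exactly this shape.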

### Why are the changes needed?

The bug is described in the section above.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UT added.

Closes #36002 from HeartSaVioR/SPARK-38684.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit 2f8613f)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
HeartSaVioR committed Apr 1, 2022
1 parent b4f996a commit 8a072ef
Showing 2 changed files with 121 additions and 23 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

```diff
@@ -324,17 +324,22 @@ case class StreamingSymmetricHashJoinExec(
           }
         }
 
+        val initIterFn = { () =>
+          val removedRowIter = leftSideJoiner.removeOldState()
+          removedRowIter.filterNot { kv =>
+            stateFormatVersion match {
+              case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
+              case 2 => kv.matched
+              case _ => throwBadStateFormatVersionException()
+            }
+          }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        }
+
         // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
-        // elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update
-        // the match flag which the logic for outer join is relying on.
-        val removedRowIter = leftSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot { kv =>
-          stateFormatVersion match {
-            case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
-            case 2 => kv.matched
-            case _ => throwBadStateFormatVersionException()
-          }
-        }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
+        // the interface contract on StateStore.iterator and end up with correctness issue.
+        // Please refer SPARK-38684 for more details.
+        val outerOutputIter = new LazilyInitializingJoinedRowIterator(initIterFn)
 
         hashJoinOutputIter ++ outerOutputIter
       case RightOuter =>
```
```diff
@@ -344,14 +349,23 @@ case class StreamingSymmetricHashJoinExec(
             postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
           }
         }
-        val removedRowIter = rightSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot { kv =>
-          stateFormatVersion match {
-            case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
-            case 2 => kv.matched
-            case _ => throwBadStateFormatVersionException()
-          }
-        }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        val initIterFn = { () =>
+          val removedRowIter = rightSideJoiner.removeOldState()
+          removedRowIter.filterNot { kv =>
+            stateFormatVersion match {
+              case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
+              case 2 => kv.matched
+              case _ => throwBadStateFormatVersionException()
+            }
+          }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+        }
+
+        // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
+        // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
+        // the interface contract on StateStore.iterator and end up with correctness issue.
+        // Please refer SPARK-38684 for more details.
+        val outerOutputIter = new LazilyInitializingJoinedRowIterator(initIterFn)
 
         hashJoinOutputIter ++ outerOutputIter
       case FullOuter =>
```
```diff
@@ -360,10 +374,25 @@ case class StreamingSymmetricHashJoinExec(
             case 2 => kv.matched
             case _ => throwBadStateFormatVersionException()
           }
-        val leftSideOutputIter = leftSideJoiner.removeOldState().filterNot(
-          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
-        val rightSideOutputIter = rightSideJoiner.removeOldState().filterNot(
-          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        val leftSideInitIterFn = { () =>
+          val removedRowIter = leftSideJoiner.removeOldState()
+          removedRowIter.filterNot(isKeyToValuePairMatched)
+            .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        }
+
+        val rightSideInitIterFn = { () =>
+          val removedRowIter = rightSideJoiner.removeOldState()
+          removedRowIter.filterNot(isKeyToValuePairMatched)
+            .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+        }
+
+        // NOTE: we need to make sure both `leftSideOutputIter` and `rightSideOutputIter` are
+        // evaluated "after" exhausting all of elements in `hashJoinOutputIter`, otherwise it may
+        // lead to out of sync according to the interface contract on StateStore.iterator and
+        // end up with correctness issue. Please refer SPARK-38684 for more details.
+        val leftSideOutputIter = new LazilyInitializingJoinedRowIterator(leftSideInitIterFn)
+        val rightSideOutputIter = new LazilyInitializingJoinedRowIterator(rightSideInitIterFn)
 
         hashJoinOutputIter ++ leftSideOutputIter ++ rightSideOutputIter
       case _ => throwBadJoinTypeException()
```
```diff
@@ -638,4 +667,12 @@ case class StreamingSymmetricHashJoinExec(
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec =
     copy(left = newLeft, right = newRight)
+
+  private class LazilyInitializingJoinedRowIterator(
+      initFn: () => Iterator[JoinedRow]) extends Iterator[JoinedRow] {
+    private lazy val iter: Iterator[JoinedRow] = initFn()
+
+    override def hasNext: Boolean = iter.hasNext
+    override def next(): JoinedRow = iter.next()
+  }
 }
```
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala

```diff
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
```
```diff
@@ -1353,6 +1353,67 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite {
     ).select(Symbol("leftKey1"), Symbol("rightKey1"), Symbol("leftKey2"), Symbol("rightKey2"),
       $"leftWindow.end".cast("long"), Symbol("leftValue"), Symbol("rightValue"))
   }
+
+  test("SPARK-38684: outer join works correctly even if processing input rows and " +
+    "evicting state rows for same grouping key happens in the same micro-batch") {
+
+    // The test is to demonstrate the correctness issue in outer join before SPARK-38684.
+    withSQLConf(
+      SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) {
+
+      val input1 = MemoryStream[(Timestamp, String, String)]
+      val df1 = input1.toDF
+        .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment")
+        .withWatermark("eventTime", "0 second")
+
+      val input2 = MemoryStream[(Timestamp, String, String)]
+      val df2 = input2.toDF
+        .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment")
+        .withWatermark("eventTime", "0 second")
+
+      val joined = df1.as("left")
+        .join(df2.as("right"),
+          expr("""
+            |left.id = right.id AND left.eventTime BETWEEN
+            |  right.eventTime - INTERVAL 30 seconds AND
+            |  right.eventTime + INTERVAL 30 seconds
+          """.stripMargin),
+          joinType = "leftOuter")
+
+      testStream(joined)(
+        MultiAddData(
+          (input1, Seq((Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "left in batch 1"))),
+          (input2, Seq((Timestamp.valueOf("2020-01-02 00:01:00"), "abc", "right in batch 1")))
+        ),
+        CheckNewAnswer(),
+        MultiAddData(
+          (input1, Seq((Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "left in batch 2"))),
+          (input2, Seq((Timestamp.valueOf("2020-01-02 01:01:00"), "abc", "right in batch 2")))
+        ),
+        // watermark advanced to "2020-01-02 00:00:00"
+        CheckNewAnswer(),
+        AddData(input1, (Timestamp.valueOf("2020-01-02 01:30:00"), "abc", "left in batch 3")),
+        // watermark advanced to "2020-01-02 01:00:00"
+        CheckNewAnswer(
+          (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "left in batch 1", null, null, null)
+        ),
+        // left side state should still contain "left in batch 2" and "left in batch 3";
+        // we should see both rows in the left side since
+        // - "left in batch 2" is going to be evicted in this batch
+        // - "left in batch 3" is going to be matched with new row in right side
+        AddData(input2,
+          (Timestamp.valueOf("2020-01-02 01:30:10"), "abc", "match with left in batch 3")),
+        // watermark advanced to "2020-01-02 01:01:00"
+        CheckNewAnswer(
+          (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "left in batch 2",
+            null, null, null),
+          (Timestamp.valueOf("2020-01-02 01:30:00"), "abc", "left in batch 3",
+            Timestamp.valueOf("2020-01-02 01:30:10"), "abc", "match with left in batch 3")
+        )
+      )
+    }
+  }
 }
 
 class StreamingFullOuterJoinSuite extends StreamingJoinSuite {
```
