[SPARK-33655][SQL] Improve performance of processing FETCH_PRIOR #30600

FetchIterator.scala (new file)

@@ -0,0 +1,107 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.thriftserver

/**
 * An iterator that can move its fetch start position forward, backward, or to an
 * absolute offset, used to serve Thrift FETCH_NEXT / FETCH_PRIOR / FETCH_FIRST
 * requests over a query's result rows.
 */
private[hive] sealed trait FetchIterator[A] extends Iterator[A] {
  /**
   * Begin a fetch block, moving forward from the current position.
   * Resets the fetch start offset.
   */
  def fetchNext(): Unit

  /**
   * Begin a fetch block, moving the iterator back by `offset` rows from the start of
   * the previous fetch block.
   * Resets the fetch start offset.
   *
   * @param offset the number of rows to move the fetch start position backward.
   */
  def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)

  /**
   * Begin a fetch block, moving the iterator to the given position.
   * Resets the fetch start offset.
   *
   * @param pos the position to move the iterator to.
   */
  def fetchAbsolute(pos: Long): Unit

  def getFetchStart: Long

  def getPosition: Long
}

/** A [[FetchIterator]] over an in-memory array; repositioning is a constant-time index update. */
private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    // Clamp the target position into [0, src.length].
    position = (pos max 0) min src.length
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = position < src.length

  override def next(): A = {
    position += 1
    src(position.toInt - 1)
  }
}

/**
 * A [[FetchIterator]] over an [[Iterable]]; moving backward rebuilds the source
 * iterator and re-consumes rows up to the target position.
 */
private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
  private var iter: Iterator[A] = iterable.iterator

  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    val newPos = pos max 0
    if (newPos < position) resetPosition()
    while (position < newPos && hasNext) next()
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = iter.hasNext

  override def next(): A = {
    position += 1
    iter.next()
  }

  private def resetPosition(): Unit = {
    if (position != 0) {
      iter = iterable.iterator
      position = 0
      fetchStart = 0
    }
  }
}
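
Reviewer note: to make the intended semantics of the trait concrete, here is a minimal usage sketch. FetchIteratorDemo is a hypothetical harness, not part of this PR, and it has to live in org.apache.spark.sql.hive.thriftserver because both iterator classes are private[hive]. The printed pairs follow from the contract documented above: fetchNext() opens a block at the current position, and fetchPrior(n) moves n rows back from the previous block's start.

package org.apache.spark.sql.hive.thriftserver

// Hypothetical demo, not part of the PR.
object FetchIteratorDemo {
  def main(args: Array[String]): Unit = {
    val it = new ArrayFetchIterator[Int]((0 until 10).toArray)

    it.fetchNext()                              // block starts at position 0
    (0 until 3).foreach(_ => it.next())         // consume rows 0, 1, 2
    println((it.getFetchStart, it.getPosition)) // (0,3)

    it.fetchNext()                              // next block starts where the last ended
    (0 until 2).foreach(_ => it.next())         // consume rows 3, 4
    println((it.getFetchStart, it.getPosition)) // (3,5)

    it.fetchPrior(3)                            // rewind 3 rows before the last block's start
    println((it.getFetchStart, it.getPosition)) // (0,0)
  }
}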
SparkExecuteStatementOperation.scala

@@ -69,13 +69,7 @@ private[hive] class SparkExecuteStatementOperation(
 
   private var result: DataFrame = _
 
-  // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST.
-  // This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`.
-  // In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution.
-  private var resultList: Option[Array[SparkRow]] = _
-  private var previousFetchEndOffset: Long = 0
-  private var previousFetchStartOffset: Long = 0
-  private var iter: Iterator[SparkRow] = _
+  private var iter: FetchIterator[SparkRow] = _
   private var dataTypes: Array[DataType] = _
 
   private lazy val resultSchema: TableSchema = {
@@ -148,43 +142,10 @@
     setHasResultSet(true)
     val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
 
-    // Reset iter when FETCH_FIRST or FETCH_PRIOR
-    if ((order.equals(FetchOrientation.FETCH_FIRST) ||
-        order.equals(FetchOrientation.FETCH_PRIOR)) && previousFetchEndOffset != 0) {
-      // Reset the iterator to the beginning of the query.
-      iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
-        resultList = None
-        result.toLocalIterator.asScala
-      } else {
-        if (resultList.isEmpty) {
-          resultList = Some(result.collect())
-        }
-        resultList.get.iterator
-      }
-    }
-
-    var resultOffset = {
-      if (order.equals(FetchOrientation.FETCH_FIRST)) {
-        logInfo(s"FETCH_FIRST request with $statementId. Resetting to resultOffset=0")
-        0
-      } else if (order.equals(FetchOrientation.FETCH_PRIOR)) {
-        // TODO: FETCH_PRIOR should be handled more efficiently than rewinding to beginning and
-        // reiterating.
-        val targetOffset = math.max(previousFetchStartOffset - maxRowsL, 0)
-        logInfo(s"FETCH_PRIOR request with $statementId. Resetting to resultOffset=$targetOffset")
-        var off = 0
-        while (off < targetOffset && iter.hasNext) {
-          iter.next()
-          off += 1
-        }
-        off
-      } else { // FETCH_NEXT
-        previousFetchEndOffset
-      }
-    }
-
-    resultRowSet.setStartOffset(resultOffset)
-    previousFetchStartOffset = resultOffset
+    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchAbsolute(0)
+    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL)
+    else iter.fetchNext()
+    resultRowSet.setStartOffset(iter.getPosition)
     if (!iter.hasNext) {
       resultRowSet
     } else {

@@ -206,11 +167,9 @@
         }
         resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]])
         curRow += 1
-        resultOffset += 1
       }
-      previousFetchEndOffset = resultOffset
       log.info(s"Returning result set with ${curRow} rows from offsets " +
-        s"[$previousFetchStartOffset, $previousFetchEndOffset) with $statementId")
+        s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId")
       resultRowSet
     }
   }
@@ -326,14 +285,12 @@
         logDebug(result.queryExecution.toString())
         HiveThriftServer2.eventManager.onStatementParsed(statementId,
           result.queryExecution.toString())
-        iter = {
-          if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
-            resultList = None
-            result.toLocalIterator.asScala
-          } else {
-            resultList = Some(result.collect())
-            resultList.get.iterator
-          }
+        iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
+          new IterableFetchIterator[SparkRow](new Iterable[SparkRow] {
+            override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala
+          })
+        } else {
+          new ArrayFetchIterator[SparkRow](result.collect())
         }
         dataTypes = result.schema.fields.map(_.dataType)
       } catch {
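
Reviewer note (my reading of the diff, not text from the PR): the old code handled FETCH_PRIOR by rewinding to the beginning and re-iterating, the TODO removed above. After this change, the non-incremental path (ArrayFetchIterator) repositions in constant time, since fetchAbsolute is a clamped index assignment, while IterableFetchIterator rebuilds its source iterator only when it actually has to move backward. A sketch of the asymmetry, under the same package-internal assumption as before:

package org.apache.spark.sql.hive.thriftserver

// Hypothetical comparison sketch, not part of the PR.
object FetchRewindCost {
  def main(args: Array[String]): Unit = {
    val n = 1000000

    val arrayIter = new ArrayFetchIterator[Int]((0 until n).toArray)
    arrayIter.fetchAbsolute(n - 1) // O(1): clamps the target and assigns position
    arrayIter.fetchAbsolute(0)     // O(1) again, no re-iteration

    val iterableIter = new IterableFetchIterator[Int](0 until n)
    iterableIter.fetchAbsolute(n - 1) // walks forward, consuming n - 1 rows
    iterableIter.fetchAbsolute(0)     // backward move: resets to a fresh iterator
    iterableIter.fetchAbsolute(n / 2) // then re-consumes n / 2 rows forward
  }
}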
FetchIteratorSuite.scala (new file)

@@ -0,0 +1,134 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite

class FetchIteratorSuite extends SparkFunSuite {

  // Drain up to maxRowCount rows from the iterator, mirroring a single fetch request.
  private def getRows(fetchIter: FetchIterator[Int], maxRowCount: Int): Seq[Int] = {
    for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next()
  }

test("SPARK-33655: Test fetchNext and fetchPrior") {
val testData = 0 until 10

def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 0)
assertResult(0 until 2)(getRows(fetchIter, 2))
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 2)

fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 2)
assert(fetchIter.getPosition == 2)
assertResult(2 until 3)(getRows(fetchIter, 1))
assert(fetchIter.getFetchStart == 2)
assert(fetchIter.getPosition == 3)

fetchIter.fetchPrior(2)
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 0)
assertResult(0 until 3)(getRows(fetchIter, 3))
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 3)

fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 3)
assert(fetchIter.getPosition == 3)
assertResult(3 until 8)(getRows(fetchIter, 5))
assert(fetchIter.getFetchStart == 3)
assert(fetchIter.getPosition == 8)

fetchIter.fetchPrior(2)
assert(fetchIter.getFetchStart == 1)
assert(fetchIter.getPosition == 1)
assertResult(1 until 4)(getRows(fetchIter, 3))
assert(fetchIter.getFetchStart == 1)
assert(fetchIter.getPosition == 4)

fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 4)
assert(fetchIter.getPosition == 4)
assertResult(4 until 10)(getRows(fetchIter, 10))
assert(fetchIter.getFetchStart == 4)
assert(fetchIter.getPosition == 10)

fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 10)
assert(fetchIter.getPosition == 10)
assertResult(Seq.empty[Int])(getRows(fetchIter, 10))
assert(fetchIter.getFetchStart == 10)
assert(fetchIter.getPosition == 10)

fetchIter.fetchPrior(20)
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 0)
assertResult(0 until 3)(getRows(fetchIter, 3))
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 3)
}
iteratorTest(new ArrayFetchIterator[Int](testData.toArray))
iteratorTest(new IterableFetchIterator[Int](testData))
}

test("SPARK-33655: Test fetchAbsolute") {
val testData = 0 until 10

def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
fetchIter.fetchNext()
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 0)
assertResult(0 until 5)(getRows(fetchIter, 5))
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 5)

fetchIter.fetchAbsolute(2)
assert(fetchIter.getFetchStart == 2)
assert(fetchIter.getPosition == 2)
assertResult(2 until 5)(getRows(fetchIter, 3))
assert(fetchIter.getFetchStart == 2)
assert(fetchIter.getPosition == 5)

fetchIter.fetchAbsolute(7)
assert(fetchIter.getFetchStart == 7)
assert(fetchIter.getPosition == 7)
assertResult(7 until 8)(getRows(fetchIter, 1))
assert(fetchIter.getFetchStart == 7)
assert(fetchIter.getPosition == 8)

fetchIter.fetchAbsolute(20)
assert(fetchIter.getFetchStart == 10)
assert(fetchIter.getPosition == 10)
assertResult(Seq.empty[Int])(getRows(fetchIter, 1))
assert(fetchIter.getFetchStart == 10)
assert(fetchIter.getPosition == 10)

fetchIter.fetchAbsolute(0)
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 0)
assertResult(0 until 3)(getRows(fetchIter, 3))
assert(fetchIter.getFetchStart == 0)
assert(fetchIter.getPosition == 3)
}
iteratorTest(new ArrayFetchIterator[Int](testData.toArray))
iteratorTest(new IterableFetchIterator[Int](testData))
}
}