[KYUUBI #360] Correctly handle getNextRowSet with FETCH_PRIOR and FETCH_FIRST


### _Why are the changes needed?_
Closes #360.

Ref: apache/spark#30600
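
Previously the engine backed results with a one-shot `Iterator[Row]`, so every fetch request was effectively served as FETCH_NEXT. The new `FetchIterator` remembers the current position and the start of the last fetch block, which lets `getNextRowSet` honor all three orientations. Condensed from the `SparkOperation` diff below:

```scala
order match {
  case FETCH_NEXT  => iter.fetchNext()            // continue from the current position
  case FETCH_PRIOR => iter.fetchPrior(rowSetSize) // replay the previous block
  case FETCH_FIRST => iter.fetchAbsolute(0)       // restart from row 0
}
```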

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly, including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/latest/tools/testing.html#running-tests) locally before making a pull request

Closes #370 from pan3793/KYUUBI-360.

e79b8cb [Cheng Pan] [KYUUBI #360] comments
0fae3db [Cheng Pan] fix import
3d1b2a6 [Cheng Pan] [KYUUBI #360] fix ut
eda3e59 [Cheng Pan] [KYUUBI #360] fix import
16178d6 [Cheng Pan] [KYUUBI #360] ut
179404d [Cheng Pan] [KYUUBI #360] nit
455af6b [Cheng Pan] [KYUUBI #360] correct getNextRowSet with FETCH_PRIOR FETCH_FIRST
2307f1f [Cheng Pan] [KYUUBI #360] move ThriftUtils to kyuubi-common

Authored-by: Cheng Pan <379377944@qq.com>
Signed-off-by: Kent Yao <yao@apache.org>
pan3793 authored and yaooqinn committed Feb 25, 2021
1 parent d94b1c4 commit c659089
Showing 19 changed files with 332 additions and 24 deletions.
FetchIterator.scala (new file)
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.spark

/**
* Borrowed from Apache Spark, see SPARK-33655
*/
private[engine] sealed trait FetchIterator[A] extends Iterator[A] {
/**
* Begin a fetch block, forward from the current position.
* Resets the fetch start offset.
*/
def fetchNext(): Unit

/**
* Begin a fetch block, moving the iterator back by `offset` rows from the start of the
* previous fetch block.
* Resets the fetch start offset.
*
* @param offset the number of rows to move the fetch start position backward.
*/
def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)

/**
* Begin a fetch block, moving the iterator to the given position.
* Resets the fetch start offset.
*
* @param pos the position to move the iterator to.
*/
def fetchAbsolute(pos: Long): Unit

def getFetchStart: Long

def getPosition: Long
}

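/**
* A [[FetchIterator]] backed by a fully materialized array: repositioning is O(1),
* since [[fetchAbsolute]] simply clamps and assigns the position.
*/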
private[engine] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
private var fetchStart: Long = 0

private var position: Long = 0

override def fetchNext(): Unit = fetchStart = position

override def fetchAbsolute(pos: Long): Unit = {
position = (pos max 0) min src.length
fetchStart = position
}

override def getFetchStart: Long = fetchStart

override def getPosition: Long = position

override def hasNext: Boolean = position < src.length

override def next(): A = {
position += 1
src(position.toInt - 1)
}
}

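/**
* A [[FetchIterator]] backed by a re-iterable collection: moving backward recreates the
* underlying iterator and re-traverses from the head, so a rewind costs O(position).
*/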
private[engine] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
private var iter: Iterator[A] = iterable.iterator

private var fetchStart: Long = 0

private var position: Long = 0

override def fetchNext(): Unit = fetchStart = position

override def fetchAbsolute(pos: Long): Unit = {
val newPos = pos max 0
if (newPos < position) resetPosition()
while (position < newPos && hasNext) next()
fetchStart = position
}

override def getFetchStart: Long = fetchStart

override def getPosition: Long = position

override def hasNext: Boolean = iter.hasNext

override def next(): A = {
position += 1
iter.next()
}

private def resetPosition(): Unit = {
if (position != 0) {
iter = iterable.iterator
position = 0
fetchStart = 0
}
}
}
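
To make the contract concrete, a small usage sketch (illustrative only, not part of the patch; it assumes test code living in the same package, since the classes are `private[engine]`):

```scala
val it = new ArrayFetchIterator(Array(0, 1, 2, 3, 4))

it.fetchNext()                           // block starts at position 0
assert(it.take(2).toList == List(0, 1))  // position is now 2
it.fetchNext()                           // block starts at position 2
assert(it.take(2).toList == List(2, 3))  // position is now 4
it.fetchPrior(2)                         // rewind: fetchStart (2) - 2 = 0
assert(it.take(2).toList == List(0, 1))
it.fetchAbsolute(3)                      // jump straight to position 3
assert(it.take(2).toList == List(3, 4))
```
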
ExecuteStatement.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import org.apache.kyuubi.{KyuubiSQLException, Logging}
- import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
+ import org.apache.kyuubi.engine.spark.{ArrayFetchIterator, KyuubiSparkUtil}
import org.apache.kyuubi.operation.{OperationState, OperationType}
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
@@ -74,7 +74,7 @@ class ExecuteStatement(
debug(s"original result queryExecution: ${result.queryExecution}")
val castedResult = result.select(castCols: _*)
debug(s"casted result queryExecution: ${castedResult.queryExecution}")
- iter = castedResult.collect().toList.iterator
+ iter = new ArrayFetchIterator(castedResult.collect())
setState(OperationState.FINISHED)
} catch {
onError(cancel = true)
GetCatalogs.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
@@ -35,7 +36,7 @@ class GetCatalogs(spark: SparkSession, session: Session)

override protected def runInternal(): Unit = {
try {
- iter = SparkCatalogShim().getCatalogs(spark).toIterator
+ iter = new IterableFetchIterator(SparkCatalogShim().getCatalogs(spark).toList)
} catch onError()
}
}
GetColumns.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -88,9 +89,8 @@ class GetColumns(
val schemaPattern = toJavaRegex(schemaName)
val tablePattern = toJavaRegex(tableName)
val columnPattern = toJavaRegex(columnName)
- iter = SparkCatalogShim()
-   .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
-   .toList.iterator
+ iter = new IterableFetchIterator(SparkCatalogShim()
+   .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern).toList)
} catch {
onError()
}
GetFunctions.scala
@@ -22,6 +22,7 @@ import java.sql.DatabaseMetaData
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -70,7 +71,7 @@ class GetFunctions(
info.getClassName)
}
}
- iter = a.toList.iterator
+ iter = new IterableFetchIterator(a.toList)
} catch {
onError()
}
GetSchemas.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -42,7 +43,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
try {
val schemaPattern = toJavaRegex(schema)
val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
- iter = rows.toList.toIterator
+ iter = new IterableFetchIterator(rows)
} catch onError()
}
}
GetTableTypes.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -33,6 +34,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
}

override protected def runInternal(): Unit = {
- iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
+ iter = new IterableFetchIterator(SparkCatalogShim.sparkTableTypes.map(Row(_)).toList)
}
}
GetTables.scala
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -73,7 +74,7 @@ class GetTables(
} else {
catalogTablesAndViews
}
- iter = allTableAndViews.toList.iterator
+ iter = new IterableFetchIterator(allTableAndViews)
} catch {
onError()
}
GetTypeInfo.scala
@@ -22,6 +22,7 @@ import java.sql.Types._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

+ import org.apache.kyuubi.engine.spark.IterableFetchIterator
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -83,7 +84,7 @@ class GetTypeInfo(spark: SparkSession, session: Session)
}

override protected def runInternal(): Unit = {
- iter = Seq(
+ iter = new IterableFetchIterator(Seq(
toRow("VOID", NULL),
toRow("BOOLEAN", BOOLEAN),
toRow("TINYINT", TINYINT, 3),
Expand All @@ -101,6 +102,6 @@ class GetTypeInfo(spark: SparkSession, session: Session)
toRow("MAP", JAVA_OBJECT),
toRow("STRUCT", STRUCT),
toRow("INTERVAL", OTHER)
- ).toList.iterator
+ ))
}
}
SparkOperation.scala
@@ -25,8 +25,9 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.KyuubiSQLException
+ import org.apache.kyuubi.engine.spark.FetchIterator
import org.apache.kyuubi.operation.{AbstractOperation, OperationState}
- import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation
+ import org.apache.kyuubi.operation.FetchOrientation._
import org.apache.kyuubi.operation.OperationState.OperationState
import org.apache.kyuubi.operation.OperationType.OperationType
import org.apache.kyuubi.operation.log.OperationLog
@@ -36,7 +37,7 @@ import org.apache.kyuubi.session.Session
abstract class SparkOperation(spark: SparkSession, opType: OperationType, session: Session)
extends AbstractOperation(opType, session) {

- protected var iter: Iterator[Row] = _
+ protected var iter: FetchIterator[Row] = _

protected final val operationLog: OperationLog =
OperationLog.createOperationLog(session.handle, getHandle)
@@ -130,8 +131,15 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
validateDefaultFetchOrientation(order)
assertState(OperationState.FINISHED)
setHasResultSet(true)
+ order match {
+   case FETCH_NEXT => iter.fetchNext()
+   case FETCH_PRIOR => iter.fetchPrior(rowSetSize);
+   case FETCH_FIRST => iter.fetchAbsolute(0);
+ }
val taken = iter.take(rowSetSize)
- RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+ val resultRowSet = RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+ resultRowSet.setStartRowOffset(iter.getPosition)
+ resultRowSet
}

override def shouldRunAsync: Boolean = false
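
Putting it together: a client paging through a five-row result with `rowSetSize = 2` would observe the sequence below (a hypothetical trace; `FetchResults` stands for the Thrift RPC that ultimately reaches `getNextRowSet`):

```scala
// FetchResults(FETCH_NEXT)  -> rows 0-1   fetchNext(): block starts at position 0
// FetchResults(FETCH_NEXT)  -> rows 2-3   fetchNext(): block starts at position 2
// FetchResults(FETCH_PRIOR) -> rows 0-1   fetchPrior(2): rewind to fetchStart (2) - 2 = 0
// FetchResults(FETCH_FIRST) -> rows 0-1   fetchAbsolute(0): restart from the beginning
```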