Hive module.
rxin committed Jan 27, 2015
1 parent d35efd5 commit 6d53134
Showing 13 changed files with 81 additions and 43 deletions.
22 changes: 20 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -22,9 +22,11 @@ import scala.reflect.ClassTag

import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -170,9 +172,13 @@ class DataFrame(

override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)

def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)

def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)

override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
Sample(fraction, withReplacement, seed, logicalPlan)
}

/////////////////////////////////////////////////////////////////////////////

@@ -238,6 +244,18 @@ class DataFrame(
sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
}

@Experimental
override def saveAsTable(tableName: String): Unit = {
sqlContext.executePlan(
CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
}

@Experimental
override def insertInto(tableName: String, overwrite: Boolean): Unit = {
sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
Map.empty, logicalPlan, overwrite)).toRdd
}

override def toJSON: RDD[String] = {
val rowSchema = this.schema
this.mapPartitions { iter =>
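For orientation, a rough usage sketch of what this file now exposes on DataFrame: the intersect/except/sample operators and the experimental saveAsTable/insertInto hooks. The hiveContext value and the events/events_archive/events_sample table names are hypothetical; saveAsTable and insertInto plan through CreateTableAsSelect/InsertIntoTable, so the sketch assumes a metastore-backed context such as HiveContext.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.hive.HiveContext

object DataFrameOpsSketch {
  def demo(hiveContext: HiveContext): Unit = {
    // Hypothetical input tables; any two DataFrames with matching schemas would do.
    val current  = hiveContext.table("events")
    val archived = hiveContext.table("events_archive")

    val overlap = current.intersect(archived)   // rows present in both
    val onlyNew = current.except(archived)      // rows only in `current`

    // Roughly 10% sample without replacement, with an explicit seed (a seedless overload also exists).
    val sampled: DataFrame = current.sample(withReplacement = false, fraction = 0.1, seed = 42L)

    // Experimental persistence hooks: create a metastore table, then append into it.
    sampled.saveAsTable("events_sample")
    onlyNew.insertInto("events_sample")         // overwrite defaults to false
  }
}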
19 changes: 18 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql

import scala.reflect.ClassTag

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


trait RDDApi[T] {
@@ -129,6 +131,12 @@ trait DataFrameSpecificApi {

def except(other: DataFrame): DataFrame

def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame

def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
sample(withReplacement, fraction, Utils.random.nextLong)
}

/////////////////////////////////////////////////////////////////////////////
// Column mutation
/////////////////////////////////////////////////////////////////////////////
@@ -144,11 +152,20 @@

def rdd: RDD[Row]

def toJSON: RDD[String]

def registerTempTable(tableName: String): Unit

def saveAsParquetFile(path: String): Unit

def toJSON: RDD[String]
@Experimental
def saveAsTable(tableName: String): Unit

@Experimental
def insertInto(tableName: String, overwrite: Boolean): Unit

@Experimental
def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)

/////////////////////////////////////////////////////////////////////////////
// Stat functions
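The new sample and insertInto declarations above use an overload-with-default pattern: the trait supplies the convenience forms, and implementors only need to provide the fully parameterized methods. A stand-alone sketch of that pattern with a made-up type parameter (the real trait draws its seed from Spark's Utils.random rather than scala.util.Random):

import scala.util.Random

// T stands in for DataFrame; only the overload wiring is the point here.
trait SamplingApi[T] {
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): T
  def sample(withReplacement: Boolean, fraction: Double): T =
    sample(withReplacement, fraction, Random.nextLong())

  def insertInto(tableName: String, overwrite: Boolean): Unit
  def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
}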
@@ -23,6 +23,7 @@ import java.io._
import java.util.{ArrayList => JArrayList}

import jline.{ConsoleReader, History}

import org.apache.commons.lang.StringUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
@@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket

import org.apache.spark.Logging
import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim

private[hive] object SparkSQLCLIDriver {
private var prompt = "spark-sql"
@@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.Logging
import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}
import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation(
sessionToActivePool: SMap[SessionHandle, String])
extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging {

private var result: SchemaRDD = _
private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _

@@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
result.toLocalIterator
result.rdd.toLocalIterator
} else {
result.collect().iterator
}
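The change above is the mechanical consequence of swapping SchemaRDD for DataFrame: the local iterator now has to come from the underlying RDD via .rdd. A minimal sketch of the same switch outside the shim (the object, method, and parameter names are made up):

import org.apache.spark.sql.{DataFrame, Row}

object CollectSketch {
  // `incremental` mirrors the spark.sql.thriftServer.incrementalCollect flag read above.
  def resultIterator(result: DataFrame, incremental: Boolean): Iterator[Row] =
    if (incremental) {
      result.rdd.toLocalIterator   // stream partitions to the driver one at a time
    } else {
      result.collect().iterator    // materialize the whole result on the driver first
    }
}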
@@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.Logging
import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD}
import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation(
// NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {

private var result: SchemaRDD = _
private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _

@@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
result.toLocalIterator
result.rdd.toLocalIterator
} else {
result.collect().iterator
}
@@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"

override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
new this.QueryExecution(plan)

override def sql(sqlText: String): SchemaRDD = {
override def sql(sqlText: String): DataFrame = {
val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
@@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override protected[sql] val planner = hivePlanner

/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
extends super.QueryExecution(logicalPlan) {

/**
* Returns the result as a hive compatible sequence of strings. For native commands, the
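With this change HiveContext.sql returns a DataFrame and dispatches on spark.sql.dialect as shown above. A minimal usage sketch, assuming a SparkContext named sc and an existing Hive table src (both hypothetical here):

import org.apache.spark.SparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.hive.HiveContext

object HiveSqlSketch {
  def hiveQuery(sc: SparkContext): DataFrame = {
    val hiveContext = new HiveContext(sc)
    // "hiveql" is HiveContext's default dialect; setting "sql" instead would route the
    // statement through the plain Spark SQL parser branch of sql() shown above.
    hiveContext.setConf("spark.sql.dialect", "hiveql")
    hiveContext.sql("SELECT key, value FROM src LIMIT 10")   // now a DataFrame, not a SchemaRDD
  }
}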
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import scala.collection.JavaConversions._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy}
import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
@@ -55,16 +55,15 @@ private[hive] trait HiveStrategies {
*/
@Experimental
object ParquetConversion extends Strategy {
implicit class LogicalPlanHacks(s: SchemaRDD) {
def lowerCase =
new SchemaRDD(s.sqlContext, s.logicalPlan)
implicit class LogicalPlanHacks(s: DataFrame) {
def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan)

def addPartitioningAttributes(attrs: Seq[Attribute]) = {
// Don't add the partitioning key if it's already present in the data.
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) {
s
} else {
new SchemaRDD(
new DataFrame(
s.sqlContext,
s.logicalPlan transform {
case p: ParquetRelation => p.copy(partitioningAttributes = attrs)
@@ -97,13 +96,13 @@ private[hive] trait HiveStrategies {
// We are going to throw the predicates and projection back at the whole optimization
// sequence, so let's unresolve all the attributes, allowing them to be rebound to the
// matching parquet attributes.
val unresolvedOtherPredicates = otherPredicates.map(_ transform {
val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
}).reduceOption(And).getOrElse(Literal(true))
}).reduceOption(And).getOrElse(Literal(true)))

val unresolvedProjection = projectList.map(_ transform {
val unresolvedProjection: Seq[Column] = projectList.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
})
}).map(new Column(_))

try {
if (relation.hiveQlTable.isPartitioned) {
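The key point in this hunk is that raw Catalyst expressions can no longer be handed straight to the DataFrame API; they are conjoined and then lifted into the public Column type. A small sketch of that wrapping (the object and helper names are made up):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Literal}

object ColumnWrapSketch {
  // AND together any remaining Catalyst predicates (defaulting to `true` when there are
  // none) and lift the combined expression into a Column, as the strategy above now does.
  def toFilterColumn(predicates: Seq[Expression]): Column =
    new Column(predicates.reduceOption(And).getOrElse(Literal(true)))
}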
@@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))

override def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
new this.QueryExecution(plan)

/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
@@ -150,16 +150,17 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {

val describedTable = "DESCRIBE (\\w+)".r

protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
lazy val logical = HiveQl.parseSql(hql)
protected[hive] class HiveQLQueryExecution(hql: String)
extends this.QueryExecution(HiveQl.parseSql(hql)) {
def hiveExec() = runSqlHive(hql)
override def toString = hql + "\n" + super.toString
}

/**
* Override QueryExecution with special debug workflow.
*/
abstract class QueryExecution extends super.QueryExecution {
class QueryExecution(logicalPlan: LogicalPlan)
extends super.QueryExecution(logicalPlan) {
override lazy val analyzed = {
val describedTables = logical match {
case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil
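Both here and in HiveContext above, QueryExecution now takes the logical plan as a constructor parameter instead of an abstract val, which is what lets HiveQLQueryExecution simply pass the parsed plan upward. A toy sketch of that refactor, using made-up stand-in classes rather than the real ones:

object QueryExecutionSketch {
  // Before: the plan was an abstract member, so every use needed an anonymous subclass.
  abstract class OldExecution { val logical: String; def run(): String = s"executing: $logical" }
  val before = new OldExecution { val logical = "parsed plan" }

  // After: the plan flows through the constructor, so `new QueryExecution(plan)` works
  // directly and subclasses just forward a pre-parsed plan.
  class NewExecution(val logical: String) { def run(): String = s"executing: $logical" }
  class HqlExecution(hql: String) extends NewExecution(s"parsed($hql)")
  val after = new HqlExecution("SELECT 1")
}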
10 changes: 5 additions & 5 deletions sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -36,12 +36,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer either contains all of the keywords or
* contains none of them.
* @param rdd the [[SchemaRDD]] to be executed
* @param rdd the [[DataFrame]] to be executed
* @param exists true to assert that the keywords appear in the output; false to assert
* that none of them appear
* @param keywords the keywords to check for
*/
def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -54,10 +54,10 @@

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[SchemaRDD]] to be executed
* @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
*/
protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -101,7 +101,7 @@ class QueryTest extends PlanTest {
}
}

protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}

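A hypothetical test sketch using the updated checkAnswer/checkExistence signatures; it assumes the standard TestHive fixture table src and the ScalaTest plumbing QueryTest inherits via PlanTest:

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive._

class ExampleQueryTest extends QueryTest {
  test("constant expression against the src fixture") {
    // checkAnswer now takes a DataFrame; this uses the single-Row overload.
    checkAnswer(sql("SELECT 1 + 1 FROM src LIMIT 1"), Row(2))
  }

  test("EXPLAIN output is recognizable as a plan") {
    // checkExistence(df, exists = true, ...) asserts that every keyword appears in the output.
    checkExistence(sql("EXPLAIN SELECT key FROM src"), true, "== Physical Plan ==")
  }
}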
@@ -20,15 +20,15 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.{QueryTest, SchemaRDD}
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.storage.RDDBlockId

class CachedTableSuite extends QueryTest {
/**
* Throws a test failed exception when the number of cached tables differs from the expected
* number.
*/
def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
@@ -27,11 +27,12 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars

import org.apache.spark.{SparkFiles, SparkException}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.{SQLConf, Row, SchemaRDD}

case class TestData(a: Int, b: String)

@@ -473,7 +474,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}

def isExplanation(result: SchemaRDD) = {
def isExplanation(result: DataFrame) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
explanation.contains("== Physical Plan ==")
}
@@ -842,7 +843,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
def collectResults(rdd: SchemaRDD): Set[(String, String)] =
def collectResults(rdd: DataFrame): Set[(String, String)] =
rdd.collect().map {
case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
@@ -17,9 +17,10 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.Row

import org.apache.spark.util.Utils

@@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest {
sql("create table spark_4959 (col1 string)")
sql("""insert into table spark_4959 select "hi" from src limit 1""")
table("spark_4959").select(
'col1.as('CaseSensitiveColName),
'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2")
'col1.as("CaseSensitiveColName"),
'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2")

assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi"))
assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi"))
assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
}
}
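Two API shifts show up in this test: column aliases in the DSL are now plain Strings rather than Symbols, and head() replaces first() for pulling a single Row. A small sketch against the TestHive fixtures (table, sql, and registerTempTable come from TestHive; src is its usual key/value table; the object and method names are made up):

import org.apache.spark.sql.Row
import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive.test.TestHive._

object AliasSketch {
  def firstRenamedRow(): Row = {
    val renamed = table("src").select('key.as("k"), 'value.as("v"))   // String aliases
    renamed.registerTempTable("src_renamed")
    sql("SELECT k, v FROM src_renamed").head()                        // head() fetches one Row
  }
}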
@@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f3,
| getStruct(1).f4,
| getStruct(1).f5 FROM src LIMIT 1
""".stripMargin).first() === Row(1, 2, 3, 4, 5))
""".stripMargin).head() === Row(1, 2, 3, 4, 5))
}

test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {