Skip to content

Commit

Permalink
add CaseKeyWhen
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed May 7, 2015
1 parent 4f87e95 commit 3ce54e1
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
| IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
{ case c ~ t ~ f => If(c, t, f) }
| CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
(ELSE ~> expression).? <~ END ^^ {
case casePart ~ altPart ~ elsePart =>
val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
}
CaseWhen(altExprs ++ elsePart.toList)
val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(whenExpr, thenExpr)
} ++ elsePart
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
}
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,31 +631,24 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
case Seq(_, value) => value.dataType
case Seq(elseVal) => elseVal.dataType
}.toSeq

logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")

if (valueTypes.distinct.size > 1) {
val commonType = valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2)
.getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = branches.sliding(2, 2).map {
case Seq(cond, value) if value.dataType != commonType =>
Seq(cond, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
CaseWhen(transformedBranches)
} else {
// Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
cw
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
def childrenResolved: Boolean = !children.exists(!_.resolved)
def childrenResolved: Boolean = children.forall(_.resolved)

/**
* Returns a string representation of this expression that does not have developer centric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}

trait CaseWhenLike extends Expression {
self: Product =>

type EvaluatedType = Any

// Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
// element is the value for the default catch-all case (if provided).
// Hence, `branches` consists of at least two elements, and can have an odd or even length.
def branches: Seq[Expression]

@transient lazy val whenList =
branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
@transient lazy val thenList =
branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)

// both then and else val should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1

override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
}
valueTypes.head
}

override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
}

// scalastyle:off
/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
*
* The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
* translated to this form at parsing time. Namely, such a statement gets translated to
* "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
*
* Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
* element is the value for the default catch-all case (if provided). Hence, `branches` consists of
* at least two elements, and can have an odd or even length.
*/
// scalastyle:on
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

// Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray

override def children: Seq[Expression] = branches

override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
override lazy val resolved: Boolean =
childrenResolved &&
whenList.forall(_.dataType == BooleanType) &&
valueTypesEqual

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
while (i < len - 1) {
if (branchesArr(i).eval(input) == true) {
return branchesArr(i + 1).eval(input)
}
i += 2
}
var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
}
branches(1).dataType
return res
}

override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
}

// scalastyle:off
/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
*/
// scalastyle:on
case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {

// Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray
@transient private[this] lazy val predicates =
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
@transient private[this] lazy val values =
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
@transient private[this] lazy val elseValue =
if (branches.length % 2 == 0) None else Option(branches.last)

override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
override def children: Seq[Expression] = key +: branches

override lazy val resolved: Boolean = {
if (!childrenResolved) {
false
} else {
val allCondBooleans = predicates.forall(_.dataType == BooleanType)
// both then and else val should be considered.
val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
allCondBooleans && dataTypesEqual
}
}
override lazy val resolved: Boolean =
childrenResolved && valueTypesEqual

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
val evaluatedKey = key.eval(input)
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
var res: Any = null
while (i < len - 1) {
if (branchesArr(i).eval(input) == true) {
res = branchesArr(i + 1).eval(input)
return res
if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
return branchesArr(i + 1).eval(input)
}
i += 2
}
var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
}
res
return res
}

private def equalNullSafe(l: Any, r: Any) = {
if (l == null && r == null) {
true
} else if (l == null || r == null) {
false
} else {
l == r
}
}

override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
s"CASE $key" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
}

test("case key when") {
val row = create_row(null, 1, 2, "a", "b", "c")
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.string.at(3)
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)

val literalNull = Literal.create(null, BooleanType)
val literalInt = Literal(1)
val literalString = Literal("a")

checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row)
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)

checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
}

test("complex type") {
val row = create_row(
"^Ba*n", // 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* TODO: This can be optimized to use broadcast join when replacementMap is large.
*/
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
lit(target).cast(col.dataType).expr :: Nil
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
}.toSeq
new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}

private def convertToDouble(v: Any): Double = v match {
Expand Down
12 changes: 2 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1249,16 +1249,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
CaseWhen(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
val transformed = branches.drop(1).sliding(2, 2).map {
case Seq(condVal, value) =>
// FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
// Hence effectful / non-deterministic key expressions are *not* supported at the moment.
// We should consider adding new Expressions to get around this.
Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)),
nodeToExpr(value))
case Seq(elseVal) => Seq(nodeToExpr(elseVal))
}.toSeq.reduce(_ ++ _)
CaseWhen(transformed)
val keyExpr = nodeToExpr(branches.head)
CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))

/* Complex datatype manipulation */
case Token("[", child :: ordinal :: Nil) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,11 @@ class SQLQuerySuite extends QueryTest {
(6, "c", 0, 6)
).map(i => Row(i._1, i._2, i._3, i._4)))
}

test("test case key when") {
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
checkAnswer(
sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"),
Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil)
}
}

0 comments on commit 3ce54e1

Please sign in to comment.