Skip to content

Commit

Permalink
Simplify conditional catalyst expressions generated by udf-compiler
Browse files Browse the repository at this point in the history
if (c) true else false => c
if (c) false else true => !c

Signed-off-by: Sean Lee <selee@nvidia.com>
  • Loading branch information
seanprime7 committed Apr 26, 2022
1 parent 961343b commit 633771a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ object CatalystExpressionBuilder extends Logging {
case And(Literal.TrueLiteral, c) => simplifyExpr(c)
case And(c, Literal.TrueLiteral) => simplifyExpr(c)
case And(Literal.FalseLiteral, _) => Literal.FalseLiteral
case And(_, Literal.FalseLiteral) => Literal.FalseLiteral
case And(c1@LessThan(s1, Literal(v1, t1)),
c2@LessThan(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
Expand Down Expand Up @@ -347,6 +348,7 @@ object CatalystExpressionBuilder extends Logging {
}
case And(c1, c2) => And(simplifyExpr(c1), simplifyExpr(c2))
case Or(Literal.TrueLiteral, _) => Literal.TrueLiteral
case Or(_, Literal.TrueLiteral) => Literal.TrueLiteral
case Or(Literal.FalseLiteral, c) => simplifyExpr(c)
case Or(c, Literal.FalseLiteral) => simplifyExpr(c)
case Or(c1@GreaterThan(s1, Literal(v1, t1)),
Expand Down Expand Up @@ -374,6 +376,7 @@ object CatalystExpressionBuilder extends Logging {
case Not(LessThanOrEqual(c1, c2)) => GreaterThan(c1, c2)
case Not(GreaterThan(c1, c2)) => LessThanOrEqual(c1, c2)
case Not(GreaterThanOrEqual(c1, c2)) => LessThan(c1, c2)
case Not(c) => Not(simplifyExpr(c))
case EqualTo(Literal(v1, _), Literal(v2, _)) =>
if (v1 == v2) Literal.TrueLiteral else Literal.FalseLiteral
case LessThan(If(c1,
Expand Down Expand Up @@ -424,6 +427,9 @@ object CatalystExpressionBuilder extends Logging {
}
case If(c, Repr.ArrayBuffer(t), Repr.ArrayBuffer(f)) => Repr.ArrayBuffer(If(c, t, f))
case If(c, Repr.StringBuilder(t), Repr.StringBuilder(f)) => Repr.StringBuilder(If(c, t, f))
case If(c, Literal.TrueLiteral, Literal.FalseLiteral) => c
case If(c, Literal.FalseLiteral, Literal.TrueLiteral) => Not(c)
case If(c, t, f) => If(simplifyExpr(c), simplifyExpr(t), simplifyExpr(f))
case _ => expr
}
logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}")
Expand Down
24 changes: 24 additions & 0 deletions udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,30 @@ class OpcodeSuite extends FunSuite {
checkEquiv(result, ref)
}

// Checks that a UDF of the form `if (c) true else false` is compiled, yields the same
// rows as the plain condition `x < 20`, and leaves no "if" in the analyzed plan string.
test("Conditional simplification - if (c) true else false => c") {
  val myudf: (Int) => Boolean = i => {
    if (i < 20) true else false
  }
  val u = makeUdf(myudf)
  val dataset = List(20, 19).toDF("x")
  val result = dataset.withColumn("new", u(col("x")))
  val ref = dataset.withColumn("new", col("x") < 20)
  assert(udfIsCompiled(result))
  // Compare against the reference so the simplified expression is verified for
  // correct results, not only for having been compiled.
  checkEquiv(result, ref)
  // The plan string is checked for the lowercase token "if" to confirm the
  // conditional was folded away.
  assert(!result.queryExecution.analyzed.toString.contains("if"))
}

// Checks that a UDF of the form `if (c) false else true` matches the reference
// column `x >= 20` and that the analyzed plan string contains no "if".
test("Conditional simplification - if (c) false else true => !c") {
  val negatedUdf: (Int) => Boolean = i => {
    if (i < 20) false else true
  }
  val compiled = makeUdf(negatedUdf)
  val input = List(20, 19).toDF("x")
  val actual = input.withColumn("new", compiled(col("x")))
  val expected = input.withColumn("new", col("x") >= 20)
  checkEquiv(actual, expected)
  // Render the analyzed plan once, then look for the lowercase token "if".
  val analyzedPlan = actual.queryExecution.analyzed.toString
  assert(!analyzedPlan.contains("if"))
}

test("LDC_W opcode") {
val myudf: () => String = () => {
val myString : String = "a"
Expand Down

0 comments on commit 633771a

Please sign in to comment.