diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1b1e2ad71e7c8..b2fc334ac893e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -543,27 +543,33 @@ object LikeSimplification extends Rule[LogicalPlan] { private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(input, Literal(pattern, StringType), escapeChar) => + case l @ Like(input, Literal(pattern, StringType), escapeChar) => if (pattern == null) { // If pattern is null, return null value directly, since "col like null" == null. Literal(null, BooleanType) } else { - val escapeStr = String.valueOf(escapeChar) pattern.toString match { - case startsWith(prefix) if !prefix.endsWith(escapeStr) => + // There are three different situations when pattern containing escapeChar: + // 1. pattern contains invalid escape sequence, e.g. 'm\aca' + // 2. pattern contains escaped wildcard character, e.g. 'ma\%ca' + // 3. pattern contains escaped escape character, e.g. 'ma\\ca' + // Although there are patterns can be optimized if we handle the escape first, we just + // skip this rule if pattern contains any escapeChar for simplicity. + case p if p.contains(escapeChar) => l + case startsWith(prefix) => StartsWith(input, Literal(prefix)) case endsWith(postfix) => EndsWith(input, Literal(postfix)) // 'a%a' pattern is basically same with 'a%' && '%a'. // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith(escapeStr) => + case startsAndEndsWith(prefix, postfix) => And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith(escapeStr) => + case contains(infix) => Contains(input, Literal(infix)) case equalTo(str) => EqualTo(input, Literal(str)) - case _ => Like(input, Literal.create(pattern, StringType), escapeChar) + case _ => l } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index 436f62e4225c8..1812dce0da426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -116,4 +116,52 @@ class LikeSimplificationSuite extends PlanTest { val optimized2 = Optimize.execute(originalQuery2.analyze) comparePlans(optimized2, originalQuery2.analyze) } + + test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { + val originalQuery1 = + testRelation + .where(('a like "abc%") || ('a like "\\abc%")) + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = testRelation + .where(StartsWith('a, "abc") || ('a like "\\abc%")) + .analyze + comparePlans(optimized1, correctAnswer1) + + val originalQuery2 = + testRelation + .where(('a like "%xyz") || ('a like "%xyz\\")) + val optimized2 = Optimize.execute(originalQuery2.analyze) + val correctAnswer2 = testRelation + .where(EndsWith('a, "xyz") || ('a like "%xyz\\")) + .analyze + comparePlans(optimized2, correctAnswer2) + + val originalQuery3 = + testRelation + .where(('a like ("@bc%def", '@')) || ('a like "abc%def")) + val optimized3 = Optimize.execute(originalQuery3.analyze) + val correctAnswer3 = testRelation + .where(('a like ("@bc%def", '@')) || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + comparePlans(optimized3, correctAnswer3) + + val originalQuery4 = + testRelation + .where(('a like "%mn%") || ('a like ("%mn%", '%'))) + val optimized4 = Optimize.execute(originalQuery4.analyze) + val correctAnswer4 = testRelation + .where(Contains('a, "mn") || ('a like ("%mn%", '%'))) + .analyze + comparePlans(optimized4, correctAnswer4) + + val originalQuery5 = + testRelation + .where(('a like "abc") || ('a like ("abbc", 'b'))) + val optimized5 = Optimize.execute(originalQuery5.analyze) + val correctAnswer5 = testRelation + .where(('a === "abc") || ('a like ("abbc", 'b'))) + .analyze + comparePlans(optimized5, correctAnswer5) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 727482e551a8b..2eeb729ece3fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3718,6 +3718,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { + withTempView("df") { + Seq("m@ca").toDF("s").createOrReplaceTempView("df") + + val e = intercept[AnalysisException] { + sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM df").collect() + } + assert(e.message.contains("the pattern 'm%@ca' is invalid, " + + "the escape character is not allowed to precede '@'")) + + checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM df"), Row(true)) + } + } } case class Foo(bar: Option[String])