Skip to content

Commit

Permalink
Implement AST-based regular expression fuzz tests (#4504)
Browse files Browse the repository at this point in the history
* Regexp AST fuzz test

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* bug fix

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* skip known issues

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* code cleanup and skip some known issues

* remove debug println

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* add more capabilities

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* save progress

* Implement hex and octal generators

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* prepare for review

* update URL for Pattern javadoc
  • Loading branch information
andygrove authored Jan 22, 2022
1 parent a409d0e commit 4349acd
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -716,9 +716,9 @@ sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent
override def toRegexString: String = s"\\$a"
}

sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent {
sealed case class RegexChar(ch: Char) extends RegexCharacterClassComponent {
override def children(): Seq[RegexAST] = Seq.empty
override def toRegexString: String = s"$a"
override def toRegexString: String = s"$ch"
}

sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids

import java.util.regex.Pattern

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.{HashSet, ListBuffer}
import scala.util.{Random, Try}

import ai.rapids.cudf.{ColumnVector, CudfException}
Expand Down Expand Up @@ -346,18 +346,56 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
.map(_ => r.nextString())

// generate patterns that are valid on both CPU and GPU
val patterns = ListBuffer[String]()
while (patterns.length < 5000) {
val patterns = HashSet[String]()
while (patterns.size < 5000) {
val pattern = r.nextString()
if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) {
patterns += pattern
if (!patterns.contains(pattern)) {
if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) {
patterns += pattern
}
}
}

if (replace) {
assertCpuGpuMatchesRegexpReplace(patterns, data)
assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
} else {
assertCpuGpuMatchesRegexpFind(patterns, data)
assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
}
}

// Fuzzes regexp find with patterns produced from randomly generated regex ASTs.
test("AST fuzz test - regexp_find") {
  doAstFuzzTest(validChars = Some(REGEXP_LIMITED_CHARS_FIND), replace = false)
}

// Fuzzes regexp replace with patterns produced from randomly generated regex ASTs.
test("AST fuzz test - regexp_replace") {
  doAstFuzzTest(validChars = Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
}

/**
 * Runs an AST-based fuzz test: generates random input strings and random regular
 * expression patterns, then delegates to the CPU/GPU comparison helpers.
 *
 * Only patterns that compile with `java.util.regex.Pattern` and that `transpile`
 * accepts without throwing are kept, and each kept pattern is unique.
 *
 * @param validChars characters used to generate the input data and, when defined,
 *                   the patterns as well; patterns fall back to
 *                   `REGEXP_LIMITED_CHARS_FIND` when this is `None`
 * @param replace    `true` to exercise regexp replace, `false` to exercise regexp find
 */
private def doAstFuzzTest(validChars: Option[String], replace: Boolean): Unit = {

  val r = new EnhancedRandom(new Random(seed = 0L),
    FuzzerOptions(validChars, maxStringLen = 12))

  // Build patterns from the character set the caller asked for. Previously this was
  // hard-coded to REGEXP_LIMITED_CHARS_FIND, so the replace test silently generated
  // its patterns from the find character set.
  val fuzzer = new FuzzRegExp(validChars.getOrElse(REGEXP_LIMITED_CHARS_FIND))

  val data = Range(0, 1000)
    .map(_ => r.nextString())

  // generate patterns that are valid on both CPU and GPU
  val patterns = HashSet[String]()
  while (patterns.size < 5000) {
    val pattern = fuzzer.generate(0).toRegexString
    // check for duplicates first so that known patterns are not compiled again
    if (!patterns.contains(pattern) &&
        Try(Pattern.compile(pattern)).isSuccess &&
        Try(transpile(pattern, replace)).isSuccess) {
      patterns += pattern
    }
  }

  if (replace) {
    assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
  } else {
    assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
  }
}

Expand Down Expand Up @@ -475,3 +513,229 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
/** Parses `pattern` with a new `RegexParser` and returns the resulting AST. */
private def parse(pattern: String): RegexAST = {
  val parser = new RegexParser(pattern)
  parser.parse()
}

}

/**
 * Random generator of regular expression patterns for fuzz testing.
 *
 * Instead of producing random strings, a pattern is assembled as a `RegexAST` and then
 * rendered with `toRegexString`. This results in better coverage than randomly
 * generating strings directly.
 *
 * See https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html for
 * Java regular expression syntax.
 *
 * @param suggestedChars  candidate literal characters for the generated patterns
 * @param skipKnownIssues when `true` (the default), constructs that are tied to the
 *                        issues referenced in the comments below are not generated
 */
class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) {
  // depth at which generate() stops producing nested terms
  private val maxDepth = 5
  // fixed seed, so the sequence of generated patterns is repeatable
  private val rr = new Random(0)

  /** The characters that literal terms are drawn from. */
  val chars = if (!skipKnownIssues) {
    suggestedChars
  } else {
    // skip '\\' and '-' due to https://github.com/NVIDIA/spark-rapids/issues/4505
    suggestedChars.filterNot(ch => ch == '\\' || ch == '-')
  }

  /** Chooses one of `options` uniformly at random and evaluates it. */
  private def pickAndRun[T](options: Seq[() => T]): T = {
    val choice = rr.nextInt(options.length)
    options(choice)()
  }

  /**
   * Generates a random AST node.
   *
   * @param depth current nesting depth; when it equals the maximum depth only
   *              non-nested terms are produced
   * @return the generated node
   */
  def generate(depth: Int): RegexAST = {
    if (depth != maxDepth) {
      val alwaysEnabled: Seq[() => RegexAST] = Seq(
        () => lineTerminator,
        () => escapedChar,
        () => char,
        () => hexDigit,
        () => octalDigit,
        () => characterClass,
        () => predefinedCharacterClass,
        () => group(depth),
        () => boundaryMatch,
        () => sequence(depth))
      val candidates = if (skipKnownIssues) {
        alwaysEnabled
      } else {
        alwaysEnabled ++ Seq(
          () => repetition(depth), // https://github.com/NVIDIA/spark-rapids/issues/4487
          () => choice(depth)) // https://github.com/NVIDIA/spark-rapids/issues/4603
      }
      pickAndRun(candidates)
    } else {
      // when we reach maximum depth we generate a non-nested type
      nonNestedTerm
    }
  }

  /** A random term that never recurses back into `generate`. */
  private def nonNestedTerm: RegexAST = {
    val candidates: Seq[() => RegexAST] = Seq(
      () => lineTerminator,
      () => escapedChar,
      () => char,
      () => hexDigit,
      () => octalDigit,
      () => charRange,
      () => boundaryMatch,
      () => predefinedCharacterClass
    )
    pickAndRun(candidates)
  }

  /** A random member of a character class, such as the `a` or `a-z` in `[a-z]`. */
  private def characterClassComponent: RegexCharacterClassComponent = {
    val alwaysEnabled = Seq[() => RegexCharacterClassComponent](
      () => char,
      () => charRange)
    val candidates = if (skipKnownIssues) {
      alwaysEnabled
    } else {
      alwaysEnabled ++ Seq(
        () => escapedChar, // https://github.com/NVIDIA/spark-rapids/issues/4505
        () => hexDigit, // https://github.com/NVIDIA/spark-rapids/issues/4486
        () => octalDigit) // https://github.com/NVIDIA/spark-rapids/issues/4409
    }
    pickAndRun(candidates)
  }

  /** A random character range; the descending ranges such as `z-a` are included on purpose. */
  private def charRange: RegexCharacterClassComponent = {
    val fixedRanges = Seq[() => RegexCharacterClassComponent](
      () => RegexCharacterRange('a', 'z'),
      () => RegexCharacterRange('A', 'Z'),
      () => RegexCharacterRange('z', 'a'),
      () => RegexCharacterRange('Z', 'A'),
      () => RegexCharacterRange('0', '9'),
      () => RegexCharacterRange('9', '0')
    )
    val candidates = if (skipKnownIssues) {
      fixedRanges
    } else {
      // we do not support escaped characters in character ranges yet
      // see https://github.com/NVIDIA/spark-rapids/issues/4505
      // (the two bounds are drawn independently, so they usually differ)
      fixedRanges ++ Seq(() => RegexCharacterRange(char.ch, char.ch))
    }
    pickAndRun(candidates)
  }

  /** A sequence of three terms generated one level deeper. */
  private def sequence(depth: Int) = {
    val parts = ListBuffer.fill[RegexAST](3)(generate(depth + 1))
    RegexSequence(parts)
  }

  /** A character class with three random components, randomly negated. */
  private def characterClass = {
    // the components are generated before the negation flag is drawn
    val components = ListBuffer.fill[RegexCharacterClassComponent](3)(characterClassComponent)
    val isNegated = rr.nextBoolean()
    RegexCharacterClass(negated = isNegated, characters = components)
  }

  /** A single literal character drawn from `chars`. */
  private def char: RegexChar = {
    val index = rr.nextInt(chars.length)
    RegexChar(chars.charAt(index))
  }

  /** Any escaped character */
  private def escapedChar: RegexEscaped = {
    val literal = char
    RegexEscaped(literal.ch)
  }

  /** One of the line terminator characters, or the two-character sequence CR LF. */
  private def lineTerminator: RegexAST = {
    val candidates = Seq[() => RegexAST](
      () => RegexChar('\r'),
      () => RegexChar('\n'),
      () => RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))),
      () => RegexChar('\u0085'),
      () => RegexChar('\u2028'),
      () => RegexChar('\u2029')
    )
    pickAndRun(candidates)
  }

  /** A boundary matcher: `^`, `$`, or one of the escapes b, B, A, G, Z, z. */
  private def boundaryMatch: RegexAST = {
    val candidates = Seq[() => RegexAST](
      () => RegexChar('^'),
      () => RegexChar('$'),
      () => RegexEscaped('b'),
      () => RegexEscaped('B'),
      () => RegexEscaped('A'),
      () => RegexEscaped('G'),
      () => RegexEscaped('Z'),
      () => RegexEscaped('z')
    )
    pickAndRun(candidates)
  }

  /** A predefined character class: `.`, or one of the escapes d, D, s, S, w, W. */
  private def predefinedCharacterClass: RegexAST = {
    val candidates = Seq[() => RegexAST](
      () => RegexChar('.'),
      () => RegexEscaped('d'),
      () => RegexEscaped('D'),
      () => RegexEscaped('s'),
      () => RegexEscaped('S'),
      () => RegexEscaped('w'),
      () => RegexEscaped('W')
    )
    pickAndRun(candidates)
  }

  /** A hexadecimal character escape holding a randomly chosen hex string. */
  private def hexDigit: RegexHexDigit = {
    // \\xhh The character with hexadecimal value 0xhh
    // \\uhhhh The character with hexadecimal value 0xhhhh
    // \\x{h...h} The character with hexadecimal value 0xh...h
    // (Character.MIN_CODE_POINT <= 0xh...h <= Character.MAX_CODE_POINT)
    //
    // NOTE(review): the bound passed to nextInt is exclusive and toHexString does not
    // zero-pad, so 0xff and 0xffff are never produced and small values have fewer
    // digits than the fixed-width forms listed above. Confirm against how
    // RegexHexDigit renders its value before changing this, since any change alters
    // the set of generated patterns.
    val hexStrings: Seq[() => String] = Seq(
      () => rr.nextInt(0xFF).toHexString,
      () => rr.nextInt(0xFFFF).toHexString,
      () => Character.MIN_CODE_POINT.toHexString,
      () => Character.MAX_CODE_POINT.toHexString
    )
    RegexHexDigit(pickAndRun(hexStrings))
  }

  /** An octal character escape with a leading zero followed by one to three octal digits. */
  private def octalDigit: RegexOctalChar = {
    // \\0n The character with octal value 0n (0 <= n <= 7)
    // \\0nn The character with octal value 0nn (0 <= n <= 7)
    // \\0mnn The character with octal value 0mnn (0 <= m <= 3, 0 <= n <= 7)
    val octalDigits = "01234567"
    def randomDigits(count: Int): String =
      Range(0, count).map(_ => octalDigits.charAt(rr.nextInt(octalDigits.length))).mkString
    val digitStrings: Seq[() => String] = Seq(
      () => randomDigits(1),
      () => randomDigits(2),
      // this will generate some invalid octal numbers where the first digit > 3
      () => randomDigits(3)
    )
    RegexOctalChar("0" + pickAndRun(digitStrings))
  }

  /** An alternation between two terms generated one level deeper. */
  private def choice(depth: Int) = {
    val left = generate(depth + 1)
    val right = generate(depth + 1)
    RegexChoice(left, right)
  }

  /** A group, randomly capturing or not, around a term generated one level deeper. */
  private def group(depth: Int) = {
    // the capture flag is drawn before the nested term is generated
    val isCapture = rr.nextBoolean()
    RegexGroup(capture = isCapture, generate(depth + 1))
  }

  /** A term generated one level deeper with a greedy, reluctant or possessive quantifier. */
  private def repetition(depth: Int) = {
    // the repeated term is generated first, then its quantifier
    def quantified = RegexRepetition(generate(depth + 1), quantifier)
    val candidates = Seq(
      // greedy quantifier
      () => quantified,
      // reluctant quantifier
      () => RegexRepetition(quantified, SimpleQuantifier('?')),
      // possessive quantifier
      () => RegexRepetition(quantified, SimpleQuantifier('+'))
    )
    pickAndRun(candidates)
  }

  /** A random quantifier: `+`, `*`, `?`, a fixed length, or a variable length. */
  private def quantifier: RegexQuantifier = {
    val candidates = Seq[() => RegexQuantifier](
      () => SimpleQuantifier('+'),
      () => SimpleQuantifier('*'),
      () => SimpleQuantifier('?'),
      () => QuantifierFixedLength(rr.nextInt(3)),
      () => QuantifierVariableLength(rr.nextInt(3), None),
      () => {
        // this intentionally generates some invalid quantifiers where the maxLength
        // is less than the minLength, such as "{2,1}" which should be handled as a
        // literal string match on "{2,1}" rather than as a valid quantifier.
        val minLength = rr.nextInt(3)
        val maxLength = rr.nextInt(3)
        QuantifierVariableLength(minLength, Some(maxLength))
      }
    )
    pickAndRun(candidates)
  }
}

0 comments on commit 4349acd

Please sign in to comment.