Skip to content

Commit

Permalink
Improve CAST string to float implementation to handle more edge cases (
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jul 20, 2021
1 parent 345cd70 commit 2f965b6
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 61 deletions.
150 changes: 97 additions & 53 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,76 @@ object GpuCast extends Arm {
})
}

def sanitizeStringToFloat(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {

// This regex gets applied after the transformation to normalize use of Inf and is
// just strict enough to filter out known edge cases that would result in incorrect
// values. We further filter out invalid values using the cuDF isFloat method.
val VALID_FLOAT_REGEX =
"^" + // start of line
"[+\\-]?" + // optional + or - at start of string
"(" +
"(" +
"(" +
"([0-9]+)|" + // digits, OR
"([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR
"([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing
")" +
"([eE][+\\-]?[0-9]+)?" + // exponent
"[fFdD]?" + // floating-point designator
")" +
"|Inf" + // Infinity
"|[nN][aA][nN]" + // NaN
")" +
"$" // end of line

withResource(input.lstrip()) { stripped =>
withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString =>
// filter out strings containing breaking whitespace
val withoutWhitespace = withResource(ColumnVector.fromStrings("\r", "\n")) {
verticalWhitespace =>
withResource(stripped.contains(verticalWhitespace)) {
_.ifElse(nullString, stripped)
}
}
// replace all possible versions of "Inf" and "Infinity" with "Inf"
val inf = withResource(withoutWhitespace) { _ =>
withoutWhitespace.stringReplaceWithBackrefs(
"(?:[iI][nN][fF])" + "(?:[iI][nN][iI][tT][yY])?", "Inf")
}
// replace "+Inf" with "Inf" because cuDF only supports "Inf" and "-Inf"
val infWithoutPlus = withResource(inf) { _ =>
withResource(GpuScalar.from("+Inf", DataTypes.StringType)) { search =>
withResource(GpuScalar.from("Inf", DataTypes.StringType)) { replace =>
inf.stringReplace(search, replace)
}
}
}
// filter out any strings that are not valid floating point numbers according
// to the regex pattern
val floatOrNull = withResource(infWithoutPlus) { _ =>
withResource(infWithoutPlus.matchesRe(VALID_FLOAT_REGEX)) { isFloat =>
if (ansiEnabled) {
withResource(isFloat.all()) { allMatch =>
// Check that all non-null values are valid floats.
if (allMatch.isValid && !allMatch.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_FLOAT_CAST_MSG)
}
infWithoutPlus.incRefCount()
}
} else {
isFloat.ifElse(infWithoutPlus, nullString)
}
}
}
// strip floating-point designator 'f' or 'd' but don't strip the 'f' from 'Inf'
withResource(floatOrNull) {
_.stringReplaceWithBackrefs("([^n])[fFdD]$", "\\1")
}
}
}
}

def sanitizeStringToIntegralType(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {
// Convert any strings containing whitespace to null values. The input is assumed to already
// have been stripped of leading and trailing whitespace
Expand All @@ -200,9 +270,8 @@ object GpuCast extends Arm {
val regex = "^[+\\-]?[0-9]+$"
withResource(sanitized.matchesRe(regex)) { isInt =>
withResource(isInt.all()) { allInts =>
// Check that all non-null values are valid integers. Note that allInts will be false
// if all rows are null so we need to check for that condition.
if (!allInts.getBoolean && sanitized.getNullCount != sanitized.getRowCount) {
// Check that all non-null values are valid integers.
if (allInts.isValid && !allInts.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
Expand Down Expand Up @@ -724,7 +793,7 @@ case class GpuCast(
// in ansi mode, fail if any values are not valid bool strings
if (ansiEnabled) {
withResource(validBools.all()) { isAllBool =>
if (!isAllBool.getBoolean) {
if (isAllBool.isValid && !isAllBool.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
Expand Down Expand Up @@ -753,9 +822,8 @@ case class GpuCast(
withResource(sanitized.isInteger(dType)) { isInt =>
if (ansiEnabled) {
withResource(isInt.all()) { allInts =>
// Check that all non-null values are valid integers. Note that allInts will be false
// if all rows are null so we need to check for that condition.
if (!allInts.getBoolean && sanitized.getNullCount != sanitized.getRowCount) {
// Check that all non-null values are valid integers.
if (allInts.isValid && !allInts.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
Expand All @@ -774,59 +842,35 @@ case class GpuCast(
ansiEnabled: Boolean,
dType: DType): ColumnVector = {

// TODO: since cudf doesn't support case-insensitive regex, we have to generate all
// possible strings. But these should cover most of the cases
val POS_INF_REGEX = "^[+]?(?:infinity|inf|Infinity|Inf|INF|INFINITY)$"
val NEG_INF_REGEX = "^[\\-](?:infinity|inf|Infinity|Inf|INF|INFINITY)$"
val NAN_REGEX = "^(?:nan|NaN|NAN)$"


// 1. convert the different infinities to "Inf"/"-Inf" which is the only variation cudf
// understands
// 2. identify the nans
// 3. identify the floats. "nan", "null" and letters are not considered floats
// 4. if ansi is enabled we want to throw and exception if the string is neither float nor nan
// 5. convert everything thats not floats to null
// 6. set the indices where we originally had nans to Float.NaN
//
// NOTE Limitation: "1.7976931348623159E308" and "-1.7976931348623159E308" are not considered
// Inf even though spark does

if (ansiEnabled && input.hasNulls()) {
throw new NumberFormatException(GpuCast.INVALID_FLOAT_CAST_MSG)
}
// First replace different spellings/cases of infinity with Inf and -Infinity with -Inf
val posInfReplaced = withResource(input.matchesRe(POS_INF_REGEX)) { containsInf =>
withResource(Scalar.fromString("Inf")) { inf =>
containsInf.ifElse(inf, input)
}
}
val withPosNegInfinityReplaced = withResource(posInfReplaced) { withPositiveInfinityReplaced =>
withResource(withPositiveInfinityReplaced.matchesRe(NEG_INF_REGEX)) { containsNegInf =>
withResource(Scalar.fromString("-Inf")) { negInf =>
containsNegInf.ifElse(negInf, withPositiveInfinityReplaced)
}
}
}
withResource(withPosNegInfinityReplaced) { withPosNegInfinityReplaced =>
// 1. convert the different infinities to "Inf"/"-Inf" which is the only variation cudf
// understands
// 2. identify the nans
// 3. identify the floats. "nan", "null" and letters are not considered floats
// 4. if ansi is enabled we want to throw and exception if the string is neither float nor nan
// 5. convert everything thats not floats to null
// 6. set the indices where we originally had nans to Float.NaN
//
// NOTE Limitation: "1.7976931348623159E308" and "-1.7976931348623159E308" are not considered
// Inf even though Spark does

val NAN_REGEX = "^[nN][aA][nN]$"

withResource(GpuCast.sanitizeStringToFloat(input, ansiEnabled)) { sanitized =>
//Now identify the different variations of nans
withResource(withPosNegInfinityReplaced.matchesRe(NAN_REGEX)) { isNan =>
withResource(sanitized.matchesRe(NAN_REGEX)) { isNan =>
// now check if the values are floats
withResource(withPosNegInfinityReplaced.isFloat()) { isFloat =>
withResource(sanitized.isFloat) { isFloat =>
if (ansiEnabled) {
withResource(isNan.not()) { notNan =>
withResource(isFloat.not()) { notFloat =>
withResource(notFloat.and(notNan)) { notFloatAndNotNan =>
withResource(notFloatAndNotNan.any()) { notNanAndNotFloat =>
if (notNanAndNotFloat.getBoolean()) {
throw new NumberFormatException(GpuCast.INVALID_FLOAT_CAST_MSG)
}
}
withResource(isNan.or(isFloat)) { nanOrFloat =>
withResource(nanOrFloat.all()) { allNanOrFloat =>
// Check that all non-null values are valid floats or NaN.
if (allNanOrFloat.isValid && !allNanOrFloat.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_FLOAT_CAST_MSG)
}
}
}
}
withResource(withPosNegInfinityReplaced.castTo(dType)) { casted =>
withResource(sanitized.castTo(dType)) { casted =>
withResource(Scalar.fromNull(dType)) { nulls =>
withResource(isFloat.ifElse(casted, nulls)) { floatsOnly =>
withResource(FloatUtils.getNanScalar(dType)) { nan =>
Expand Down
52 changes: 44 additions & 8 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
}

private val BOOL_CHARS = " \t\r\nFALSEfalseTRUEtrue01yesYESnoNO"
private val NUMERIC_CHARS = "inf \t\r\n0123456789.+-eE"
private val NUMERIC_CHARS = "infinityINFINITY \t\r\n0123456789.+-eEfFdD"
private val DATE_CHARS = " \t\r\n0123456789:-/TZ"

test("Cast from string to boolean using random inputs") {
Expand Down Expand Up @@ -122,15 +122,28 @@ class CastOpSuite extends GpuExpressionTestSuite {
testCastStringTo(DataTypes.LongType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

ignore("Cast from string to float using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2900
test("Cast from string to float using random inputs") {
testCastStringTo(DataTypes.FloatType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

ignore("Cast from string to double using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2900
test("Cast from string to float using hand-picked values") {
testCastStringTo(DataTypes.FloatType, Seq(".", "e", "Infinity", "+Infinity", "-Infinity",
"+nAn", "-naN", "Nan", "5f", "1.2f", "\riNf", null))
}

test("Cast from string to float ANSI mode with nulls") {
testCastStringTo(DataTypes.FloatType, Seq(null, null, null), ansiMode = AnsiExpectSuccess)
}

test("Cast from string to float ANSI mode with invalid values") {
val values = Seq(".", "e")
// test the values individually
for (value <- values ) {
testCastStringTo(DataTypes.FloatType, Seq(value), ansiMode = AnsiExpectFailure)
}
}

test("Cast from string to double using random inputs") {
testCastStringTo(DataTypes.DoubleType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

Expand Down Expand Up @@ -214,8 +227,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
assert(cpuRow.getInt(INDEX_ID) === gpuRow.getInt(INDEX_ID))
val cpuValue = cpuRow.get(INDEX_C1)
val gpuValue = gpuRow.get(INDEX_C1)
if (!compare(cpuValue, gpuValue)) {
if (!compare(cpuValue, gpuValue, epsilon = 0.0001)) {
val inputValue = cpuRow.getString(INDEX_C0)
.replace("\r", "\\r")
.replace("\t", "\\t")
.replace("\n", "\\n")
fail(s"Mismatch casting string [$inputValue] " +
s"to $toType. CPU: $cpuValue; GPU: $gpuValue")
}
Expand Down Expand Up @@ -872,6 +888,26 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

test("CAST string to float - sanitize step") {
val testPairs = Seq(
("\tinf", "Inf"),
("\t+InFinITy", "Inf"),
("\tInFinITy", "Inf"),
("\t-InFinITy", "-Inf"),
("\t61f", "61"),
(".8E4f", ".8E4")
)
val inputs = testPairs.map(_._1)
val expected = testPairs.map(_._2)
withResource(ColumnVector.fromStrings(inputs: _*)) { v =>
withResource(ColumnVector.fromStrings(expected: _*)) { expected =>
withResource(GpuCast.sanitizeStringToFloat(v, ansiEnabled = false)) { actual =>
CudfTestHelper.assertColumnsAreEqual(expected, actual)
}
}
}
}

test("CAST string to integer - sanitize step") {
val testPairs: Seq[(String, String)] = Seq(
(null, null),
Expand Down

0 comments on commit 2f965b6

Please sign in to comment.