Skip to content

Commit

Permalink
Reduce regex use in CAST (NVIDIA#1741)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Feb 18, 2021
1 parent 457ce71 commit d994a98
Showing 1 changed file with 112 additions and 63 deletions.
175 changes: 112 additions & 63 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,18 @@ class CastExprMeta[INPUT <: CastBase](

object GpuCast {

private val DATE_REGEX_YYYY = "\\A\\d{4}\\Z"
private val DATE_REGEX_YYYY_MM = "\\A\\d{4}\\-\\d{2}\\Z"
private val DATE_REGEX_YYYY_MM_DD = "\\A\\d{4}\\-\\d{2}\\-\\d{2}([ T](:?[\\r\\n]|.)*)?\\Z"

private val TIMESTAMP_REGEX_YYYY = "\\A\\d{4}\\Z"
private val TIMESTAMP_REGEX_YYYY_MM = "\\A\\d{4}\\-\\d{2}[ ]?\\Z"
private val TIMESTAMP_REGEX_YYYY_MM_DD = "\\A\\d{4}\\-\\d{2}\\-\\d{2}[ ]?\\Z"
private val TIMESTAMP_REGEX_FULL =
"\\A\\d{4}\\-\\d{2}\\-\\d{2}[ T]\\d{2}:\\d{2}:\\d{2}\\.\\d{6}Z\\Z"
private val TIMESTAMP_REGEX_NO_DATE = "\\A[T]?(\\d{2}:\\d{2}:\\d{2}\\.\\d{6}Z)\\Z"

/**
* The length of a timestamp with 6 digits for microseconds followed by 'Z', such
* as "2020-01-01T12:34:56.123456Z".
*/
private val FULL_TIMESTAMP_LENGTH = 27

/**
* Regex for identifying strings that contain numeric values that can be casted to integral
* types. This includes floating point numbers but not numbers containing exponents.
Expand Down Expand Up @@ -629,21 +630,21 @@ case class GpuCast(
}

/**
* Parse dates that match the provided regex. This method does not close the `input`
* ColumnVector.
* Parse dates that match the provided length and format. This method does not
* close the `input` ColumnVector.
*
* @param input Input ColumnVector
* @param len The string length to match against
* @param cudfFormat The cuDF timestamp format to match against
* @return ColumnVector containing timestamps for input entries that match both
* the length and format, and null for other entries
*/
def convertDateOrNull(
def convertFixedLenDateOrNull(
input: ColumnVector,
regex: String,
len: Int,
cudfFormat: String): ColumnVector = {

val isValidDate = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
}
}

withResource(isValidDate) { isValidDate =>
withResource(isValidTimestamp(input, len, cudfFormat)) { isValidDate =>
withResource(input.asTimestampDays(cudfFormat)) { asDays =>
withResource(Scalar.fromNull(DType.TIMESTAMP_DAYS)) { nullScalar =>
isValidDate.ifElse(asDays, nullScalar)
Expand All @@ -653,21 +654,45 @@ case class GpuCast(
}

/** This method does not close the `input` ColumnVector. */
def convertDateOr(
def convertVarLenDateOr(
input: ColumnVector,
regex: String,
cudfFormat: String,
orElse: ColumnVector): ColumnVector = {

val isValidDate = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
withResource(orElse) { orElse =>
val isValidDate = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
}
}
withResource(isValidDate) { isValidDate =>
withResource(input.asTimestampDays(cudfFormat)) { asDays =>
isValidDate.ifElse(asDays, orElse)
}
}
}
}

withResource(isValidDate) { isValidDate =>
withResource(input.asTimestampDays(cudfFormat)) { asDays =>
withResource(orElse) { orElse =>
/**
* Parse dates that match the provided length and format. This method does not
* close the `input` ColumnVector.
*
* @param input Input ColumnVector
* @param len The string length to match against
* @param cudfFormat The cuDF timestamp format to match against
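* @param orElse ColumnVector supplying the result for entries that do not match; it is
* wrapped in `withResource` by this method
* @return ColumnVector containing timestamps for input entries that match both
* the length and format, and the corresponding `orElse` entry otherwise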
*/
def convertFixedLenDateOr(
input: ColumnVector,
len: Int,
cudfFormat: String,
orElse: ColumnVector): ColumnVector = {

withResource(orElse) { orElse =>
withResource(isValidTimestamp(input, len, cudfFormat)) { isValidDate =>
withResource(input.asTimestampDays(cudfFormat)) { asDays =>
isValidDate.ifElse(asDays, orElse)
}
}
Expand All @@ -691,9 +716,9 @@ case class GpuCast(
withResource(sanitizedInput) { sanitizedInput =>

// convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m",
convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y")))
val converted = convertVarLenDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertFixedLenDateOr(sanitizedInput, 7, "%Y-%m",
convertFixedLenDateOrNull(sanitizedInput, 4, "%Y")))

// handle special dates like "epoch", "now", etc.
specialDates.foldLeft(converted)((prev, specialDate) =>
Expand Down Expand Up @@ -728,72 +753,73 @@ case class GpuCast(
* Parse dates that match the provided regex. This method does not close the `input`
* ColumnVector.
*/
def convertTimestampOrNull(
def convertFixedLenTimestampOrNull(
input: ColumnVector,
regex: String,
len: Int,
cudfFormat: String): ColumnVector = {

val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
}
}

withResource(isValidTimestamp) { isValidTimestamp =>
withResource(isValidTimestamp(input, len, cudfFormat)) { isTimestamp =>
withResource(Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS)) { nullScalar =>
withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays =>
isValidTimestamp.ifElse(asDays, nullScalar)
isTimestamp.ifElse(asDays, nullScalar)
}
}
}
}

/** This method does not close the `input` ColumnVector. */
def convertTimestampOr(
def convertVarLenTimestampOr(
input: ColumnVector,
regex: String,
cudfFormat: String,
orElse: ColumnVector): ColumnVector = {

val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
withResource(orElse) { orElse =>
val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch =>
withResource(input.isTimestamp(cudfFormat)) { isTimestamp =>
isMatch.and(isTimestamp)
}
}
}
withResource(isValidTimestamp) { isValidTimestamp =>
withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays =>
withResource(orElse) { orElse =>
withResource(isValidTimestamp) { isValidTimestamp =>
withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays =>
isValidTimestamp.ifElse(asDays, orElse)
}
}
}
}

/** This method does not close the `input` ColumnVector. */
def convertTimestampFullOr(
def convertFullTimestampOr(
input: ColumnVector,
orElse: ColumnVector): ColumnVector = {

val cudfFormat1 = "%Y-%m-%d %H:%M:%S.%f"
val cudfFormat2 = "%Y-%m-%dT%H:%M:%S.%f"

// valid dates must match the regex and either of the cuDF formats
val isCudfMatch = withResource(input.isTimestamp(cudfFormat1)) { isTimestamp1 =>
withResource(input.isTimestamp(cudfFormat2)) { isTimestamp2 =>
isTimestamp1.or(isTimestamp2)
withResource(orElse) { orElse =>

// valid timestamps must have the full timestamp length and match either of the cuDF formats
val isCudfMatch = withResource(input.isTimestamp(cudfFormat1)) { isTimestamp1 =>
withResource(input.isTimestamp(cudfFormat2)) { isTimestamp2 =>
isTimestamp1.or(isTimestamp2)
}
}
}
val isValidTimestamp = withResource(isCudfMatch) { isCudfMatch =>
withResource(input.matchesRe(TIMESTAMP_REGEX_FULL)) { isRegexMatch =>
isCudfMatch.and(isRegexMatch)

val isValidTimestamp = withResource(isCudfMatch) { isCudfMatch =>
val isValidLength = withResource(Scalar.fromInt(FULL_TIMESTAMP_LENGTH)) { requiredLen =>
withResource(input.getCharLengths) { actualLen =>
requiredLen.equalTo(actualLen)
}
}
withResource(isValidLength) { isValidLength =>
isValidLength.and(isCudfMatch)
}
}
}

// we only need to parse with one of the cuDF formats because the parsing code ignores
// the ' ' or 'T' between the date and time components
withResource(isValidTimestamp) { isValidTimestamp =>
withResource(input.asTimestampMicroseconds(cudfFormat1)) { asDays =>
withResource(orElse) { orElse =>
// we only need to parse with one of the cuDF formats because the parsing code ignores
// the ' ' or 'T' between the date and time components
withResource(isValidTimestamp) { isValidTimestamp =>
withResource(input.asTimestampMicroseconds(cudfFormat1)) { asDays =>
isValidTimestamp.ifElse(asDays, orElse)
}
}
Expand Down Expand Up @@ -832,10 +858,10 @@ case class GpuCast(
withResource(sanitizedInput) { sanitizedInput =>
// convert dates that are in valid timestamp formats
val converted =
convertTimestampFullOr(sanitizedInput,
convertTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM, "%Y-%m",
convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))
convertFullTimestampOr(sanitizedInput,
convertVarLenTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertVarLenTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM, "%Y-%m",
convertFixedLenTimestampOrNull(sanitizedInput, 4, "%Y"))))

// handle special dates like "epoch", "now", etc.
val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
Expand All @@ -862,6 +888,29 @@ case class GpuCast(
}
}

/**
 * Build a boolean mask marking the entries of `input` whose character length equals
 * `len` and which also satisfy `isTimestamp` for the given cuDF format string.
 *
 * The `input` ColumnVector is not closed here. The returned mask is not wrapped in
 * `withResource` by this method, so the caller is responsible for it.
 *
 * @param input Input ColumnVector
 * @param len The exact string length an entry must have
 * @param cudfFormat The cuDF timestamp format an entry must match
 * @return ColumnVector of booleans that is true only where both the length check
 *         and the format check pass
 */
private def isValidTimestamp(input: ColumnVector, len: Int, cudfFormat: String) = {
  // Step 1: compare every entry's character count against the required length.
  val lengthMask = withResource(Scalar.fromInt(len)) { expectedLen =>
    withResource(input.getCharLengths) { charLens =>
      expectedLen.equalTo(charLens)
    }
  }
  // Step 2: combine the length mask with the cuDF format check.
  withResource(lengthMask) { hasExpectedLen =>
    withResource(input.isTimestamp(cudfFormat)) { matchesFormat =>
      hasExpectedLen.and(matchesFormat)
    }
  }
}

/**
* Cast column of long values to a smaller integral type (bytes, short, int).
*
Expand Down

0 comments on commit d994a98

Please sign in to comment.