diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 6fa5c106c8b..090ad4e2493 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -90,17 +90,18 @@ class CastExprMeta[INPUT <: CastBase]( object GpuCast { - private val DATE_REGEX_YYYY = "\\A\\d{4}\\Z" - private val DATE_REGEX_YYYY_MM = "\\A\\d{4}\\-\\d{2}\\Z" private val DATE_REGEX_YYYY_MM_DD = "\\A\\d{4}\\-\\d{2}\\-\\d{2}([ T](:?[\\r\\n]|.)*)?\\Z" - private val TIMESTAMP_REGEX_YYYY = "\\A\\d{4}\\Z" private val TIMESTAMP_REGEX_YYYY_MM = "\\A\\d{4}\\-\\d{2}[ ]?\\Z" private val TIMESTAMP_REGEX_YYYY_MM_DD = "\\A\\d{4}\\-\\d{2}\\-\\d{2}[ ]?\\Z" - private val TIMESTAMP_REGEX_FULL = - "\\A\\d{4}\\-\\d{2}\\-\\d{2}[ T]\\d{2}:\\d{2}:\\d{2}\\.\\d{6}Z\\Z" private val TIMESTAMP_REGEX_NO_DATE = "\\A[T]?(\\d{2}:\\d{2}:\\d{2}\\.\\d{6}Z)\\Z" + /** + * The length of a timestamp with 6 digits for microseconds followed by 'Z', such + * as "2020-01-01T12:34:56.123456Z". + */ + private val FULL_TIMESTAMP_LENGTH = 27 + /** * Regex for identifying strings that contain numeric values that can be casted to integral * types. This includes floating point numbers but not numbers containing exponents. @@ -629,21 +630,21 @@ case class GpuCast( } /** - * Parse dates that match the provided regex. This method does not close the `input` - * ColumnVector. + * Parse dates that match the provided length and format. This method does not + * close the `input` ColumnVector. + * + * @param input Input ColumnVector + * @param len The string length to match against + * @param cudfFormat The cuDF timestamp format to match against + * @return ColumnVector containing timestamps for input entries that match both + * the length and format, and null for other entries */ - def convertDateOrNull( + def convertFixedLenDateOrNull( input: ColumnVector, - regex: String, + len: Int, cudfFormat: String): ColumnVector = { - val isValidDate = withResource(input.matchesRe(regex)) { isMatch => - withResource(input.isTimestamp(cudfFormat)) { isTimestamp => - isMatch.and(isTimestamp) - } - } - - withResource(isValidDate) { isValidDate => + withResource(isValidTimestamp(input, len, cudfFormat)) { isValidDate => withResource(input.asTimestampDays(cudfFormat)) { asDays => withResource(Scalar.fromNull(DType.TIMESTAMP_DAYS)) { nullScalar => isValidDate.ifElse(asDays, nullScalar) @@ -653,21 +654,45 @@ case class GpuCast( } /** This method does not close the `input` ColumnVector. */ - def convertDateOr( + def convertVarLenDateOr( input: ColumnVector, regex: String, cudfFormat: String, orElse: ColumnVector): ColumnVector = { - val isValidDate = withResource(input.matchesRe(regex)) { isMatch => - withResource(input.isTimestamp(cudfFormat)) { isTimestamp => - isMatch.and(isTimestamp) + withResource(orElse) { orElse => + val isValidDate = withResource(input.matchesRe(regex)) { isMatch => + withResource(input.isTimestamp(cudfFormat)) { isTimestamp => + isMatch.and(isTimestamp) + } + } + withResource(isValidDate) { isValidDate => + withResource(input.asTimestampDays(cudfFormat)) { asDays => + isValidDate.ifElse(asDays, orElse) + } } } + } - withResource(isValidDate) { isValidDate => - withResource(input.asTimestampDays(cudfFormat)) { asDays => - withResource(orElse) { orElse => + /** + * Parse dates that match the provided length and format. This method does not + * close the `input` ColumnVector. + * + * @param input Input ColumnVector + * @param len The string length to match against + * @param cudfFormat The cuDF timestamp format to match against + * @return ColumnVector containing timestamps for input entries that match both + * the length and format, and null for other entries + */ + def convertFixedLenDateOr( + input: ColumnVector, + len: Int, + cudfFormat: String, + orElse: ColumnVector): ColumnVector = { + + withResource(orElse) { orElse => + withResource(isValidTimestamp(input, len, cudfFormat)) { isValidDate => + withResource(input.asTimestampDays(cudfFormat)) { asDays => isValidDate.ifElse(asDays, orElse) } } @@ -691,9 +716,9 @@ case class GpuCast( withResource(sanitizedInput) { sanitizedInput => // convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd - val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d", - convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m", - convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y"))) + val converted = convertVarLenDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d", + convertFixedLenDateOr(sanitizedInput, 7, "%Y-%m", + convertFixedLenDateOrNull(sanitizedInput, 4, "%Y"))) // handle special dates like "epoch", "now", etc. specialDates.foldLeft(converted)((prev, specialDate) => @@ -728,41 +753,35 @@ case class GpuCast( * Parse dates that match the the provided regex. This method does not close the `input` * ColumnVector. */ - def convertTimestampOrNull( + def convertFixedLenTimestampOrNull( input: ColumnVector, - regex: String, + len: Int, cudfFormat: String): ColumnVector = { - val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch => - withResource(input.isTimestamp(cudfFormat)) { isTimestamp => - isMatch.and(isTimestamp) - } - } - - withResource(isValidTimestamp) { isValidTimestamp => + withResource(isValidTimestamp(input, len, cudfFormat)) { isTimestamp => withResource(Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS)) { nullScalar => withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays => - isValidTimestamp.ifElse(asDays, nullScalar) + isTimestamp.ifElse(asDays, nullScalar) } } } } /** This method does not close the `input` ColumnVector. */ - def convertTimestampOr( + def convertVarLenTimestampOr( input: ColumnVector, regex: String, cudfFormat: String, orElse: ColumnVector): ColumnVector = { - val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch => - withResource(input.isTimestamp(cudfFormat)) { isTimestamp => - isMatch.and(isTimestamp) + withResource(orElse) { orElse => + val isValidTimestamp = withResource(input.matchesRe(regex)) { isMatch => + withResource(input.isTimestamp(cudfFormat)) { isTimestamp => + isMatch.and(isTimestamp) + } } - } - withResource(isValidTimestamp) { isValidTimestamp => - withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays => - withResource(orElse) { orElse => + withResource(isValidTimestamp) { isValidTimestamp => + withResource(input.asTimestampMicroseconds(cudfFormat)) { asDays => isValidTimestamp.ifElse(asDays, orElse) } } @@ -770,30 +789,37 @@ case class GpuCast( } /** This method does not close the `input` ColumnVector. */ - def convertTimestampFullOr( + def convertFullTimestampOr( input: ColumnVector, orElse: ColumnVector): ColumnVector = { val cudfFormat1 = "%Y-%m-%d %H:%M:%S.%f" val cudfFormat2 = "%Y-%m-%dT%H:%M:%S.%f" - // valid dates must match the regex and either of the cuDF formats - val isCudfMatch = withResource(input.isTimestamp(cudfFormat1)) { isTimestamp1 => - withResource(input.isTimestamp(cudfFormat2)) { isTimestamp2 => - isTimestamp1.or(isTimestamp2) + withResource(orElse) { orElse => + + // valid dates must match the regex and either of the cuDF formats + val isCudfMatch = withResource(input.isTimestamp(cudfFormat1)) { isTimestamp1 => + withResource(input.isTimestamp(cudfFormat2)) { isTimestamp2 => + isTimestamp1.or(isTimestamp2) + } } - } - val isValidTimestamp = withResource(isCudfMatch) { isCudfMatch => - withResource(input.matchesRe(TIMESTAMP_REGEX_FULL)) { isRegexMatch => - isCudfMatch.and(isRegexMatch) + + val isValidTimestamp = withResource(isCudfMatch) { isCudfMatch => + val isValidLength = withResource(Scalar.fromInt(FULL_TIMESTAMP_LENGTH)) { requiredLen => + withResource(input.getCharLengths) { actualLen => + requiredLen.equalTo(actualLen) + } + } + withResource(isValidLength) { isValidLength => + isValidLength.and(isCudfMatch) + } } - } - // we only need to parse with one of the cuDF formats because the parsing code ignores - // the ' ' or 'T' between the date and time components - withResource(isValidTimestamp) { isValidTimestamp => - withResource(input.asTimestampMicroseconds(cudfFormat1)) { asDays => - withResource(orElse) { orElse => + // we only need to parse with one of the cuDF formats because the parsing code ignores + // the ' ' or 'T' between the date and time components + withResource(isValidTimestamp) { isValidTimestamp => + withResource(input.asTimestampMicroseconds(cudfFormat1)) { asDays => isValidTimestamp.ifElse(asDays, orElse) } } @@ -832,10 +858,10 @@ case class GpuCast( withResource(sanitizedInput) { sanitizedInput => // convert dates that are in valid timestamp formats val converted = - convertTimestampFullOr(sanitizedInput, - convertTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM_DD, "%Y-%m-%d", - convertTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM, "%Y-%m", - convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y")))) + convertFullTimestampOr(sanitizedInput, + convertVarLenTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM_DD, "%Y-%m-%d", + convertVarLenTimestampOr(sanitizedInput, TIMESTAMP_REGEX_YYYY_MM, "%Y-%m", + convertFixedLenTimestampOrNull(sanitizedInput, 4, "%Y")))) // handle special dates like "epoch", "now", etc. val finalResult = specialDates.foldLeft(converted)((prev, specialDate) => @@ -862,6 +888,29 @@ case class GpuCast( } } + /** + * Determine which timestamps are the specified length and also comply with the specified + * cuDF format string. + * + * @param input Input ColumnVector + * @param len The string length to match against + * @param cudfFormat The cuDF timestamp format to match against + * @return ColumnVector containing booleans representing which entries match both + * the length and format + */ + private def isValidTimestamp(input: ColumnVector, len: Int, cudfFormat: String) = { + val isCorrectLength = withResource(Scalar.fromInt(len)) { requiredLen => + withResource(input.getCharLengths) { actualLen => + requiredLen.equalTo(actualLen) + } + } + withResource(isCorrectLength) { isCorrectLength => + withResource(input.isTimestamp(cudfFormat)) { isTimestamp => + isCorrectLength.and(isTimestamp) + } + } + } + /** * Cast column of long values to a smaller integral type (bytes, short, int). *