From 92370d04511d21cd1bbf2d1dcd62fd8734883e9f Mon Sep 17 00:00:00 2001
From: Firestarman
Date: Mon, 4 Dec 2023 19:41:33 +0800
Subject: [PATCH] more formats support by additional conversions for FromUnixTime

Signed-off-by: Firestarman
---
Review notes (this section is ignored when the patch is applied):
 * The scalar overload of GpuFromUnixTime.doColumnar delegates to the column overload,
   which already applies convertFunc, so it no longer applies convertFunc a second time.
 * tagExprForGpu assigns colConverter on every pass instead of only as a side effect
   inside `.map`, so a stale converter can never survive a re-tag.
 * Line breaks and indentation were restored from a whitespace-collapsed copy, and the
   blob ids on the `index` lines were not regenerated; regenerate the patch with
   `git format-patch` before publishing it.

 .../nvidia/spark/rapids/GpuOverrides.scala    |  6 +-
 .../sql/rapids/datetimeExpressions.scala      | 67 ++++++++++++++++++-
 2 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 4d45dacfd0d..553017cb6ff 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1770,11 +1770,7 @@ object GpuOverrides extends Logging {
         ("format", TypeSig.lit(TypeEnum.STRING)
             .withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
             TypeSig.STRING)),
-      (a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
-        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-          // passing the already converted strf string for a little optimization
-          GpuFromUnixTime(lhs, rhs, strfFormat)
-      }),
+      (a, conf, p, r) => new FromUnitTimeMeta(a, conf, p, r)),
     expr[FromUTCTimestamp](
       "Render the input UTC timestamp in the input timezone",
       ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 8f6c591787f..9b59c1772c4 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
 import com.nvidia.spark.rapids.shims.ShimBinaryExpression
 
-import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -997,27 +997,90 @@ case class GpuGetTimestamp(
   override def right: Expression = format
 }
 
+/**
+ * Meta for [[FromUnixTime]]. Besides the formats cuDF handles directly, a format listed
+ * in FORMATS_BY_CONVERSION is supported by rendering an intermediate format and then
+ * post-processing the resulting string column with the paired converter.
+ */
+class FromUnitTimeMeta(a: FromUnixTime,
+    override val conf: RapidsConf,
+    val p: Option[RapidsMeta[_, _, _]],
+    r: DataFromReplacementRule) extends UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
+
+  private type FmtConverter = ColumnView => ColumnVector
+
+  // Assigned by tagExprForGpu(); None when the format needs no post-conversion.
+  private var colConverter: Option[FmtConverter] = None
+
+  /**
+   * More supported formats by additional conversions.
+   * Need to remove the entry if the key format is supported by cuDF.
+   */
+  private val FORMATS_BY_CONVERSION: Map[String, (String, FmtConverter)] = Map(
+    // spark format -> (intermediate format, converter)
+    "yyyyMMdd" -> (("yyyy-MM-dd",
+      col => {
+        withResource(Scalar.fromString("-")) { dashStr =>
+          withResource(Scalar.fromString("")) { emptyStr =>
+            col.stringReplace(dashStr, emptyStr)
+          }
+        }
+      }
+    ))
+  )
+
+  override def tagExprForGpu(): Unit = {
+    extractStringLit(a.right) match {
+      case Some(rightLit) =>
+        sparkFormat = rightLit
+        val conversion = FORMATS_BY_CONVERSION.get(sparkFormat)
+        // Always (re)assign, so a repeated tagging pass can never keep a stale converter.
+        colConverter = conversion.map(_._2)
+        val inputFormat = conversion.map(_._1).getOrElse(sparkFormat)
+        strfFormat = DateUtils.tagAndGetCudfFormat(this,
+          inputFormat, a.left.dataType == DataTypes.StringType)
+      case None =>
+        willNotWorkOnGpu("format has to be a string literal")
+    }
+  }
+
+  override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+    // passing the already converted strf string for a little optimization
+    GpuFromUnixTime(lhs, rhs, strfFormat, colConverter)
+}
+
 case class GpuFromUnixTime(
     sec: Expression,
     format: Expression,
     strfFormat: String,
+    colConverter: Option[ColumnView => ColumnVector],
     timeZoneId: Option[String] = None)
   extends GpuBinaryExpressionArgsAnyScalar
     with TimeZoneAwareExpression
     with ImplicitCastInputTypes {
 
+  // Chosen once per expression so that no "if...else" is evaluated for each input batch.
+  // When a converter is present, the intermediate column is closed after the conversion.
+  private val convertFunc: ColumnVector => ColumnVector = colConverter match {
+    case Some(converter) => col => withResource(col) { c => converter(c) }
+    case None => col => col
+  }
+
   override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
     // we aren't using rhs as it was already converted in the GpuOverrides while creating the
     // expressions map and passed down here as strfFormat
-    withResource(lhs.getBase.asTimestampSeconds) { tsVector =>
+    val ret = withResource(lhs.getBase.asTimestampSeconds) { tsVector =>
       tsVector.asStrings(strfFormat)
     }
+    convertFunc(ret)
   }
 
   override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
+    // The column overload already applies convertFunc; applying it again here would
+    // run the conversion twice on the same data.
     withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
       doColumnar(expandedLhs, rhs)
     }
   }
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {