Skip to content

Commit

Permalink
More formats support by additional conversions for FromUnixTime (#8)
Browse files Browse the repository at this point in the history
Support format yyyyMMdd for GpuFromUnixTime

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
  • Loading branch information
firestarman authored Dec 5, 2023
1 parent efa2a4e commit 6b4463d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1770,11 +1770,7 @@ object GpuOverrides extends Logging {
("format", TypeSig.lit(TypeEnum.STRING)
.withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
// passing the already converted strf string for a little optimization
GpuFromUnixTime(lhs, rhs, strfFormat)
}),
(a, conf, p, r) => new FromUnitTimeMeta(a ,conf ,p ,r)),
expr[FromUTCTimestamp](
"Render the input UTC timestamp in the input timezone",
ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.ShimBinaryExpression

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -997,27 +997,86 @@ case class GpuGetTimestamp(
override def right: Expression = format
}

/**
 * Rule meta for `FromUnixTime` that widens the set of accepted format strings by pairing
 * an intermediate format with a post-processing step on the resulting string column.
 *
 * NOTE(review): the class name looks like a typo of "FromUnixTimeMeta"; it is kept as is
 * because the replacement rule in GpuOverrides instantiates it under this name.
 *
 * @param a    the `FromUnixTime` expression being tagged and converted
 * @param conf the plugin configuration, forwarded to the parent meta
 * @param p    the optional parent meta, forwarded to the parent meta
 * @param r    the replacement rule data, forwarded to the parent meta
 */
class FromUnitTimeMeta(a: FromUnixTime,
    override val conf: RapidsConf,
    val p: Option[RapidsMeta[_, _, _]],
    r: DataFromReplacementRule) extends UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {

  /** A post-processing step applied to the formatted string column. */
  private type FmtConverter = ColumnView => ColumnVector

  // Populated by tagExprForGpu() when the requested format needs post-processing,
  // and handed to GpuFromUnixTime by convertToGpu().
  private var colConverter: Option[FmtConverter] = None

  /**
   * Formats that are produced indirectly: spark format -> (intermediate format, converter).
   * The intermediate format is the one that gets translated for cuDF; the converter then
   * rewrites the formatted strings into the requested spark format.
   *
   * An entry should be removed once cuDF supports its key format directly.
   */
  private val FORMATS_BY_CONVERSION: Map[String, (String, FmtConverter)] = Map(
    // Render as "yyyy-MM-dd" first, then strip every "-" from the strings.
    // NOTE(review): kept as a function literal on purpose; it is passed on to the GPU
    // expression, so confirm what it captures before turning it into a method of this class.
    "yyyyMMdd" -> (("yyyy-MM-dd",
      formatted => {
        withResource(Scalar.fromString("-")) { target =>
          withResource(Scalar.fromString("")) { replacement =>
            formatted.stringReplace(target, replacement)
          }
        }
      }
    ))
  )

  /**
   * Validates the format argument and resolves the format string handed to the GPU
   * expression. A non-literal format marks the expression as unable to run on the GPU.
   */
  override def tagExprForGpu(): Unit = {
    extractStringLit(a.right) match {
      case None =>
        willNotWorkOnGpu("format has to be a string literal")
      case Some(formatLiteral) =>
        sparkFormat = formatLiteral
        // Swap in the intermediate format (and remember its converter) when the
        // requested format is only reachable through a conversion.
        val formatToTranslate = FORMATS_BY_CONVERSION.get(sparkFormat) match {
          case Some((intermediateFormat, converter)) =>
            colConverter = Some(converter)
            intermediateFormat
          case None =>
            sparkFormat
        }
        strfFormat = DateUtils.tagAndGetCudfFormat(this,
          formatToTranslate, a.left.dataType == DataTypes.StringType)
    }
  }

  /**
   * Builds the GPU expression, reusing the format string and optional converter that
   * tagExprForGpu() already resolved instead of deriving them again from `rhs`.
   */
  override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
    GpuFromUnixTime(lhs, rhs, strfFormat, colConverter)
  }
}

case class GpuFromUnixTime(
sec: Expression,
format: Expression,
strfFormat: String,
colConverter: Option[ColumnView => ColumnVector],
timeZoneId: Option[String] = None)
extends GpuBinaryExpressionArgsAnyScalar
with TimeZoneAwareExpression
with ImplicitCastInputTypes {

// Post-processing applied to the formatted string column. The optional converter is
// resolved once here, at construction, so no per-batch branching or Option lookup is
// needed. When a converter is present the input column is closed after it has been
// converted; otherwise the input column is returned untouched.
private val convertFunc: ColumnVector => ColumnVector = colConverter match {
  case Some(converter) => col => withResource(col)(converter)
  case None => col => col
}

/**
 * Formats a column of seconds as strings using the precomputed `strfFormat`, then runs
 * the optional post-conversion on the result.
 *
 * @param lhs the input column holding the seconds values
 * @param rhs the format scalar; unused, because the format was already translated while
 *            the expression was being planned and arrives here as `strfFormat`
 * @return the formatted (and, if required, post-converted) string column
 */
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
  val formatted = withResource(lhs.getBase.asTimestampSeconds) { secondsCol =>
    secondsCol.asStrings(strfFormat)
  }
  // convertFunc takes over `formatted`: it either returns it as is or closes it after
  // producing the converted column.
  convertFunc(formatted)
}

/**
 * Scalar-input variant: expands the scalar to a column of `numRows` rows and delegates
 * to the column overload.
 *
 * @param numRows number of rows to produce
 * @param lhs     the scalar seconds value
 * @param rhs     the format scalar, passed through to the column overload
 * @return the formatted string column with `numRows` rows
 */
override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
  withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
    // The column overload already applies convertFunc to its result. Applying it a
    // second time here would redo the conversion on every batch and would corrupt the
    // output for any converter that is not idempotent.
    doColumnar(expandedLhs, rhs)
  }
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
Expand Down

0 comments on commit 6b4463d

Please sign in to comment.