Port whole parsePartitions method from Spark3.3 to Gpu side #6048

Merged
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types.shims

import java.time.ZoneId

import org.apache.spark.sql.types.DataType

object PartitionValueCastShims {
// AnyTimestampType, TimestampNTZType and AnsiIntervalType are not defined before Spark 3.2.0,
// so this shim returns false for Spark versions from 3.1.1 (311) until 3.2.0 (320)
def isSupportedType(dt: DataType): Boolean = false

def castTo(desiredType: DataType, value: String, zoneId: ZoneId): Any = {
throw new IllegalArgumentException(s"Unexpected type $desiredType")
}
}
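This pre-3.2.0 shim always declines, so every partition value is handled by the shared cast paths. A minimal usage sketch of that gate-then-delegate pattern (the helper object and method names are hypothetical, not from this PR):

import java.time.ZoneId
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.shims.PartitionValueCastShims

object PartitionCastSketch {
  // Probe the shim first; on pre-3.2.0 Sparks isSupportedType is always false,
  // so every type routes back to the version-independent casting logic.
  def castPartitionValueViaShim(dt: DataType, raw: String, zone: ZoneId): Option[Any] = {
    if (PartitionValueCastShims.isSupportedType(dt)) {
      Some(PartitionValueCastShims.castTo(dt, raw, zone))
    } else {
      None
    }
  }
}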
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types.shims

import java.time.ZoneId

import scala.util.Try

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{AnsiIntervalType, AnyTimestampType, DataType, DateType}

object PartitionValueCastShims {
def isSupportedType(dt: DataType): Boolean = dt match {
// Timestamp types
case dt if AnyTimestampType.acceptsType(dt) => true
case _: AnsiIntervalType => true
case _ => false
}

// Only for TimestampType and TimestampNTZType
def castTo(desiredType: DataType, value: String, zoneId: ZoneId): Any = desiredType match {
// Copied from org/apache/spark/sql/execution/datasources/PartitionUtils.scala
case dt if AnyTimestampType.acceptsType(desiredType) =>
Try {
Cast(Literal(unescapePathName(value)), dt, Some(zoneId.getId)).eval()
}.getOrElse {
Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), dt).eval()
}
}
}
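A hedged sketch of calling this 3.2.0+ shim on an escaped partition-directory value (the strings are illustrative): unescapePathName turns "%3A" back into ":", and if the direct timestamp cast fails, the value is cast to DateType first and then widened to the requested timestamp type.

import java.time.ZoneId
import org.apache.spark.sql.types.TimestampType
import org.apache.spark.sql.types.shims.PartitionValueCastShims

object CastToSketch {
  val zone = ZoneId.of("UTC")
  // Returns Catalyst's internal value (microseconds since epoch) on success,
  // or null if neither the timestamp cast nor the date fallback parses.
  val micros: Any =
    PartitionValueCastShims.castTo(TimestampType, "2022-01-01 00%3A00%3A00", zone)
}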
@@ -0,0 +1,222 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util.rapids

import java.text.SimpleDateFormat
import java.time.LocalDate
import java.util.{Date, Locale}

import org.apache.commons.lang3.time.FastDateFormat

import org.apache.spark.sql.catalyst.util.{DateTimeFormatterHelper, DateTimeUtils, LegacyDateFormats}
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._

// Copied from org/apache/spark/sql/catalyst/util/DateFormatter
// for https://github.com/NVIDIA/spark-rapids/issues/6026
// It can be removed when Spark 3.3 is the minimum supported Spark version
sealed trait DateFormatter extends Serializable {
def parse(s: String): Int // returns days since epoch

def format(days: Int): String
def format(date: Date): String
def format(localDate: LocalDate): String

def validatePatternString(): Unit
}

class Iso8601DateFormatter(
pattern: String,
locale: Locale,
legacyFormat: LegacyDateFormats.LegacyDateFormat,
isParsing: Boolean)
extends DateFormatter with DateTimeFormatterHelper {

@transient
private lazy val formatter = getOrCreateFormatter(pattern, locale, isParsing)

@transient
protected lazy val legacyFormatter =
DateFormatter.getLegacyFormatter(pattern, locale, legacyFormat)

override def parse(s: String): Int = {
try {
val localDate = toLocalDate(formatter.parse(s))
localDateToDays(localDate)
} catch checkParsedDiff(s, legacyFormatter.parse)
}

override def format(localDate: LocalDate): String = {
try {
localDate.format(formatter)
} catch checkFormattedDiff(toJavaDate(localDateToDays(localDate)),
(d: Date) => format(d))
}

override def format(days: Int): String = {
format(LocalDate.ofEpochDay(days))
}

override def format(date: Date): String = {
legacyFormatter.format(date)
}

override def validatePatternString(): Unit = {
try {
formatter
} catch checkLegacyFormatter(pattern, legacyFormatter.validatePatternString)
()
}
}

/**
* The formatter for dates which doesn't require users to specify a pattern. While formatting,
* it uses the default pattern [[DateFormatter.defaultPattern]]. In parsing, it follows the CAST
* logic in conversion of strings to Catalyst's DateType.
*
* @param locale The locale overrides the system locale and is used in formatting.
* @param legacyFormat Defines the formatter used for legacy dates.
* @param isParsing Whether the formatter is used for parsing (`true`) or for formatting (`false`).
*/
class DefaultDateFormatter(
locale: Locale,
legacyFormat: LegacyDateFormats.LegacyDateFormat,
isParsing: Boolean)
extends Iso8601DateFormatter(DateFormatter.defaultPattern, locale, legacyFormat, isParsing)

trait LegacyDateFormatter extends DateFormatter {
def parseToDate(s: String): Date

override def parse(s: String): Int = {
fromJavaDate(new java.sql.Date(parseToDate(s).getTime))
}

override def format(days: Int): String = {
format(DateTimeUtils.toJavaDate(days))
}

override def format(localDate: LocalDate): String = {
format(localDateToDays(localDate))
}
}

/**
* The legacy formatter is based on Apache Commons FastDateFormat. The formatter uses the default
* JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions.
*
* Note: Use of the default JVM time zone makes the formatter compatible with the legacy
* `DateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default
* JVM time zone too.
*
* @param pattern `java.text.SimpleDateFormat` compatible pattern.
* @param locale The locale overrides the system locale and is used in parsing/formatting.
*/
class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter {
@transient
private lazy val fdf = FastDateFormat.getInstance(pattern, locale)
override def parseToDate(s: String): Date = fdf.parse(s)
override def format(d: Date): String = fdf.format(d)
override def validatePatternString(): Unit = fdf
}

// scalastyle:off line.size.limit
/**
* The legacy formatter is based on `java.text.SimpleDateFormat`. The formatter uses the default
* JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions.
*
* Note: Use of the default JVM time zone makes the formatter compatible with the legacy
* `DateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default
* JVM time zone too.
*
* @param pattern The pattern describing the date and time format.
* See <a href="https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html">
* Date and Time Patterns</a>
* @param locale The locale whose date format symbols should be used. It overrides the system
* locale in parsing/formatting.
*/
// scalastyle:on line.size.limit
class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter {
@transient
private lazy val sdf = new SimpleDateFormat(pattern, locale)
override def parseToDate(s: String): Date = sdf.parse(s)
override def format(d: Date): String = sdf.format(d)
override def validatePatternString(): Unit = sdf
}

object DateFormatter {
import LegacyDateFormats._

val defaultLocale: Locale = Locale.US

val defaultPattern: String = "yyyy-MM-dd"

private def getFormatter(
format: Option[String],
locale: Locale = defaultLocale,
legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT,
isParsing: Boolean): DateFormatter = {
if (SQLConf.get.legacyTimeParserPolicy == LEGACY) {
getLegacyFormatter(format.getOrElse(defaultPattern), locale, legacyFormat)
} else {
val df = format
.map(new Iso8601DateFormatter(_, locale, legacyFormat, isParsing))
.getOrElse(new DefaultDateFormatter(locale, legacyFormat, isParsing))
df.validatePatternString()
df
}
}

def getLegacyFormatter(
pattern: String,
locale: Locale,
legacyFormat: LegacyDateFormat): DateFormatter = {
legacyFormat match {
case FAST_DATE_FORMAT =>
new LegacyFastDateFormatter(pattern, locale)
case SIMPLE_DATE_FORMAT | LENIENT_SIMPLE_DATE_FORMAT =>
new LegacySimpleDateFormatter(pattern, locale)
}
}

def apply(
format: Option[String],
locale: Locale,
legacyFormat: LegacyDateFormat,
isParsing: Boolean): DateFormatter = {
getFormatter(format, locale, legacyFormat, isParsing)
}

def apply(
format: String,
locale: Locale,
legacyFormat: LegacyDateFormat,
isParsing: Boolean): DateFormatter = {
getFormatter(Some(format), locale, legacyFormat, isParsing)
}

def apply(format: String, isParsing: Boolean = false): DateFormatter = {
getFormatter(Some(format), isParsing = isParsing)
}

def apply(): DateFormatter = {
getFormatter(None, isParsing = false)
}
}
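A hedged round-trip sketch using the copied DateFormatter (pattern and values are illustrative): parse returns days since the Unix epoch, and format renders those days back with the same pattern.

import org.apache.spark.sql.catalyst.util.rapids.DateFormatter

object DateFormatterSketch {
  val formatter = DateFormatter("yyyy-MM-dd", isParsing = true)
  val days: Int = formatter.parse("2022-07-20")   // days since 1970-01-01, e.g. 19193
  val rendered: String = formatter.format(days)   // "2022-07-20"
}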
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util.rapids

import java.text.ParseException
import java.time.DateTimeException
import java.time.format.DateTimeParseException

// Copied from Spark: org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
// for https://github.com/NVIDIA/spark-rapids/issues/6026
// It can be removed when Spark 3.3 is the minimum supported Spark version
sealed trait TimestampFormatter extends Serializable {
/**
* Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time.
*
* @param s - string with timestamp to parse
* @param allowTimeZone - indicates strict parsing of timezone
* @return microseconds since epoch.
* @throws ParseException can be thrown by legacy parser
* @throws DateTimeParseException can be thrown by new parser
* @throws DateTimeException unable to obtain local date or time
* @throws IllegalStateException The formatter for timestamp without time zone should always
* implement this method. The exception should never be hit.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@throws(classOf[DateTimeException])
@throws(classOf[IllegalStateException])
def parseWithoutTimeZone(s: String, allowTimeZone: Boolean): Long =
throw new IllegalStateException(
s"The method `parseWithoutTimeZone(s: String, allowTimeZone: Boolean)` should be " +
"implemented in the formatter of timestamp without time zone")

/**
* Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time.
* Zone-id and zone-offset components are ignored.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@throws(classOf[DateTimeException])
@throws(classOf[IllegalStateException])
final def parseWithoutTimeZone(s: String): Long =
// This is implemented to adhere to the original behaviour of `parseWithoutTimeZone` where we
// did not fail if timestamp contained zone-id or zone-offset component and instead ignored it.
parseWithoutTimeZone(s, true)
}
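A minimal sketch (hypothetical, not part of this PR) of the conversion contract a timestamp-without-time-zone formatter must satisfy: interpret the wall-clock time as-is, with no zone adjustment. Because the trait is sealed, a real implementation would have to live in this same source file.

import java.time.{LocalDateTime, ZoneOffset}
import java.time.format.DateTimeFormatter

object NtzFormatterSketch {
  private val iso = DateTimeFormatter.ISO_LOCAL_DATE_TIME

  // Converts "2022-07-20T12:34:56" to microseconds since the Unix epoch,
  // treating the parsed wall-clock time as if it were in UTC.
  def parseWithoutTimeZone(s: String): Long = {
    val ldt = LocalDateTime.parse(s, iso)
    ldt.toEpochSecond(ZoneOffset.UTC) * 1000000L + ldt.getNano / 1000L
  }
}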