Revert "Support rebase checking for nested dates and timestamps (#9617)"
This reverts commit 401d0d8.

Signed-off-by: Nghia Truong <nghiat@nvidia.com>
ttnghia committed Nov 6, 2023
1 parent 68883ac commit b0d8327
Showing 3 changed files with 100 additions and 54 deletions.
41 changes: 31 additions & 10 deletions integration_tests/src/main/python/parquet_test.py
@@ -312,14 +312,35 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1

parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']

# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)),
ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),

# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed, delete this test and merge it
# into test_ts_read_round_trip; nested timestamps and dates are not supported right now.
@pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1126')
def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path),
conf=all_confs)

# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
@@ -337,10 +358,10 @@ def readParquetCatchException(spark, data_path):
df = spark.read.parquet(data_path).collect()
assert e_info.match(r".*SparkUpgradeException.*")

# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)),
ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed, nested timestamps and dates should be added in
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@@ -982,7 +1003,7 @@ def test_parquet_reading_from_unaligned_pages_basic_filters_with_nulls(spark_tmp


conf_for_parquet_aggregate_pushdown = {
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.sources.useV1SourceList": ""
}

@@ -1469,15 +1490,15 @@ def test_parquet_read_count(spark_tmp_path):
def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_confs, col_name):
all_confs = copy_and_update(reader_confs, {
'spark.sql.sources.useV1SourceList': v1_enabled_list})
gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
('k1', LongGen(nullable=False, min_val=1, max_val=1)),
('k2', LongGen(nullable=False, min_val=2, max_val=2)),
('k3', LongGen(nullable=False, min_val=3, max_val=3)),
('v0', LongGen()),
('v1', LongGen()),
('v2', LongGen()),
('v3', LongGen())]

gen = StructGen(gen_list, nullable=False)
data_path = spark_tmp_path + '/PAR_DATA'
reader = read_func(data_path)
66 changes: 35 additions & 31 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -16,58 +16,62 @@

package com.nvidia.spark

import ai.rapids.cudf.{ColumnView, DType, Scalar}
import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.shims.SparkShimImpl

import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.rapids.execution.TrampolineUtil

object RebaseHelper {
private[this] def isRebaseNeeded(column: ColumnView, checkType: DType,
minGood: Scalar): Boolean = {
private[this] def isDateRebaseNeeded(column: ColumnVector,
startDay: Int): Boolean = {
// TODO update this for nested column checks
// https://github.com/NVIDIA/spark-rapids/issues/1126
val dtype = column.getType
require(!dtype.hasTimeResolution || dtype == DType.TIMESTAMP_MICROSECONDS)
if (dtype == DType.TIMESTAMP_DAYS) {
val hasBad = withResource(Scalar.timestampDaysFromInt(startDay)) {
column.lessThan
}
val anyBad = withResource(hasBad) {
_.any()
}
withResource(anyBad) { _ =>
anyBad.isValid && anyBad.getBoolean
}
} else {
false
}
}

dtype match {
case `checkType` =>
private[this] def isTimeRebaseNeeded(column: ColumnVector,
startTs: Long): Boolean = {
val dtype = column.getType
if (dtype.hasTimeResolution) {
require(dtype == DType.TIMESTAMP_MICROSECONDS)
withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
withResource(column.lessThan(minGood)) { hasBad =>
withResource(hasBad.any()) { anyBad =>
anyBad.isValid && anyBad.getBoolean
withResource(hasBad.any()) { a =>
a.isValid && a.getBoolean
}
}

case DType.LIST | DType.STRUCT => (0 until column.getNumChildren).exists(i =>
withResource(column.getChildColumnView(i)) { child =>
isRebaseNeeded(child, checkType, minGood)
})

case _ => false
}
}

private[this] def isDateRebaseNeeded(column: ColumnView, startDay: Int): Boolean = {
withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
isRebaseNeeded(column, DType.TIMESTAMP_DAYS, minGood)
}
}

private[this] def isTimeRebaseNeeded(column: ColumnView, startTs: Long): Boolean = {
withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
isRebaseNeeded(column, DType.TIMESTAMP_MICROSECONDS, minGood)
}
} else {
false
}
}

def isDateRebaseNeededInRead(column: ColumnView): Boolean =
def isDateRebaseNeededInRead(column: ColumnVector): Boolean =
isDateRebaseNeeded(column, RebaseDateTime.lastSwitchJulianDay)

def isTimeRebaseNeededInRead(column: ColumnView): Boolean =
def isTimeRebaseNeededInRead(column: ColumnVector): Boolean =
isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchJulianTs)

def isDateRebaseNeededInWrite(column: ColumnView): Boolean =
def isDateRebaseNeededInWrite(column: ColumnVector): Boolean =
isDateRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianDay)

def isTimeRebaseNeededInWrite(column: ColumnView): Boolean =
def isTimeRebaseNeededInWrite(column: ColumnVector): Boolean =
isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianTs)

def newRebaseExceptionInRead(format: String): Exception = {
47 changes: 34 additions & 13 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -165,9 +165,11 @@ object GpuParquetScan {
hasInt96Timestamps: Boolean): Unit = {
(0 until table.getNumberOfColumns).foreach { i =>
val col = table.getColumn(i)
// if col is a day
if (!isCorrectedDateTimeRebase && RebaseHelper.isDateRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
}
// if col is a time
else if (hasInt96Timestamps && !isCorrectedInt96Rebase ||
!hasInt96Timestamps && !isCorrectedDateTimeRebase) {
if (RebaseHelper.isTimeRebaseNeededInRead(col)) {
@@ -199,6 +201,21 @@ object GpuParquetScan {

FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp)

val schemaHasTimestamps = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
def isTsOrDate(dt: DataType) : Boolean = dt match {
case TimestampType | DateType => true
case _ => false
}
val schemaMightNeedNestedRebase = readSchema.exists { field =>
if (DataTypeUtils.isNestedType(field.dataType)) {
TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
} else {
false
}
}

// Currently timestamp conversion is not supported.
// If support needs to be added then we need to follow the logic in Spark's
// ParquetPartitionReaderFactory and VectorizedColumnReader which essentially
@@ -208,32 +225,35 @@ object GpuParquetScan {
// were written in that timezone and convert them to UTC timestamps.
// Essentially this should boil down to a vector subtract of the scalar delta
// between the configured timezone's delta from UTC on the timestamp data.
val schemaHasTimestamps = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) {
meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion")
}

val schemaHasDates = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType])
}

sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match {
case "EXCEPTION" | "CORRECTED" => // Good
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}

sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
case "EXCEPTION" | "CORRECTED" => // Good
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaHasDates || schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
@@ -2898,3 +2918,4 @@ object ParquetPartitionReader {
block
}
}

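For reference, the pattern this revert restores in RebaseHelper.scala is a flat, top-level column check; the recursive ColumnView walk is removed. Below is a minimal, self-contained sketch of that check, not taken from the commit: it assumes the ai.rapids.cudf Java API and the Arm.withResource helper imported in the file above, and the object and method names are illustrative only.

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource

object RebaseCheckSketch {
  // True if any value in `column` falls strictly before `startDay` (days since
  // the epoch), i.e. the column would need legacy calendar rebase handling.
  def anyDayBefore(column: ColumnVector, startDay: Int): Boolean = {
    if (column.getType == DType.TIMESTAMP_DAYS) {
      withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
        withResource(column.lessThan(minGood)) { hasBad =>
          withResource(hasBad.any()) { anyBad =>
            // any() yields an invalid scalar for an all-null column, so check validity first.
            anyBad.isValid && anyBad.getBoolean
          }
        }
      }
    } else {
      false
    }
  }
}

The date and timestamp variants in the diff differ only in the cutoff scalar (lastSwitchJulianDay/Ts for reads, lastSwitchGregorianDay/Ts for writes) and in the microsecond-resolution type check.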
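The GpuParquetScan.scala changes also bring back a schema-level guard: if a nested (array, map, or struct) field contains a date or timestamp anywhere inside it, the scan is tagged as unsupported under the EXCEPTION and LEGACY rebase modes. A hedged sketch of that guard in plain Spark types follows; the plugin itself relies on TrampolineUtil.dataTypeExistsRecursively and DataTypeUtils.isNestedType, so the helper names here are illustrative assumptions rather than the plugin's API.

import org.apache.spark.sql.types._

object NestedRebaseCheckSketch {
  // Recursively look for DateType or TimestampType inside a (possibly nested) type.
  private def containsTsOrDate(dt: DataType): Boolean = dt match {
    case TimestampType | DateType => true
    case ArrayType(elementType, _) => containsTsOrDate(elementType)
    case MapType(keyType, valueType, _) => containsTsOrDate(keyType) || containsTsOrDate(valueType)
    case StructType(fields) => fields.exists(f => containsTsOrDate(f.dataType))
    case _ => false
  }

  // True when a top-level field is nested and holds a date or timestamp somewhere inside it.
  def schemaMightNeedNestedRebase(readSchema: StructType): Boolean =
    readSchema.exists { field =>
      val isNested = field.dataType match {
        case _: ArrayType | _: MapType | _: StructType => true
        case _ => false
      }
      isNested && containsTsOrDate(field.dataType)
    }
}

Only CORRECTED mode is left untouched by this guard, which matches the diff above: corrected-mode files are already in the proleptic Gregorian calendar and need no rebase at all.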