Support ANSI intervals to/from Parquet (#4810)
* Support ANSI intervals to/from Parquet

Signed-off-by: Chong Gao <res_life@163.com>
Chong Gao authored Mar 8, 2022
1 parent ef9236a commit fbb2f07
Showing 13 changed files with 475 additions and 38 deletions.
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/asserts.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.

from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm
from datetime import date, datetime
from datetime import date, datetime, timedelta
from decimal import Decimal
import math
from pyspark.sql import Row
@@ -92,6 +92,9 @@ def _assert_equal(cpu, gpu, float_check, path):
assert cpu == gpu, "GPU and CPU decimal values are different at {}".format(path)
elif isinstance(cpu, bytearray):
assert cpu == gpu, "GPU and CPU bytearray values are different at {}".format(path)
elif isinstance(cpu, timedelta):
# Used by interval type DayTimeInterval for Pyspark 3.3.0+
assert cpu == gpu, "GPU and CPU timedelta values are different at {}".format(path)
elif (cpu == None):
assert cpu == gpu, "GPU and CPU are not both null at {}".format(path)
else:
27 changes: 27 additions & 0 deletions integration_tests/src/main/python/data_gen.py
@@ -612,6 +612,33 @@ def make_null():
return None
self._start(rand, make_null)

# DayTimeIntervalGen is for Spark 3.3.0+
# DayTimeIntervalType(startField, endField): represents a day-time interval made up of a contiguous subset of the following fields:
# SECOND, seconds within minutes and possibly fractions of a second [0..59.999999],
# MINUTE, minutes within hours [0..59],
# HOUR, hours within days [0..23],
# DAY, days in the range [0..106751991].
# For more details: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
# Note: 106751991/365 = 292471 years, far beyond year 9999. The bound is not a mistake: it is the
# largest whole number of days representable when the interval is stored as a signed 64-bit count
# of microseconds (see the sketch after this class).
class DayTimeIntervalGen(DataGen):
"""Generate DayTimeIntervalType values"""
def __init__(self, max_days = None, nullable=True, special_cases =[timedelta(seconds = 0)]):
super().__init__(DayTimeIntervalType(), nullable=nullable, special_cases=special_cases)
if max_days is None:
self._max_days = 106751991
else:
self._max_days = max_days
def start(self, rand):
self._start(rand,
lambda : timedelta(
microseconds = rand.randint(0, 999999),
seconds = rand.randint(0, 59),
minutes = rand.randint(0, 59),
hours = rand.randint(0, 23),
days = rand.randint(0, self._max_days),
)
)
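For reference, here is a quick arithmetic check of the DAY bound discussed above; it is a standalone sketch and not part of this change:

# Illustration only: the documented upper bound of 106751991 days is exactly the number of
# whole days that fit in a signed 64-bit count of microseconds, which is how Spark stores
# DayTimeIntervalType values internally.
LONG_MAX = 2**63 - 1                         # largest signed 64-bit value
MICROS_PER_DAY = 24 * 60 * 60 * 1_000_000    # microseconds in one day
print(LONG_MAX // MICROS_PER_DAY)            # 106751991
print(106751991 // 365)                      # roughly 292471 years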

def skip_if_not_utc():
if (not is_tz_utc()):
skip_unless_precommit_tests('The java system time zone is not set to UTC')
14 changes: 12 additions & 2 deletions integration_tests/src/main/python/date_time_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
from datetime import date, datetime, timezone
from marks import incompat, allow_non_gpu
from pyspark.sql.types import *
from spark_session import with_spark_session, is_before_spark_311
from spark_session import with_spark_session, is_before_spark_311, is_before_spark_330
import pyspark.sql.functions as f

# We only support literal intervals for TimeSub
@@ -41,6 +41,16 @@ def test_timeadd(data_gen):
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_timeadd_daytime_column():
gen_list = [
# the timestamp column's maximum year is 1000
('t', TimestampGen(end = datetime(1000, 1, 1, tzinfo=timezone.utc))),
# the interval is at most about 8000 years of days, so the added result stays within the supported timestamp range
('d', DayTimeIntervalGen(max_days = 8000 * 365))]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_dateaddinterval(data_gen):
days, seconds = data_gen
18 changes: 18 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -789,3 +789,21 @@ def test_parquet_read_field_id(spark_tmp_path):
lambda spark: spark.read.schema(readSchema).parquet(data_path),
'FileSourceScanExec',
{"spark.sql.parquet.fieldId.read.enabled": "true"}) # default is false

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_parquet_read_daytime_interval_cpu_file(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('_c1', DayTimeIntervalGen())]
# write DayTimeInterval with CPU
with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_parquet_read_daytime_interval_gpu_file(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('_c1', DayTimeIntervalGen())]
# write DayTimeInterval with GPU
with_gpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path))
11 changes: 11 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -418,3 +418,14 @@ def test_parquet_write_field_id(spark_tmp_path):
data_path,
'DataWritingCommandExec',
conf = {"spark.sql.parquet.fieldId.write.enabled" : "true"}) # default is true

@pytest.mark.order(1) # at the head of xdist worker queue if pytest-order is installed
@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_write_daytime_interval(spark_tmp_path):
gen_list = [('_c1', DayTimeIntervalGen())]
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=writer_confs)
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter

import org.apache.spark.sql.types.DataType

object GpuTypeShims {

/**
* Whether this Shim supports the data type in the row-to-column converter
* @param otherType the data type to check in the Shim
* @return true if the Shim supports otherType, false otherwise.
*/
def hasConverterForType(otherType: DataType) : Boolean = false

/**
* Get the TypeConverter for the data type in this Shim
* Note: hasConverterForType should be called first
* @param t the data type
* @param nullable whether the type is nullable
* @return the row-to-column converter for the data type
*/
def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = {
throw new RuntimeException(s"No converter is found for type $t.")
}

/**
* Get the cuDF type for the Spark data type
* @param t the Spark data type
* @return the cuDF type if the Shim supports it, null otherwise
*/
def toRapidsOrNull(t: DataType): DType = null
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.v2

import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression
@@ -59,48 +59,59 @@ case class GpuTimeAdd(start: Expression,
override def columnarEval(batch: ColumnarBatch): Any = {
withResourceIfAllowed(left.columnarEval(batch)) { lhs =>
withResourceIfAllowed(right.columnarEval(batch)) { rhs =>
// lhs is start, rhs is interval
(lhs, rhs) match {
case (l: GpuColumnVector, intvlS: GpuScalar) =>
val interval = intvlS.dataType match {
case (l: GpuColumnVector, intervalS: GpuScalar) =>
// get the interval as a long number of microseconds
val interval = intervalS.dataType match {
case CalendarIntervalType =>
// Scalar does not support 'CalendarInterval' now, so use
// the Scala value instead.
// Skip the null check because it will be detected by the following calls.
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
val calendarI = intervalS.getValue.asInstanceOf[CalendarInterval]
if (calendarI.months != 0) {
throw new UnsupportedOperationException("Months aren't supported at the moment")
}
intvl.days * microSecondsInOneDay + intvl.microseconds
calendarI.days * microSecondsInOneDay + calendarI.microseconds
case _: DayTimeIntervalType =>
// Scalar does not support 'DayTimeIntervalType' now, so use
// the Scala value instead.
intvlS.getValue.asInstanceOf[Long]
intervalS.getValue.asInstanceOf[Long]
case _ =>
throw new UnsupportedOperationException("GpuTimeAdd unsupported data type: " +
intvlS.dataType)
throw new UnsupportedOperationException(
"GpuTimeAdd unsupported data type: " + intervalS.dataType)
}

// add interval
if (interval != 0) {
withResource(Scalar.fromLong(interval)) { us_s =>
withResource(l.getBase.bitCastTo(DType.INT64)) { us =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS),
dataType)
}
}
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d =>
GpuColumnVector.from(timestampAddDuration(l.getBase, d), dataType)
}
} else {
l.incRefCount()
}
case (l: GpuColumnVector, r: GpuColumnVector) =>
(l.dataType(), r.dataType) match {
case (_: TimestampType, _: DayTimeIntervalType) =>
// DayTimeIntervalType is stored as long
// bitCastTo is similar to reinterpret_cast; it is cheap, so its cost can be ignored.
withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration =>
GpuColumnVector.from(timestampAddDuration(l.getBase, duration), dataType)
}
case _ =>
throw new UnsupportedOperationException(
"GpuTimeAdd takes column and interval as an argument only")
}
case _ =>
throw new UnsupportedOperationException("GpuTimeAdd takes column and interval as an " +
"argument only")
throw new UnsupportedOperationException(
"GpuTimeAdd takes column and interval as an argument only")
}
}
}
}

private def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
// Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion,
// and currently BinaryOperable.implicitConversion returns Long.
// Instead, explicitly specify TIMESTAMP_MICROSECONDS as the return type.
cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
}
}
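To make the duration arithmetic above concrete, here is a minimal standalone Python sketch (illustrative only, not plugin code) of the addition GpuTimeAdd performs once both operands are 64-bit microsecond values; the interval matches the '1 02:03:04' DAY TO SECOND literal used in test_timeadd_daytime_column:

from datetime import datetime, timedelta, timezone

# Both operands are 64-bit microsecond counts: the timestamp as microseconds since the epoch
# and the day-time interval as a microsecond duration, so the add is plain integer arithmetic
# whose result is then interpreted as TIMESTAMP_MICROSECONDS.
epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
ts_micros = int((datetime(2022, 3, 8, tzinfo=timezone.utc) - epoch) / timedelta(microseconds=1))
interval_micros = int(timedelta(days=1, hours=2, minutes=3, seconds=4) / timedelta(microseconds=1))
result = epoch + timedelta(microseconds=ts_micros + interval_micros)
print(result)   # 2022-03-09 02:03:04+00:00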
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{LongConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType}

/**
* Spark stores ANSI YearMonthIntervalType as int32 and ANSI DayTimeIntervalType as int64
* internally when computing.
* See the comments of YearMonthIntervalType; the following is copied from Spark:
* Internally, values of year-month intervals are stored in `Int` values as amount of months
* that are calculated by the formula:
* -/+ (12 * YEAR + MONTH)
* See the comments of DayTimeIntervalType; the following is copied from Spark:
* Internally, values of day-time intervals are stored in `Long` values as amount of time in terms
* of microseconds that are calculated by the formula:
* -/+ (24*60*60 * DAY + 60*60 * HOUR + 60 * MINUTE + SECOND) * 1000000
*
* Spark also stores ANSI intervals as int32 and int64 in Parquet files:
* - year-month intervals as `INT32`
* - day-time intervals as `INT64`
* To load the values as intervals back, Spark puts the info about interval types
* to the extra key `org.apache.spark.sql.parquet.row.metadata`:
* $ java -jar parquet-tools-1.12.0.jar meta ./part-...-c000.snappy.parquet
* creator: parquet-mr version 1.12.1 (build 2a5c06c58fa987f85aa22170be14d927d5ff6e7d)
* extra: org.apache.spark.version = 3.3.0
* extra: org.apache.spark.sql.parquet.row.metadata =
* {"type":"struct","fields":[...,
* {"name":"i","type":"interval year to month","nullable":false,"metadata":{}}]}
* file schema: spark_schema
* --------------------------------------------------------------------------------
* ...
* i: REQUIRED INT32 R:0 D:0
*
* For details see https://issues.apache.org/jira/browse/SPARK-36825
*/
object GpuTypeShims {

/**
* Whether this Shim supports the data type in the row-to-column converter
* @param otherType the data type to check in the Shim
* @return true if the Shim supports otherType, false otherwise.
*/
def hasConverterForType(otherType: DataType) : Boolean = {
otherType match {
case DayTimeIntervalType(_, _) => true
case _ => false
}
}

/**
* Get the TypeConverter for the data type in this Shim
* Note: hasConverterForType should be called first
* @param t the data type
* @param nullable whether the type is nullable
* @return the row-to-column converter for the data type
*/
def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = {
(t, nullable) match {
case (DayTimeIntervalType(_, _), true) => LongConverter
case (DayTimeIntervalType(_, _), false) => NotNullLongConverter
case _ => throw new RuntimeException(s"No converter is found for type $t.")
}
}

/**
* Get the cuDF type for the Spark data type
* @param t the Spark data type
* @return the cuDF type if the Shim supports it, null otherwise
*/
def toRapidsOrNull(t: DataType): DType = {
t match {
case _: DayTimeIntervalType =>
// use int64 as Spark does
DType.INT64
case _ =>
null
}
}
}
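As a worked example of the storage formulas quoted in the comment above, here is a small standalone Python sketch (not part of the change; the interval literals are arbitrary illustrative values) encoding INTERVAL '2-3' YEAR TO MONTH and INTERVAL '1 02:03:04' DAY TO SECOND the way Spark stores them:

# Year-month intervals are stored as an int32 number of months: 12 * YEAR + MONTH.
year_month_months = 12 * 2 + 3                                  # 27 months
# Day-time intervals are stored as an int64 number of microseconds:
# (24*60*60 * DAY + 60*60 * HOUR + 60 * MINUTE + SECOND) * 1000000
day_time_micros = (24 * 60 * 60 * 1 + 60 * 60 * 2 + 60 * 3 + 4) * 1_000_000
print(year_month_months)   # 27, written to Parquet as INT32
print(day_time_micros)     # 93784000000, written to Parquet as INT64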