parquet writer support for TIMESTAMP_MILLIS (NVIDIA#726)
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Sep 15, 2020
1 parent 5774b44 commit 5f6ed5c
Showing 3 changed files with 78 additions and 10 deletions.
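In user-facing terms, the change lets the GPU Parquet writer honor spark.sql.parquet.outputTimestampType=TIMESTAMP_MILLIS; previously anything other than TIMESTAMP_MICROS made the plugin fall back to the CPU. Below is a minimal PySpark sketch of that scenario, assuming a Spark session with the RAPIDS Accelerator plugin already configured; the path and column names are illustrative only.

# Sketch only: assumes the RAPIDS Accelerator plugin is installed and enabled.
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_timestamp

spark = SparkSession.builder.appName("ts-millis-demo").getOrCreate()

# Write timestamps as Parquet TIMESTAMP_MILLIS instead of the default TIMESTAMP_MICROS.
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")

df = spark.range(1000).withColumn("ts", current_timestamp())
df.write.mode("overwrite").parquet("/tmp/ts_millis_demo")

With this commit the write above can stay on the GPU; before it, the writer was tagged willNotWorkOnGpu for any output timestamp type other than TIMESTAMP_MICROS.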
23 changes: 19 additions & 4 deletions integration_tests/src/main/python/parquet_test.py
@@ -291,18 +291,32 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, mt_opt):
@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn)
@pytest.mark.parametrize('mt_opt', ["true", "false"])
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list):
@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"])
def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list, ts_type):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
'spark.sql.parquet.outputTimestampType': ts_type,
'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
'spark.sql.sources.useV1SourceList': v1_enabled_list})

@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
@pytest.mark.parametrize('ts_rebase', ['CORRECTED'])
@ignore_order
def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
gen = TimestampGen()
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_type})
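A note on why the new round trip still compares equal: TIMESTAMP_MILLIS stores whole milliseconds, so both the CPU writer and the GPU writer drop sub-millisecond precision and the re-read values agree. A rough illustration of that truncation, plain Python rather than plugin code:

micros = 1_600_000_000_123_456        # a timestamp in microseconds since the epoch
millis = micros // 1_000              # what a TIMESTAMP_MILLIS column can represent
assert millis * 1_000 == 1_600_000_000_123_000   # the trailing 456 microseconds are lost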

parquet_part_write_gens = [
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
# Some file systems have issues with UTF8 strings so to help the test pass even there
@@ -314,7 +328,8 @@ def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list)
@pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn)
@pytest.mark.parametrize('mt_opt', ["true", "false"])
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list):
@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list, ts_type):
gen_list = [('a', RepeatSeqGen(parquet_gen, 10)),
('b', parquet_gen)]
data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -323,7 +338,7 @@ def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_l
lambda spark, path: spark.read.parquet(path),
data_path,
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
'spark.sql.parquet.outputTimestampType': ts_type,
'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
'spark.sql.sources.useV1SourceList': v1_enabled_list})

@@ -30,8 +30,10 @@ import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetWriteSupport}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch

object GpuParquetFileFormat {
def tagGpuSupport(
@@ -64,10 +66,9 @@ object GpuParquetFileFormat {
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
if (schemaHasTimestamps) {
// TODO: Could support TIMESTAMP_MILLIS by performing cast on all timestamp input columns
sqlConf.parquetOutputTimestampType match {
case ParquetOutputTimestampType.TIMESTAMP_MICROS =>
case t => meta.willNotWorkOnGpu(s"Output timestamp type $t is not supported")
if(!isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) {
meta.willNotWorkOnGpu(s"Output timestamp type " +
s"${sqlConf.parquetOutputTimestampType} is not supported")
}
}

@@ -100,6 +101,15 @@
case _ => None
}
}

def isOutputTimestampTypeSupported(
outputTimestampType: ParquetOutputTimestampType.Value): Boolean = {
outputTimestampType match {
case ParquetOutputTimestampType.TIMESTAMP_MICROS |
ParquetOutputTimestampType.TIMESTAMP_MILLIS => true
case _ => false
}
}
}
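For reference, Spark's spark.sql.parquet.outputTimestampType accepts three values: INT96, TIMESTAMP_MICROS and TIMESTAMP_MILLIS. The new helper accepts the latter two, and anything else still takes the willNotWorkOnGpu path shown above. A Python paraphrase of the check, for illustration only (not plugin code):

SUPPORTED_GPU_OUTPUT_TIMESTAMP_TYPES = {"TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"}

def is_output_timestamp_type_supported(output_type: str) -> bool:
    # Mirrors GpuParquetFileFormat.isOutputTimestampTypeSupported; INT96 is not supported.
    return output_type in SUPPORTED_GPU_OUTPUT_TIMESTAMP_TYPES

assert is_output_timestamp_type_supported("TIMESTAMP_MILLIS")
assert not is_output_timestamp_type_supported("INT96")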

class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
@@ -160,7 +170,7 @@
sparkSession.sessionState.conf.writeLegacyParquetFormat.toString)

val outputTimestampType = sparkSession.sessionState.conf.parquetOutputTimestampType
if (outputTimestampType != ParquetOutputTimestampType.TIMESTAMP_MICROS) {
if(!GpuParquetFileFormat.isOutputTimestampTypeSupported(outputTimestampType)) {
val hasTimestamps = dataSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
@@ -217,6 +227,8 @@ class GpuParquetWriter(
context: TaskAttemptContext)
extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") {

val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)

override def scanTableBeforeWrite(table: Table): Unit = {
if (dateTimeRebaseException) {
(0 until table.getNumberOfColumns).foreach { i =>
@@ -227,6 +239,32 @@
}
}

/**
* Persists a columnar batch. Invoked on the executor side. When writing to dynamically
* partitioned tables, dynamic partition columns are not included in columns to be written.
* NOTE: It is the writer's responsibility to close the batch.
*/
override def write(batch: ColumnarBatch,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
val newBatch =
if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
cv => {
cv.dataType() match {
case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
withResource(cv.getBase()) { v =>
v.castTo(DType.TIMESTAMP_MILLISECONDS)
})
case _ => cv
}
}
})
} else {
batch
}

super.write(newBatch, statsTrackers)
}
override val tableWriter: TableWriter = {
val writeContext = new ParquetWriteSupport().init(conf)
val builder = ParquetWriterOptions.builder()
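One way to confirm what the writer above actually puts on disk is to inspect the output with pyarrow. This is a sketch that assumes pyarrow is installed and reuses the illustrative path and column name from the earlier example:

import pyarrow as pa
import pyarrow.parquet as pq

table = pq.read_table("/tmp/ts_millis_demo")
ts_type = table.schema.field("ts").type
# With outputTimestampType=TIMESTAMP_MILLIS the column comes back with a millisecond unit.
assert pa.types.is_timestamp(ts_type) and ts_type.unit == "ms"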
@@ -103,6 +103,21 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
}
}

testExpectedGpuException(
"Old timestamps millis in EXCEPTION mode",
classOf[SparkException],
oldTsDf,
new SparkConf()
.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")
.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")) {
val tempFile = File.createTempFile("oldTimeStamp", "parquet")
tempFile.delete()
frame => {
frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath)
frame
}
}
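The new test above pins the GPU writer to the behavior CPU Spark already has: with rebase mode EXCEPTION, timestamps from before the 1582 Gregorian cutover fail the write instead of being silently rebased. A rough PySpark illustration of that behavior (not the suite's Scala harness; the path is illustrative):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")

old_df = spark.createDataFrame([("1500-01-01 00:00:00",)], ["s"]) \
    .select(F.to_timestamp("s").alias("ts"))
try:
    old_df.write.mode("overwrite").parquet("/tmp/old_ts_exception_demo")
except Exception as err:   # Spark raises for pre-1582 values in EXCEPTION mode
    print(type(err).__name__, err)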

testExpectedGpuException(
"Old timestamps in EXCEPTION mode",
classOf[SparkException],