parquet writer support for TIMESTAMP_MILLIS #726

Merged · 8 commits · Sep 15, 2020
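This PR teaches the RAPIDS Accelerator's Parquet writer to handle spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS by casting timestamp columns to millisecond precision before the GPU write, instead of declining the plan for the GPU. A minimal usage sketch, assuming a SparkSession named spark with the plugin enabled and a DataFrame df holding a timestamp column (both names and the path are illustrative, not part of the PR):

// Sketch only: spark, df, and the output path are assumed for illustration.
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")
df.write.parquet("/tmp/ts_millis") // timestamp columns are stored at millisecond precision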
Changes from all commits
integration_tests/src/main/python/parquet_test.py (19 additions, 4 deletions)

@@ -291,18 +291,32 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, mt_opt):
 @pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn)
 @pytest.mark.parametrize('mt_opt', ["true", "false"])
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list):
+@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"])
+def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list, ts_type):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path),
             lambda spark, path: spark.read.parquet(path),
             data_path,
             conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
-                  'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
+                  'spark.sql.parquet.outputTimestampType': ts_type,
                   'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
                   'spark.sql.sources.useV1SourceList': v1_enabled_list})
 
+@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
+@pytest.mark.parametrize('ts_rebase', ['CORRECTED'])
+@ignore_order
+def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
+    gen = TimestampGen()
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+            lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
+            lambda spark, path: spark.read.parquet(path),
+            data_path,
+            conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
+                  'spark.sql.parquet.outputTimestampType': ts_type})
+
 parquet_part_write_gens = [
     byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     # Some file systems have issues with UTF8 strings so to help the test pass even there
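A note on why the round trip above can compare GPU and CPU results directly even though TIMESTAMP_MILLIS loses precision: both writers persist millisecond precision, so sub-millisecond digits are dropped the same way on each side before the read-back comparison. This is my reading of why the test holds, not a statement from the PR. A worked value:

// Illustration only: 1600128000123456L is 2020-09-15 00:00:00.123456 UTC in microseconds.
val micros = 1600128000123456L
val millis = Math.floorDiv(micros, 1000L) // 1600128000123L, the trailing 456 microseconds are gone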
@@ -314,7 +328,8 @@ def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list)
 @pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn)
 @pytest.mark.parametrize('mt_opt', ["true", "false"])
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list):
+@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
+def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list, ts_type):
     gen_list = [('a', RepeatSeqGen(parquet_gen, 10)),
                 ('b', parquet_gen)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -323,7 +338,7 @@ def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list)
             lambda spark, path: spark.read.parquet(path),
             data_path,
             conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
-                  'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
+                  'spark.sql.parquet.outputTimestampType': ts_type,
                   'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
                   'spark.sql.sources.useV1SourceList': v1_enabled_list})

GpuParquetFileFormat.scala

@@ -30,8 +30,10 @@ import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetWriteSupport}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
+import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
-import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 object GpuParquetFileFormat {
   def tagGpuSupport(
@@ -64,10 +66,9 @@ object GpuParquetFileFormat {
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
     }
     if (schemaHasTimestamps) {
-      // TODO: Could support TIMESTAMP_MILLIS by performing cast on all timestamp input columns
-      sqlConf.parquetOutputTimestampType match {
-        case ParquetOutputTimestampType.TIMESTAMP_MICROS =>
-        case t => meta.willNotWorkOnGpu(s"Output timestamp type $t is not supported")
+      if (!isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) {
+        meta.willNotWorkOnGpu(s"Output timestamp type " +
+          s"${sqlConf.parquetOutputTimestampType} is not supported")
       }
     }
 
Expand Down Expand Up @@ -100,6 +101,15 @@ object GpuParquetFileFormat {
case _ => None
}
}

def isOutputTimestampTypeSupported(
outputTimestampType: ParquetOutputTimestampType.Value): Boolean = {
outputTimestampType match {
case ParquetOutputTimestampType.TIMESTAMP_MICROS |
ParquetOutputTimestampType.TIMESTAMP_MILLIS => true
case _ => false
}
}
}

class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
@@ -160,7 +170,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
       sparkSession.sessionState.conf.writeLegacyParquetFormat.toString)
 
     val outputTimestampType = sparkSession.sessionState.conf.parquetOutputTimestampType
-    if (outputTimestampType != ParquetOutputTimestampType.TIMESTAMP_MICROS) {
+    if (!GpuParquetFileFormat.isOutputTimestampTypeSupported(outputTimestampType)) {
       val hasTimestamps = dataSchema.exists { field =>
         TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
       }
@@ -217,6 +227,8 @@ class GpuParquetWriter(
     context: TaskAttemptContext)
   extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") {
 
+  val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)
+
   override def scanTableBeforeWrite(table: Table): Unit = {
     if (dateTimeRebaseException) {
       (0 until table.getNumberOfColumns).foreach { i =>
@@ -227,6 +239,32 @@
     }
   }
 
+  /**
+   * Persists a columnar batch. Invoked on the executor side. When writing to dynamically
+   * partitioned tables, dynamic partition columns are not included in columns to be written.
+   * NOTE: It is the writer's responsibility to close the batch.
+   */
+  override def write(batch: ColumnarBatch,
+      statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
+    val newBatch =
+      if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
+        new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
+          cv => {
+            cv.dataType() match {
+              case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
+                withResource(cv.getBase()) { v =>
+                  v.castTo(DType.TIMESTAMP_MILLISECONDS)
+                })
+              case _ => cv
+            }
+          }
+        })
+      } else {
+        batch
+      }
+
+    super.write(newBatch, statsTrackers)
+  }
   override val tableWriter: TableWriter = {
     val writeContext = new ParquetWriteSupport().init(conf)
     val builder = ParquetWriterOptions.builder()
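The write override above does its truncation through cuDF's column cast. A self-contained sketch of that single operation against the cudf Java API, assuming ai.rapids.cudf is on the classpath (the sample value and this standalone usage are illustrative, not code from the PR):

import ai.rapids.cudf.{ColumnVector, DType}

// Build a microsecond timestamp column, cast it to milliseconds, and close both vectors.
val micros = ColumnVector.timestampMicroSecondsFromLongs(1600128000123456L)
try {
  val millis = micros.castTo(DType.TIMESTAMP_MILLISECONDS)
  // The cast drops sub-millisecond precision: ...123456 us becomes ...123 ms.
  millis.close()
} finally {
  micros.close()
}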
ParquetWriterSuite.scala

@@ -103,6 +103,21 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
     }
   }
 
+  testExpectedGpuException(
+    "Old timestamps millis in EXCEPTION mode",
+    classOf[SparkException],
+    oldTsDf,
+    new SparkConf()
+      .set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")
+      .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")) {
+    val tempFile = File.createTempFile("oldTimeStamp", "parquet")
+    tempFile.delete()
+    frame => {
+      frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath)
+      frame
+    }
+  }
+
   testExpectedGpuException(
     "Old timestamps in EXCEPTION mode",
     classOf[SparkException],
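The new suite entry mirrors the existing EXCEPTION-mode test with millisecond output. A sketch of the scenario it guards, assuming an active SparkSession named spark (the suite itself builds its input via oldTsDf rather than this literal):

import java.sql.Timestamp

spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")

import spark.implicits._
// A timestamp well before the 1582 Gregorian cutover trips the rebase check on write.
Seq(Timestamp.valueOf("1500-01-01 00:00:00")).toDF("ts")
  .write.mode("overwrite").parquet("/tmp/old_ts") // expected to fail with SparkException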