From 4c6325b008c686055ec87648967bf4fa788c7771 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Thu, 17 Feb 2022 09:33:32 +0800
Subject: [PATCH] Add ignoreCorruptFiles for ORC readers

Signed-off-by: Bobby Wang
---
 .../src/main/python/json_test.py              | 29 +++++++++++++++++--
 integration_tests/src/main/python/orc_test.py | 24 ++++++++++++++-
 .../nvidia/spark/rapids/GpuOrcScanBase.scala  | 26 +++++++++++++----
 3 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py
index a349e3b3e67..52b1d60af65 100644
--- a/integration_tests/src/main/python/json_test.py
+++ b/integration_tests/src/main/python/json_test.py
@@ -16,9 +16,10 @@
 
 from asserts import assert_gpu_and_cpu_are_equal_collect
 from data_gen import *
-from src.main.python.marks import approximate_float, allow_non_gpu
+from conftest import is_databricks_runtime
+from marks import approximate_float, allow_non_gpu, ignore_order
 
-from src.main.python.spark_session import with_cpu_session
+from spark_session import with_cpu_session, with_gpu_session
 
 json_supported_gens = [
     # Spark does not escape '\r' or '\n' even though it uses it to mark end of record
@@ -206,3 +207,27 @@ def test_json_unquotedCharacters(std_input_path, filename, schema, read_func, al
             schema,
             {"allowUnquotedControlChars": allow_unquoted_chars}),
         conf=_enable_all_types_conf)
+
+@ignore_order
+@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
+@pytest.mark.skipif(is_databricks_runtime(), reason="Databricks does not support ignoreCorruptFiles")
+def test_json_read_with_corrupt_files(spark_tmp_path, v1_enabled_list):
+    first_data_path = spark_tmp_path + '/JSON_DATA/first'
+    with_cpu_session(lambda spark : spark.range(1).toDF("a").write.json(first_data_path))
+    second_data_path = spark_tmp_path + '/JSON_DATA/second'
+    with_cpu_session(lambda spark : spark.range(1, 2).toDF("a").write.orc(second_data_path))
+    third_data_path = spark_tmp_path + '/JSON_DATA/third'
+    with_cpu_session(lambda spark : spark.range(2, 3).toDF("a").write.json(third_data_path))
+
+    all_confs = copy_and_update(_enable_all_types_conf,
+                                {'spark.sql.files.ignoreCorruptFiles': "true",
+                                 'spark.sql.sources.useV1SourceList': v1_enabled_list})
+    schema = StructType([StructField("a", IntegerType())])
+
+    # When ignoreCorruptFiles is enabled the GPU read should not throw, while the CPU can read all
+    # three files even without it, so we only verify that the GPU read does not throw.
+ with_gpu_session( + lambda spark : spark.read.schema(schema) + .json([first_data_path, second_data_path, third_data_path]) + .collect(), + conf=all_confs) diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py index ba5fca0711c..dbbaddde7fa 100644 --- a/integration_tests/src/main/python/orc_test.py +++ b/integration_tests/src/main/python/orc_test.py @@ -20,6 +20,7 @@ from pyspark.sql.types import * from spark_session import with_cpu_session, is_before_spark_330 from parquet_test import _nested_pruning_schemas +from conftest import is_databricks_runtime pytestmark = pytest.mark.nightly_resource_consuming_test @@ -464,4 +465,25 @@ def do_orc_scan(spark): assert_cpu_and_gpu_are_equal_collect_with_capture( do_orc_scan, exist_classes= "FileSourceScanExec", - non_exist_classes= "GpuBatchScanExec") \ No newline at end of file + non_exist_classes= "GpuBatchScanExec") + + +@ignore_order +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +@pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn) +@pytest.mark.skipif(is_databricks_runtime(), reason="Databricks does not support ignoreCorruptFiles") +def test_orc_read_with_corrupt_files(spark_tmp_path, reader_confs, v1_enabled_list): + first_data_path = spark_tmp_path + '/ORC_DATA/first' + with_cpu_session(lambda spark : spark.range(1).toDF("a").write.orc(first_data_path)) + second_data_path = spark_tmp_path + '/ORC_DATA/second' + with_cpu_session(lambda spark : spark.range(1, 2).toDF("a").write.orc(second_data_path)) + third_data_path = spark_tmp_path + '/ORC_DATA/third' + with_cpu_session(lambda spark : spark.range(2, 3).toDF("a").write.json(third_data_path)) + + all_confs = copy_and_update(reader_confs, + {'spark.sql.files.ignoreCorruptFiles': "true", + 'spark.sql.sources.useV1SourceList': v1_enabled_list}) + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.read.orc([first_data_path, second_data_path, third_data_path]), + conf=all_confs) \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala index 095386d3337..14afd362602 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import java.io.DataOutputStream +import java.io.{DataOutputStream, FileNotFoundException, IOException} import java.net.URI import java.nio.ByteBuffer import java.nio.channels.{Channels, WritableByteChannel} @@ -163,12 +163,14 @@ case class GpuOrcMultiFilePartitionReaderFactory( private val numThreads = rapidsConf.orcMultiThreadReadNumThreads private val maxNumFileProcessed = rapidsConf.maxNumOrcFilesParallel private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, filters) + private val ignoreMissingFiles = sqlConf.ignoreMissingFiles + private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles // we can't use the coalescing files reader when InputFileName, InputFileBlockStart, // or InputFileBlockLength because we are combining all the files into a single buffer // and we don't know which file is associated with each row. 
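+  // Likewise, when ignoreCorruptFiles is enabled we avoid the coalescing reader: it merges files
+  // into one buffer, so a corrupt file could not be skipped on its own there.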
override val canUseCoalesceFilesReader: Boolean = - rapidsConf.isOrcCoalesceFileReadEnabled && !queryUsesInputFile + rapidsConf.isOrcCoalesceFileReadEnabled && !(queryUsesInputFile || ignoreCorruptFiles) override val canUseMultiThreadReader: Boolean = rapidsConf.isOrcMultiThreadReadEnabled @@ -183,7 +185,7 @@ case class GpuOrcMultiFilePartitionReaderFactory( PartitionReader[ColumnarBatch] = { new MultiFileCloudOrcPartitionReader(conf, files, dataSchema, readDataSchema, partitionSchema, maxReadBatchSizeRows, maxReadBatchSizeBytes, numThreads, maxNumFileProcessed, - debugDumpPrefix, filters, filterHandler, metrics) + debugDumpPrefix, filters, filterHandler, metrics, ignoreMissingFiles, ignoreCorruptFiles) } /** @@ -1292,6 +1294,8 @@ private case class GpuOrcFileFilterHandler( * @param filters filters passed into the filterHandler * @param filterHandler used to filter the ORC stripes * @param execMetrics the metrics + * @param ignoreMissingFiles Whether to ignore missing files + * @param ignoreCorruptFiles Whether to ignore corrupt files */ class MultiFileCloudOrcPartitionReader( conf: Configuration, @@ -1306,9 +1310,11 @@ class MultiFileCloudOrcPartitionReader( debugDumpPrefix: String, filters: Array[Filter], filterHandler: GpuOrcFileFilterHandler, - override val execMetrics: Map[String, GpuMetric]) + override val execMetrics: Map[String, GpuMetric], + ignoreMissingFiles: Boolean, + ignoreCorruptFiles: Boolean) extends MultiFileCloudPartitionReaderBase(conf, files, numThreads, maxNumFileProcessed, filters, - execMetrics) with MultiFileReaderFunctions with OrcPartitionReaderBase { + execMetrics, ignoreCorruptFiles) with MultiFileReaderFunctions with OrcPartitionReaderBase { private case class HostMemoryBuffersWithMetaData( override val partitionedFile: PartitionedFile, @@ -1329,6 +1335,16 @@ class MultiFileCloudOrcPartitionReader( TrampolineUtil.setTaskContext(taskContext) try { doRead() + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${partFile.filePath}", e) + HostMemoryBuffersWithMetaData(partFile, Array((null, 0)), 0, null, None) + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: ${partFile.filePath}", e) + HostMemoryBuffersWithMetaData(partFile, Array((null, 0)), 0, null, None) } finally { TrampolineUtil.unsetTaskContext() }
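
A minimal usage sketch (not part of the patch) of how the new behavior is exercised from user code;
it assumes a Spark build with the RAPIDS plugin on the classpath, and the object name and paths
below are hypothetical:

    import org.apache.spark.sql.SparkSession

    object OrcIgnoreCorruptFilesExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("orc-ignore-corrupt-files")
          // Skip corrupt files instead of failing the read. With this patch the ORC coalescing
          // reader is disabled in this mode and the per-file reader skips the corrupt files.
          .config("spark.sql.files.ignoreCorruptFiles", "true")
          .getOrCreate()

        // Hypothetical input: one directory of valid ORC files and one containing a non-ORC file,
        // mirroring how the tests above build their corrupt inputs.
        val df = spark.read.orc("/data/orc/first", "/data/orc/second")
        df.show()

        spark.stop()
      }
    }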