Add ignoreCorruptFiles for ORC readers [databricks] #4809

Merged: 1 commit, Feb 18, 2022
29 changes: 27 additions & 2 deletions integration_tests/src/main/python/json_test.py
@@ -16,9 +16,10 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from src.main.python.marks import approximate_float, allow_non_gpu
from conftest import is_databricks_runtime
from marks import approximate_float, allow_non_gpu, ignore_order

from src.main.python.spark_session import with_cpu_session
from spark_session import with_cpu_session, with_gpu_session

json_supported_gens = [
# Spark does not escape '\r' or '\n' even though it uses them to mark the end of a record
@@ -206,3 +207,27 @@ def test_json_unquotedCharacters(std_input_path, filename, schema, read_func, al
schema,
{"allowUnquotedControlChars": allow_unquoted_chars}),
conf=_enable_all_types_conf)

@ignore_order
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
@pytest.mark.skipif(is_databricks_runtime(), reason="Databricks does not support ignoreCorruptFiles")
def test_json_read_with_corrupt_files(spark_tmp_path, v1_enabled_list):
    first_data_path = spark_tmp_path + '/JSON_DATA/first'
    with_cpu_session(lambda spark : spark.range(1).toDF("a").write.json(first_data_path))
    second_data_path = spark_tmp_path + '/JSON_DATA/second'
    with_cpu_session(lambda spark : spark.range(1, 2).toDF("a").write.orc(second_data_path))
    third_data_path = spark_tmp_path + '/JSON_DATA/third'
    with_cpu_session(lambda spark : spark.range(2, 3).toDF("a").write.json(third_data_path))

    all_confs = copy_and_update(_enable_all_types_conf,
                                {'spark.sql.files.ignoreCorruptFiles': "true",
                                 'spark.sql.sources.useV1SourceList': v1_enabled_list})
    schema = StructType([StructField("a", IntegerType())])

    # When ignoreCorruptFiles is enabled, the GPU read should not throw an exception. The CPU
    # can read all three files successfully even without ignoring corrupt files, so the CPU and
    # GPU results would not necessarily match; here we only check that the GPU read does not throw.
    with_gpu_session(
            lambda spark : spark.read.schema(schema)
                .json([first_data_path, second_data_path, third_data_path])
                .collect(),
            conf=all_confs)
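For readers who want to try this outside the test harness, here is a minimal standalone PySpark sketch of the behavior the new JSON test exercises; the session, scratch directory, and data are illustrative assumptions, not part of the test suite or the plugin.

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

# Assumed local session; in the test suite the session and confs come from the harness.
spark = (SparkSession.builder
         .config("spark.sql.files.ignoreCorruptFiles", "true")
         .getOrCreate())

base = "/tmp/ignore_corrupt_json_demo"  # hypothetical scratch directory
spark.range(1).toDF("a").write.mode("overwrite").json(base + "/first")     # readable JSON
spark.range(1, 2).toDF("a").write.mode("overwrite").orc(base + "/second")  # "corrupt" from JSON's point of view
spark.range(2, 3).toDF("a").write.mode("overwrite").json(base + "/third")  # readable JSON

schema = StructType([StructField("a", IntegerType())])

# The scan should complete instead of failing on the ORC file. What comes back for that
# file can differ between the CPU and GPU readers, which is why the test above only
# asserts that the GPU read does not throw.
rows = spark.read.schema(schema).json([base + "/first", base + "/second", base + "/third"]).collect()
print(rows)
```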
24 changes: 23 additions & 1 deletion integration_tests/src/main/python/orc_test.py
@@ -20,6 +20,7 @@
from pyspark.sql.types import *
from spark_session import with_cpu_session, is_before_spark_330
from parquet_test import _nested_pruning_schemas
from conftest import is_databricks_runtime

pytestmark = pytest.mark.nightly_resource_consuming_test

@@ -464,4 +465,25 @@ def do_orc_scan(spark):
assert_cpu_and_gpu_are_equal_collect_with_capture(
do_orc_scan,
exist_classes= "FileSourceScanExec",
non_exist_classes= "GpuBatchScanExec")


@ignore_order
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
@pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="Databricks does not support ignoreCorruptFiles")
def test_orc_read_with_corrupt_files(spark_tmp_path, reader_confs, v1_enabled_list):
    first_data_path = spark_tmp_path + '/ORC_DATA/first'
    with_cpu_session(lambda spark : spark.range(1).toDF("a").write.orc(first_data_path))
    second_data_path = spark_tmp_path + '/ORC_DATA/second'
    with_cpu_session(lambda spark : spark.range(1, 2).toDF("a").write.orc(second_data_path))
    third_data_path = spark_tmp_path + '/ORC_DATA/third'
    with_cpu_session(lambda spark : spark.range(2, 3).toDF("a").write.json(third_data_path))

    all_confs = copy_and_update(reader_confs,
                                {'spark.sql.files.ignoreCorruptFiles': "true",
                                 'spark.sql.sources.useV1SourceList': v1_enabled_list})

    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.read.orc([first_data_path, second_data_path, third_data_path]),
            conf=all_confs)
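The ORC case is simpler to verify: a JSON file sitting in an ORC path should fail to parse on both CPU and GPU, so the two readers can be compared directly once the corrupt file is skipped. Below is a standalone sketch of the two related Spark settings, again with an assumed session and scratch directory rather than the test harness.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # assumed existing session (plain Spark or with the RAPIDS plugin)

# Skip files that raise an error while being read, e.g. a non-ORC file in an ORC directory.
spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")
# Skip files that were listed for the scan but deleted before they could be read.
spark.conf.set("spark.sql.files.ignoreMissingFiles", "true")

base = "/tmp/ignore_corrupt_orc_demo"  # hypothetical scratch directory
spark.range(1).toDF("a").write.mode("overwrite").orc(base + "/first")
spark.range(1, 2).toDF("a").write.mode("overwrite").json(base + "/second")  # not valid ORC

# Both readers should return only the row from 'first' and skip 'second'
# instead of failing the whole scan.
print(spark.read.orc([base + "/first", base + "/second"]).collect())
```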
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import java.io.DataOutputStream
import java.io.{DataOutputStream, FileNotFoundException, IOException}
import java.net.URI
import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}
@@ -163,12 +163,14 @@ case class GpuOrcMultiFilePartitionReaderFactory(
private val numThreads = rapidsConf.orcMultiThreadReadNumThreads
private val maxNumFileProcessed = rapidsConf.maxNumOrcFilesParallel
private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, filters)
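// These flags mirror spark.sql.files.ignoreMissingFiles and spark.sql.files.ignoreCorruptFiles.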
private val ignoreMissingFiles = sqlConf.ignoreMissingFiles
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

// We can't use the coalescing files reader when InputFileName, InputFileBlockStart,
// or InputFileBlockLength is used, because we combine all the files into a single buffer
// and don't know which file each row came from.
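// Coalescing is also disabled when ignoreCorruptFiles is set, since a corrupt file could not
// be skipped on its own once its contents have been merged into the combined buffer.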
override val canUseCoalesceFilesReader: Boolean =
rapidsConf.isOrcCoalesceFileReadEnabled && !queryUsesInputFile
rapidsConf.isOrcCoalesceFileReadEnabled && !(queryUsesInputFile || ignoreCorruptFiles)

override val canUseMultiThreadReader: Boolean = rapidsConf.isOrcMultiThreadReadEnabled

@@ -183,7 +185,7 @@
PartitionReader[ColumnarBatch] = {
new MultiFileCloudOrcPartitionReader(conf, files, dataSchema, readDataSchema, partitionSchema,
maxReadBatchSizeRows, maxReadBatchSizeBytes, numThreads, maxNumFileProcessed,
debugDumpPrefix, filters, filterHandler, metrics)
debugDumpPrefix, filters, filterHandler, metrics, ignoreMissingFiles, ignoreCorruptFiles)
}

/**
@@ -1292,6 +1294,8 @@ private case class GpuOrcFileFilterHandler(
* @param filters filters passed into the filterHandler
* @param filterHandler used to filter the ORC stripes
* @param execMetrics the metrics
* @param ignoreMissingFiles Whether to ignore missing files
* @param ignoreCorruptFiles Whether to ignore corrupt files
*/
class MultiFileCloudOrcPartitionReader(
conf: Configuration,
@@ -1306,9 +1310,11 @@ class MultiFileCloudOrcPartitionReader(
debugDumpPrefix: String,
filters: Array[Filter],
filterHandler: GpuOrcFileFilterHandler,
override val execMetrics: Map[String, GpuMetric])
override val execMetrics: Map[String, GpuMetric],
ignoreMissingFiles: Boolean,
ignoreCorruptFiles: Boolean)
extends MultiFileCloudPartitionReaderBase(conf, files, numThreads, maxNumFileProcessed, filters,
execMetrics) with MultiFileReaderFunctions with OrcPartitionReaderBase {
execMetrics, ignoreCorruptFiles) with MultiFileReaderFunctions with OrcPartitionReaderBase {

private case class HostMemoryBuffersWithMetaData(
override val partitionedFile: PartitionedFile,
@@ -1329,6 +1335,16 @@ class MultiFileCloudOrcPartitionReader(
TrampolineUtil.setTaskContext(taskContext)
try {
doRead()
} catch {
case e: FileNotFoundException if ignoreMissingFiles =>
logWarning(s"Skipped missing file: ${partFile.filePath}", e)
HostMemoryBuffersWithMetaData(partFile, Array((null, 0)), 0, null, None)
// Throw FileNotFoundException even if `ignoreCorruptFiles` is true
case e: FileNotFoundException if !ignoreMissingFiles => throw e
case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles =>
logWarning(
s"Skipped the rest of the content in the corrupted file: ${partFile.filePath}", e)
HostMemoryBuffersWithMetaData(partFile, Array((null, 0)), 0, null, None)
} finally {
TrampolineUtil.unsetTaskContext()
}
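Stepping back from the diff, the new catch block boils down to a per-file skip-or-rethrow decision. Here is a rough Python rendering of that decision, purely for illustration; the function name, logger, and empty-result placeholder are assumptions, not the plugin's API.

```python
import logging

logger = logging.getLogger(__name__)

def read_one_file(path, do_read, ignore_missing_files, ignore_corrupt_files, empty_result):
    """Sketch of the reader's error handling: skip missing or corrupt files when asked,
    otherwise let the exception propagate and fail the task."""
    try:
        return do_read(path)
    except FileNotFoundError as e:
        if ignore_missing_files:
            logger.warning("Skipped missing file: %s", path, exc_info=e)
            return empty_result
        # A missing file is still an error when only ignoreCorruptFiles is set.
        raise
    except (OSError, RuntimeError) as e:
        if ignore_corrupt_files:
            logger.warning("Skipped the rest of the content in the corrupted file: %s", path, exc_info=e)
            return empty_result
        raise
```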