diff --git a/docs/configs.md b/docs/configs.md index 8d49ed4d31c..edad85b1904 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -75,9 +75,10 @@ Name | Description | Default Value spark.rapids.sql.explain|Explain why some parts of a query were not placed on a GPU or not. Possible values are ALL: print everything, NONE: print nothing, NOT_ON_GPU: print only parts of a query that did not go on the GPU|NONE spark.rapids.sql.fast.sample|Option to turn on fast sample. If enable it is inconsistent with CPU sample because of GPU sample algorithm is inconsistent with CPU.|false spark.rapids.sql.format.avro.enabled|When set to true enables all avro input and output acceleration. (only input is currently supported anyways)|false +spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel|A limit on the maximum number of files per task processed in parallel on the CPU side before the file is sent to the GPU. This affects the amount of host memory used when reading the files in parallel. Used with MULTITHREADED reader, see spark.rapids.sql.format.avro.reader.type|2147483647 spark.rapids.sql.format.avro.multiThreadedRead.numThreads|The maximum number of threads, on one executor, to use for reading small avro files in parallel. This can not be changed at runtime after the executor has started. Used with MULTITHREADED reader, see spark.rapids.sql.format.avro.reader.type.|20 spark.rapids.sql.format.avro.read.enabled|When set to true enables avro input acceleration|false -spark.rapids.sql.format.avro.reader.type|Sets the avro reader type. We support different types that are optimized for different environments. The original Spark style reader can be selected by setting this to PERFILE which individually reads and copies files to the GPU. Loading many small files individually has high overhead, and using COALESCING is recommended instead. The COALESCING reader is good when using a local file system where the executors are on the same nodes or close to the nodes the data is being read on. This reader coalesces all the files assigned to a task into a single host buffer before sending it down to the GPU. It copies blocks from a single file into a host buffer in separate threads in parallel, see spark.rapids.sql.format.avro.multiThreadedRead.numThreads. By default this is set to AUTO so we select the reader we think is best. This will be COALESCING.|AUTO +spark.rapids.sql.format.avro.reader.type|Sets the avro reader type. We support different types that are optimized for different environments. The original Spark style reader can be selected by setting this to PERFILE which individually reads and copies files to the GPU. Loading many small files individually has high overhead, and using either COALESCING or MULTITHREADED is recommended instead. The COALESCING reader is good when using a local file system where the executors are on the same nodes or close to the nodes the data is being read on. This reader coalesces all the files assigned to a task into a single host buffer before sending it down to the GPU. It copies blocks from a single file into a host buffer in separate threads in parallel, see spark.rapids.sql.format.avro.multiThreadedRead.numThreads. MULTITHREADED is good for cloud environments where you are reading from a blobstore that is totally separate and likely has a higher I/O read cost. Many times the cloud environments also get better throughput when you have multiple readers in parallel. 
This reader uses multiple threads to read each file in parallel and each file is sent to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. See spark.rapids.sql.format.avro.multiThreadedRead.numThreads and spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel to control the number of threads and amount of memory used. By default this is set to AUTO so we select the reader we think is best. This will either be the COALESCING or the MULTITHREADED based on whether we think the file is in the cloud. See spark.rapids.cloudSchemes.|AUTO spark.rapids.sql.format.csv.enabled|When set to false disables all csv input and output acceleration. (only input is currently supported anyways)|true spark.rapids.sql.format.csv.read.enabled|When set to false disables csv input acceleration|true spark.rapids.sql.format.json.enabled|When set to true enables all json input and output acceleration. (only input is currently supported anyways)|false diff --git a/integration_tests/src/main/python/avro_test.py b/integration_tests/src/main/python/avro_test.py index 05219c04d76..2bea2f7cc88 100644 --- a/integration_tests/src/main/python/avro_test.py +++ b/integration_tests/src/main/python/avro_test.py @@ -30,7 +30,7 @@ 'spark.rapids.sql.format.avro.enabled': 'true', 'spark.rapids.sql.format.avro.read.enabled': 'true'} -rapids_reader_types = ['PERFILE', 'COALESCING'] +rapids_reader_types = ['PERFILE', 'COALESCING', 'MULTITHREADED'] # 50 files for the coalescing reading case coalescingPartitionNum = 50 @@ -117,3 +117,23 @@ def test_coalescing_uniform_sync(spark_tmp_path, v1_enabled_list): # read the coalesced files by CPU with_cpu_session( lambda spark: spark.read.format("avro").load(dump_path).collect()) + + +@ignore_order(local=True) +@pytest.mark.parametrize('v1_enabled_list', ["", "avro"], ids=["v1", "v2"]) +@pytest.mark.parametrize('reader_type', rapids_reader_types) +def test_avro_read_with_corrupt_files(spark_tmp_path, reader_type, v1_enabled_list): + first_dpath = spark_tmp_path + '/AVRO_DATA/first' + with_cpu_session(lambda spark : spark.range(1).toDF("a").write.format("avro").save(first_dpath)) + second_dpath = spark_tmp_path + '/AVRO_DATA/second' + with_cpu_session(lambda spark : spark.range(1, 2).toDF("a").write.format("avro").save(second_dpath)) + third_dpath = spark_tmp_path + '/AVRO_DATA/third' + with_cpu_session(lambda spark : spark.range(2, 3).toDF("a").write.json(third_dpath)) + + all_confs = copy_and_update(_enable_all_types_conf, { + 'spark.sql.files.ignoreCorruptFiles': "true", + 'spark.sql.sources.useV1SourceList': v1_enabled_list}) + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : spark.read.format("avro").load([first_dpath, second_dpath, third_dpath]), + conf=all_confs) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataFileReader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataFileReader.scala index f4039dcc09c..b868f86e44f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataFileReader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataFileReader.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import java.io.{InputStream, IOException} +import java.io.{EOFException, InputStream, IOException, OutputStream} import java.nio.charset.StandardCharsets import scala.collection.mutable @@ -27,12 +27,13 @@ import org.apache.avro.file.SeekableInput import org.apache.avro.io.{BinaryData, BinaryDecoder, DecoderFactory} import 
org.apache.commons.io.output.{CountingOutputStream, NullOutputStream} -private class AvroSeekableInputStream(in: SeekableInput) extends InputStream with SeekableInput { +private[rapids] class AvroSeekableInputStream(in: SeekableInput) extends InputStream + with SeekableInput { var oneByte = new Array[Byte](1) override def read(): Int = { val n = read(oneByte, 0, 1) - if (n == 1) return oneByte(0) & 0xff else return n + if (n == 1) oneByte(0) & 0xff else n } override def read(b: Array[Byte]): Int = in.read(b, 0, b.length) @@ -121,26 +122,46 @@ object Header { /** * The each Avro block information * - * @param blockStart the start of block - * @param blockLength the whole block length = the size between two sync buffers + sync buffer - * @param blockSize the block data size - * @param count how many entries in this block + * @param blockStart the start of block + * @param blockSize the whole block size = the size between two sync buffers + sync buffer + * @param dataSize the block data size + * @param count how many entries in this block */ -case class BlockInfo(blockStart: Long, blockLength: Long, blockDataSize: Long, count: Long) +case class BlockInfo(blockStart: Long, blockSize: Long, dataSize: Long, count: Long) /** - * AvroDataFileReader parses the Avro file to get the header and all block information + * The mutable version of the BlockInfo without block start. + * This is for reusing an existing instance when accessing data in the iterator pattern. + * + * @param blockSize the whole block size (the size between two sync buffers + sync buffer size) + * @param dataSize the data size in this block + * @param count how many entries in this block */ -class AvroDataFileReader(si: SeekableInput) extends AutoCloseable { - private val sin = new AvroSeekableInputStream(si) - sin.seek(0) // seek to the start of file and get some meta info. - private var vin: BinaryDecoder = DecoderFactory.get.binaryDecoder(sin, vin); +case class MutableBlockInfo(var blockSize: Long, var dataSize: Long, var count: Long) + +/** The parent of the Rapids Avro file readers */ +abstract class AvroFileReader(si: SeekableInput) extends AutoCloseable { + // Children should update this pointer accordingly. 
+ protected var curBlockStart = 0L + protected val headerSync: Array[Byte] = new Array[Byte](SYNC_SIZE) + + protected val sin = new AvroSeekableInputStream(si) + sin.seek(0) // seek to the file start to parse the header + protected var vin: BinaryDecoder = DecoderFactory.get.binaryDecoder(sin, vin) val (header, headerSize): (Header, Long) = initialize() - // store all blocks info - lazy val blocks: Seq[BlockInfo] = parseBlocks() + private lazy val partialMatchTable: Array[Int] = computePartialMatchTable(headerSync) + + protected def seek(position: Long): Unit = { + sin.seek(position) + vin = DecoderFactory.get.binaryDecoder(sin, vin) + } + // parse the Avro file header: + // ---------------------------------- + // | magic | metadata | sync marker | + // ---------------------------------- private def initialize(): (Header, Long) = { // read magic val magic = new Array[Byte](MAGIC.length) @@ -167,30 +188,122 @@ class AvroDataFileReader(si: SeekableInput) extends AutoCloseable { } while (l != 0) } // read sync marker - val sync = new Array[Byte](SYNC_SIZE) - vin.readFixed(sync) - (Header(meta.toMap, sync), sin.tell - vin.inputStream.available) + vin.readFixed(headerSync) + // store some information + curBlockStart = sin.tell - vin.inputStream.available + (Header(meta.toMap, headerSync), curBlockStart) } - private def seek(position: Long): Unit = { - sin.seek(position) - vin = DecoderFactory.get().binaryDecoder(this.sin, vin); + /** Return true if current point is past the next sync point after a position. */ + def pastSync(position: Long): Boolean = { + // If 'position' is in a block, this block will be read into this partition. + // If 'position' is in a sync marker, the next block will be read into this partition. + (curBlockStart >= position + SYNC_SIZE) || (curBlockStart >= sin.length()) + } + + /** + * Move to the next synchronization point after a position. To process a range + * of file entries, call this with the starting position, then check + * "pastSync(long)" with the end position before each call to "peekBlock()" or + * "readNextRawBlock". + * (Based off of the 'sync' in "DataFileReader" of apache/avro) + */ + def sync(position: Long): Unit = { + seek(position) + val pm = partialMatchTable + val in = vin.inputStream() + // Search for the sequence of bytes in the stream using Knuth-Morris-Pratt + var i = 0L + var j = 0 + var b = in.read() + while (b != -1) { + val cb = b.toByte + while (j > 0 && cb != headerSync(j)) { + j = pm(j - 1) + } + if (cb == headerSync(j)) { + j += 1 + } + if (j == SYNC_SIZE) { + curBlockStart = position + i + 1L + return + } + b = in.read() + i += 1 + } + // if no match set to the end position + curBlockStart = sin.tell() + } + + /** + * Compute that Knuth-Morris-Pratt partial match table. + * + * @param pattern The pattern being searched + * @return the pre-computed partial match table + * + * @see William Fiset + * Algorithms + * (Based off of the 'computePartialMatchTable' in "DataFileReader" of apache/avro) + */ + private def computePartialMatchTable(pattern: Array[Byte]): Array[Int] = { + val pm = new Array[Int](pattern.length) + var i = 1 + var len = 0 + while (i < pattern.length) { + if (pattern(i) == pattern(len)) { + len += 1 + pm(i) = len + i += 1 + } else { + if (len > 0) { + len = pm(len - 1) + } else { + i += 1 + } + } + } + pm + } + + override def close(): Unit = { + vin.inputStream().close() + } +} + +/** + * AvroMetaFileReader collects the blocks' information from the Avro file + * without reading the block data. 
+ */ +class AvroMetaFileReader(si: SeekableInput) extends AvroFileReader(si) { + // store the blocks info + private var blocks: Seq[BlockInfo] = null + private var curStop: Long = -1L + + /** + * Collect the metadata of the blocks until the given stop point. + * The start block can also be specified by calling 'sync(start)' first. + * + * It is recommended setting start and stop positions to minimize what + * is going to be read. + */ + def getPartialBlocks(stop: Long): Seq[BlockInfo] = { + if (curStop != stop || blocks == null) { + blocks = parsePartialBlocks(stop) + curStop = stop + } + blocks } - private def parseBlocks(): Seq[BlockInfo] = { - var blockStart = headerSize - if (blockStart >= sin.length() || vin.isEnd()) { + private def parsePartialBlocks(stop: Long): Seq[BlockInfo] = { + if (curBlockStart >= sin.length() || vin.isEnd()) { // no blocks return Seq.empty } val blocks = mutable.ArrayBuffer.empty[BlockInfo] - // buf is used for writing long - val buf = new Array[Byte](12) - while (blockStart < sin.length()) { - seek(blockStart) - if (vin.isEnd()) { - return blocks.toSeq - } + // buf is used for writing a long, requiring at most 10 bytes. + val buf = new Array[Byte](10) + while (curBlockStart < sin.length() && !pastSync(stop)) { + seek(curBlockStart) val blockCount = vin.readLong() val blockDataSize = vin.readLong() if (blockDataSize > Integer.MAX_VALUE || blockDataSize < 0) { @@ -198,27 +311,136 @@ class AvroDataFileReader(si: SeekableInput) extends AutoCloseable { } // Get how many bytes used to store the value of count and block data size. - val blockCountLen = BinaryData.encodeLong(blockCount, buf, 0) - val blockDataSizeLen: Int = BinaryData.encodeLong(blockDataSize, buf, 0) - + val countLongLen = BinaryData.encodeLong(blockCount, buf, 0) + val dataSizeLongLen = BinaryData.encodeLong(blockDataSize, buf, 0) // (len of entries) + (len of block size) + (block size) + (sync size) - val blockLength = blockCountLen + blockDataSizeLen + blockDataSize + SYNC_SIZE - blocks += BlockInfo(blockStart, blockLength, blockDataSize, blockCount) + val blockLength = countLongLen + dataSizeLongLen + blockDataSize + SYNC_SIZE + blocks += BlockInfo(curBlockStart, blockLength, blockDataSize, blockCount) // Do we need to check the SYNC BUFFER, or just let cudf do it? - blockStart += blockLength + curBlockStart += blockLength } blocks.toSeq } - override def close(): Unit = { - vin.inputStream().close() +} + +/** + * AvroDataFileReader reads the Avro file data in the iterator pattern. + * You can use it as below. + * while(reader.hasNextBlock) { + * val b = reader.peekBlock + * estimateBufSize(b) + * // allocate the batch buffer + * reader.readNextRawBlock(buffer_as_out_stream) + * } + */ +class AvroDataFileReader(si: SeekableInput) extends AvroFileReader(si) { + // Avro file format: + // ---------------------------------------- + // | header | block | block | ... | block | + // ---------------------------------------- + // One block format: / \ + // ---------------------------------------------------- + // | Count | Data Size | Data in binary | sync marker | + // ---------------------------------------------------- + // - longsBuffer for the encoded block count and data size, each is at most 10 bytes. + // - syncBuffer for the sync marker, with the fixed size: 16 bytes. 
+ // - dataBuffer for the block binary data + private val longsBuffer = new Array[Byte](20) + private val syncBuffer = new Array[Byte](SYNC_SIZE) + private var dataBuffer: Array[Byte] = null + + // count of objects in block + private var curCount: Long = 0L + // size in bytes of the serialized objects in block + private var curDataSize: Long = 0L + // block size = encoded count long size + encoded data-size long size + data size + 16 + private var curBlockSize: Long = 0L + // a flag to indicate whether a block is currently available + private var curBlockReady = false + + /** Test if there is a block. */ + def hasNextBlock(): Boolean = { + try { + if (curBlockReady) { return true } + // if reaching the end of the stream + if (vin.isEnd()) { return false } + curCount = vin.readLong() // read block count + curDataSize = vin.readLong() // read block data size + if (curDataSize > Int.MaxValue || curDataSize < 0) { + throw new IOException(s"Invalid data size: $curDataSize, should be in [0, Int.MaxValue].") + } + // Get how many bytes are used to store the values of count and block data size. + val countLongLen = BinaryData.encodeLong(curCount, longsBuffer, 0) + val dataSizeLongLen = BinaryData.encodeLong(curDataSize, longsBuffer, countLongLen) + curBlockSize = countLongLen + dataSizeLongLen + curDataSize + SYNC_SIZE + curBlockReady = true + true + } catch { + case _: EOFException => false + } } + + /** + * Get the current block metadata. Call 'readNextRawBlock' to get the block raw data. + * It is better to check its existence by calling 'hasNextBlock' first. + * This will not move the reader position forward. + */ + def peekBlock(reuse: MutableBlockInfo): MutableBlockInfo = { + if (!hasNextBlock) { + throw new NoSuchElementException + } + if (reuse == null) { + MutableBlockInfo(curBlockSize, curDataSize, curCount) + } else { + reuse.blockSize = curBlockSize + reuse.dataSize = curDataSize + reuse.count = curCount + reuse + } + } + + /** + * Read the current block raw data into the given output stream. + */ + def readNextRawBlock(out: OutputStream): Unit = { + // This is designed to reduce the data copy as much as possible. + // Currently it leverages the BinaryDecoder, and data will be copied twice: + // once to the temporary buffer `dataBuffer`, and again to the output stream (the + // batch buffer in native). + // Later we may want to implement a Decoder ourselves to copy the data from the raw + // buffer directly. 
+ if (!hasNextBlock) { + throw new NoSuchElementException + } + val dataSize = curDataSize.toInt + if (dataBuffer == null || dataBuffer.size < dataSize) { + dataBuffer = new Array[Byte](dataSize) + } + // throws if it can't read the size requested + vin.readFixed(dataBuffer, 0, dataSize) + vin.readFixed(syncBuffer) + curBlockStart = sin.tell - vin.inputStream.available + if (!headerSync.sameElements(syncBuffer)) { + curBlockReady = false + throw new IOException("Invalid sync!") + } + out.write(longsBuffer, 0, (curBlockSize - curDataSize - SYNC_SIZE).toInt) + out.write(dataBuffer, 0, dataSize) + out.write(syncBuffer) + curBlockReady = false + } + } -object AvroDataFileReader { +object AvroFileReader { + + def openMetaReader(si: SeekableInput): AvroMetaFileReader = { + new AvroMetaFileReader(si) + } - def openReader(si: SeekableInput): AvroDataFileReader = { + def openDataReader(si: SeekableInput): AvroDataFileReader = { new AvroDataFileReader(si) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 2d86f4ed315..cc412b5bc7e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -962,18 +962,27 @@ object RapidsConf { .doc("Sets the avro reader type. We support different types that are optimized for " + "different environments. The original Spark style reader can be selected by setting this " + "to PERFILE which individually reads and copies files to the GPU. Loading many small files " + - "individually has high overhead, and using COALESCING is " + + "individually has high overhead, and using either COALESCING or MULTITHREADED is " + "recommended instead. The COALESCING reader is good when using a local file system where " + "the executors are on the same nodes or close to the nodes the data is being read on. " + "This reader coalesces all the files assigned to a task into a single host buffer before " + "sending it down to the GPU. It copies blocks from a single file into a host buffer in " + "separate threads in parallel, see " + "spark.rapids.sql.format.avro.multiThreadedRead.numThreads. " + + "MULTITHREADED is good for cloud environments where you are reading from a blobstore " + + "that is totally separate and likely has a higher I/O read cost. Many times the cloud " + + "environments also get better throughput when you have multiple readers in parallel. " + + "This reader uses multiple threads to read each file in parallel and each file is sent " + + "to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. " + + "See spark.rapids.sql.format.avro.multiThreadedRead.numThreads and " + + "spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel to control " + + "the number of threads and amount of memory used. " + "By default this is set to AUTO so we select the reader we think is best. This will " + - "be COALESCING.") + "either be the COALESCING or the MULTITHREADED based on whether we think the file is " + + "in the cloud. 
See spark.rapids.cloudSchemes.") .stringConf .transform(_.toUpperCase(java.util.Locale.ROOT)) - .checkValues((RapidsReaderType.values - RapidsReaderType.MULTITHREADED).map(_.toString)) + .checkValues(RapidsReaderType.values.map(_.toString)) .createWithDefault(RapidsReaderType.AUTO.toString) val AVRO_MULTITHREAD_READ_NUM_THREADS = @@ -985,6 +994,16 @@ object RapidsConf { .integerConf .createWithDefault(20) + val AVRO_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL = + conf("spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel") + .doc("A limit on the maximum number of files per task processed in parallel on the CPU " + + "side before the file is sent to the GPU. This affects the amount of host memory used " + + "when reading the files in parallel. Used with MULTITHREADED reader, see " + + "spark.rapids.sql.format.avro.reader.type") + .integerConf + .checkValue(v => v > 0, "The maximum number of files must be greater than 0.") + .createWithDefault(Integer.MAX_VALUE) + val ENABLE_RANGE_WINDOW_BYTES = conf("spark.rapids.sql.window.range.byte.enabled") .doc("When the order-by column of a range based window is byte type and " + "the range boundary calculated for a value has overflow, CPU and GPU will get " + @@ -1776,8 +1795,13 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isAvroCoalesceFileReadEnabled: Boolean = isAvroAutoReaderEnabled || RapidsReaderType.withName(get(AVRO_READER_TYPE)) == RapidsReaderType.COALESCING + lazy val isAvroMultiThreadReadEnabled: Boolean = isAvroAutoReaderEnabled || + RapidsReaderType.withName(get(AVRO_READER_TYPE)) == RapidsReaderType.MULTITHREADED + lazy val avroMultiThreadReadNumThreads: Int = get(AVRO_MULTITHREAD_READ_NUM_THREADS) + lazy val maxNumAvroFilesParallel: Int = get(AVRO_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL) + lazy val shuffleManagerEnabled: Boolean = get(SHUFFLE_MANAGER_ENABLED) lazy val shuffleTransportEnabled: Boolean = get(SHUFFLE_TRANSPORT_ENABLE) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index cfd02b7424c..50f1a524017 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -99,7 +99,7 @@ object ExternalSource { if (hasSparkAvroJar) { format match { case _: AvroFileFormat => - val f = GpuAvroMultiFilePartitionReaderFactory( + GpuAvroMultiFilePartitionReaderFactory( fileScan.relation.sparkSession.sessionState.conf, fileScan.rapidsConf, broadcastedConf, @@ -110,22 +110,6 @@ object ExternalSource { fileScan.allMetrics, pushedFilters, fileScan.queryUsesInputFile) - // Now only coalescing is supported, so need to check if it can be used - // for the final choice. 
- if (f.canUseCoalesceFilesReader){ - f - } else { - // Fall back to PerFile reading - GpuAvroPartitionReaderFactory( - fileScan.relation.sparkSession.sessionState.conf, - fileScan.rapidsConf, - broadcastedConf, - fileScan.relation.dataSchema, - fileScan.requiredSchema, - fileScan.relation.partitionSchema, - new AvroOptions(fileScan.relation.options, broadcastedConf.value.value), - fileScan.allMetrics) - } case _ => // never reach here throw new RuntimeException(s"File format $format is not supported yet") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala index f33362c7b6d..d9af5460ca9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala @@ -29,6 +29,7 @@ import scala.math.max import ai.rapids.cudf.{AvroOptions => CudfAvroOptions, HostMemoryBuffer, NvtxColor, NvtxRange, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuMetric.{GPU_DECODE_TIME, NUM_OUTPUT_BATCHES, PEAK_DEVICE_MEMORY, READ_FS_TIME, SEMAPHORE_WAIT_TIME, WRITE_BUFFER_TIME} +import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.avro.Schema import org.apache.avro.file.DataFileConstants.SYNC_SIZE import org.apache.avro.mapred.FsInput @@ -42,7 +43,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.avro.AvroOptions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.read.{PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.datasources.{PartitionedFile, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory, FileScan} @@ -119,17 +120,9 @@ case class GpuAvroScan( GpuAvroPartitionReaderFactory(sparkSession.sessionState.conf, rapidsConf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, parsedOptions, metrics) } else { - val f = GpuAvroMultiFilePartitionReaderFactory(sparkSession.sessionState.conf, + GpuAvroMultiFilePartitionReaderFactory(sparkSession.sessionState.conf, rapidsConf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, parsedOptions, metrics, pushedFilters, queryUsesInputFile) - // Now only coalescing is supported, so need to check it can be used for the final choice. 
- if (f.canUseCoalesceFilesReader) { - f - } else { - // Fall back to PerFile reading - GpuAvroPartitionReaderFactory(sparkSession.sessionState.conf, rapidsConf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, parsedOptions, metrics) - } } } @@ -167,6 +160,8 @@ case class GpuAvroPartitionReaderFactory( private val maxReadBatchSizeRows = rapidsConf.maxReadBatchSizeRows private val maxReadBatchSizeBytes = rapidsConf.maxReadBatchSizeBytes + override def supportColumnarReads(partition: InputPartition): Boolean = true + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { throw new IllegalStateException("ROW BASED PARSING IS NOT SUPPORTED ON THE GPU...") } @@ -202,6 +197,7 @@ case class GpuAvroMultiFilePartitionReaderFactory( private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles private val numThreads = rapidsConf.avroMultiThreadReadNumThreads + private val maxNumFileProcessed = rapidsConf.maxNumAvroFilesParallel // we can't use the coalescing files reader when InputFileName, InputFileBlockStart, // or InputFileBlockLength because we are combining all the files into a single buffer @@ -209,8 +205,7 @@ case class GpuAvroMultiFilePartitionReaderFactory( override val canUseCoalesceFilesReader: Boolean = rapidsConf.isAvroCoalesceFileReadEnabled && !(queryUsesInputFile || ignoreCorruptFiles) - // disbale multi-threaded until it is supported. - override val canUseMultiThreadReader: Boolean = false + override val canUseMultiThreadReader: Boolean = rapidsConf.isAvroMultiThreadReadEnabled /** * File format short name used for logging and other things to uniquely identity @@ -221,10 +216,20 @@ case class GpuAvroMultiFilePartitionReaderFactory( /** * Build the PartitionReader for cloud reading */ + @scala.annotation.nowarn( + "msg=value ignoreExtension in class AvroOptions is deprecated*" + ) override def buildBaseColumnarReaderForCloud( - files: Array[PartitionedFile], + partFiles: Array[PartitionedFile], conf: Configuration): PartitionReader[ColumnarBatch] = { - throw new UnsupportedOperationException() + val files = if (options.ignoreExtension) { + partFiles + } else { + partFiles.filter(_.filePath.endsWith(".avro")) + } + new GpuMultiFileCloudAvroPartitionReader(conf, files, numThreads, maxNumFileProcessed, + filters, metrics, ignoreCorruptFiles, ignoreMissingFiles, debugDumpPrefix, + readDataSchema, partitionSchema, maxReadBatchSizeRows, maxReadBatchSizeBytes) } /** @@ -282,6 +287,11 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase def readDataSchema: StructType + def conf: Configuration + + // Same default buffer size with Parquet readers. + val cacheBufferSize = conf.getInt("avro.read.allocation.size", 8 * 1024 * 1024) + /** * Read the host data to GPU for decoding, and return it as a cuDF Table. 
* The input host buffer should contain valid data, otherwise the behavior is @@ -362,7 +372,7 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase if (numBytes == 0 || numBytes + estBytes <= maxReadBatchSizeBytes) { currentChunk += blockIter.next() numRows += currentChunk.last.count - numAvroBytes += currentChunk.last.blockDataSize + numAvroBytes += currentChunk.last.dataSize numBytes += estBytes readNextBatch() } @@ -411,7 +421,7 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase // Start from the Header var totalSize: Long = headerSize // Add all blocks - totalSize += blocks.map(_.blockLength).sum + totalSize += blocks.map(_.blockSize).sum totalSize } @@ -426,7 +436,7 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase // Copy every block without the tailing sync marker if a sync is given. This // is for coalescing reader who requires to append this given sync marker // to each block. Then we can not merge sequential blocks. - blocks.map(b => CopyRange(b.blockStart, b.blockLength - SYNC_SIZE)) + blocks.map(b => CopyRange(b.blockStart, b.blockSize - SYNC_SIZE)) }.getOrElse(computeCopyRanges(blocks)) val copySyncFunc: OutputStream => Unit = if (sync.isEmpty) { @@ -434,8 +444,8 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase } else { out => out.write(sync.get, 0, SYNC_SIZE) } - // copy cache: 8MB - val copyCache = new Array[Byte](8 * 1024 * 1024) + // copy cache, default to 8MB + val copyCache = new Array[Byte](cacheBufferSize) var readTime, writeTime = 0L copyRanges.foreach { range => @@ -479,7 +489,7 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase currentCopyStart = block.blockStart currentCopyEnd = currentCopyStart } - currentCopyEnd += block.blockLength + currentCopyEnd += block.blockSize } if (currentCopyEnd != currentCopyStart) { @@ -492,7 +502,7 @@ trait GpuAvroReaderBase extends Arm with Logging { self: FilePartitionReaderBase /** A PartitionReader that reads an AVRO file split on the GPU. */ class GpuAvroPartitionReader( - conf: Configuration, + override val conf: Configuration, partFile: PartitionedFile, blockMeta: AvroBlockMeta, override val readDataSchema: StructType, @@ -552,13 +562,250 @@ class GpuAvroPartitionReader( } +/** + * A PartitionReader that can read multiple AVRO files in parallel. + * This is most efficient running in a cloud environment where the I/O of reading is slow. + * + * When reading a file, it + * - seeks to the start position of the first block located in this partition. + * - next, parses the meta and sync, rewrites the meta and sync, and copies the data to a + * batch buffer per block, until reaching the last one of the current partition. + * - sends batches to GPU at last. + * + * @param conf the Hadoop configuration + * @param files the partitioned files to read + * @param numThreads the size of the threadpool + * @param maxNumFileProcessed threshold to control the maximum file number to be + * submitted to threadpool + * @param filters filters passed into the filterHandler + * @param execMetrics the metrics + * @param ignoreCorruptFiles Whether to ignore corrupt files + * @param ignoreMissingFiles Whether to ignore missing files + * @param debugDumpPrefix a path prefix to use for dumping the fabricated AVRO data or null + * @param readDataSchema the Spark schema describing what will be read + * @param partitionSchema Schema of partitions. 
+ * @param maxReadBatchSizeRows soft limit on the maximum number of rows to be read per batch + * @param maxReadBatchSizeBytes soft limit on the maximum number of bytes to be read per batch + */ +class GpuMultiFileCloudAvroPartitionReader( + override val conf: Configuration, + files: Array[PartitionedFile], + numThreads: Int, + maxNumFileProcessed: Int, + filters: Array[Filter], + execMetrics: Map[String, GpuMetric], + ignoreCorruptFiles: Boolean, + ignoreMissingFiles: Boolean, + override val debugDumpPrefix: Option[String], + override val readDataSchema: StructType, + partitionSchema: StructType, + maxReadBatchSizeRows: Integer, + maxReadBatchSizeBytes: Long) + extends MultiFileCloudPartitionReaderBase(conf, files, numThreads, maxNumFileProcessed, filters, + execMetrics, ignoreCorruptFiles) with MultiFileReaderFunctions with GpuAvroReaderBase { + + override def readBatch(fileBufsAndMeta: HostMemoryBuffersWithMetaDataBase): + Option[ColumnarBatch] = fileBufsAndMeta match { + case buffer: AvroHostBuffersWithMeta => + val bufsAndSizes = buffer.memBuffersAndSizes + val (dataBuf, dataSize) = bufsAndSizes.head + val partitionValues = buffer.partitionedFile.partitionValues + val optBatch = if (dataBuf == null) { + // Not reading any data, but add in partition data if needed + // Someone is going to process this data, even if it is just a row count + GpuSemaphore.acquireIfNecessary(TaskContext.get(), metrics(SEMAPHORE_WAIT_TIME)) + val emptyBatch = new ColumnarBatch(Array.empty, dataSize.toInt) + addPartitionValues(Some(emptyBatch), partitionValues, partitionSchema) + } else { + val maybeBatch = sendToGpu(dataBuf, dataSize, files) + // we have to add partition values here for this batch, we already verified that + // it's not different for all the blocks in this batch + addPartitionValues(maybeBatch, partitionValues, partitionSchema) + } + // Update the current buffers + closeOnExcept(optBatch) { _ => + if (bufsAndSizes.length > 1) { + val updatedBuffers = bufsAndSizes.drop(1) + currentFileHostBuffers = Some(buffer.copy(memBuffersAndSizes = updatedBuffers)) + } else { + currentFileHostBuffers = None + } + } + optBatch + case t => + throw new RuntimeException(s"Unknown avro buffer type: ${t.getClass.getSimpleName}") + } + + override def getThreadPool(numThreads: Int): ThreadPoolExecutor = + AvroMultiFileThreadPool.getOrCreateThreadPool(getFileFormatShortName, numThreads) + + override final def getFileFormatShortName: String = "AVRO" + + override def getBatchRunner( + tc: TaskContext, + file: PartitionedFile, + config: Configuration, + filters: Array[Filter]): Callable[HostMemoryBuffersWithMetaDataBase] = + new ReadBatchRunner(tc, file, config, filters) + + /** Two utils classes */ + private case class AvroHostBuffersWithMeta( + override val partitionedFile: PartitionedFile, + override val memBuffersAndSizes: Array[(HostMemoryBuffer, Long)], + override val bytesRead: Long) extends HostMemoryBuffersWithMetaDataBase + + private class ReadBatchRunner( + taskContext: TaskContext, + partFile: PartitionedFile, + config: Configuration, + filters: Array[Filter]) extends Callable[HostMemoryBuffersWithMetaDataBase] with Logging { + + override def call(): HostMemoryBuffersWithMetaDataBase = { + TrampolineUtil.setTaskContext(taskContext) + try { + doRead() + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${partFile.filePath}", e) + AvroHostBuffersWithMeta(partFile, Array((null, 0)), 0) + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true 
+ case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @(_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: ${partFile.filePath}", e) + AvroHostBuffersWithMeta(partFile, Array((null, 0)), 0) + } finally { + TrampolineUtil.unsetTaskContext() + } + } + + private def createBufferAndMeta( + arrayBufSize: Array[(HostMemoryBuffer, Long)], + startingBytesRead: Long): HostMemoryBuffersWithMetaDataBase = { + val bytesRead = fileSystemBytesRead() - startingBytesRead + AvroHostBuffersWithMeta(partFile, arrayBufSize, bytesRead) + } + + private val stopPosition = partFile.start + partFile.length + + /** + * Read the split into one or more batches. + * Here is an overview of the process: + * - some preparation + * - while (has next block in this split) { + * - 1) peek the current head block and estimate the batch buffer size + * - 2) read as many blocks as possible to fill the batch buffer + * - 3) once the batch is done, append it to the list + * - } + * - post processing + */ + private def doRead(): HostMemoryBuffersWithMetaDataBase = { + val startingBytesRead = fileSystemBytesRead() + val in = new FsInput(new Path(new URI(partFile.filePath)), config) + val reader = closeOnExcept(in) { _ => AvroFileReader.openDataReader(in) } + withResource(reader) { _ => + // Go to the start of the first block after the start position + reader.sync(partFile.start) + if (!reader.hasNextBlock || isDone) { + // no data or got close before finishing, return null buffer and zero size + return createBufferAndMeta(Array((null, 0)), startingBytesRead) + } + val hostBuffers = new ArrayBuffer[(HostMemoryBuffer, Long)] + try { + val headerSize = reader.headerSize + var isBlockSizeEstimated = false + var estBlocksSize = 0L + var totalRowsNum = 0 + var curBlock: MutableBlockInfo = null + while (reader.hasNextBlock && !reader.pastSync(stopPosition)) { + // Get the block metadata first + curBlock = reader.peekBlock(curBlock) + if (!isBlockSizeEstimated) { + // Initialize the estimated total block size. + // The AVRO file has no special section for block metadata, and collecting the + // block meta through the whole file is quite expensive for files in the cloud. + // So we do not know the target buffer size ahead of time and have to estimate it: + // "the estimated total block size = partFile.length + additional space" + // We use "additional space = one block length * 1.2" because we may + // move the start and stop positions when reading this split to keep the + // integrity of edge blocks. + // In the worst case the stop position is one byte after a block start, + // so we need to read the whole block into the current batch. And this block + // may be larger than the first block. So we reserve additional space + // of 'one block length * 1.2' to try to avoid reading it into a new + // batch. + estBlocksSize = partFile.length + (curBlock.blockSize * 1.2F).toLong + isBlockSizeEstimated = true + } + + var estSizeToRead = if (estBlocksSize > maxReadBatchSizeBytes) { + maxReadBatchSizeBytes + } else if (estBlocksSize < curBlock.blockSize) { + // This may happen only for the last block. 
+ logInfo("Less buffer is estimated, read the last block into a new batch.") + curBlock.blockSize + } else { + estBlocksSize + } + // Allocate the buffer for the header and blocks for a batch + closeOnExcept(HostMemoryBuffer.allocate(headerSize + estSizeToRead)) { hmb => + val out = new HostMemoryOutputStream(hmb) + // Write the header to the output stream + AvroFileWriter(out).writeHeader(reader.header) + // Read the block data to the output stream + var batchRowsNum: Int = 0 + var hasNextBlock = true + do { + reader.readNextRawBlock(out) + batchRowsNum += curBlock.count.toInt + estSizeToRead -= curBlock.blockSize + // Continue reading the next block into the current batch when + // - the next block exists, and + // - the remaining buffer is enough to hold the next block, and + // - the batch rows number does not go beyond the upper limit. + hasNextBlock = reader.hasNextBlock && !reader.pastSync(stopPosition) + if (hasNextBlock) { + curBlock = reader.peekBlock(curBlock) + } + } while (hasNextBlock && curBlock.blockSize <= estSizeToRead && + batchRowsNum <= maxReadBatchSizeRows) + + // One batch is done + hostBuffers += ((hmb, out.getPos)) + totalRowsNum += batchRowsNum + estBlocksSize -= (out.getPos - headerSize) + } + } // end of while + + val bufAndSize: Array[(HostMemoryBuffer, Long)] = if (readDataSchema.isEmpty) { + // Overload the size to be the number of rows with null buffer + Array((null, totalRowsNum)) + } else if (isDone) { + // got close before finishing, return null buffer and zero size + hostBuffers.foreach(_._1.safeClose(new Exception)) + Array((null, 0)) + } else { + hostBuffers.toArray + } + createBufferAndMeta(bufAndSize, startingBytesRead) + } catch { + case e: Throwable => + hostBuffers.foreach(_._1.safeClose(e)) + throw e + } + } // end of withResource(reader) + } // end of doRead + } // end of Class ReadBatchRunner + +} + /** * A PartitionReader that can read multiple AVRO files up to the certain size. It will * coalesce small files together and copy the block data in a separate thread pool to speed * up processing the small files before sending down to the GPU. 
*/ class GpuMultiFileAvroPartitionReader( - conf: Configuration, + override val conf: Configuration, splits: Array[PartitionedFile], clippedBlocks: Seq[AvroSingleDataBlockInfo], override val readDataSchema: StructType, @@ -586,7 +833,6 @@ class GpuMultiFileAvroPartitionReader( s" splitting it into a new batch!") return true } - false } @@ -711,25 +957,15 @@ case class AvroFileFilterHandler( ) val ignoreExtension = options.ignoreExtension - private def passSync(blockStart: Long, position: Long): Boolean = { - blockStart >= position + SYNC_SIZE - } - def filterBlocks(partFile: PartitionedFile): AvroBlockMeta = { if (ignoreExtension || partFile.filePath.endsWith(".avro")) { val in = new FsInput(new Path(new URI(partFile.filePath)), hadoopConf) - closeOnExcept(in) { _ => - withResource(AvroDataFileReader.openReader(in)) { reader => - val blocks = reader.blocks - val filteredBlocks = new ArrayBuffer[BlockInfo]() - blocks.foreach(block => { - if (partFile.start <= block.blockStart - SYNC_SIZE && - !passSync(block.blockStart, partFile.start + partFile.length)) { - filteredBlocks.append(block) - } - }) - AvroBlockMeta(reader.header, reader.headerSize, filteredBlocks) - } + val reader = closeOnExcept(in) { _ => AvroFileReader.openMetaReader(in) } + withResource(reader) { _ => + // Get blocks only belong to this split + reader.sync(partFile.start) + val partBlocks = reader.getPartialBlocks(partFile.start + partFile.length) + AvroBlockMeta(reader.header, reader.headerSize, partBlocks) } } else { AvroBlockMeta(null, 0L, Seq.empty) @@ -762,8 +998,8 @@ case class AvroSchemaWrapper(schema: Schema) extends SchemaBase /** avro BlockInfo wrapper */ case class AvroDataBlock(blockInfo: BlockInfo) extends DataBlockBase { override def getRowCount: Long = blockInfo.count - override def getReadDataSize: Long = blockInfo.blockDataSize - override def getBlockSize: Long = blockInfo.blockLength + override def getReadDataSize: Long = blockInfo.dataSize + override def getBlockSize: Long = blockInfo.blockSize } case class AvroSingleDataBlockInfo( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuReaderTypeSuites.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuReaderTypeSuites.scala index b48366d6fe7..e78705c3be4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuReaderTypeSuites.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuReaderTypeSuites.scala @@ -31,6 +31,7 @@ trait ReaderTypeSuite extends SparkQueryCompareTestSuite with Arm { protected def format: String protected def otherConfs: Iterable[(String, String)] = Seq.empty + protected def testContextOk: Boolean = true private def checkReaderType( readerFactory: PartitionReaderFactory, @@ -83,11 +84,9 @@ trait ReaderTypeSuite extends SparkQueryCompareTestSuite with Arm { }, conf.setAll(otherConfs)) } } -} - -trait MultiReaderTypeSuite extends ReaderTypeSuite { test("Use coalescing reading for local files") { + assume(testContextOk) val testFile = Array("/tmp/xyz") Seq(format, "").foreach(useV1Source => { val conf = new SparkConf() @@ -97,6 +96,7 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } test("Use multithreaded reading for cloud files") { + assume(testContextOk) val testFile = Array("s3:/tmp/xyz") Seq(format, "").foreach(useV1Source => { val conf = new SparkConf() @@ -106,6 +106,7 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } test("Force coalescing reading for cloud files when setting COALESCING ") { + assume(testContextOk) val testFile = Array("s3:/tmp/xyz") Seq(format, "").foreach(useV1Source => { 
val conf = new SparkConf() @@ -116,6 +117,7 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } test("Force multithreaded reading for local files when setting MULTITHREADED") { + assume(testContextOk) val testFile = Array("/tmp/xyz") Seq(format, "").foreach(useV1Source => { val conf = new SparkConf() @@ -126,6 +128,7 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } test("Use multithreaded reading for input expression even setting COALESCING") { + assume(testContextOk) val testFile = Array("/tmp/xyz") Seq(format, "").foreach(useV1Source => { val conf = new SparkConf() @@ -136,6 +139,7 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } test("Use multithreaded reading for ignoreCorruptFiles even setting COALESCING") { + assume(testContextOk) val testFile = Array("/tmp/xyz") Seq(format, "").foreach(useV1Source => { val conf = new SparkConf() @@ -147,11 +151,11 @@ trait MultiReaderTypeSuite extends ReaderTypeSuite { } } -class GpuParquetReaderTypeSuites extends MultiReaderTypeSuite { +class GpuParquetReaderTypeSuites extends ReaderTypeSuite { override protected def format: String = "parquet" } -class GpuOrcReaderTypeSuites extends MultiReaderTypeSuite { +class GpuOrcReaderTypeSuites extends ReaderTypeSuite { override protected def format: String = "orc" } @@ -160,60 +164,5 @@ class GpuAvroReaderTypeSuites extends ReaderTypeSuite { override lazy val otherConfs: Iterable[(String, String)] = Seq( ("spark.rapids.sql.format.avro.read.enabled", "true"), ("spark.rapids.sql.format.avro.enabled", "true")) - - private lazy val hasAvroJar = ExternalSource.hasSparkAvroJar - - test("Use coalescing reading for local files") { - assume(hasAvroJar) - val testFile = Array("/tmp/xyz") - Seq(format, "").foreach(useV1Source => { - val conf = new SparkConf() - .set("spark.sql.sources.useV1SourceList", useV1Source) - testReaderType(conf, testFile, COALESCING) - }) - } - - test("Use coalescing reading for cloud files if coalescing can work") { - assume(hasAvroJar) - val testFile = Array("s3:/tmp/xyz") - Seq(format, "").foreach(useV1Source => { - val conf = new SparkConf() - .set("spark.sql.sources.useV1SourceList", useV1Source) - testReaderType(conf, testFile, COALESCING) - }) - } - - test("Force coalescing reading for cloud files when setting COALESCING ") { - assume(hasAvroJar) - val testFile = Array("s3:/tmp/xyz") - Seq(format, "").foreach(useV1Source => { - val conf = new SparkConf() - .set("spark.sql.sources.useV1SourceList", useV1Source) - .set(s"spark.rapids.sql.format.${format}.reader.type", "COALESCING") - testReaderType(conf, testFile, COALESCING) - }) - } - - test("Use per-file reading for input expression even setting COALESCING") { - assume(hasAvroJar) - val testFile = Array("/tmp/xyz") - Seq(format, "").foreach(useV1Source => { - val conf = new SparkConf() - .set("spark.sql.sources.useV1SourceList", useV1Source) - .set(s"spark.rapids.sql.format.${format}.reader.type", "COALESCING") - testReaderType(conf, testFile, PERFILE, hasInputExpression=true) - }) - } - - test("Use per-file reading for ignoreCorruptFiles even setting COALESCING") { - assume(hasAvroJar) - val testFile = Array("/tmp/xyz") - Seq(format, "").foreach(useV1Source => { - val conf = new SparkConf() - .set("spark.sql.sources.useV1SourceList", useV1Source) - .set("spark.sql.files.ignoreCorruptFiles", "true") - .set(s"spark.rapids.sql.format.${format}.reader.type", "COALESCING") - testReaderType(conf, testFile, PERFILE) - }) - } + override lazy val testContextOk = ExternalSource.hasSparkAvroJar }
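With the MULTITHREADED reader now selectable for Avro, a job can opt into it explicitly instead of relying on AUTO. A minimal sketch, assuming the RAPIDS plugin and spark-avro are already on the classpath and enabled; the app name, thread count, file limit, and input path below are illustrative values, while the config keys come from this patch and the existing plugin docs:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: enable the RAPIDS Avro acceleration and force the new
// MULTITHREADED reader instead of relying on AUTO selection.
val spark = SparkSession.builder()
  .appName("avro-multithreaded-read-example")                              // placeholder
  .config("spark.rapids.sql.format.avro.enabled", "true")
  .config("spark.rapids.sql.format.avro.read.enabled", "true")
  .config("spark.rapids.sql.format.avro.reader.type", "MULTITHREADED")
  // threads per executor used to read small Avro files in parallel
  .config("spark.rapids.sql.format.avro.multiThreadedRead.numThreads", "20")
  // new in this patch: cap on files buffered on the CPU side per task
  .config("spark.rapids.sql.format.avro.multiThreadedRead.maxNumFilesParallel", "4")
  .getOrCreate()

// Reading is unchanged; the reader type only affects how the plugin copies data.
val df = spark.read.format("avro").load("s3a://some-bucket/path/to/avro")  // placeholder path
df.show(5)
```

Leaving spark.rapids.sql.format.avro.reader.type at AUTO keeps the behavior described in the updated doc string: the plugin picks COALESCING or MULTITHREADED depending on whether the path scheme is listed in spark.rapids.cloudSchemes.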
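AvroFileReader.sync in the patch locates the next 16-byte sync marker with a Knuth-Morris-Pratt scan over the input stream and records the offset just past it in curBlockStart. A self-contained sketch of that search, assuming a plain InputStream and toy data rather than the plugin's AvroSeekableInputStream; the object and method names are made up for illustration:

```scala
import java.io.{ByteArrayInputStream, InputStream}

object SyncSearchSketch {
  // Build the Knuth-Morris-Pratt partial match (failure) table for the sync marker.
  def computePartialMatchTable(pattern: Array[Byte]): Array[Int] = {
    val pm = new Array[Int](pattern.length)
    var i = 1
    var len = 0
    while (i < pattern.length) {
      if (pattern(i) == pattern(len)) {
        len += 1
        pm(i) = len
        i += 1
      } else if (len > 0) {
        len = pm(len - 1)
      } else {
        i += 1
      }
    }
    pm
  }

  // Return the offset of the first byte AFTER the sync marker, or -1 if it is absent.
  // This mirrors what AvroFileReader.sync stores in curBlockStart (relative to the
  // position it seeked to).
  def offsetPastSync(in: InputStream, sync: Array[Byte]): Long = {
    val pm = computePartialMatchTable(sync)
    var i = 0L
    var j = 0
    var b = in.read()
    while (b != -1) {
      val cb = b.toByte
      while (j > 0 && cb != sync(j)) {
        j = pm(j - 1)
      }
      if (cb == sync(j)) {
        j += 1
      }
      if (j == sync.length) {
        return i + 1
      }
      b = in.read()
      i += 1
    }
    -1L
  }

  def main(args: Array[String]): Unit = {
    val sync = Array.tabulate[Byte](16)(_.toByte)    // a fake 16-byte marker
    val data = Array.fill[Byte](100)(0x7F.toByte) ++ sync ++ Array.fill[Byte](8)(0.toByte)
    println(offsetPastSync(new ByteArrayInputStream(data), sync))  // prints 116 (= 100 + 16)
  }
}
```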
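Both readers size a block as "(len of entries) + (len of block size) + (block size) + (sync size)", where the two length-prefix longs use Avro's zig-zag varint encoding (at most 10 bytes each, which is why longsBuffer is 20 bytes and buf is 10). A small worked sketch of that arithmetic, assuming the standard zig-zag varint rules instead of calling BinaryData.encodeLong:

```scala
object AvroBlockSizeSketch {
  val SyncSize = 16  // length of the Avro sync marker (DataFileConstants.SYNC_SIZE)

  // Bytes needed by Avro's zig-zag varint encoding of a long; this is assumed to
  // match what the patch gets back from BinaryData.encodeLong.
  def encodedLongLength(value: Long): Int = {
    var n = (value << 1) ^ (value >> 63)  // zig-zag encode
    var len = 1
    while ((n & ~0x7FL) != 0L) {
      n = n >>> 7
      len += 1
    }
    len
  }

  // blockSize = varint(count) + varint(dataSize) + dataSize + sync marker,
  // matching the "(len of entries) + (len of block size) + (block size) + (sync size)"
  // comment in the patch.
  def blockSize(count: Long, dataSize: Long): Long =
    encodedLongLength(count) + encodedLongLength(dataSize) + dataSize + SyncSize

  def main(args: Array[String]): Unit = {
    // A block with 1000 rows and 65536 bytes of serialized data:
    // 2 (count varint) + 3 (size varint) + 65536 + 16 = 65557 bytes on disk
    println(blockSize(1000L, 65536L))
  }
}
```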
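The new AvroDataFileReader is driven as an iterator: hasNextBlock, then peekBlock for the metadata, then readNextRawBlock to copy the raw block bytes. A sketch of a standalone caller, assuming the classes added by this patch are on the classpath and substituting a ByteArrayOutputStream for the HostMemoryOutputStream the plugin actually writes into:

```scala
import java.io.ByteArrayOutputStream

import com.nvidia.spark.rapids.{AvroFileReader, MutableBlockInfo}
import org.apache.avro.mapred.FsInput
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object DumpAvroBlocks {
  def main(args: Array[String]): Unit = {
    // args(0): path to an Avro file (local or HDFS/S3 via the Hadoop configuration)
    val in = new FsInput(new Path(args(0)), new Configuration())
    val reader = AvroFileReader.openDataReader(in)
    try {
      var block: MutableBlockInfo = null
      while (reader.hasNextBlock()) {
        block = reader.peekBlock(block)       // metadata only, does not advance the reader
        val out = new ByteArrayOutputStream(block.blockSize.toInt)
        reader.readNextRawBlock(out)          // count + size varints, data, sync marker
        println(s"rows=${block.count} dataSize=${block.dataSize} copied=${out.size()}")
      }
    } finally {
      reader.close()
      in.close()
    }
  }
}
```

GpuMultiFileCloudAvroPartitionReader.doRead follows the same loop per split, except that it first calls sync(partFile.start), stops once pastSync(stopPosition) returns true, and writes the header with AvroFileWriter(out).writeHeader(reader.header) before copying the blocks.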
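The batch sizing branch inside doRead is easy to miss in the diff above: the estimate is capped by maxReadBatchSizeBytes first, and only bumped up to the block size when the running estimate falls short (which should only happen for the last block). A toy restatement with the same branch order; the function name and parameters are illustrative:

```scala
// Pick the buffer size for the next batch given the running estimate of the
// remaining blocks, the size of the block just peeked, and the soft cap.
def estimateSizeToRead(
    estBlocksSize: Long,            // partFile.length + 1.2 * first block, minus bytes already read
    nextBlockSize: Long,            // blockSize of the block about to be read
    maxReadBatchSizeBytes: Long): Long = {
  if (estBlocksSize > maxReadBatchSizeBytes) {
    maxReadBatchSizeBytes           // respect the soft per-batch byte limit
  } else if (estBlocksSize < nextBlockSize) {
    nextBlockSize                   // may happen only for the last block
  } else {
    estBlocksSize
  }
}
```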