diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 52aa5a69737ef..424526eafdfaa 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types._ // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[avro] class AvroOutputWriter( - path: String, + val path: String, context: TaskAttemptContext, schema: StructType, avroSchema: Schema) extends OutputWriter { diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index df64de4b10075..837883e53d306 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04e740039f005..9d09715d25932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3150,6 +3150,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters") + .internal() + .doc("Maximum number of output file writers to use concurrently. If the number of writers " + + "needed reaches this limit, the task will sort the rest of the output and then write it.") + .version("3.2.0") + .intConf + .createWithDefault(0) + /** * Holds information about keys that have been deprecated. * @@ -3839,6 +3847,8 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties.
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b6b07de8a5d17..4f60a9d4c8c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -53,11 +53,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty private[this] var numFiles: Int = 0 - private[this] var submittedFiles: Int = 0 + private[this] var numSubmittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: Option[String] = None + private[this] val submittedFiles = mutable.HashSet[String]() /** * Get the size of the file expected to have been written by a worker. @@ -134,23 +134,20 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) partitions.append(partitionValues) } - override def newBucket(bucketId: Int): Unit = { - // currently unhandled + override def newFile(filePath: String): Unit = { + submittedFiles += filePath + numSubmittedFiles += 1 } - override def newFile(filePath: String): Unit = { - statCurrentFile() - curFile = Some(filePath) - submittedFiles += 1 + override def closeFile(filePath: String): Unit = { + updateFileStats(filePath) + submittedFiles.remove(filePath) } - private def statCurrentFile(): Unit = { - curFile.foreach { path => - getFileSize(path).foreach { len => - numBytes += len - numFiles += 1 - } - curFile = None + private def updateFileStats(filePath: String): Unit = { + getFileSize(filePath).foreach { len => + numBytes += len + numFiles += 1 } } @@ -159,7 +156,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - statCurrentFile() + submittedFiles.foreach(updateFileStats) + submittedFiles.clear() // Reports bytesWritten and recordsWritten to the Spark output metrics. Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics => @@ -167,8 +165,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) outputMetrics.setRecordsWritten(numRows) } - if (submittedFiles != numFiles) { - logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + if (numSubmittedFiles != numFiles) { + logInfo(s"Expected $numSubmittedFiles files, but only saw $numFiles. 
" + "This could be due to the output format not writing empty files, " + "or files being not immediately visible in the filesystem.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 6de9b1d7cea4b..8230737a61ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow @@ -28,6 +29,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -52,19 +55,35 @@ abstract class FileFormatDataWriter( protected val statsTrackers: Seq[WriteTaskStatsTracker] = description.statsTrackers.map(_.newTaskInstance()) - protected def releaseResources(): Unit = { + /** Release resources of `currentWriter`. */ + protected def releaseCurrentWriter(): Unit = { if (currentWriter != null) { try { currentWriter.close() + statsTrackers.foreach(_.closeFile(currentWriter.path())) } finally { currentWriter = null } } } - /** Writes a record */ + /** Release all resources. */ + protected def releaseResources(): Unit = { + // Call `releaseCurrentWriter()` by default, as this is the only resource to be released. + releaseCurrentWriter() + } + + /** Writes a record. */ def write(record: InternalRow): Unit + + /** Write an iterator of records. */ + def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext) { + write(iterator.next()) + } + } + /** * Returns the summary of relative information which * includes the list of partition strings written out. The list of partitions is sent back @@ -144,34 +163,38 @@ class SingleDirectoryDataWriter( } /** - * Writes data to using dynamic partition writes, meaning this single function can write to + * Holds common logic for writing data with dynamic partition writes, meaning it can write to * multiple directories (partitions) or files (bucketing). */ -class DynamicPartitionDataWriter( +abstract class BaseDynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends FileFormatDataWriter(description, taskAttemptContext, committer) { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = description.partitionColumns.nonEmpty + protected val isPartitioned = description.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. 
*/ - private val isBucketed = description.bucketIdExpression.isDefined + protected val isBucketed = description.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: $description + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description """.stripMargin) - private var fileCounter: Int = _ - private var recordsInFile: Long = _ - private var currentPartitionValues: Option[UnsafeRow] = None - private var currentBucketId: Option[Int] = None + /** Number of records in current file. */ + protected var recordsInFile: Long = _ + + /** + * File counter for writing current partition or bucket. For same partition or bucket, + * we may have more than one file, due to number of records limit per file. + */ + protected var fileCounter: Int = _ /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { + protected lazy val getPartitionValues: InternalRow => UnsafeRow = { val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) row => proj(row) } @@ -186,22 +209,24 @@ class DynamicPartitionDataWriter( if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ + /** + * Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. + */ private lazy val getPartitionPath: InternalRow => String = { val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) row => proj(row).getString(0) } /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { + protected lazy val getBucketId: InternalRow => Int = { val proj = UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) row => proj(row).getInt(0) } /** Returns the data columns to be written given an input row */ - private val getOutputRow = + protected val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) /** @@ -209,13 +234,20 @@ class DynamicPartitionDataWriter( * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * @param partitionValues the partition which all tuples being written by this OutputWriter * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + * @param bucketId the bucket which all tuples being written by this OutputWriter belong to + * @param closeCurrentWriter close and release resource for current writer */ - private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + protected def renewCurrentWriter( + partitionValues: Option[InternalRow], + bucketId: Option[Int], + closeCurrentWriter: Boolean): Unit = { + recordsInFile = 0 - releaseResources() + if (closeCurrentWriter) { + releaseCurrentWriter() + } val partDir = partitionValues.map(getPartitionPath(_)) partDir.foreach(updatedPartitions.add) @@ -243,6 +275,51 @@ class DynamicPartitionDataWriter( statsTrackers.foreach(_.newFile(currentPath)) } + /** + * Open a new output writer when number of records exceeding limit. + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + protected def renewCurrentWriterIfTooManyRecords( + partitionValues: Option[InternalRow], + bucketId: Option[Int]): Unit = { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + renewCurrentWriter(partitionValues, bucketId, closeCurrentWriter = true) + } + + /** + * Writes the given record with current writer. + * + * @param record The record to write + */ + protected def writeRecord(record: InternalRow): Unit = { + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** + * Dynamic partition writer with single writer, meaning only one writer is opened at any time for + * writing. The records to be written are required to be sorted on partition and/or bucket + * column(s) before writing. + */ +class DynamicPartitionDataSingleWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + + private var currentPartitionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + override def write(record: InternalRow): Unit = { val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None @@ -255,25 +332,199 @@ class DynamicPartitionDataWriter( } if (isBucketed) { currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) } fileCounter = 0 - newOutputWriter(currentPartitionValues, currentBucketId) + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. 
- fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId) + } + writeRecord(record) + } +} + +/** + * Dynamic partition writer with concurrent writers, meaning multiple concurrent writers are opened + * for writing. + * + * The process has the following steps: + * - Step 1: Maintain a map of output writers per each partition and/or bucket columns. Keep all + * writers opened and write rows one by one. + * - Step 2: If number of concurrent writers exceeds limit, sort rest of rows on partition and/or + * bucket column(s). Write rows one by one, and eagerly close the writer when finishing + * each partition and/or bucket. + * + * Caller is expected to call `writeWithIterator()` instead of `write()` to write records. + */ +class DynamicPartitionDataConcurrentWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) + with Logging { + + /** Wrapper class to index a unique concurrent output writer. */ + private case class WriterIndex( + var partitionValues: Option[UnsafeRow], + var bucketId: Option[Int]) + + /** Wrapper class for status of a unique concurrent output writer. */ + private class WriterStatus( + var outputWriter: OutputWriter, + var recordsInFile: Long, + var fileCounter: Int) + + /** + * State to indicate if we are falling back to sort-based writer. + * Because we first try to use concurrent writers, its initial value is false. + */ + private var sorted: Boolean = false + private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]() + + /** + * The index for current writer. Intentionally make the index mutable and reusable. + * Avoid JVM GC issue when many short-living `WriterIndex` objects are created + * if switching between concurrent writers frequently. + */ + private val currentWriterId = WriterIndex(None, None) + + /** + * Release resources for all concurrent output writers. + */ + override protected def releaseResources(): Unit = { + currentWriter = null + concurrentWriters.values.foreach(status => { + if (status.outputWriter != null) { + try { + status.outputWriter.close() + } finally { + status.outputWriter = null + } + } + }) + concurrentWriters.clear() + } - newOutputWriter(currentPartitionValues, currentBucketId) + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentWriterId.partitionValues != nextPartitionValues || + currentWriterId.bucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (currentWriter != null) { + if (!sorted) { + // Update writer status in concurrent writers map, because the writer is probably needed + // again later for writing other rows. + updateCurrentWriterStatusInMap() + } else { + // Remove writer status in concurrent writers map and release current writer resource, + // because the writer is not needed any more. 
+ concurrentWriters.remove(currentWriterId) + releaseCurrentWriter() + } + } + + if (isBucketed) { + currentWriterId.bucketId = nextBucketId + } + if (isPartitioned && currentWriterId.partitionValues != nextPartitionValues) { + currentWriterId.partitionValues = Some(nextPartitionValues.get.copy()) + if (!concurrentWriters.contains(currentWriterId)) { + statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get)) + } + } + setupCurrentWriterUsingMap() } - val outputRow = getOutputRow(record) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 + + if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + renewCurrentWriterIfTooManyRecords(currentWriterId.partitionValues, currentWriterId.bucketId) + // Update writer status in concurrent writers map, as a new writer is created. + updateCurrentWriterStatusInMap() + } + writeRecord(record) + } + + /** + * Write iterator of records with concurrent writers. + */ + override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext && !sorted) { + write(iterator.next()) + } + + if (iterator.hasNext) { + clearCurrentWriterStatus() + val sorter = concurrentOutputWriterSpec.createSorter() + val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]]) + while (sortIterator.hasNext) { + write(sortIterator.next()) + } + } + } + + /** + * Update current writer status in map. + */ + private def updateCurrentWriterStatusInMap(): Unit = { + val status = concurrentWriters(currentWriterId) + status.outputWriter = currentWriter + status.recordsInFile = recordsInFile + status.fileCounter = fileCounter + } + + /** + * Retrieve writer in map, or create a new writer if not exists. + */ + private def setupCurrentWriterUsingMap(): Unit = { + if (concurrentWriters.contains(currentWriterId)) { + val status = concurrentWriters(currentWriterId) + currentWriter = status.outputWriter + recordsInFile = status.recordsInFile + fileCounter = status.fileCounter + } else { + fileCounter = 0 + renewCurrentWriter( + currentWriterId.partitionValues, + currentWriterId.bucketId, + closeCurrentWriter = false) + if (!sorted) { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters, + s"Number of concurrent output file writers is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters}") + } else { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters + 1, + s"Number of output file writers after sort is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters + 1}") + } + concurrentWriters.put( + currentWriterId.copy(), + new WriterStatus(currentWriter, recordsInFile, fileCounter)) + if (concurrentWriters.size >= concurrentOutputWriterSpec.maxWriters && !sorted) { + // Fall back to sort-based sequential writer mode. + logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches the threshold. " + + "Fall back from concurrent writers to sort-based sequential writer. You may change " + + s"threshold with configuration ${SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key}") + sorted = true + } + } + } + + /** + * Clear the current writer status in map. 
+ */ + private def clearCurrentWriterStatus(): Unit = { + if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { + updateCurrentWriterStatusInMap() + } + currentWriterId.partitionValues = None + currentWriterId.bucketId = None + currentWriter = null + recordsInFile = 0 + fileCounter = 0 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 6300e10c0bb3d..6839a4db0bc28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String @@ -73,6 +73,11 @@ object FileFormatWriter extends Logging { copy(child = newChild) } + /** Describes how concurrent output writers should be executed. */ + case class ConcurrentOutputWriterSpec( + maxWriters: Int, + createSorter: () => UnsafeExternalRowSorter) + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -177,18 +182,27 @@ object FileFormatWriter extends Logging { committer.setupJob(job) try { - val rdd = if (orderingMatched) { - empty2NullPlan.execute() + val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { + (empty2NullPlan.execute(), None) } else { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
val orderingExpr = bindReferences( requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) - SortExec( + val sortPlan = SortExec( orderingExpr, global = false, - child = empty2NullPlan).execute() + child = empty2NullPlan) + + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + (empty2NullPlan.execute(), + Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) + } else { + (sortPlan.execute(), None) + } } // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single @@ -211,7 +225,8 @@ object FileFormatWriter extends Logging { sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, committer, - iterator = iter) + iterator = iter, + concurrentOutputWriterSpec = concurrentOutputWriterSpec) }, rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { @@ -245,7 +260,8 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): WriteTaskResult = { + iterator: Iterator[InternalRow], + concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -273,15 +289,19 @@ object FileFormatWriter extends Logging { } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + concurrentOutputWriterSpec match { + case Some(spec) => + new DynamicPartitionDataConcurrentWriter( + description, taskAttemptContext, committer, spec) + case _ => + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + } } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - while (iterator.hasNext) { - dataWriter.write(iterator.next()) - } + dataWriter.writeWithIterator(iterator) dataWriter.commit() })(catchBlock = { // If there is an error, abort the task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index 1d7abe5b938c2..7c479d986f3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -57,7 +57,7 @@ abstract class OutputWriterFactory extends Serializable { */ abstract class OutputWriter { /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. */ def write(row: InternalRow): Unit @@ -67,4 +67,9 @@ abstract class OutputWriter { * the task output is committed. */ def close(): Unit + + /** + * The file path to write. Invoked on the executor side. 
+ */ + def path(): String } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala index c39a82ee037bc..aaf866bced868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala @@ -32,20 +32,7 @@ trait WriteTaskStats extends Serializable * A trait for classes that are capable of collecting statistics on data that's being processed by * a single write task in [[FileFormatWriter]] - i.e. there should be one instance per executor. * - * This trait is coupled with the way [[FileFormatWriter]] works, in the sense that its methods - * will be called according to how tuples are being written out to disk, namely in sorted order - * according to partitionValue(s), then bucketId. - * - * As such, a typical call scenario is: - * - * newPartition -> newBucket -> newFile -> newRow -. - * ^ |______^___________^ ^ ^____| - * | | |______________| - * | |____________________________| - * |____________________________________________| - * - * newPartition and newBucket events are only triggered if the relation to be written out is - * partitioned and/or bucketed, respectively. + * newPartition event is only triggered if the relation to be written out is partitioned. */ trait WriteTaskStatsTracker { @@ -56,22 +43,20 @@ trait WriteTaskStatsTracker { */ def newPartition(partitionValues: InternalRow): Unit - /** - * Process the fact that a new bucket is about to written. - * Only triggered when the relation is bucketed by a (non-empty) sequence of columns. - * @param bucketId The bucket number. - */ - def newBucket(bucketId: Int): Unit - /** * Process the fact that a new file is about to be written. * @param filePath Path of the file into which future rows will be written. */ def newFile(filePath: String): Unit + /** + * Process the fact that a file is finished to be written and closed. + * @param filePath Path of the file. + */ + def closeFile(filePath: String): Unit + /** * Process the fact that a new row to update the tracked statistics accordingly. - * The row will be written to the most recently witnessed file (via `newFile`). * @note Keep in mind that any overhead here is per-row, obviously, * so implementations should be as lightweight as possible. * @param row Current data row to be processed. 
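With newBucket gone and closeFile added, the tracker contract above no longer assumes files are written one at a time in partition/bucket order: a tracker is told when each file is opened and when it is closed, and several files may be open in between. A minimal sketch of a custom task-side tracker against the interface as it stands in this patch; FileCountStats and FileCountTracker are illustrative names, not Spark classes.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{WriteTaskStats, WriteTaskStatsTracker}

// Illustrative stats payload returned from the task side.
case class FileCountStats(openedFiles: Int, closedFiles: Int, rows: Long) extends WriteTaskStats

// Counts opened/closed files and rows using the revised per-file lifecycle:
// newFile() when a writer is opened, closeFile() when it is released.
class FileCountTracker extends WriteTaskStatsTracker {
  private var opened = 0
  private var closed = 0
  private var rows = 0L

  override def newPartition(partitionValues: InternalRow): Unit = ()  // not tracked in this sketch
  override def newFile(filePath: String): Unit = opened += 1
  override def closeFile(filePath: String): Unit = closed += 1
  override def newRow(row: InternalRow): Unit = rows += 1
  override def getFinalStats(): WriteTaskStats = FileCountStats(opened, closed, rows)
}
```

A matching WriteJobStatsTracker would still be needed to create task instances and aggregate the returned FileCountStats on the driver, as BasicWriteJobStatsTracker does for its own stats.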
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala index 2b549536ae486..35d0e098b19e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class CsvOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala index 719d72f5b9b52..55602ce2ed9b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class JsonOutputWriter( - path: String, + val path: String, options: JSONOptions, dataSchema: StructType, context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 08086bcd91f6e..6f215737f5703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ private[sql] class OrcOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 70f6726c581a2..efb322f3fc906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
-class ParquetOutputWriter(path: String, context: TaskAttemptContext) +class ParquetOutputWriter(val path: String, context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala index 2b1b81f60ceb4..2fb37c0dc0359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class TextOutputWriter( - path: String, + val path: String, dataSchema: StructType, lineSeparator: Array[Byte], context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index 1f25fed3000b2..d827e83623570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} -import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, SingleDirectoryDataWriter, WriteJobDescription} +import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWriter, SingleDirectoryDataWriter, WriteJobDescription} case class FileWriterFactory ( description: WriteJobDescription, @@ -35,7 +35,7 @@ case class FileWriterFactory ( if (description.partitionColumns.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 273658fcfa4c2..41d11568750cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.test import java.io.File -import java.util.Locale +import java.util.{Locale, Random} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ @@ -1219,4 +1219,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } } + + test("SPARK-26164: Allow concurrent writers for multiple partitions and buckets") { + withTable("t1", "t2") { + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + val df = spark.range(200).map(_ => { + val n = r.nextInt() + (n, n.toString, n % 5) + }).toDF("k1", "k2", "part") + df.write.format("parquet").saveAsTable("t1") + spark.sql("CREATE TABLE t2(k1 int, k2 string, part int) USING parquet PARTITIONED " + + "BY (part) 
CLUSTERED BY (k1) INTO 3 BUCKETS") + val queryToInsertTable = "INSERT OVERWRITE TABLE t2 SELECT k1, k2, part FROM t1" + + Seq( + // Single writer + 0, + // Concurrent writers without fallback + 200, + // concurrent writers with fallback + 3 + ).foreach { maxWriters => + withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key -> maxWriters.toString) { + spark.sql(queryToInsertTable).collect() + checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + + withSQLConf(SQLConf.MAX_RECORDS_PER_FILE.key -> "1") { + spark.sql(queryToInsertTable).collect() + checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + } + } + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index c51c521cacba0..d4ec590f79f5e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -108,7 +108,7 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc) } class HiveOutputWriter( - path: String, + val path: String, fileSinkConf: FileSinkDesc, jobConf: JobConf, dataSchema: StructType) extends OutputWriter with HiveInspectors { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4707311341fcb..d2ac06ad0a16a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -271,7 +271,7 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index d1b97b2852fbc..debe1ab734cc9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -117,7 +117,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) +class SimpleTextOutputWriter(val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
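The remaining hunks are the mechanical half of the change: every OutputWriter implementation promotes its path constructor parameter to a val, which is what satisfies the new abstract path() member that FileFormatDataWriter reads when reporting closeFile() to the stats trackers. Here is a sketch of what a writer outside the Spark tree would look like after this change; EchoOutputWriter is a made-up example that borrows CodecStreams.createOutputStreamWriter the same way the in-tree text-based writers do.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType

class EchoOutputWriter(
    val path: String,  // previously `path: String`; the val now implements OutputWriter.path()
    dataSchema: StructType,
    context: TaskAttemptContext) extends OutputWriter {

  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

  override def write(row: InternalRow): Unit = {
    // Purely illustrative; a real format would serialize `row` against `dataSchema`.
    writer.write(row.toString)
    writer.write("\n")
  }

  override def close(): Unit = writer.close()
}
```

Since path() has no default implementation, any third-party OutputWriter compiled against the old abstract class needs the same one-word adjustment.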
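End to end, the behaviour is driven by the new internal conf alone. At its default of 0 the planner keeps the existing sort-then-single-writer path; with a positive value, dynamic-partition and bucketed writes that have no explicit sort columns skip the task-level sort, keep up to that many writers open per task, and fall back to the sorted path only if the limit is hit. A usage sketch (the output path and the value 32 are illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("concurrent-writers-sketch")
  .master("local[*]")
  .config("spark.sql.maxConcurrentOutputFileWriters", "32")  // internal conf added by this patch
  .getOrCreate()
import spark.implicits._

// With 10 distinct partition values and up to 32 concurrent writers per task,
// this partitioned write proceeds without a task-level sort on `part`.
spark.range(1000)
  .select($"id", ($"id" % 10).as("part"))
  .write
  .partitionBy("part")
  .mode("overwrite")
  .parquet("/tmp/concurrent_writers_demo")
```

The new test in DataFrameReaderWriterSuite exercises the same behaviour against a table that is both partitioned and bucketed, covering the disabled (0), no-fallback (200) and fallback (3) settings, and additionally crosses them with spark.sql.files.maxRecordsPerFile = 1 to force per-file rollover.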
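The control flow of DynamicPartitionDataConcurrentWriter.writeWithIterator is easier to see stripped of the Spark plumbing. The self-contained toy below (no Spark classes; the real code keys writers by partition values and bucket id, and sorts the remainder with UnsafeExternalRowSorter rather than in memory) mirrors the two phases: keep one open writer per key until maxWriters is reached, then sort what is left and write it key by key, closing each writer as soon as its run of rows ends.

```scala
import scala.collection.mutable

object ConcurrentWriteSketch {
  // Stand-in for OutputWriter: just remembers what it was asked to write.
  final class FakeWriter(val key: String) {
    val rows = mutable.ArrayBuffer.empty[String]
    def write(row: String): Unit = rows += row
    def close(): Unit = println(s"closed writer for key=$key after ${rows.size} rows")
  }

  def writeAll(records: Iterator[(String, String)], maxWriters: Int): Unit = {
    val writers = mutable.HashMap.empty[String, FakeWriter]
    var fellBack = false

    // Phase 1: concurrent mode, one open writer per key, rows written as they arrive.
    while (records.hasNext && !fellBack) {
      val (key, row) = records.next()
      writers.getOrElseUpdate(key, new FakeWriter(key)).write(row)
      if (writers.size >= maxWriters) {
        fellBack = true  // too many open writers: switch to the sorted path
      }
    }

    // Phase 2: sort-based fallback over the remaining input.
    if (records.hasNext) {
      var current: Option[FakeWriter] = None
      records.toSeq.sortBy(_._1).foreach { case (key, row) =>
        if (!current.exists(_.key == key)) {
          current.foreach(_.close())  // previous key's run ended: close eagerly
          current = Some(writers.remove(key).getOrElse(new FakeWriter(key)))
        }
        current.get.write(row)
      }
      current.foreach(_.close())
    }

    // Writers opened in phase 1 whose keys never reappeared in the sorted remainder.
    writers.values.foreach(_.close())
  }
}
```

For example, ConcurrentWriteSketch.writeAll(Iterator("a" -> "r1", "b" -> "r2", "a" -> "r3", "c" -> "r4"), maxWriters = 2) writes the first two rows in concurrent mode, then sorts the rest, reuses the writer already open for "a", and closes each writer as its key completes.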
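One implementation detail worth noting in the concurrent writer: currentWriterId is a single mutable WriterIndex reused for every map probe, and it is only copy()-ed at the moment a new entry is stored, which is how the per-row lookup avoids allocating a short-lived key object (the GC concern called out in its comment). A minimal illustration of the same pattern, with made-up names:

```scala
import scala.collection.mutable

// Mutable lookup key; only immutable copies of it are ever stored in the map.
case class Key(var partition: Option[String], var bucket: Option[Int])

object KeyReuseSketch {
  private val probe = Key(None, None)
  private val writers = mutable.HashMap.empty[Key, String]

  def lookupOrCreate(partition: Option[String], bucket: Option[Int]): String = {
    // Mutate the single probe instance instead of allocating a new key per row.
    probe.partition = partition
    probe.bucket = bucket
    writers.get(probe) match {
      case Some(writer) => writer
      case None =>
        val writer = s"writer-for-$partition-$bucket"
        writers.put(probe.copy(), writer)  // snapshot the key: stored keys must never mutate
        writer
    }
  }
}
```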