From e39a105d7b6adfc25b494c5c3b08c486f4ae90f3 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Tue, 8 Sep 2020 17:34:53 -0700
Subject: [PATCH] Add Parquet-based cache serializer (#638)

* upmerged
* Pluggable cache using parquet to compress/decompress
* test change needed for running with the serializer
* Added GpuInMemoryTableScanExec
* add 3.1 dependency
* cache plugin with shims
* Tagged RowToColumnar
* cleaning up the TransitionOverrides

Signed-off-by: Raza Jafri

* cleanup

Signed-off-by: Raza Jafri

* review changes

Signed-off-by: Raza Jafri

* only read necessary columns
* missing configs.md
* regenerated configs.md
* addressed review comments
* fix the assert

Co-authored-by: Raza Jafri
---
 .../src/main/python/cache_test.py             |   2 +-
 .../rapids/shims/spark300/Spark300Shims.scala |   7 +-
 .../ParquetCachedBatchSerializer.scala        | 288 ++++++++++++++++++
 .../rapids/shims/spark310/Spark310Shims.scala |  33 +-
 .../GpuColumnarToRowTransitionExec.scala      |  25 ++
 .../spark310/GpuInMemoryTableScanExec.scala   | 115 +++++++
 .../spark/rapids/GpuColumnarToRowExec.scala   | 157 +++++-----
 .../spark/rapids/GpuRowToColumnarExec.scala   |  29 +-
 .../spark/rapids/GpuTransitionOverrides.scala |  23 +-
 .../com/nvidia/spark/rapids/SparkShims.scala  |   4 +-
 10 files changed, 591 insertions(+), 92 deletions(-)
 create mode 100644 shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
 create mode 100644 shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuColumnarToRowTransitionExec.scala
 create mode 100644 shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuInMemoryTableScanExec.scala

diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py
index 573f4e8ccb3..346f2da0ccd 100644
--- a/integration_tests/src/main/python/cache_test.py
+++ b/integration_tests/src/main/python/cache_test.py
@@ -149,7 +149,7 @@ def do_join(spark):
                 TimestampGen()]
 
 @pytest.mark.parametrize('data_gen', all_gen_restricting_dates, ids=idfn)
-@allow_non_gpu('InMemoryTableScanExec', 'DataWritingCommandExec')
+@allow_non_gpu('DataWritingCommandExec')
 def test_cache_posexplode_makearray(spark_tmp_path, data_gen):
     if is_spark_300() and data_gen.data_type == BooleanType():
         pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index bf9473cb45b..bf4298a67d0 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -359,6 +359,11 @@ class Spark300Shims extends SparkShims {
 
   override def copyFileSourceScanExec(scanExec: GpuFileSourceScanExec,
       supportsSmallFileOpt: Boolean): GpuFileSourceScanExec = {
-    scanExec.copy(supportsSmallFileOpt=supportsSmallFileOpt)
+    scanExec.copy(supportsSmallFileOpt = supportsSmallFileOpt)
+  }
+
+  override def getGpuColumnarToRowTransition(plan: SparkPlan,
+      exportColumnRdd: Boolean): GpuColumnarToRowExecParent = {
+    GpuColumnarToRowExec(plan, exportColumnRdd)
   }
 }
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
new file mode 100644
index 00000000000..46c3d1f568a
--- /dev/null
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
@@ -0,0 +1,288 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark310
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import ai.rapids.cudf._
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.StorageLevel
+
+class ParquetBufferConsumer(val numRows: Int) extends HostBufferConsumer with AutoCloseable {
+  @transient private[this] val offHeapBuffers = mutable.Queue[(HostMemoryBuffer, Long)]()
+  private var buffer: Array[Byte] = null
+
+  override def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit = {
+    offHeapBuffers += Tuple2(buffer, len)
+  }
+
+  def getBuffer(): Array[Byte] = {
+    if (buffer == null) {
+      writeBuffers()
+    }
+    buffer
+  }
+
+  def close(): Unit = {
+    if (buffer == null) {
+      writeBuffers()
+    }
+  }
+
+  private def writeBuffers(): Unit = {
+    val toProcess = offHeapBuffers.dequeueAll(_ => true)
+    // this could be problematic if the buffers are big as their cumulative length could be more
+    // than Int.MaxValue. We could just have a list of buffers in that case and iterate over them
+    val bytes = toProcess.unzip._2.sum
+
+    // for now assert bytes are less than Int.MaxValue
+    assert(bytes <= Int.MaxValue)
+    buffer = new Array(bytes.toInt)
+    try {
+      var offset: Int = 0
+      toProcess.foreach(ops => {
+        val origBuffer = ops._1
+        val len = ops._2.toInt
+        origBuffer.asByteBuffer().get(buffer, offset, len)
+        offset = offset + len
+      })
+    } finally {
+      toProcess.map(_._1).safeClose()
+    }
+  }
+}
+
+object ParquetCachedBatch {
+  def apply(parquetBuff: ParquetBufferConsumer): ParquetCachedBatch = {
+    new ParquetCachedBatch(parquetBuff.numRows, parquetBuff.getBuffer())
+  }
+}
+
+case class ParquetCachedBatch(numRows: Int, buffer: Array[Byte]) extends CachedBatch {
+  override def sizeInBytes: Long = buffer.length
+}
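A note on the consumer above: cuDF's chunked Parquet writer hands back a queue of host buffers, and `writeBuffers` concatenates them into a single `Array[Byte]`, so one cached batch is capped at `Int.MaxValue` bytes. A minimal sketch of the multi-buffer alternative the comment mentions might look like the following (the `ChunkedBufferConsumer` name is hypothetical and not part of this PR):

```scala
import scala.collection.mutable

import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer}

// Hypothetical variant that sidesteps the single 2GB array limit by keeping
// one byte array per buffer handed over by the writer instead of one big one.
class ChunkedBufferConsumer extends HostBufferConsumer with AutoCloseable {
  private val chunks = mutable.ArrayBuffer[Array[Byte]]()

  override def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit = {
    // each individual buffer from the writer is assumed to fit in one array
    require(len <= Int.MaxValue)
    val chunk = new Array[Byte](len.toInt)
    try {
      buffer.asByteBuffer().get(chunk, 0, len.toInt)
    } finally {
      buffer.close()
    }
    chunks += chunk
  }

  def totalBytes: Long = chunks.map(_.length.toLong).sum

  def close(): Unit = chunks.clear()
}
```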
+
+/**
+ * Spark wants the producer to close the batch. We have a listener in this iterator that will
+ * close the batch after the task is completed.
+ */
+private case class CloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]) extends
+  Iterator[ColumnarBatch] {
+  var cb: ColumnarBatch = null
+
+  private def closeCurrentBatch(): Unit = {
+    if (cb != null) {
+      cb.close()
+      cb = null
+    }
+  }
+
+  TaskContext.get().addTaskCompletionListener[Unit]((tc: TaskContext) => {
+    closeCurrentBatch()
+  })
+
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): ColumnarBatch = {
+    closeCurrentBatch()
+    cb = iter.next()
+    cb
+  }
+}
+
+/**
+ * This class assumes the data is columnar and the plugin is on.
+ */
+class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true
+
+  override def supportsColumnarOutput(schema: StructType): Boolean = true
+
+  /**
+   * Convert an `RDD[ColumnarBatch]` into an `RDD[CachedBatch]` in preparation for caching the data.
+   * This method uses the Parquet writer on the GPU to write the cached batch.
+   * @param input the input `RDD` to be converted.
+   * @param schema the schema of the data being stored.
+   * @param storageLevel where the data will be stored.
+   * @param conf the config for the query.
+   * @return The data converted into a format more suitable for caching.
+   */
+  override def convertColumnarBatchToCachedBatch(input: RDD[ColumnarBatch],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    def putOnGpuIfNeeded(batch: ColumnarBatch): ColumnarBatch = {
+      if (batch.numCols() > 0 && !batch.column(0).isInstanceOf[GpuColumnVector]) {
+        val s = StructType(
+          schema.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+        val gpuCB = new GpuColumnarBatchBuilder(s, batch.numRows(), batch).build(batch.numRows())
+        batch.close()
+        gpuCB
+      } else {
+        batch
+      }
+    }
+
+    input.map(batch => {
+      withResource(putOnGpuIfNeeded(batch)) { gpuCB =>
+        compressColumnarBatchWithParquet(gpuCB)
+      }
+    })
+  }
+
+  private def compressColumnarBatchWithParquet(gpuCB: ColumnarBatch): ParquetCachedBatch = {
+    val buffer = new ParquetBufferConsumer(gpuCB.numRows())
+    withResource(GpuColumnVector.from(gpuCB)) { table =>
+      withResource(Table.writeParquetChunked(ParquetWriterOptions.DEFAULT, buffer)) { writer =>
+        writer.write(table)
+      }
+    }
+    ParquetCachedBatch(buffer)
+  }
+
+  /**
+   * This method decodes the CachedBatch, leaving it on the GPU to avoid the extra copy back to
+   * the host.
+   * @param input the cached batches that should be converted.
+   * @param cacheAttributes the attributes of the data in the batch.
+   * @param selectedAttributes the fields that should be loaded from the data and the order they
+   *     should appear in the output batch.
+   * @param conf the configuration for the job.
+   * @return an RDD of the input cached batches transformed into the ColumnarBatch format.
+   */
+  def gpuConvertCachedBatchToColumnarBatch(input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[ColumnarBatch] = {
+    convertCachedBatchToColumnarInternal(input, cacheAttributes, selectedAttributes)
+  }
+
+  private def convertCachedBatchToColumnarInternal(input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute]) = {
+
+    val requestedColumnIndices = selectedAttributes.map(a =>
+      cacheAttributes.map(_.exprId).indexOf(a.exprId))
+
+    val cbRdd: RDD[ColumnarBatch] = input.map(batch => {
+      if (batch.isInstanceOf[ParquetCachedBatch]) {
+        val parquetCB = batch.asInstanceOf[ParquetCachedBatch]
+        val parquetOptions = ParquetOptions.builder().includeColumn(requestedColumnIndices
+          .map(i => "_col" + i).asJavaCollection).build()
+        withResource(Table.readParquet(parquetOptions, parquetCB.buffer, 0,
+          parquetCB.sizeInBytes)) { table =>
+          withResource(GpuColumnVector.from(table)) { cb =>
+            val cols = GpuColumnVector.extractColumns(cb)
+            new ColumnarBatch(requestedColumnIndices.map(ordinal =>
+              cols(ordinal).incRefCount()).toArray, cb.numRows())
+          }
+        }
+      } else {
+        throw new IllegalStateException("I don't know how to convert this batch")
+      }
+    })
+    cbRdd
+  }
+
+  /**
+   * Convert the cached data into a ColumnarBatch, taking the result data back to the host.
+   * @param input the cached batches that should be converted.
+   * @param cacheAttributes the attributes of the data in the batch.
+   * @param selectedAttributes the fields that should be loaded from the data and the order they
+   *     should appear in the output batch.
+   * @param conf the configuration for the job.
+   * @return an RDD of the input cached batches transformed into the ColumnarBatch format.
+   */
+  override def convertCachedBatchToColumnarBatch(input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[ColumnarBatch] = {
+    val batches = convertCachedBatchToColumnarInternal(input, cacheAttributes,
+      selectedAttributes)
+    val cbRdd = batches.map(batch => {
+      withResource(batch) { gpuBatch =>
+        val cols = GpuColumnVector.extractColumns(gpuBatch)
+        new ColumnarBatch(cols.map(_.copyToHost()).toArray, gpuBatch.numRows())
+      }
+    })
+    cbRdd.mapPartitions(iter => new CloseableColumnBatchIterator(iter))
+  }
+
+  /**
+   * Convert the cached batch into `InternalRow`s.
+   * @param input the cached batches that should be converted.
+   * @param cacheAttributes the attributes of the data in the batch.
+   * @param selectedAttributes the fields that should be loaded from the data and the order they
+   *     should appear in the output rows.
+   * @param conf the configuration for the job.
+   * @return RDD of the rows that were stored in the cached batches.
+   */
+  override def convertCachedBatchToInternalRow(input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[InternalRow] = {
+    val cb = convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
+    val rowRdd = cb.mapPartitions(iter => {
+      new ColumnarToRowIterator(iter)
+    })
+    rowRdd
+  }
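The column pruning in `convertCachedBatchToColumnarInternal` above relies on cuDF writing the cached table with default column names `_col0`, `_col1`, and so on; each selected attribute is resolved back to its position in the cached schema by `exprId`, and only those Parquet columns are read. A small standalone sketch of that mapping, assuming spark-catalyst on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.IntegerType

// Maps each selected attribute to the "_colN" name it was cached under,
// mirroring requestedColumnIndices in the serializer above.
def columnNamesToRead(cacheAttributes: Seq[Attribute],
    selectedAttributes: Seq[Attribute]): Seq[String] =
  selectedAttributes.map(a => "_col" + cacheAttributes.map(_.exprId).indexOf(a.exprId))

val id = AttributeReference("id", IntegerType)()
val item = AttributeReference("item", IntegerType)()
// caching [id, item] and selecting only [item] reads just the second column
assert(columnNamesToRead(Seq(id, item), Seq(item)) == Seq("_col1"))
```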
+
+  /**
+   * Convert an `RDD[InternalRow]` into an `RDD[CachedBatch]` in preparation for caching the
+   * data. We use the RowToColumnarIterator to convert one batch at a time.
+   * @param input the input `RDD` to be converted.
+   * @param schema the schema of the data being stored.
+   * @param storageLevel where the data will be stored.
+   * @param conf the config for the query.
+   * @return The data converted into a format more suitable for caching.
+   */
+  override def convertInternalRowToCachedBatch(input: RDD[InternalRow],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    val s = StructType(schema.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+    val converters = new GpuRowToColumnConverter(s)
+    val columnarBatchRdd = input.mapPartitions(iter => {
+      new RowToColumnarIterator(iter, s, RequireSingleBatch, converters)
+    })
+    columnarBatchRdd.map(cb => {
+      withResource(cb) { columnarBatch =>
+        val cachedBatch = compressColumnarBatchWithParquet(columnarBatch)
+        cachedBatch
+      }
+    })
+  }
+
+  override def buildFilter(predicates: Seq[Expression],
+      cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = {
+    // essentially a no-op
+    (partId: Int, b: Iterator[CachedBatch]) => b
+  }
+}
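For reviewers who want to try the serializer end to end: it is wired in through Spark 3.1's `spark.sql.cache.serializer` static conf, the same one `getGpuColumnarToRowTransition` checks below. A hypothetical session setup might look like the following; exact plugin configs vary by deployment:

```scala
import org.apache.spark.sql.SparkSession

// spark.sql.cache.serializer is a static conf, so it must be set before the
// session starts.
val spark = SparkSession.builder()
  .appName("parquet-cached-batch-serializer-example")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.sql.cache.serializer",
    "com.nvidia.spark.rapids.shims.spark310.ParquetCachedBatchSerializer")
  .getOrCreate()

val df = spark.range(0, 1000000).selectExpr("id", "id % 10 AS bucket")
df.cache()   // cached batches will be compressed as Parquet on the GPU
df.count()   // materializes the cache
```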
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
index 86bdf57c832..740dad6124a 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
@@ -28,17 +28,20 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
 import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
-import org.apache.spark.sql.rapids.shims.spark310._
+import org.apache.spark.sql.rapids.shims.spark310.{GpuInMemoryTableScanExec, ShuffleManagerShim}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -141,6 +144,21 @@ class Spark310Shims extends Spark301Shims {
             canUseSmallFileOpt)
         }
       }),
+    GpuOverrides.exec[InMemoryTableScanExec](
+      "Implementation of InMemoryTableScanExec to use GPU-accelerated caching",
+      (scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
+        override def tagPlanForGpu(): Unit = {
+          if (!scan.relation.cacheBuilder.serializer.isInstanceOf[ParquetCachedBatchSerializer]) {
+            willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
+          }
+        }
+        /**
+         * Convert InMemoryTableScanExec to a GPU-enabled version.
+         */
+        override def convertToGpu(): GpuExec = {
+          GpuInMemoryTableScanExec(scan.attributes, scan.predicates, scan.relation)
+        }
+      }),
     GpuOverrides.exec[SortMergeJoinExec](
       "Sort merge join, replacing with shuffled hash join",
       (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
@@ -223,6 +241,17 @@ class Spark310Shims extends Spark301Shims {
 
   override def copyFileSourceScanExec(scanExec: GpuFileSourceScanExec,
       supportsSmallFileOpt: Boolean): GpuFileSourceScanExec = {
-    scanExec.copy(supportsSmallFileOpt=supportsSmallFileOpt)
+    scanExec.copy(supportsSmallFileOpt = supportsSmallFileOpt)
+  }
+
+  override def getGpuColumnarToRowTransition(plan: SparkPlan,
+      exportColumnRdd: Boolean): GpuColumnarToRowExecParent = {
+    val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
+    val serClass = Class.forName(serName)
+    if (serClass == classOf[ParquetCachedBatchSerializer]) {
+      org.apache.spark.sql.rapids.shims.spark310.GpuColumnarToRowTransitionExec(plan)
+    } else {
+      GpuColumnarToRowExec(plan)
+    }
   }
 }
diff --git a/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuColumnarToRowTransitionExec.scala b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuColumnarToRowTransitionExec.scala
new file mode 100644
index 00000000000..d282d30a0a1
--- /dev/null
+++ b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuColumnarToRowTransitionExec.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims.spark310
+
+import com.nvidia.spark.rapids.GpuColumnarToRowExecParent
+
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
+
+case class GpuColumnarToRowTransitionExec(child: SparkPlan,
+    override val exportColumnarRdd: Boolean = false)
+  extends GpuColumnarToRowExecParent(child, exportColumnarRdd) with ColumnarToRowTransition
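One way to sanity-check that the shim picked the transition above is to inspect the executed plan of a cached query. A spark-shell style sketch, assuming the session from the earlier example:

```scala
val cached = spark.range(100).toDF("id").cache()
cached.count()
// With ParquetCachedBatchSerializer active, the tree should contain the
// shim's GpuColumnarToRowTransitionExec rather than plain GpuColumnarToRowExec.
println(cached.queryExecution.executedPlan.treeString)
```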
diff --git a/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuInMemoryTableScanExec.scala b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuInMemoryTableScanExec.scala
new file mode 100644
index 00000000000..ea8ae8f68bc
--- /dev/null
+++ b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuInMemoryTableScanExec.scala
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims.spark310
+
+import com.nvidia.spark.rapids.GpuExec
+import com.nvidia.spark.rapids.shims.spark310.ParquetCachedBatchSerializer
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+case class GpuInMemoryTableScanExec(
+    attributes: Seq[Attribute],
+    predicates: Seq[Expression],
+    @transient relation: InMemoryRelation) extends LeafExecNode with GpuExec {
+
+  override val nodeName: String = {
+    relation.cacheBuilder.tableName match {
+      case Some(_) =>
+        "Scan " + relation.cacheBuilder.cachedName
+      case _ =>
+        super.nodeName
+    }
+  }
+
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
+
+  override def doCanonicalize(): SparkPlan =
+    copy(attributes = attributes.map(QueryPlan.normalizeExpressions(_, relation.output)),
+      predicates = predicates.map(QueryPlan.normalizeExpressions(_, relation.output)),
+      relation = relation.canonicalized.asInstanceOf[InMemoryRelation])
+
+  override def vectorTypes: Option[Seq[String]] =
+    relation.cacheBuilder.serializer.vectorTypes(attributes, conf)
+
+  private lazy val columnarInputRDD: RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val buffers = filteredCachedBatches()
+    relation.cacheBuilder.serializer.asInstanceOf[ParquetCachedBatchSerializer]
+      .gpuConvertCachedBatchToColumnarBatch(
+        buffers,
+        relation.output,
+        attributes,
+        conf).map { cb =>
+      numOutputRows += cb.numRows()
+      cb
+    }
+  }
+
+  override def output: Seq[Attribute] = attributes
+
+  private def updateAttribute(expr: Expression): Expression = {
+    // attributes can be pruned so using relation's output.
+    // E.g., relation.output is [id, item] but this scan's output can be [item] only.
+    val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
+    expr.transform {
+      case attr: Attribute => attrMap.getOrElse(attr, attr)
+    }
+  }
+
+  // The cached version does not change the outputPartitioning of the original SparkPlan.
+  // But the cached version could alias output, so we need to replace output.
+  override def outputPartitioning: Partitioning = {
+    relation.cachedPlan.outputPartitioning match {
+      case e: Expression => updateAttribute(e).asInstanceOf[Partitioning]
+      case other => other
+    }
+  }
+
+  // The cached version does not change the outputOrdering of the original SparkPlan.
+  // But the cached version could alias output, so we need to replace output.
+  override def outputOrdering: Seq[SortOrder] =
+    relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
+
+  lazy val enableAccumulatorsForTest: Boolean = sqlContext.conf.inMemoryTableScanStatisticsEnabled
+
+  // Accumulators used for testing purposes
+  lazy val readPartitions = sparkContext.longAccumulator
+  lazy val readBatches = sparkContext.longAccumulator
+
+  private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
+
+  private def filteredCachedBatches() = {
+    // Right now just return the batch without filtering
+    relation.cacheBuilder.cachedColumnBuffers
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException("This Exec only deals with Columnar Data")
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    columnarInputRDD
+  }
+}
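`GpuInMemoryTableScanExec` passes `relation.output` as the cache schema and its own (possibly pruned) `attributes` as the selection, so with the Parquet-backed cache a narrow projection only decodes the columns it needs — the "only read necessary columns" item in the commit message. A quick spark-shell style illustration, reusing the session from the earlier sketch:

```scala
val wide = spark.range(1000).selectExpr("id", "id * 2 AS item", "id * 3 AS extra")
wide.cache().count()

val pruned = wide.select("item")
pruned.collect()
// The scan over the cached relation should list only [item] in its output.
println(pruned.queryExecution.executedPlan.treeString)
```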
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala
index 198cc94cf43..6e47f8e6095 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala
@@ -16,7 +16,7 @@
 
 package com.nvidia.spark.rapids
 
-import ai.rapids.cudf.NvtxColor
+import ai.rapids.cudf.{NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids.GpuMetricNames._
 
 import org.apache.spark.TaskContext
@@ -32,7 +32,83 @@
 import org.apache.spark.sql.rapids.execution.GpuColumnToRowMapPartitionsRDD
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
-case class GpuColumnarToRowExec(child: SparkPlan, exportColumnarRdd: Boolean = false)
+class ColumnarToRowIterator(batches: Iterator[ColumnarBatch], numInputBatches: SQLMetric = null,
+    numOutputRows: SQLMetric = null, totalTime: SQLMetric = null) extends Iterator[InternalRow] {
+  // GPU batches read in must be closed by the receiver (us)
+  @transient var cb: ColumnarBatch = null
+  var it: java.util.Iterator[InternalRow] = null
+
+  TaskContext.get().addTaskCompletionListener[Unit](_ => closeCurrentBatch())
+
+  private def closeCurrentBatch(): Unit = {
+    if (cb != null) {
+      cb.close()
+      cb = null
+    }
+  }
+
+  def loadNextBatch(): Unit = {
+    closeCurrentBatch()
+    if (it != null) {
+      it = null
+    }
+    if (batches.hasNext) {
+      val devCb = batches.next()
+      val nvtxRange = if (totalTime != null) {
+        new NvtxWithMetrics("ColumnarToRow: batch", NvtxColor.RED, totalTime)
+      } else {
+        new NvtxRange("ColumnarToRow: batch", NvtxColor.RED)
+      }
+
+      try {
+        cb = new ColumnarBatch(GpuColumnVector.extractColumns(devCb).map(_.copyToHost()),
+          devCb.numRows())
+        it = cb.rowIterator()
+        if (numInputBatches != null) {
+          numInputBatches += 1
+        }
+        // In order to match the numOutputRows metric in the generated code we update
+        // numOutputRows for each batch. This is less accurate than doing it at output
+        // because it will over count the number of rows output in the case of a limit,
+        // but it is more efficient.
+        if (numOutputRows != null) {
+          numOutputRows += cb.numRows()
+        }
+      } finally {
+        devCb.close()
+        // Leaving the GPU for a while
+        GpuSemaphore.releaseIfNecessary(TaskContext.get())
+        nvtxRange.close()
+      }
+    }
+  }
+
+  override def hasNext: Boolean = {
+    val itHasNext = it != null && it.hasNext
+    if (!itHasNext) {
+      loadNextBatch()
+      it != null && it.hasNext
+    } else {
+      itHasNext
+    }
+  }
+
+  override def next(): InternalRow = {
+    if (it == null || !it.hasNext) {
+      loadNextBatch()
+    }
+    if (it == null) {
+      throw new NoSuchElementException()
+    }
+    it.next()
+  }
+
+  // This is to convert the InternalRow to an UnsafeRow. Even though the type is
+  // InternalRow, some downstream operations like collect require it to
+  // be UnsafeRow
+}
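The point of hoisting this iterator out of `GpuColumnarToRowExec` is that the cache serializer can reuse it without SQL metrics, which is why the metric parameters default to `null` and every update is guarded. A sketch of that executor-side usage, the same pattern `convertCachedBatchToInternalRow` uses above:

```scala
import com.nvidia.spark.rapids.ColumnarToRowIterator

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.ColumnarBatch

// Runs inside mapPartitions on executors, where TaskContext.get() is non-null
// for the iterator's task-completion listener.
def toRows(batches: RDD[ColumnarBatch]): RDD[InternalRow] =
  batches.mapPartitions(iter => new ColumnarToRowIterator(iter))
```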
+
+abstract class GpuColumnarToRowExecParent(child: SparkPlan, val exportColumnarRdd: Boolean)
   extends UnaryExecNode with CodegenSupport with GpuExec {
   // We need to do this so the assertions don't fail
   override def supportsColumnar = false
@@ -71,71 +147,7 @@ case class GpuColumnarToRowExec(child: SparkPlan, exportColumnarRdd: Boolean = f
     val f = (batches: Iterator[ColumnarBatch]) => {
       // UnsafeProjection is not serializable so do it on the executor side
       val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
-      new Iterator[InternalRow] {
-        // GPU batches read in must be closed by the receiver (us)
-        @transient var cb: ColumnarBatch = null
-        var it: java.util.Iterator[InternalRow] = null
-
-        TaskContext.get().addTaskCompletionListener[Unit](_ => closeCurrentBatch())
-
-        private def closeCurrentBatch(): Unit = {
-          if (cb != null) {
-            cb.close()
-            cb = null
-          }
-        }
-
-        def loadNextBatch(): Unit = {
-          closeCurrentBatch()
-          if (it != null) {
-            it = null
-          }
-          if (batches.hasNext) {
-            val devCb = batches.next()
-            val nvtxRange = new NvtxWithMetrics("ColumnarToRow: batch", NvtxColor.RED, totalTime)
-            try {
-              cb = new ColumnarBatch(GpuColumnVector.extractColumns(devCb).map(_.copyToHost()),
-                devCb.numRows())
-              it = cb.rowIterator()
-              numInputBatches += 1
-              // In order to match the numOutputRows metric in the generated code we update
-              // numOutputRows for each batch. This is less accurate than doing it at output
-              // because it will over count the number of rows output in the case of a limit,
-              // but it is more efficient.
-              numOutputRows += cb.numRows()
-            } finally {
-              devCb.close()
-              // Leaving the GPU for a while
-              GpuSemaphore.releaseIfNecessary(TaskContext.get())
-              nvtxRange.close()
-            }
-          }
-        }
-
-        override def hasNext: Boolean = {
-          val itHasNext = it != null && it.hasNext
-          if (!itHasNext) {
-            loadNextBatch()
-            it != null && it.hasNext
-          } else {
-            itHasNext
-          }
-        }
-
-        override def next(): InternalRow = {
-          if (it == null || !it.hasNext) {
-            loadNextBatch()
-          }
-          if (it == null) {
-            throw new NoSuchElementException()
-          }
-          it.next()
-        }
-
-        // This is to convert the InternalRow to an UnsafeRow. Even though the type is
-        // InternalRow some operations downstream operations like collect require it to
-        // be UnsafeRow
-      }.map(toUnsafe)
+      new ColumnarToRowIterator(batches, numInputBatches, numOutputRows, totalTime).map(toUnsafe)
     }
 
     val cdata = child.executeColumnar()
@@ -211,7 +223,7 @@ case class GpuColumnarToRowExec(child: SparkPlan, exportColumnarRdd: Boolean = f
          |   });
          | }
        """.stripMargin.trim)
-      s"$initTCListenerFuncName();" }, forceInline=true)
+      s"$initTCListenerFuncName();" }, forceInline = true)
     val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
     val columnVectorClzs = child.vectorTypes.getOrElse(
@@ -282,3 +294,12 @@ case class GpuColumnarToRowExec(child: SparkPlan, exportColumnarRdd: Boolean = f
     """.stripMargin
   }
 }
+
+object GpuColumnarToRowExecParent {
+  def unapply(arg: GpuColumnarToRowExecParent): Option[(SparkPlan, Boolean)] = {
+    Option(Tuple2(arg.child, arg.exportColumnarRdd))
+  }
+}
+
+case class GpuColumnarToRowExec(child: SparkPlan, override val exportColumnarRdd: Boolean = false)
+  extends GpuColumnarToRowExecParent(child, exportColumnarRdd)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala
index 396ba2dd4ca..256c4386b28 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala
@@ -16,7 +16,7 @@
 
 package com.nvidia.spark.rapids
 
-import ai.rapids.cudf.NvtxColor
+import ai.rapids.cudf.{NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder
 import com.nvidia.spark.rapids.GpuMetricNames._
 import com.nvidia.spark.rapids.GpuRowToColumnConverter.{FixedWidthTypeConverter, VariableWidthTypeConverter}
@@ -363,10 +363,10 @@ class RowToColumnarIterator(
     localSchema: StructType,
     localGoal: CoalesceGoal,
     converters: GpuRowToColumnConverter,
-    totalTime: SQLMetric,
-    numInputRows: SQLMetric,
-    numOutputRows: SQLMetric,
-    numOutputBatches: SQLMetric) extends Iterator[ColumnarBatch] {
+    totalTime: SQLMetric = null,
+    numInputRows: SQLMetric = null,
+    numOutputRows: SQLMetric = null,
+    numOutputBatches: SQLMetric = null) extends Iterator[ColumnarBatch] {
 
   private val dataTypes: Array[DataType] = localSchema.fields.map(_.dataType)
   private val variableWidthColumnCount = dataTypes.count(dt => !GpuBatchUtils.isFixedWidth(dt))
@@ -431,15 +431,26 @@ class RowToColumnarIterator(
     // option here
     Option(TaskContext.get()).foreach(GpuSemaphore.acquireIfNecessary)
 
-    val buildRange = new NvtxWithMetrics("RowToColumnar", NvtxColor.GREEN, totalTime)
+    var buildRange: NvtxRange = null
+    if (totalTime != null) {
+      buildRange = new NvtxWithMetrics("RowToColumnar", NvtxColor.GREEN, totalTime)
+    } else {
+      buildRange = new NvtxRange("RowToColumnar", NvtxColor.GREEN)
+    }
     val ret = try {
       builders.build(rowCount)
     } finally {
       buildRange.close()
     }
-    numInputRows += rowCount
-    numOutputRows += rowCount
-    numOutputBatches += 1
+    if (numInputRows != null) {
+      numInputRows += rowCount
+    }
+    if (numOutputRows != null) {
+      numOutputRows += rowCount
+    }
+    if (numOutputBatches != null) {
+      numOutputBatches += 1
+    }
 
     // refine the targetRows estimate based on the average of all batches processed so far
     totalOutputBytes += GpuColumnVector.getTotalDeviceMemoryUsed(ret)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index c339a8e0683..f33a52c3939 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.ExecutedCommandExec
-import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
@@ -40,11 +39,15 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
     case HostColumnarToGpu(r2c: RowToColumnarExec, goal) =>
       GpuRowToColumnarExec(optimizeGpuPlanTransitions(r2c.child), goal)
     case ColumnarToRowExec(bb: GpuBringBackToHost) =>
-      GpuColumnarToRowExec(optimizeGpuPlanTransitions(bb.child))
+      getColumnarToRowExec(optimizeGpuPlanTransitions(bb.child))
     case p =>
       p.withNewChildren(p.children.map(optimizeGpuPlanTransitions))
   }
 
+  private def getColumnarToRowExec(plan: SparkPlan, exportColumnRdd: Boolean = false) = {
+    ShimLoader.getSparkShims.getGpuColumnarToRowTransition(plan, exportColumnRdd)
+  }
+
   def optimizeAdaptiveTransitions(plan: SparkPlan): SparkPlan = plan match {
     case HostColumnarToGpu(r2c: RowToColumnarExec, goal) =>
       GpuRowToColumnarExec(optimizeAdaptiveTransitions(r2c.child), goal)
@@ -63,9 +66,9 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
     // in future query stages. Note that because these query stages have already executed, we
     // don't need to recurse down and optimize them again
     case ColumnarToRowExec(e: BroadcastQueryStageExec) =>
-      GpuColumnarToRowExec(e)
+      getColumnarToRowExec(e)
     case ColumnarToRowExec(e: ShuffleQueryStageExec) =>
-      GpuColumnarToRowExec(e)
+      getColumnarToRowExec(e)
 
     case HostColumnarToGpu(e: BroadcastQueryStageExec, _) => e
     case HostColumnarToGpu(e: ShuffleQueryStageExec, _) => e
@@ -74,7 +77,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
       optimizeAdaptiveTransitions(bb.child) match {
         case e: GpuBroadcastExchangeExecBase => e
         case e: GpuShuffleExchangeExecBase => e
-        case other => GpuColumnarToRowExec(other)
+        case other => getColumnarToRowExec(other)
       }
 
     case p =>
@@ -82,7 +85,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
   }
 
   def optimizeCoalesce(plan: SparkPlan): SparkPlan = plan match {
-    case c2r: GpuColumnarToRowExec if c2r.child.isInstanceOf[GpuCoalesceBatches] =>
+    case c2r: GpuColumnarToRowExecParent if c2r.child.isInstanceOf[GpuCoalesceBatches] =>
       // Don't build a batch if we are just going to go back to ROWS
       val co = c2r.child.asInstanceOf[GpuCoalesceBatches]
       c2r.withNewChildren(co.children.map(optimizeCoalesce))
@@ -326,7 +329,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
           throw new IllegalArgumentException("It looks like some operations were " +
             s"pushed down to InMemoryTableScanExec ${imts.expressions.mkString(",")}")
         }
-      case _: GpuColumnarToRowExec => () // Ignored
+      case _: GpuColumnarToRowExecParent => () // Ignored
       case _: ExecutedCommandExec => () // Ignored
       case _: RDDScanExec => () // Ignored
       case _: ShuffleExchangeExec => () // Ignored for now, we don't force it to the GPU if
@@ -352,9 +355,9 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
   }
 
   def detectAndTagFinalColumnarOutput(plan: SparkPlan): SparkPlan = plan match {
-    case d: DeserializeToObjectExec if d.child.isInstanceOf[GpuColumnarToRowExec] =>
-      val gpuColumnar = d.child.asInstanceOf[GpuColumnarToRowExec]
-      plan.withNewChildren(Seq(GpuColumnarToRowExec(gpuColumnar.child, true)))
+    case d: DeserializeToObjectExec if d.child.isInstanceOf[GpuColumnarToRowExecParent] =>
+      val gpuColumnar = d.child.asInstanceOf[GpuColumnarToRowExecParent]
+      plan.withNewChildren(Seq(getColumnarToRowExec(gpuColumnar.child, true)))
     case _ => plan
   }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index ad183978e36..705e3394236 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.joins._
@@ -65,6 +65,8 @@ trait SparkShims {
   def getBuildSide(join: HashJoin): GpuBuildSide
   def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide
   def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]]
+  def getGpuColumnarToRowTransition(plan: SparkPlan,
+      exportColumnRdd: Boolean): GpuColumnarToRowExecParent
   def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]]
   def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]]
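Since the overrides above now match on `GpuColumnarToRowExecParent` rather than the concrete case class, the companion `unapply` added in this PR lets plan-walking code treat both `GpuColumnarToRowExec` and the 3.1 transition uniformly. A small sketch of that pattern:

```scala
import com.nvidia.spark.rapids.GpuColumnarToRowExecParent

import org.apache.spark.sql.execution.SparkPlan

// Matches either concrete transition via the extractor object above and
// reports whether it was tagged to export a columnar RDD.
def exportsColumnarRdd(plan: SparkPlan): Boolean = plan match {
  case GpuColumnarToRowExecParent(_, exportColumnRdd) => exportColumnRdd
  case _ => false
}
```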