Adapt to storage-partitioned join additions in SPARK-37377 [databricks] #5144

Merged: 2 commits, Apr 5, 2022
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.{GpuBatchScanExecMetrics, ScanWithMetrics}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.execution.datasources.v2._

case class GpuBatchScanExec(
    output: Seq[AttributeReference],
    @transient scan: Scan) extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics {
  @transient lazy val batch: Batch = scan.toBatch

  scan match {
    case s: ScanWithMetrics => s.metrics = allMetrics ++ additionalMetrics
    case _ =>
  }

  override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()

  override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

  override lazy val inputRDD: RDD[InternalRow] = {
    new GpuDataSourceRDD(sparkContext, partitions, readerFactory)
  }

  override def doCanonicalize(): GpuBatchScanExec = {
    this.copy(output = output.map(QueryPlan.normalizeExpressions(_, output)))
  }
}
@@ -14,15 +14,14 @@
* limitations under the License.
*/

package com.nvidia.spark.rapids
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.shims.ShimDataSourceRDD
import com.nvidia.spark.rapids.{MetricsBatchIterator, PartitionIterator}

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.DataSourceRDDPartition
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
@@ -36,7 +35,7 @@ class GpuDataSourceRDD(
sc: SparkContext,
@transient private val inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory)
extends ShimDataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads = true) {
extends DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads = true) {

private def castPartition(split: Partition): DataSourceRDDPartition = split match {
case p: DataSourceRDDPartition => p
@@ -53,50 +52,11 @@
}
}

private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
private[this] var valuePrepared = false

override def hasNext: Boolean = {
if (!valuePrepared) {
valuePrepared = reader.next()
}
valuePrepared
}

override def next(): T = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
valuePrepared = false
reader.get()
}
}

private class MetricsBatchIterator(iter: Iterator[ColumnarBatch]) extends Iterator[ColumnarBatch] {
private[this] val inputMetrics = TaskContext.get().taskMetrics().inputMetrics

override def hasNext: Boolean = iter.hasNext

override def next(): ColumnarBatch = {
val batch = iter.next()
TrampolineUtil.incInputRecordsRows(inputMetrics, batch.numRows())
batch
object GpuDataSourceRDD {
def apply(
sc: SparkContext,
inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory): GpuDataSourceRDD = {
new GpuDataSourceRDD(sc, inputPartitions, partitionReaderFactory)
}
}

/** Wraps a columnar PartitionReader to update bytes read metric based on filesystem statistics. */
class PartitionReaderWithBytesRead(reader: PartitionReader[ColumnarBatch])
extends PartitionReader[ColumnarBatch] {
private[this] val inputMetrics = TaskContext.get.taskMetrics().inputMetrics
private[this] val getBytesRead = TrampolineUtil.getFSBytesReadOnThreadCallback()

override def next(): Boolean = {
val result = reader.next()
TrampolineUtil.incBytesRead(inputMetrics, getBytesRead())
result
}

override def get(): ColumnarBatch = reader.get()

override def close(): Unit = reader.close()
}

This file was deleted.

@@ -422,7 +422,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
}

override def convertToGpu(): GpuExec =
GpuBatchScanExec(p.output, childScans.head.convertToGpu())
GpuBatchScanExec(p.output, childScans.head.convertToGpu(), p.runtimeFilters)
})
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,15 +14,14 @@
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims
package org.apache.spark.sql.execution.datasources.rapids

import org.apache.spark.SparkContext
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources.Filter

class ShimDataSourceRDD(
sc: SparkContext,
@transient private val inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory,
columnarReads: Boolean
) extends DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads)
object DataSourceStrategyUtils {
// Trampoline utility to access protected translateRuntimeFilter
def translateRuntimeFilter(expr: Expression): Option[Filter] =
DataSourceStrategy.translateRuntimeFilter(expr)
}
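
The object above lives in the org.apache.spark.sql.execution.datasources.rapids package so it can reach the protected DataSourceStrategy.translateRuntimeFilter. For context, a minimal sketch (not taken verbatim from this PR) of how the trampoline is consumed: runtime filters arrive as Catalyst DynamicPruningExpression nodes, and only those that translate to a data source Filter are pushed down. This mirrors the filteredPartitions logic in the new GpuBatchScanExec later in this diff; RuntimeFilterTranslation and toDataSourceFilters are illustrative names.

import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils
import org.apache.spark.sql.sources.Filter

object RuntimeFilterTranslation {
  // Translate DynamicPruningExpression runtime filters into data source Filters,
  // dropping anything that cannot be translated (same shape as the
  // filteredPartitions logic in GpuBatchScanExec below).
  def toDataSourceFilters(runtimeFilters: Seq[Expression]): Seq[Filter] =
    runtimeFilters.flatMap {
      case DynamicPruningExpression(e) => DataSourceStrategyUtils.translateRuntimeFilter(e)
      case _ => None
    }
}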
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims

import com.google.common.base.Objects
import com.nvidia.spark.rapids.{GpuBatchScanExecMetrics, ScanWithMetrics}

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils
import org.apache.spark.sql.execution.datasources.v2._

case class GpuBatchScanExec(
    output: Seq[AttributeReference],
    @transient scan: Scan,
    runtimeFilters: Seq[Expression] = Seq.empty)
  extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics {
  @transient lazy val batch: Batch = scan.toBatch

  scan match {
    case s: ScanWithMetrics => s.metrics = allMetrics ++ additionalMetrics
    case _ =>
  }

  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
  override def equals(other: Any): Boolean = other match {
    case other: GpuBatchScanExec =>
      this.batch == other.batch && this.runtimeFilters == other.runtimeFilters
    case _ =>
      false
  }

  override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)

  @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()

  @transient private lazy val filteredPartitions: Seq[InputPartition] = {
    val dataSourceFilters = runtimeFilters.flatMap {
      case DynamicPruningExpression(e) => DataSourceStrategyUtils.translateRuntimeFilter(e)
      case _ => None
    }

    if (dataSourceFilters.nonEmpty && scan.isInstanceOf[SupportsRuntimeFiltering]) {
      val originalPartitioning = outputPartitioning

      // the cast is safe as runtime filters are only assigned if the scan can be filtered
      val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
      filterableScan.filter(dataSourceFilters.toArray)

      // call toBatch again to get filtered partitions
      val newPartitions = scan.toBatch.planInputPartitions()

      originalPartitioning match {
        case p: DataSourcePartitioning if p.numPartitions != newPartitions.size =>
          throw new SparkException(
            "Data source must have preserved the original partitioning during runtime filtering; " +
            s"reported num partitions: ${p.numPartitions}, " +
            s"num partitions after runtime filtering: ${newPartitions.size}")
        case _ =>
          // no validation is needed as the data source did not report any specific partitioning
      }

      newPartitions
    } else {
      partitions
    }
  }

  override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

  override lazy val inputRDD: RDD[InternalRow] = {
    if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
      // return an empty RDD with 1 partition if dynamic filtering removed the only split
      sparkContext.parallelize(Array.empty[InternalRow], 1)
    } else {
      new GpuDataSourceRDD(sparkContext, partitions, readerFactory)
    }
  }

  override def doCanonicalize(): GpuBatchScanExec = {
    this.copy(
      output = output.map(QueryPlan.normalizeExpressions(_, output)),
      runtimeFilters = QueryPlan.normalizePredicates(
        runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
        output))
  }

  override def simpleString(maxFields: Int): String = {
    val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
    val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
    val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
    redact(result)
  }
}
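
The runtime-filtering path above only fires when the underlying scan implements SupportsRuntimeFiltering. As a rough, hypothetical stub (not part of this PR; class name and column name are illustrative), the contract the data source must honor looks like this: record the pushed filters in filter(), and have the next toBatch.planInputPartitions() return the pruned partitions while preserving any reported partitioning.

import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.read.{Batch, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Hypothetical scan sketching the DSv2 runtime-filtering contract.
class ExampleRuntimeFilterableScan(schema: StructType) extends Scan with SupportsRuntimeFiltering {
  private var pushed: Array[Filter] = Array.empty

  override def readSchema(): StructType = schema

  // Columns that runtime filters may reference, e.g. a join/partition key.
  override def filterAttributes(): Array[NamedReference] =
    Array(FieldReference(Seq("part_col")))

  // Remember the filters; subsequent partition planning must honor them.
  override def filter(filters: Array[Filter]): Unit = {
    pushed = filters
  }

  // Planning of the (possibly pruned) input partitions is omitted in this sketch.
  override def toBatch: Batch = throw new UnsupportedOperationException("sketch only")
}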
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.{MetricsBatchIterator, PartitionIterator}

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* A replacement for DataSourceRDD that does NOT compute the bytes read input metric.
* DataSourceRDD assumes all reads occur on the task thread, and some GPU input sources
* use multithreaded readers that cannot generate proper metrics with DataSourceRDD.
* @note It is the responsibility of users of this RDD to generate the bytes read input
* metric explicitly!
*/
class GpuDataSourceRDD(
    sc: SparkContext,
    @transient private val inputPartitions: Seq[InputPartition],
    partitionReaderFactory: PartitionReaderFactory
) extends DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads = true,
    Map.empty[String, SQLMetric]) {

  private def castPartition(split: Partition): DataSourceRDDPartition = split match {
    case p: DataSourceRDDPartition => p
    case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split")
  }

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    val inputPartition = castPartition(split).inputPartition
    val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
    val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
    context.addTaskCompletionListener[Unit](_ => batchReader.close())
    // TODO: SPARK-25083 remove the type erasure hack in data source scan
    new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
  }
}

object GpuDataSourceRDD {
  def apply(
      sc: SparkContext,
      inputPartitions: Seq[InputPartition],
      partitionReaderFactory: PartitionReaderFactory): GpuDataSourceRDD = {
    new GpuDataSourceRDD(sc, inputPartitions, partitionReaderFactory)
  }
}
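
Because GpuDataSourceRDD intentionally skips the bytes-read input metric, callers are expected to report it themselves, typically by wrapping the columnar PartitionReader. A sketch along the lines of the PartitionReaderWithBytesRead wrapper shown as removed earlier in this diff (imports added here for self-containment):

import org.apache.spark.TaskContext
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.vectorized.ColumnarBatch

/** Wraps a columnar PartitionReader to update the bytes read metric from filesystem statistics. */
class PartitionReaderWithBytesRead(reader: PartitionReader[ColumnarBatch])
    extends PartitionReader[ColumnarBatch] {
  private[this] val inputMetrics = TaskContext.get.taskMetrics().inputMetrics
  private[this] val getBytesRead = TrampolineUtil.getFSBytesReadOnThreadCallback()

  override def next(): Boolean = {
    val result = reader.next()
    // Bump the task's bytes-read metric using the filesystem statistics callback.
    TrampolineUtil.incBytesRead(inputMetrics, getBytesRead())
    result
  }

  override def get(): ColumnarBatch = reader.get()

  override def close(): Unit = reader.close()
}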