Fix the issue of exporting Column RDD (#4335)
Signed-off-by: Bobby Wang <wbo4958@gmail.com>
wbo4958 authored Dec 13, 2021
1 parent bc0cccb commit dab062e
Showing 5 changed files with 26 additions and 8 deletions.
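The three shim hunks below apply the same fix for the Spark 3.1.x, Databricks 3.1.x, and Spark 3.2.x shims: the method that builds the columnar-to-row transition receives an exportColumnRdd flag but previously dropped it when constructing the exec node, so the flag presumably fell back to a default and the columnar RDD was never exported. A minimal, self-contained sketch of that bug class (the case class and its default parameter are stand-ins for illustration, not the real GpuColumnarToRowExec signature):

```scala
object ExportFlagSketch {
  // Toy stand-in; the default parameter is an assumption for illustration.
  case class ColumnarToRowNode(planName: String, exportColumnRdd: Boolean = false)

  // Before the fix: the flag is accepted but never forwarded.
  def transitionBefore(planName: String, exportColumnRdd: Boolean): ColumnarToRowNode =
    ColumnarToRowNode(planName)

  // After the fix: the flag is threaded through to the node.
  def transitionAfter(planName: String, exportColumnRdd: Boolean): ColumnarToRowNode =
    ColumnarToRowNode(planName, exportColumnRdd)

  def main(args: Array[String]): Unit = {
    assert(!transitionBefore("scan", exportColumnRdd = true).exportColumnRdd) // the bug
    assert(transitionAfter("scan", exportColumnRdd = true).exportColumnRdd)   // the fix
  }
}
```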
@@ -466,9 +466,9 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
     val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
     val serClass = ShimLoader.loadClass(serName)
     if (serClass == classOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
-      GpuColumnarToRowTransitionExec(plan)
+      GpuColumnarToRowTransitionExec(plan, exportColumnRdd)
     } else {
-      GpuColumnarToRowExec(plan)
+      GpuColumnarToRowExec(plan, exportColumnRdd)
     }
   }

@@ -636,9 +636,9 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
     val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
     val serClass = ShimLoader.loadClass(serName)
     if (serClass == classOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
-      GpuColumnarToRowTransitionExec(plan)
+      GpuColumnarToRowTransitionExec(plan, exportColumnRdd)
     } else {
-      GpuColumnarToRowExec(plan)
+      GpuColumnarToRowExec(plan, exportColumnRdd)
     }
   }

@@ -866,9 +866,9 @@ trait Spark32XShims extends SparkShims with Logging {
     val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
     val serClass = ShimLoader.loadClass(serName)
     if (serClass == classOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
-      GpuColumnarToRowTransitionExec(plan)
+      GpuColumnarToRowTransitionExec(plan, exportColumnRdd)
     } else {
-      GpuColumnarToRowExec(plan)
+      GpuColumnarToRowExec(plan, exportColumnRdd)
     }
   }

@@ -654,7 +654,8 @@ object InternalColumnarRddConverter extends Logging {
     convert(df)
   }
 
-  def convert(df: DataFrame): RDD[Table] = {
+  // Extract RDD[ColumnarBatch] directly
+  def extractRDDColumnarBatch(df: DataFrame): (Option[RDD[ColumnarBatch]], RDD[Row]) = {
     val schema = df.schema
     val unsupported = schema.map(_.dataType).filter( dt => !GpuOverrides.isSupportedType(dt,
       allowMaps = true, allowStringMaps = true, allowNull = true, allowStruct = true, allowArray
@@ -709,7 +710,12 @@ object InternalColumnarRddConverter extends Logging {
       logDebug(s"Cannot extract columnar RDD directly. " +
         s"(First MapPartitionsRDD not found $rdd)")
     }
+    (batch, input)
+  }
 
+  def convert(df: DataFrame): RDD[Table] = {
+    val schema = df.schema
+    val (batch, input) = extractRDDColumnarBatch(df)
     val b = batch.getOrElse({
       // We have to fall back to doing a slow transition.
       val converters = new GpuExternalRowToColumnConverter(schema)
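The hunk above splits the direct RDD[ColumnarBatch] extraction out of convert into a separate extractRDDColumnarBatch method, so the extraction can be exercised on its own (as the new test below does) while convert keeps its slow row-based fallback. A hedged usage sketch, assuming the spark-rapids plugin is on the classpath and enabled; the parquet path is a placeholder:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.rapids.execution.InternalColumnarRddConverter

object ExtractBatchExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("extract-columnar-batch")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .getOrCreate()

    // Placeholder path; any parquet file with supported types should do.
    val df = spark.read.parquet("/path/to/data.parquet")

    // New in this commit: returns the columnar RDD directly when the plan
    // allows it, plus the row RDD that would feed the fallback path.
    val (maybeBatches, rowRdd) = InternalColumnarRddConverter.extractRDDColumnarBatch(df)

    maybeBatches match {
      case Some(batches) => println(s"Columnar path, ${batches.getNumPartitions} partitions")
      case None => println(s"Row fallback, ${rowRdd.getNumPartitions} partitions")
    }
    spark.stop()
  }
}
```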
@@ -18,9 +18,10 @@ package org.apache.spark.sql.rapids.execution
 
 import scala.collection.mutable
 
-import com.nvidia.spark.rapids.{ColumnarToRowIterator, GpuBatchUtilsSuite, NoopMetric, SparkQueryCompareTestSuite}
+import com.nvidia.spark.rapids.{ColumnarToRowIterator, GpuBatchUtilsSuite, NoopMetric, SparkQueryCompareTestSuite, TestResourceFinder}
 import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.util.MapData
 import org.apache.spark.sql.types._

@@ -272,5 +273,16 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite {
     }
   }
 
+  test("InternalColumnarRddConverter should extractRDDTable RDD[ColumnarBatch]") {
+    withGpuSparkSession(spark => {
+      val path = TestResourceFinder.getResourcePath("disorder-read-schema.parquet")
+      val df = spark.read.parquet(path)
+      val (optionRddColumnBatch, _) = InternalColumnarRddConverter.extractRDDColumnarBatch(df)
+
+      assert(optionRddColumnBatch.isDefined, "Can't extract RDD[ColumnarBatch]")
+
+    }, new SparkConf().set("spark.rapids.sql.test.allowedNonGpu", "DeserializeToObjectExec"))
+  }
+
 }
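For completeness, the pre-existing convert entry point (shown refactored above) still yields an RDD[Table] of GPU-backed cudf tables. A hedged sketch of consuming it, assuming the plugin and cudf are available; resource and error handling are simplified:

```scala
import ai.rapids.cudf.Table
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.rapids.execution.InternalColumnarRddConverter

object ConvertExample {
  // Count rows per partition from the GPU-backed tables. Table.getRowCount
  // and Table.close are cudf APIs; each table is closed after use.
  def rowCounts(df: DataFrame): RDD[Long] = {
    val tables: RDD[Table] = InternalColumnarRddConverter.convert(df)
    tables.map { t =>
      try t.getRowCount
      finally t.close()
    }
  }
}
```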
