Allow specifying a superclass for non-GPU execs (#2247)
Fixes #2220
    
Signed-off-by: Gera Shegalov <gera@apache.org>
gerashegalov authored Apr 26, 2021
1 parent 4f2d6e8 commit 13e199f
Showing 4 changed files with 61 additions and 29 deletions.
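At a glance: before this change the test harness compared only the exec's simple class name against the allowed non-GPU list, so tests such as test_csv_fallback had to name the exact operator (FileSourceScanExec). The new PlanUtils.sameClass also accepts a fully qualified parent type, e.g. org.apache.spark.sql.execution.LeafExecNode, and matches any exec derived from it via JVM reflection. A minimal standalone sketch of that reflection check, assuming spark-sql is on the classpath (the object and method names here are illustrative, not part of the plugin):

import org.apache.spark.sql.execution.FileSourceScanExec

object SuperclassMatchSketch {
  // True when planClass is the allowed type itself or a subclass of it; the allowed
  // name is resolved at runtime, so an unknown class simply yields false.
  def isAllowed(planClass: Class[_], allowedName: String): Boolean =
    scala.util.Try(Class.forName(allowedName))
      .map(_.isAssignableFrom(planClass))
      .getOrElse(false)

  def main(args: Array[String]): Unit = {
    val scanClass = classOf[FileSourceScanExec]
    // FileSourceScanExec mixes in LeafExecNode, so allowing the parent type covers it.
    println(isAllowed(scanClass, "org.apache.spark.sql.execution.LeafExecNode"))  // true
    println(isAllowed(scanClass, "org.apache.spark.sql.execution.UnaryExecNode")) // false
  }
}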
12 changes: 7 additions & 5 deletions integration_tests/src/main/python/csv_test.py
@@ -15,6 +15,7 @@
 import pytest
 
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_fallback_write
+from conftest import get_non_gpu_allowed
 from datetime import datetime, timezone
 from data_gen import *
 from marks import *
@@ -268,27 +269,28 @@ def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
             lambda spark : spark.read.schema(schema).csv(data_path),
             conf=updated_conf)
 
-@allow_non_gpu('FileSourceScanExec')
+@allow_non_gpu('org.apache.spark.sql.execution.LeafExecNode')
 @pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql])
 @pytest.mark.parametrize('disable_conf', ['spark.rapids.sql.format.csv.enabled', 'spark.rapids.sql.format.csv.read.enabled'])
 def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
     data_gens =[
         StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
         byte_gen, short_gen, int_gen, long_gen, boolean_gen, date_gen]
 
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
     gen = StructGen(gen_list, nullable=False)
     data_path = spark_tmp_path + '/CSV_DATA'
     schema = gen.data_type
     updated_conf=_enable_all_types_conf.copy()
     updated_conf[disable_conf]='false'
 
     reader = read_func(data_path, schema)
     with_cpu_session(
             lambda spark : gen_df(spark, gen).write.csv(data_path))
     assert_gpu_fallback_collect(
             lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')),
-            'FileSourceScanExec',
+            # TODO add support for lists
+            cpu_fallback_class_name=get_non_gpu_allowed()[0],
             conf=updated_conf)
 
 csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
@@ -345,7 +347,7 @@ def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list):
 
 @pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
 def test_input_meta(spark_tmp_path, v1_enabled_list):
-    gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False)
+    gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False)
     first_data_path = spark_tmp_path + '/CSV_DATA/key=0'
     with_cpu_session(
             lambda spark : gen_df(spark, gen).write.csv(first_data_path))
13 changes: 4 additions & 9 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -392,17 +392,12 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
     }
   }
 
-  private def getBaseNameFromClass(planClassStr: String): String = {
-    val firstDotIndex = planClassStr.lastIndexOf(".")
-    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
-  }
-
   def assertIsOnTheGpu(exp: Expression, conf: RapidsConf): Unit = {
     // There are no GpuAttributeReference or GpuSortOrder
     if (!exp.isInstanceOf[AttributeReference] &&
         !exp.isInstanceOf[SortOrder] &&
         !exp.isInstanceOf[GpuExpression] &&
-        !conf.testingAllowedNonGpu.contains(getBaseNameFromClass(exp.getClass.toString))) {
+        !conf.testingAllowedNonGpu.contains(PlanUtils.getBaseNameFromClass(exp.getClass.toString))) {
       throw new IllegalArgumentException(s"The expression $exp is not columnar ${exp.getClass}")
     }
     exp.children.foreach(subExp => assertIsOnTheGpu(subExp, conf))
@@ -438,9 +433,9 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
         // Ignored for now, we don't force it to the GPU if
         // children are not on the gpu
       }
-      case other =>
-        if (!plan.supportsColumnar &&
-            !conf.testingAllowedNonGpu.contains(getBaseNameFromClass(other.getClass.toString))) {
+      case _ =>
+        if (!plan.supportsColumnar && !conf.testingAllowedNonGpu.exists(nonGpuClass =>
+            PlanUtils.sameClass(plan, nonGpuClass))) {
           throw new IllegalArgumentException(s"Part of the plan is not columnar " +
             s"${plan.getClass}\n${plan}")
         }
47 changes: 47 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/PlanUtils.scala
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.util.Try

import org.apache.spark.sql.execution.SparkPlan

object PlanUtils {
  def getBaseNameFromClass(planClassStr: String): String = {
    val firstDotIndex = planClassStr.lastIndexOf(".")
    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
  }

  /**
   * Determines if plan is either fallbackCpuClass or a subclass thereof
   *
   * Useful subclass expression are LeafLike
   *
   * @param plan
   * @param fallbackCpuClass
   * @return
   */
  def sameClass(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
    val planClass = plan.getClass
    val execNameWithoutPackage = getBaseNameFromClass(planClass.getName)
    execNameWithoutPackage == fallbackCpuClass ||
      plan.getClass.getName == fallbackCpuClass ||
      Try(java.lang.Class.forName(fallbackCpuClass))
        .map(_.isAssignableFrom(planClass))
        .getOrElse(false)
  }
}
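The three alternatives in sameClass mean the fallback class may be given as a simple name, as a fully qualified name, or as a fully qualified ancestor type that is resolved reflectively. A rough usage sketch, assuming the plugin and spark-sql are on the classpath; LocalTableScanExec is built with empty output/rows purely so there is a concrete SparkPlan to test against:

import org.apache.spark.sql.execution.LocalTableScanExec
import com.nvidia.spark.rapids.PlanUtils

object SameClassSketch {
  def main(args: Array[String]): Unit = {
    // A trivially constructible CPU exec; LocalTableScanExec extends LeafExecNode.
    val plan = LocalTableScanExec(Seq.empty, Seq.empty)

    println(PlanUtils.sameClass(plan, "LocalTableScanExec"))                                // true: simple-name match
    println(PlanUtils.sameClass(plan, "org.apache.spark.sql.execution.LocalTableScanExec")) // true: fully qualified match
    println(PlanUtils.sameClass(plan, "org.apache.spark.sql.execution.LeafExecNode"))       // true: parent type, via isAssignableFrom
    println(PlanUtils.sameClass(plan, "not.a.real.Exec"))                                   // false: Class.forName fails, Try absorbs it
  }
}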
18 changes: 3 additions & 15 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -318,28 +318,16 @@ object ExecutionPlanCaptureCallback {
       s"Could not find $fallbackCpuClass in the GPU plan\n$executedPlan")
   }
 
-  private def getBaseNameFromClass(planClassStr: String): String = {
-    val firstDotIndex = planClassStr.lastIndexOf(".")
-    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
-  }
-
   private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
-    if (!exp.isInstanceOf[GpuExpression] &&
-        getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass) {
-      true
-    } else {
+    !exp.isInstanceOf[GpuExpression] &&
+      PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||
       exp.children.exists(didFallBack(_, fallbackCpuClass))
-    }
   }
 
   private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
     val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(plan))
-    if (!executedPlan.isInstanceOf[GpuExec] &&
-        getBaseNameFromClass(executedPlan.getClass.getName) == fallbackCpuClass) {
-      true
-    } else {
+    !executedPlan.isInstanceOf[GpuExec] && PlanUtils.sameClass(executedPlan, fallbackCpuClass) ||
       executedPlan.expressions.exists(didFallBack(_, fallbackCpuClass))
-    }
   }
 }
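Both didFallBack overloads now collapse the old if/else into a single boolean expression: a node counts as having fallen back if it is not GPU-backed and matches the requested class (for plans this goes through PlanUtils.sameClass, so a parent type works here as well), or if anything beneath it fell back. A compact sketch of the same recursion over a made-up tree type, purely to show the shape of the check:

object FallbackRecursionSketch {
  // A stand-in for SparkPlan/Expression nodes; not a real plugin type.
  final case class Node(name: String, onGpu: Boolean, children: Seq[Node] = Seq.empty)

  // Same shape as the refactored didFallBack: match at this node, or anywhere below it.
  def didFallBack(node: Node, fallbackClass: String): Boolean =
    !node.onGpu && node.name == fallbackClass ||
      node.children.exists(didFallBack(_, fallbackClass))

  def main(args: Array[String]): Unit = {
    val plan = Node("GpuProjectExec", onGpu = true,
      Seq(Node("FileSourceScanExec", onGpu = false)))
    println(didFallBack(plan, "FileSourceScanExec")) // true: found in a child
    println(didFallBack(plan, "SortExec"))           // false: nothing matches
  }
}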
