From 13e199fbc3e5225f8f7796b265c623ace2bec6c4 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Mon, 26 Apr 2021 09:54:49 -0700 Subject: [PATCH] Allow specifying a superclass for non-GPU execs (#2247) Fixes #2220 Signed-off-by: Gera Shegalov --- integration_tests/src/main/python/csv_test.py | 12 +++-- .../spark/rapids/GpuTransitionOverrides.scala | 13 ++--- .../com/nvidia/spark/rapids/PlanUtils.scala | 47 +++++++++++++++++++ .../com/nvidia/spark/rapids/Plugin.scala | 18 ++----- 4 files changed, 61 insertions(+), 29 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/PlanUtils.scala diff --git a/integration_tests/src/main/python/csv_test.py b/integration_tests/src/main/python/csv_test.py index 6710f6ce46a..2d8f0bcc338 100644 --- a/integration_tests/src/main/python/csv_test.py +++ b/integration_tests/src/main/python/csv_test.py @@ -15,6 +15,7 @@ import pytest from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_fallback_write +from conftest import get_non_gpu_allowed from datetime import datetime, timezone from data_gen import * from marks import * @@ -268,27 +269,28 @@ def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list): lambda spark : spark.read.schema(schema).csv(data_path), conf=updated_conf) -@allow_non_gpu('FileSourceScanExec') +@allow_non_gpu('org.apache.spark.sql.execution.LeafExecNode') @pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql]) @pytest.mark.parametrize('disable_conf', ['spark.rapids.sql.format.csv.enabled', 'spark.rapids.sql.format.csv.read.enabled']) def test_csv_fallback(spark_tmp_path, read_func, disable_conf): data_gens =[ StringGen('(\\w| |\t|\ud720){0,10}', nullable=False), byte_gen, short_gen, int_gen, long_gen, boolean_gen, date_gen] - + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)] gen = StructGen(gen_list, nullable=False) data_path = spark_tmp_path + '/CSV_DATA' schema = gen.data_type updated_conf=_enable_all_types_conf.copy() updated_conf[disable_conf]='false' - + reader = read_func(data_path, schema) with_cpu_session( lambda spark : gen_df(spark, gen).write.csv(data_path)) assert_gpu_fallback_collect( lambda spark : reader(spark).select(f.col('*'), f.col('_c2') + f.col('_c3')), - 'FileSourceScanExec', + # TODO add support for lists + cpu_fallback_class_name=get_non_gpu_allowed()[0], conf=updated_conf) csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM', @@ -345,7 +347,7 @@ def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_ @pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) def test_input_meta(spark_tmp_path, v1_enabled_list): - gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False) + gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False) first_data_path = spark_tmp_path + '/CSV_DATA/key=0' with_cpu_session( lambda spark : gen_df(spark, gen).write.csv(first_data_path)) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index f950c925736..6983ec78e1e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -392,17 +392,12 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { } } - private def getBaseNameFromClass(planClassStr: String): String = { - val firstDotIndex = planClassStr.lastIndexOf(".") - if 
(firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr - } - def assertIsOnTheGpu(exp: Expression, conf: RapidsConf): Unit = { // There are no GpuAttributeReference or GpuSortOrder if (!exp.isInstanceOf[AttributeReference] && !exp.isInstanceOf[SortOrder] && !exp.isInstanceOf[GpuExpression] && - !conf.testingAllowedNonGpu.contains(getBaseNameFromClass(exp.getClass.toString))) { + !conf.testingAllowedNonGpu.contains(PlanUtils.getBaseNameFromClass(exp.getClass.toString))) { throw new IllegalArgumentException(s"The expression $exp is not columnar ${exp.getClass}") } exp.children.foreach(subExp => assertIsOnTheGpu(subExp, conf)) @@ -438,9 +433,9 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { // Ignored for now, we don't force it to the GPU if // children are not on the gpu } - case other => - if (!plan.supportsColumnar && - !conf.testingAllowedNonGpu.contains(getBaseNameFromClass(other.getClass.toString))) { + case _ => + if (!plan.supportsColumnar && !conf.testingAllowedNonGpu.exists(nonGpuClass => + PlanUtils.sameClass(plan, nonGpuClass))) { throw new IllegalArgumentException(s"Part of the plan is not columnar " + s"${plan.getClass}\n${plan}") } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PlanUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PlanUtils.scala new file mode 100644 index 00000000000..cc61eb112b1 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PlanUtils.scala @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.nvidia.spark.rapids
+
+import scala.util.Try
+
+import org.apache.spark.sql.execution.SparkPlan
+
+object PlanUtils {
+  def getBaseNameFromClass(planClassStr: String): String = {
+    val firstDotIndex = planClassStr.lastIndexOf(".")
+    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
+  }
+
+  /**
+   * Determines if plan is either fallbackCpuClass or a subclass thereof
+   *
+   * Useful superclasses to specify include LeafLike
+   *
+   * @param plan the SparkPlan to check
+   * @param fallbackCpuClass simple or fully-qualified name of the class expected on the CPU
+   * @return true if plan matches fallbackCpuClass by name or is a subclass of it
+   */
+  def sameClass(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
+    val planClass = plan.getClass
+    val execNameWithoutPackage = getBaseNameFromClass(planClass.getName)
+    execNameWithoutPackage == fallbackCpuClass ||
+      plan.getClass.getName == fallbackCpuClass ||
+      Try(java.lang.Class.forName(fallbackCpuClass))
+        .map(_.isAssignableFrom(planClass))
+        .getOrElse(false)
+  }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
index bf6e695a726..1b08887e78b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -318,28 +318,16 @@ object ExecutionPlanCaptureCallback {
       s"Could not find $fallbackCpuClass in the GPU plan\n$executedPlan")
   }
 
-  private def getBaseNameFromClass(planClassStr: String): String = {
-    val firstDotIndex = planClassStr.lastIndexOf(".")
-    if (firstDotIndex != -1) planClassStr.substring(firstDotIndex + 1) else planClassStr
-  }
-
   private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
-    if (!exp.isInstanceOf[GpuExpression] &&
-      getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass) {
-      true
-    } else {
+    !exp.isInstanceOf[GpuExpression] &&
+      PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||
       exp.children.exists(didFallBack(_, fallbackCpuClass))
-    }
   }
 
   private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
     val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(plan))
-    if (!executedPlan.isInstanceOf[GpuExec] &&
-      getBaseNameFromClass(executedPlan.getClass.getName) == fallbackCpuClass) {
-      true
-    } else {
+    !executedPlan.isInstanceOf[GpuExec] && PlanUtils.sameClass(executedPlan, fallbackCpuClass) ||
       executedPlan.expressions.exists(didFallBack(_, fallbackCpuClass))
-    }
   }
 }
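
As an illustration of the superclass matching that PlanUtils.sameClass introduces above, the following is a minimal, self-contained Scala sketch of the same rule using plain JVM reflection. The object MatchSketch and the java.util example classes are hypothetical and are not part of the patch; they only demonstrate the name-or-assignability check under those assumptions.

import scala.util.Try

// Sketch of the matching rule: a candidate class matches the requested fallback
// name when its simple name matches, its fully qualified name matches, or the
// requested name can be loaded and is a superclass (or interface) of the candidate.
object MatchSketch {
  def matches(candidate: Class[_], requestedName: String): Boolean = {
    val simpleName = candidate.getName.split('.').last
    simpleName == requestedName ||
      candidate.getName == requestedName ||
      Try(Class.forName(requestedName))
        .map(_.isAssignableFrom(candidate))
        .getOrElse(false)
  }

  def main(args: Array[String]): Unit = {
    val candidate = classOf[java.util.ArrayList[_]]
    println(matches(candidate, "ArrayList"))              // true: simple-name match
    println(matches(candidate, "java.util.AbstractList")) // true: superclass match
    println(matches(candidate, "java.util.HashMap"))      // false: unrelated class
  }
}

The csv_test.py change above relies on exactly this behavior: FileSourceScanExec extends LeafExecNode, so allowing org.apache.spark.sql.execution.LeafExecNode covers the CSV scan fallback without naming the concrete exec class.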