Skip to content

Commit

Permalink
test enhancement
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
  • Loading branch information
sperlingxx committed Aug 2, 2021
1 parent 3bbc9aa commit 66084d4
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 22 deletions.
47 changes: 47 additions & 0 deletions integration_tests/src/main/python/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,53 @@ def assert_gpu_fallback_write(write_func,

assert_equal(from_cpu, from_gpu)

def assert_cpu_and_gpu_are_equal_collect_with_capture(func,
        exist_classes='',
        non_exist_classes='',
        conf={}):
    """
    Run `func` on the CPU and on the GPU, check the captured GPU plan for the
    presence/absence of the given classes, then assert both runs returned
    equal results.

    :param func: function taking a spark session and returning a DataFrame
    :param exist_classes: comma-separated class names that must each be found
        in the GPU plan; an empty string means "check nothing"
    :param non_exist_classes: comma-separated class names that must not be
        found in the GPU plan; an empty string means "check nothing"
    :param conf: Spark configs applied to both the CPU and the GPU session
    """
    (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME')

    conf = _prep_incompat_conf(conf)

    print('### CPU RUN ###')
    cpu_start = time.time()
    from_cpu, _cpu_df = with_cpu_session(bring_back, conf=conf)
    cpu_end = time.time()
    print('### GPU RUN ###')
    gpu_start = time.time()
    from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf)
    gpu_end = time.time()

    # ''.split(',') yields [''], so the default (empty) arguments used to
    # trigger assertContains(plan, ''), which can never succeed. Drop blank
    # entries (and surrounding whitespace) before talking to the JVM.
    must_exist = [c.strip() for c in (exist_classes or '').split(',') if c.strip()]
    must_not_exist = [c.strip() for c in (non_exist_classes or '').split(',') if c.strip()]

    jvm = spark_jvm()
    capture = jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback
    for clz in must_exist:
        capture.assertContains(gpu_df._jdf, clz)
    for clz in must_not_exist:
        capture.assertNotContain(gpu_df._jdf, clz)
    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,
        gpu_end - gpu_start, cpu_end - cpu_start))
    if should_sort_locally():
        from_cpu.sort(key=_RowCmp)
        from_gpu.sort(key=_RowCmp)

    assert_equal(from_cpu, from_gpu)

def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun,
        sql,
        table_name,
        exist_classes='',
        non_exist_classes='',
        conf=None,
        debug=False):
    """
    SQL flavour of assert_cpu_and_gpu_are_equal_collect_with_capture.

    The DataFrame produced by `df_fun` is registered as the temp view
    `table_name`, `sql` is run against it, and the resulting DataFrame is
    handed to the collect-with-capture assertion together with the class
    lists and conf.

    :param df_fun: function taking a spark session and returning a DataFrame
    :param sql: SQL text to execute
    :param table_name: name of the temp view that `sql` refers to
    :param exist_classes: forwarded unchanged to the collect-with-capture helper
    :param non_exist_classes: forwarded unchanged to the collect-with-capture helper
    :param conf: Spark configs; None is treated as an empty dict
    :param debug: when True the query result is wrapped with data_gen.debug_df
    """
    effective_conf = {} if conf is None else conf

    def run_query(spark):
        source_df = df_fun(spark)
        source_df.createOrReplaceTempView(table_name)
        result_df = spark.sql(sql)
        return data_gen.debug_df(result_df) if debug else result_df

    assert_cpu_and_gpu_are_equal_collect_with_capture(
        run_query, exist_classes, non_exist_classes, effective_conf)

def assert_gpu_fallback_collect(func,
cpu_fallback_class_name,
conf={}):
Expand Down
61 changes: 43 additions & 18 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_fallback_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql,\
assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_sql_with_capture,\
assert_cpu_and_gpu_are_equal_collect_with_capture
from data_gen import *
from functools import reduce
from pyspark.sql.types import *
Expand Down Expand Up @@ -433,6 +435,11 @@ def test_hash_groupby_collect_with_single_distinct(data_gen):
f.countDistinct('c'),
f.count('c')))

@approximate_float
@ignore_order(local=True)
@incompat
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn)
def test_hash_groupby_single_distinct_collect(data_gen):
# test distinct collect with other aggregations
sql = """select a,
sort_array(collect_list(distinct b)),
Expand Down Expand Up @@ -463,7 +470,6 @@ def spark_fn(spark_session):
f.sort_array(f.collect_set('b')),
f.countDistinct('b'),
f.countDistinct('c'))
assert_gpu_and_cpu_are_equal_collect(spark_fn)
assert_gpu_fallback_collect(func=spark_fn, cpu_fallback_class_name='SortAggregateExec')

@approximate_float
Expand All @@ -479,34 +485,53 @@ def test_hash_groupby_collect_partial_replace_fallback(data_gen, conf, aqe_enabl
local_conf = conf.copy()
local_conf.update({'spark.sql.adaptive.enabled': aqe_enabled})
# test without Distinct
assert_gpu_and_cpu_are_equal_collect(
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark: gen_df(spark, data_gen, length=100)
.groupby('a')
.agg(f.sort_array(f.collect_list('b')), f.sort_array(f.collect_set('b'))),
conf=local_conf)
assert_gpu_fallback_collect(
lambda spark: gen_df(spark, data_gen, length=100)
.groupby('a')
.agg(f.sort_array(f.collect_list('b')), f.sort_array(f.collect_set('b'))),
cpu_fallback_class_name='ObjectHashAggregateExec',
exist_classes='CollectList,CollectSet',
non_exist_classes='GpuCollectList,GpuCollectSet',
conf=local_conf)
# test with single Distinct
assert_gpu_and_cpu_are_equal_collect(
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark: gen_df(spark, data_gen, length=100)
.groupby('a')
.agg(f.sort_array(f.collect_list('b')),
f.sort_array(f.collect_set('b')),
f.countDistinct('c'),
f.count('c')),
exist_classes='CollectList,CollectSet',
non_exist_classes='GpuCollectList,GpuCollectSet',
conf=local_conf)
assert_gpu_fallback_collect(
lambda spark: gen_df(spark, data_gen, length=100)
.groupby('a')
.agg(f.sort_array(f.collect_list('b')),
f.sort_array(f.collect_set('b')),
f.countDistinct('c'),
f.count('c')),
cpu_fallback_class_name='ObjectHashAggregateExec',

@ignore_order(local=True)
@allow_non_gpu('ObjectHashAggregateExec', 'ShuffleExchangeExec', 'HashAggregateExec',
               'HashPartitioning', 'SortArray', 'Alias', 'Literal',
               'CollectList', 'CollectSet', 'Max', 'AggregateExpression')
@pytest.mark.parametrize('conf', [_nans_float_conf_final, _nans_float_conf_partial], ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['true', 'false'], ids=idfn)
def test_hash_groupby_collect_partial_replace_fallback_with_other_agg(conf, aqe_enabled):
    # Ensures that the "associated fallback" of the collect aggregation does not
    # affect other aggregate plans in the same query: the inner max() is expected
    # to show up as GpuMax while CollectList/CollectSet stay on the CPU.
    local_conf = conf.copy()
    local_conf['spark.sql.adaptive.enabled'] = aqe_enabled

    # Two repeating long keys plus a monotonically generated long value column.
    input_gens = [('k1', RepeatSeqGen(LongGen(), length=20)),
                  ('k2', RepeatSeqGen(LongGen(), length=20)),
                  ('v', LongRangeGen())]

    def make_input(spark):
        return gen_df(spark, input_gens, length=100)

    query = """
select k1,
sort_array(collect_set(k2)),
sort_array(collect_list(max_v))
from
(select k1, k2,
max(v) as max_v
from table group by k1, k2
)t
group by k1"""

    assert_cpu_and_gpu_are_equal_sql_with_capture(
        make_input,
        sql=query,
        table_name='table',
        exist_classes='GpuMax,Max,CollectList,CollectSet',
        non_exist_classes='GpuObjectHashAggregateExec,GpuCollectList,GpuCollectSet',
        conf=local_conf)

@approximate_float
Expand Down
39 changes: 38 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.{DataFrame, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.util.QueryExecutionListener
Expand Down Expand Up @@ -330,6 +330,28 @@ object ExecutionPlanCaptureCallback {
assertDidFallBack(executedPlan, fallbackCpuClass)
}

/**
 * Fails with an assertion error unless `className` is found by `containsPlan`
 * in the plan obtained from `extractExecutedPlan`.
 *
 * @param gpuPlan   plan to inspect
 * @param className class name to look for
 */
def assertContains(gpuPlan: SparkPlan, className: String): Unit = {
  val resolvedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(gpuPlan))
  val found = containsPlan(resolvedPlan, className)
  // The message stays inline so it is only rendered when the assertion fails.
  assert(found, s"Could not find $className in the Spark plan\n$resolvedPlan")
}

/**
 * DataFrame overload: runs the plan-based `assertContains` on the
 * DataFrame's executed plan.
 *
 * @param df       DataFrame whose executed plan is inspected
 * @param gpuClass class name to look for
 */
def assertContains(df: DataFrame, gpuClass: String): Unit =
  assertContains(df.queryExecution.executedPlan, gpuClass)

/**
 * Fails with an assertion error if `className` is found by `containsPlan`
 * in the plan obtained from `extractExecutedPlan`.
 *
 * @param gpuPlan   plan to inspect
 * @param className class name that must be absent
 */
def assertNotContain(gpuPlan: SparkPlan, className: String): Unit = {
  val resolvedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(gpuPlan))
  val found = containsPlan(resolvedPlan, className)
  // The message stays inline so it is only rendered when the assertion fails.
  assert(!found, s"We found $className in the Spark plan\n$resolvedPlan")
}

/**
 * DataFrame overload: runs the plan-based `assertNotContain` on the
 * DataFrame's executed plan.
 *
 * @param df       DataFrame whose executed plan is inspected
 * @param gpuClass class name that must be absent
 */
def assertNotContain(df: DataFrame, gpuClass: String): Unit =
  assertNotContain(df.queryExecution.executedPlan, gpuClass)

private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
!exp.isInstanceOf[GpuExpression] &&
PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||
Expand All @@ -341,6 +363,21 @@ object ExecutionPlanCaptureCallback {
!executedPlan.isInstanceOf[GpuExec] && PlanUtils.sameClass(executedPlan, fallbackCpuClass) ||
executedPlan.expressions.exists(didFallBack(_, fallbackCpuClass))
}

/**
 * Returns true when `exp`, or any expression below it, has a class whose
 * base name (per `PlanUtils.getBaseNameFromClass`) equals `className`.
 */
private def containsExpression(exp: Expression, className: String): Boolean = {
  val matchesSelf = PlanUtils.getBaseNameFromClass(exp.getClass.getName) == className
  // Only descend into the children when this node itself is not a match.
  matchesSelf || exp.children.exists(child => containsExpression(child, className))
}

/**
 * Returns true when `className` matches this plan node (via
 * `PlanUtils.sameClass`), any expression attached to it, or any node
 * reachable through its children.
 *
 * Each visited node is first passed through `extractExecutedPlan`, and a
 * `QueryStageExec` is replaced by the plan it wraps before matching.
 */
private def containsPlan(plan: SparkPlan, className: String): Boolean = {
  val unwrapped = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(plan))
  val node = unwrapped match {
    case stage: QueryStageExec => stage.plan
    case other => other
  }
  if (PlanUtils.sameClass(node, className)) {
    true
  } else if (node.expressions.exists(expr => containsExpression(expr, className))) {
    true
  } else {
    node.children.exists(child => containsPlan(child, className))
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1076,9 +1076,8 @@ object GpuTypedImperativeSupportedAggregateExecMeta {
private def checkAndFallbackEntirely(
meta: GpuTypedImperativeSupportedAggregateExecMeta[_]): Unit = {
// We only run the check for final stages which contain TypedImperativeAggregate.
val needToCheck = meta.agg.aggregateExpressions.exists(e =>
(e.mode == Final || e.mode == Complete) &&
e.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]])
val needToCheck = meta.agg.aggregateExpressions.exists(e => e.mode == Final &&
e.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]])
if (!needToCheck) return
// Avoid duplicated check and fallback.
val checked = meta.agg.getTagValue[Boolean](entireAggFallbackCheck).contains(true)
Expand Down

0 comments on commit 66084d4

Please sign in to comment.