From 8380df856bcff6a3ed1aeb6ac632888c99219f9b Mon Sep 17 00:00:00 2001 From: Niranjan Artal <50492963+nartal1@users.noreply.github.com> Date: Thu, 3 Feb 2022 08:49:11 -0800 Subject: [PATCH] Make Collect, first and last as deterministic aggregate functions for Spark-3.3 (#4677) * Make First, Last and Collect as deterministic for Spark-3.3 Signed-off-by: Niranjan Artal --- .../shims/v2/Spark30Xuntil33XShims.scala | 7 +++++ .../spark/rapids/shims/v2/Spark33XShims.scala | 5 +++- .../spark/sql/rapids/AggregateFunctions.scala | 27 ++++++++++--------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala index ac52ba9d3a8..014ed1dccc2 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims.v2 import com.nvidia.spark.rapids._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark30Xuntil33XFileOptionsShims import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2._ @@ -28,3 +29,9 @@ trait Spark30Xuntil33XShims extends Spark30Xuntil33XFileOptionsShims { GpuOverrides.neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation") } } + +// Before Spark 3.3, First, Last and Collect were mistakenly marked as non-deterministic, so this +// shim mirrors that. They are actually deterministic iff their child expression is deterministic. 
+trait GpuDeterministicFirstLastCollectShim extends Expression { + override lazy val deterministic = false +} diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala index e54c891ca01..bad7c3d8e4d 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala @@ -22,7 +22,7 @@ import org.apache.parquet.schema.MessageType import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark33XFileOptionsShims import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering} import org.apache.spark.sql.execution.SparkPlan @@ -144,3 +144,6 @@ trait Spark33XShims extends Spark33XFileOptionsShims { }) ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap } + +// Fall back to Expression's default definition of `deterministic` +trait GpuDeterministicFirstLastCollectShim extends Expression diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 55953cc1b7c..099a7a80756 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, ScanAggregation} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2.{ShimExpression, ShimUnaryExpression} +import com.nvidia.spark.rapids.shims.v2.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -1389,7 +1389,10 @@ case class GpuAverage(child: Expression) extends GpuAggregateFunction * here). */ case class GpuFirst(child: Expression, ignoreNulls: Boolean) - extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable { + extends GpuAggregateFunction + with GpuDeterministicFirstLastCollectShim + with ImplicitCastInputTypes + with Serializable { private lazy val cudfFirst = AttributeReference("first", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() @@ -1419,8 +1422,6 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean) override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType - // First is not a deterministic function. 
- override lazy val deterministic: Boolean = false override def toString: String = s"gpufirst($child)${if (ignoreNulls) " ignore nulls"}" override def checkInputDataTypes(): TypeCheckResult = { @@ -1434,7 +1435,10 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean) } case class GpuLast(child: Expression, ignoreNulls: Boolean) - extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable { + extends GpuAggregateFunction + with GpuDeterministicFirstLastCollectShim + with ImplicitCastInputTypes + with Serializable { private lazy val cudfLast = AttributeReference("last", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() @@ -1463,8 +1467,6 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean) override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType - // Last is not a deterministic function. - override lazy val deterministic: Boolean = false override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}" override def checkInputDataTypes(): TypeCheckResult = { @@ -1477,14 +1479,13 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean) } } -trait GpuCollectBase extends GpuAggregateFunction with GpuAggregateWindowFunction { +trait GpuCollectBase + extends GpuAggregateFunction + with GpuDeterministicFirstLastCollectShim + with GpuAggregateWindowFunction { def child: Expression - // Collect operations are non-deterministic since their results depend on the - // actual order of input rows. - override lazy val deterministic: Boolean = false - override def nullable: Boolean = false override def dataType: DataType = ArrayType(child.dataType, containsNull = false)