Skip to content

Commit

Permalink
Make Collect, first and last as deterministic aggregate functions for…
Browse files Browse the repository at this point in the history
… Spark-3.3 (#4677)

* Make First, Last and Collect as deterministic for Spark-3.3

Signed-off-by: Niranjan Artal <nartal@nvidia.com>
  • Loading branch information
nartal1 authored Feb 3, 2022
1 parent c000449 commit 8380df8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark30Xuntil33XFileOptionsShims
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2._
Expand All @@ -28,3 +29,9 @@ trait Spark30Xuntil33XShims extends Spark30Xuntil33XFileOptionsShims {
GpuOverrides.neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation")
}
}

// First, Last and Collect have mistakenly been marked as non-deterministic until Spark-3.3.
// They are actually deterministic iff their child expression is deterministic.
trait GpuDeterministicFirstLastCollectShim extends Expression {
override lazy val deterministic = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.parquet.schema.MessageType
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark33XFileOptionsShims
import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.SparkPlan
Expand Down Expand Up @@ -144,3 +144,6 @@ trait Spark33XShims extends Spark33XFileOptionsShims {
})
).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap
}

// Fallback to the default definition of `deterministic`
trait GpuDeterministicFirstLastCollectShim extends Expression
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, ScanAggregation}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2.{ShimExpression, ShimUnaryExpression}
import com.nvidia.spark.rapids.shims.v2.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
Expand Down Expand Up @@ -1389,7 +1389,10 @@ case class GpuAverage(child: Expression) extends GpuAggregateFunction
* here).
*/
case class GpuFirst(child: Expression, ignoreNulls: Boolean)
extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable {
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with ImplicitCastInputTypes
with Serializable {

private lazy val cudfFirst = AttributeReference("first", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down Expand Up @@ -1419,8 +1422,6 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean)
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
// First is not a deterministic function.
override lazy val deterministic: Boolean = false
override def toString: String = s"gpufirst($child)${if (ignoreNulls) " ignore nulls"}"

override def checkInputDataTypes(): TypeCheckResult = {
Expand All @@ -1434,7 +1435,10 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean)
}

case class GpuLast(child: Expression, ignoreNulls: Boolean)
extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable {
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with ImplicitCastInputTypes
with Serializable {

private lazy val cudfLast = AttributeReference("last", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down Expand Up @@ -1463,8 +1467,6 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean)
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
// Last is not a deterministic function.
override lazy val deterministic: Boolean = false
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"

override def checkInputDataTypes(): TypeCheckResult = {
Expand All @@ -1477,14 +1479,13 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean)
}
}

trait GpuCollectBase extends GpuAggregateFunction with GpuAggregateWindowFunction {
trait GpuCollectBase
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with GpuAggregateWindowFunction {

def child: Expression

// Collect operations are non-deterministic since their results depend on the
// actual order of input rows.
override lazy val deterministic: Boolean = false

override def nullable: Boolean = false

override def dataType: DataType = ArrayType(child.dataType, containsNull = false)
Expand Down

0 comments on commit 8380df8

Please sign in to comment.