Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Collect, first and last as deterministic aggregate functions for Spark-3.3 #4677

Merged
merged 5 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark30Xuntil33XFileOptionsShims
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2._
Expand All @@ -28,3 +29,9 @@ trait Spark30Xuntil33XShims extends Spark30Xuntil33XFileOptionsShims {
GpuOverrides.neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation")
}
}

// First, Last and Collect have mistakenly been marked as non-deterministic until Spark-3.3.
// They are actually deterministic iff their child expression is deterministic.
trait GpuDeterministicFirstLastCollectShim extends Expression {
override lazy val deterministic = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.parquet.schema.MessageType
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark33XFileOptionsShims
import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.SparkPlan
Expand Down Expand Up @@ -144,3 +144,6 @@ trait Spark33XShims extends Spark33XFileOptionsShims {
})
).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap
}

// Fallback to the default definition of `deterministic`
trait GpuDeterministicFirstLastCollectShim extends Expression
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, ScanAggregation}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2.{ShimExpression, ShimUnaryExpression}
import com.nvidia.spark.rapids.shims.v2.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
Expand Down Expand Up @@ -1389,7 +1389,10 @@ case class GpuAverage(child: Expression) extends GpuAggregateFunction
* here).
*/
case class GpuFirst(child: Expression, ignoreNulls: Boolean)
extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable {
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with ImplicitCastInputTypes
with Serializable {

private lazy val cudfFirst = AttributeReference("first", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down Expand Up @@ -1419,8 +1422,6 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean)
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
// First is not a deterministic function.
override lazy val deterministic: Boolean = false
override def toString: String = s"gpufirst($child)${if (ignoreNulls) " ignore nulls"}"

override def checkInputDataTypes(): TypeCheckResult = {
Expand All @@ -1434,7 +1435,10 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean)
}

case class GpuLast(child: Expression, ignoreNulls: Boolean)
extends GpuAggregateFunction with ImplicitCastInputTypes with Serializable {
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with ImplicitCastInputTypes
with Serializable {

private lazy val cudfLast = AttributeReference("last", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down Expand Up @@ -1463,8 +1467,6 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean)
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
// Last is not a deterministic function.
override lazy val deterministic: Boolean = false
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"

override def checkInputDataTypes(): TypeCheckResult = {
Expand All @@ -1477,14 +1479,13 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean)
}
}

trait GpuCollectBase extends GpuAggregateFunction with GpuAggregateWindowFunction {
trait GpuCollectBase
extends GpuAggregateFunction
with GpuDeterministicFirstLastCollectShim
with GpuAggregateWindowFunction {

def child: Expression

// Collect operations are non-deterministic since their results depend on the
// actual order of input rows.
override lazy val deterministic: Boolean = false

override def nullable: Boolean = false

override def dataType: DataType = ArrayType(child.dataType, containsNull = false)
Expand Down