Support executing collect_list on GPU with windowing. #1548

Merged — 15 commits, Feb 5, 2021
1 change: 1 addition & 0 deletions docs/configs.md
@@ -246,6 +246,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Year"></a>spark.rapids.sql.expression.Year|`year`|Returns the year from a date or timestamp|true|None|
<a name="sql.expression.AggregateExpression"></a>spark.rapids.sql.expression.AggregateExpression| |Aggregate expression|true|None|
<a name="sql.expression.Average"></a>spark.rapids.sql.expression.Average|`avg`, `mean`|Average aggregate operator|true|None|
<a name="sql.expression.CollectList"></a>spark.rapids.sql.expression.CollectList|`collect_list`|Collect a list of elements, currently only supported in window operations.|false|This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases.|
<a name="sql.expression.Count"></a>spark.rapids.sql.expression.Count|`count`|Count aggregate operator|true|None|
<a name="sql.expression.First"></a>spark.rapids.sql.expression.First|`first_value`, `first`|first aggregate operator|true|None|
<a name="sql.expression.Last"></a>spark.rapids.sql.expression.Last|`last`, `last_value`|last aggregate operator|true|None|
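Since the `CollectList` row above defaults to `false`, the expression has to be enabled explicitly per job. A minimal, hypothetical PySpark sketch (only the config key comes from the table above; the app name and the rest of the session setup are illustrative):

```python
from pyspark.sql import SparkSession

# Sketch: opt in to the GPU collect_list expression despite the null-handling
# caveat documented in the Notes column above.
spark = (SparkSession.builder
         .appName("collect-list-on-gpu")  # illustrative name
         .config("spark.rapids.sql.expression.CollectList", "true")
         .getOrCreate())
```

The same key can equally be passed via `--conf` on `spark-submit`; this is a config fragment, not part of the PR.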
67 changes: 57 additions & 10 deletions docs/supported_ops.md
@@ -745,9 +745,9 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
</table>
@@ -15227,7 +15227,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15269,7 +15269,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15453,7 +15453,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15495,7 +15495,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15517,7 +15517,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15559,7 +15559,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15581,7 +15581,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15623,7 +15623,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -15762,6 +15762,53 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="2">CollectList</td>
<td rowSpan="2">`collect_list`</td>
<td rowSpan="2">Collect a list of elements, currently only supported in window operations.</td>
<td rowSpan="2">This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases.</td>
<td rowSpan="2">window</td>
<td>input</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td>S*</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="6">Count</td>
<td rowSpan="6">`count`</td>
<td rowSpan="6">Count aggregate operator</td>
51 changes: 51 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
@@ -213,3 +213,54 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
' range between 1 preceding and 1 following) as sum_c_asc '
'from window_agg_table'
)


'''
Spark drops nulls when collecting, but the GPU does not, so exceptions like the
following come up:

E Caused by: java.lang.AssertionError: value at 350 is null
E at ai.rapids.cudf.HostColumnVectorCore.assertsForGet(HostColumnVectorCore.java:228)
E at ai.rapids.cudf.HostColumnVectorCore.getInt(HostColumnVectorCore.java:254)
E at com.nvidia.spark.rapids.RapidsHostColumnVectorCore.getInt(RapidsHostColumnVectorCore.java:109)
E at org.apache.spark.sql.vectorized.ColumnarArray.getInt(ColumnarArray.java:128)

For now, set nullable to False so the tests pass. Once the native side supports
dropping nulls, set it back to True.
'''
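The null-handling difference described above can be modeled in plain Python without Spark. This is a sketch of the two semantics as this PR describes them, not the plugin's actual code path:

```python
def cpu_collect_list(values):
    # Spark's CPU collect_list: null (None) inputs are dropped.
    return [v for v in values if v is not None]

def gpu_collect_list(values):
    # Current GPU behavior per this PR's note: nulls are kept in the list,
    # which is why the expression is disabled by default.
    return list(values)

rows = [1, None, 2, None, 3]
cpu = cpu_collect_list(rows)  # [1, 2, 3]
gpu = gpu_collect_list(rows)  # [1, None, 2, None, 3]
```

The mismatch between `cpu` and `gpu` on null-bearing input is exactly why the test generators above force `nullable=False`.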
collect_data_gen = [
('a', RepeatSeqGen(LongGen(), length=20)),
('b', IntegerGen()),
('c_int', IntegerGen(nullable=False)),
('c_long', LongGen(nullable=False)),
('c_time', DateGen(nullable=False)),
('c_string', StringGen(nullable=False)),
('c_float', FloatGen(nullable=False)),
('c_decimal', DecimalGen(nullable=False, precision=8, scale=3)),
('c_struct', StructGen(nullable=False, children = [
['child_int', IntegerGen()],
['child_time', DateGen()],
['child_string', StringGen()],
['child_decimal', DecimalGen(nullable=False, precision=8, scale=3)]]))]

# SortExec does not support array type, so sort the result locally.
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [collect_data_gen], ids=idfn)
def test_window_aggs_for_rows_collect_list(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen, length=2048),
"window_collect_table",
'select '
' collect_list(c_int) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_int, '
' collect_list(c_long) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_long, '
' collect_list(c_time) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_time, '
' collect_list(c_string) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_string, '
' collect_list(c_float) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_float, '
' collect_list(c_decimal) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_decimal, '
' collect_list(c_struct) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_struct '
'from window_collect_table ',
{'spark.rapids.sql.expression.CollectList': 'true'})
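As a plain-Python cross-check of what the SQL above computes, here is a hedged reference model of `collect_list(...) over (partition by ... order by ... rows between UNBOUNDED PRECEDING and CURRENT ROW)`: within each partition, each row's result is the growing prefix of values seen so far in sort order. This is an illustrative model, not the plugin's implementation:

```python
from collections import defaultdict

def windowed_collect_list(rows, part_key, order_key, value_key):
    """Model collect_list over (partition by part_key order by order_key
    rows between UNBOUNDED PRECEDING and CURRENT ROW).

    Returns (partition, order, prefix_list) tuples, one per input row.
    """
    parts = defaultdict(list)
    for row in rows:
        parts[row[part_key]].append(row)
    out = []
    for key in parts:
        prefix = []
        for row in sorted(parts[key], key=lambda r: r[order_key]):
            prefix.append(row[value_key])
            out.append((row[part_key], row[order_key], list(prefix)))
    return out

rows = [
    {"a": 1, "b": 2, "c": "y"},
    {"a": 1, "b": 1, "c": "x"},
    {"a": 2, "b": 1, "c": "z"},
]
# Partition a=1 yields ["x"], then ["x", "y"]; partition a=2 yields ["z"].
```

Comparing such per-row prefix lists is what `assert_gpu_and_cpu_are_equal_sql` does here, modulo the local sort imposed by `@ignore_order(local=True)`.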
@@ -734,12 +734,11 @@ object GpuOverrides {
"Calculates a return value for every input row of a table based on a group (or " +
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested + TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
@@ -1644,11 +1643,13 @@ object GpuOverrides {
expr[AggregateExpression](
"Aggregate expression",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.NULL +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck(
"aggFunc",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.NULL +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all)),
Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) {
@@ -2170,7 +2171,22 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
})
}),
expr[CollectList](
"Collect a list of elements, currently only supported in window operations.",
/* This should eventually be 'fullAgg', but only windowing is supported for now,
so use 'aggNotGroupByOrReduction'. */
ExprChecks.aggNotGroupByOrReduction(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.STRUCT,
Collaborator: Again I think longer is clearer, but in this case, because there is only one nested type, I think it is okay. So we don't support doing a collect_list on a struct that also contains a struct?

Collaborator: This should be supported. If it's not working, it might be a bug in libcudf. I'd be happy to investigate.

@firestarman (Author, Feb 5, 2021): Updated it. The struct-of-struct case is not covered by the integration tests, so it is not added here. Do we need it?

@firestarman (Author): If yes, I will add it in a following PR.

Collaborator: Sounds good to me.

TypeSig.all))),
(c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCollectList(
childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset)
}).disabledByDefault("for now the GPU collects null values to a list, but Spark does not." +
" This will be fixed in future releases.")
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] =
@@ -2499,7 +2515,9 @@ object GpuOverrides {
(expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)),
exec[WindowExec](
"Window-operator backend",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
ExecChecks(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT).nested() + TypeSig.ARRAY,
TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
),
@@ -199,13 +199,15 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
}
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
} else {
GpuColumnVector.from(aggColumn, windowFunc.dataType)
// Do not cast nested types; return the column as is.
aggColumn.getType match {
case dType if dType.isNestedType =>
GpuColumnVector.from(aggColumn, windowFunc.dataType)
case _ =>
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
}
}

@@ -230,13 +232,15 @@
}
}
}
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
} else {
GpuColumnVector.from(aggColumn, windowFunc.dataType)
// Do not cast nested types; return the column as is.
aggColumn.getType match {
case dType if dType.isNestedType =>
GpuColumnVector.from(aggColumn, windowFunc.dataType)
case _ =>
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
}
}
}
@@ -985,7 +985,7 @@ object ExprChecks {
}

/**
* Window only operations. Spark does not support these operations as anythign but a window
* Window only operations. Spark does not support these operations as anything but a window
* operation.
*/
def windowOnly(
@@ -996,6 +996,18 @@
ExprChecksImpl(Map(
(WindowAggExprContext,
ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck))))

/**
* An aggregation check for expressions that Spark supports as group-by and
* reduction aggregations, but that the plugin supports only as window operations.
* For now this applies only to 'collect_list'.
*/
def aggNotGroupByOrReduction(
outputCheck: TypeSig,
sparkOutputSig: TypeSig,
paramCheck: Seq[ParamCheck] = Seq.empty,
repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks =
windowOnly(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)
}

/**
@@ -22,9 +22,9 @@ import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, LongType, NumericType, StructType}
import org.apache.spark.sql.types._

trait GpuAggregateFunction extends GpuExpression {
// using the child reference, define the shape of the vectors sent to
@@ -529,3 +529,42 @@ abstract class GpuLastBase(child: Expression)
override lazy val deterministic: Boolean = false
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"
}

/**
* Collects and returns a list of non-unique elements.
*
* FIXME Not sure whether the GPU version requires the two offset parameters; keep them for now.
*/
case class GpuCollectList(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends GpuDeclarativeAggregate with GpuAggregateWindowFunction {

def this(child: Expression) = this(child, 0, 0)

override lazy val deterministic: Boolean = false

override def nullable: Boolean = false

override def prettyName: String = "collect_list"

override def dataType: DataType = ArrayType(child.dataType, false)

override def children: Seq[Expression] = child :: Nil

// WINDOW FUNCTION
override val windowInputProjection: Seq[Expression] = Seq(child)
override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
Aggregation.collect().onColumn(inputs.head._2)

// Declarative aggregate: 'CollectList' does not support this path yet.
// The members below must NOT be used yet; this is enforced by
// 'ExprChecks.aggNotGroupByOrReduction' when the expression is overridden.
private lazy val cudfList = AttributeReference("collect_list", dataType)()
override val initialValues: Seq[GpuExpression] = Seq.empty
override val updateExpressions: Seq[Expression] = Seq.empty
override val mergeExpressions: Seq[GpuExpression] = Seq.empty
override val evaluateExpression: Expression = null
override val inputProjection: Seq[Expression] = Seq(child)
override def aggBufferAttributes: Seq[AttributeReference] = cudfList :: Nil
}