Skip to content

Commit

Permalink
Let collect_list support array types.
Browse files Browse the repository at this point in the history
Let collect_list with windowing support arrays of structs.

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
  • Loading branch information
firestarman committed Jan 27, 2021
1 parent cdc08e4 commit aee4fc6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,11 @@ object GpuOverrides {
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
Expand Down Expand Up @@ -1644,11 +1644,13 @@ object GpuOverrides {
expr[AggregateExpression](
"Aggregate expression",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck(
"aggFunc",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
TypeSig.all)),
Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) {
Expand Down Expand Up @@ -2174,12 +2176,11 @@ object GpuOverrides {
expr[CollectList](
"Collect a list of elements",
/* This should eventually be 'fullAgg', but only windowing is supported for now, hence 'windowOnly'. */
ExprChecks.windowOnly(TypeSig.ARRAY.nested(TypeSig.integral +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes)),
ExprChecks.windowOnly(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.STRUCT.nested(
TypeSig.integral + TypeSig.STRING + TypeSig.TIMESTAMP),
TypeSig.commonCudfTypes + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes),
TypeSig.all))),
(c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCollectList(
Expand Down Expand Up @@ -2488,7 +2489,9 @@ object GpuOverrides {
(expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)),
exec[WindowExec](
"Window-operator backend",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT +
TypeSig.ARRAY.nested(TypeSig.STRUCT + TypeSig.commonCudfTypes),
TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,8 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
}
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
} else {
GpuColumnVector.from(aggColumn, windowFunc.dataType)
}
// It seems we should not cast the type explicitly here, but instead let GpuColumnVector handle it.
GpuColumnVector.from(aggColumn, windowFunc.dataType)
}

private def evaluateRangeBasedWindowExpression(cb : ColumnarBatch) : GpuColumnVector = {
Expand All @@ -230,14 +224,8 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
}
val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
} else {
GpuColumnVector.from(aggColumn, windowFunc.dataType)
}
// It seems we should not cast the type explicitly here, but instead let GpuColumnVector handle it.
GpuColumnVector.from(aggColumn, windowFunc.dataType)
}
}

Expand Down

0 comments on commit aee4fc6

Please sign in to comment.