Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gerashegalov committed Mar 26, 2021
1 parent 9ee262b commit e5f9996
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,18 +423,13 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")

private[this] val pluginSupportedOrderableSig = (
TypeSig.commonCudfTypes +
TypeSig.NULL +
TypeSig.DECIMAL +
TypeSig.STRUCT.nested(
TypeSig.commonCudfTypes +
TypeSig.NULL +
TypeSig.DECIMAL
))

private[this] def isNestedType(dataType: DataType) = dataType match {
case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => true
private[this] val _commonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL

private[this] val pluginSupportedOrderableSig = _commonTypes +
TypeSig.STRUCT.nested(_commonTypes)

private[this] def isStructType(dataType: DataType) = dataType match {
case StructType(_) => true
case _ => false
}

Expand Down Expand Up @@ -1826,7 +1821,7 @@ object GpuOverrides {
TypeSig.orderable))),
(sortOrder, conf, p, r) => new BaseExprMeta[SortOrder](sortOrder, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (isNestedType(sortOrder.dataType)) {
if (isStructType(sortOrder.dataType)) {
val nullOrdering = sortOrder.nullOrdering
val directionDefaultNullOrdering = sortOrder.direction.defaultNullOrdering
val direction = sortOrder.direction.sql
Expand Down Expand Up @@ -2520,7 +2515,7 @@ object GpuOverrides {

override def tagPartForGpu() {
val numPartitions = rp.numPartitions
if (numPartitions > 1 && rp.ordering.exists(so => isNestedType(so.dataType))) {
if (numPartitions > 1 && rp.ordering.exists(so => isStructType(so.dataType))) {
willNotWorkOnGpu("only single partition sort is supported for nested types, " +
s"actual partitions: $numPartitions")
}
Expand Down Expand Up @@ -2772,10 +2767,10 @@ object GpuOverrides {
"The backend for the sort operator",
// The SortOrder TypeSig will govern what types can actually be used as sorting key data type.
// The types below are allowed as inputs and outputs.
ExecChecks(pluginSupportedOrderableSig, TypeSig.all),
ExecChecks(pluginSupportedOrderableSig + TypeSig.ARRAY.nested(), TypeSig.all),
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r) {
override def tagPlanForGpu() {
if (!conf.stableSort && sort.sortOrder.exists(so => isNestedType(so.dataType))) {
if (!conf.stableSort && sort.sortOrder.exists(so => isStructType(so.dataType))) {
willNotWorkOnGpu("it's disabled for nested types " +
s"unless ${RapidsConf.STABLE_SORT.key} is true")
}
Expand Down

0 comments on commit e5f9996

Please sign in to comment.