From e5f99968569b62ecf1605b7f05da1332c90bee2c Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 25 Mar 2021 17:52:17 -0700 Subject: [PATCH] wip --- docs/supported_ops.md | 4 +-- .../nvidia/spark/rapids/GpuOverrides.scala | 27 ++++++++----------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 9aec6529863..b115249e8f6 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -354,9 +354,9 @@ Accelerator supports are described below. S NS NS +PS* (missing nested BINARY, CALENDAR, MAP, STRUCT, UDT) NS -NS -PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (missing nested BINARY, CALENDAR, MAP, STRUCT, UDT) NS diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 652f5b10435..88c4c8164a5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -423,18 +423,13 @@ object GpuOverrides { "\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R", "?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">") - private[this] val pluginSupportedOrderableSig = ( - TypeSig.commonCudfTypes + - TypeSig.NULL + - TypeSig.DECIMAL + - TypeSig.STRUCT.nested( - TypeSig.commonCudfTypes + - TypeSig.NULL + - TypeSig.DECIMAL - )) - - private[this] def isNestedType(dataType: DataType) = dataType match { - case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => true + private[this] val _commonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + + private[this] val pluginSupportedOrderableSig = _commonTypes + + TypeSig.STRUCT.nested(_commonTypes) + + private[this] def isStructType(dataType: DataType) = dataType match { + case StructType(_) => true case _ => false } @@ -1826,7 +1821,7 @@ object GpuOverrides { TypeSig.orderable))), (sortOrder, conf, p, r) => new BaseExprMeta[SortOrder](sortOrder, conf, p, r) { override def tagExprForGpu(): Unit = { - if (isNestedType(sortOrder.dataType)) { + if (isStructType(sortOrder.dataType)) { val nullOrdering = sortOrder.nullOrdering val directionDefaultNullOrdering = sortOrder.direction.defaultNullOrdering val direction = sortOrder.direction.sql @@ -2520,7 +2515,7 @@ object GpuOverrides { override def tagPartForGpu() { val numPartitions = rp.numPartitions - if (numPartitions > 1 && rp.ordering.exists(so => isNestedType(so.dataType))) { + if (numPartitions > 1 && rp.ordering.exists(so => isStructType(so.dataType))) { willNotWorkOnGpu("only single partition sort is supported for nested types, " + s"actual partitions: $numPartitions") } @@ -2772,10 +2767,10 @@ object GpuOverrides { "The backend for the sort operator", // The SortOrder TypeSig will govern what types can actually be used as sorting key data type. // The types below are allowed as inputs and outputs. - ExecChecks(pluginSupportedOrderableSig, TypeSig.all), + ExecChecks(pluginSupportedOrderableSig + TypeSig.ARRAY.nested(), TypeSig.all), (sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r) { override def tagPlanForGpu() { - if (!conf.stableSort && sort.sortOrder.exists(so => isNestedType(so.dataType))) { + if (!conf.stableSort && sort.sortOrder.exists(so => isStructType(so.dataType))) { willNotWorkOnGpu("it's disabled for nested types " + s"unless ${RapidsConf.STABLE_SORT.key} is true") }