diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 386d6f32d41..53c38529e4e 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16588,13 +16588,13 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
+S* |
+S |
+S |
+S |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
NS |
@@ -16609,13 +16609,13 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
-NS |
-NS |
-NS |
-PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
-NS |
-NS |
+S* |
+S |
+S |
+S |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
NS |
@@ -16678,13 +16678,13 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
+S* |
+S |
+S |
+S |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
NS |
@@ -16699,13 +16699,13 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
-NS |
-NS |
-NS |
-PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
-NS |
-NS |
+S* |
+S |
+S |
+S |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
+PS* (missing nested UDT) |
NS |
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala
index f0b86fe4613..2821f350cf5 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.hive.{HiveGenericUDF, HiveSimpleUDF}
object GpuHiveOverrides {
+ // UDFs can support all types except UDT which does not have a clear columnar representation.
+ private val udfTypeSig = (TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.NULL +
+ TypeSig.BINARY + TypeSig.CALENDAR + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested()
+
def isSparkHiveAvailable: Boolean = {
// Using the same approach as SparkSession.hiveClassesArePresent
val loader = Thread.currentThread().getContextClassLoader
@@ -48,12 +52,9 @@ object GpuHiveOverrides {
GpuOverrides.expr[HiveSimpleUDF](
"Hive UDF, support requires the UDF to implement a RAPIDS-accelerated interface",
ExprChecks.projectNotLambda(
- TypeSig.commonCudfTypes + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes),
+ udfTypeSig,
TypeSig.all,
- repeatingParamCheck = Some(RepeatingParamCheck(
- "param",
- TypeSig.commonCudfTypes,
- TypeSig.all))),
+ repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))),
(a, conf, p, r) => new ExprMeta[HiveSimpleUDF](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.function match {
@@ -79,12 +80,9 @@ object GpuHiveOverrides {
"Hive Generic UDF, support requires the UDF to implement a " +
"RAPIDS-accelerated interface",
ExprChecks.projectNotLambda(
- TypeSig.commonCudfTypes + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes),
+ udfTypeSig,
TypeSig.all,
- repeatingParamCheck = Some(RepeatingParamCheck(
- "param",
- TypeSig.commonCudfTypes,
- TypeSig.all))),
+ repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))),
(a, conf, p, r) => new ExprMeta[HiveGenericUDF](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.function match {