diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 386d6f32d41..53c38529e4e 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -16588,13 +16588,13 @@ Accelerator support is described below. S S* S -NS -NS -NS -NS -NS -NS -NS +S* +S +S +S +PS* (missing nested UDT) +PS* (missing nested UDT) +PS* (missing nested UDT) NS @@ -16609,13 +16609,13 @@ Accelerator support is described below. S S* S -NS -NS -NS -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS -NS +S* +S +S +S +PS* (missing nested UDT) +PS* (missing nested UDT) +PS* (missing nested UDT) NS @@ -16678,13 +16678,13 @@ Accelerator support is described below. S S* S -NS -NS -NS -NS -NS -NS -NS +S* +S +S +S +PS* (missing nested UDT) +PS* (missing nested UDT) +PS* (missing nested UDT) NS @@ -16699,13 +16699,13 @@ Accelerator support is described below. S S* S -NS -NS -NS -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS -NS +S* +S +S +S +PS* (missing nested UDT) +PS* (missing nested UDT) +PS* (missing nested UDT) NS diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala index f0b86fe4613..2821f350cf5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.hive.{HiveGenericUDF, HiveSimpleUDF} object GpuHiveOverrides { + // UDFs can support all types except UDT which does not have a clear columnar representation. + private val udfTypeSig = (TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.NULL + + TypeSig.BINARY + TypeSig.CALENDAR + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested() + def isSparkHiveAvailable: Boolean = { // Using the same approach as SparkSession.hiveClassesArePresent val loader = Thread.currentThread().getContextClassLoader @@ -48,12 +52,9 @@ object GpuHiveOverrides { GpuOverrides.expr[HiveSimpleUDF]( "Hive UDF, support requires the UDF to implement a RAPIDS-accelerated interface", ExprChecks.projectNotLambda( - TypeSig.commonCudfTypes + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes), + udfTypeSig, TypeSig.all, - repeatingParamCheck = Some(RepeatingParamCheck( - "param", - TypeSig.commonCudfTypes, - TypeSig.all))), + repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), (a, conf, p, r) => new ExprMeta[HiveSimpleUDF](a, conf, p, r) { override def tagExprForGpu(): Unit = { a.function match { @@ -79,12 +80,9 @@ object GpuHiveOverrides { "Hive Generic UDF, support requires the UDF to implement a " + "RAPIDS-accelerated interface", ExprChecks.projectNotLambda( - TypeSig.commonCudfTypes + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes), + udfTypeSig, TypeSig.all, - repeatingParamCheck = Some(RepeatingParamCheck( - "param", - TypeSig.commonCudfTypes, - TypeSig.all))), + repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), (a, conf, p, r) => new ExprMeta[HiveGenericUDF](a, conf, p, r) { override def tagExprForGpu(): Unit = { a.function match {