diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 29ed4e2e274..5866cc450c7 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1764,8 +1764,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck( "pivotColumn", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.all), ParamCheck("valueColumn", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, @@ -1806,8 +1805,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( @@ -1815,8 +1813,7 @@ object GpuOverrides extends Logging { TypeSig.orderable, Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) { override def tagAggForGpu(): Unit = { @@ -1835,8 +1832,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - 
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( @@ -1844,8 +1840,7 @@ object GpuOverrides extends Logging { TypeSig.orderable, Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) { override def tagAggForGpu(): Unit = { @@ -2159,8 +2154,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), GpuOverrides.nanAggPsNote), TypeSig.ARRAY.nested(TypeSig.orderable)), (in, conf, p, r) => new UnaryExprMeta[ArrayMin](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2173,8 +2167,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), GpuOverrides.nanAggPsNote), TypeSig.ARRAY.nested(TypeSig.orderable)), (in, conf, p, r) => new UnaryExprMeta[ArrayMax](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2195,10 +2188,10 @@ object GpuOverrides extends Logging { ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL), 
TypeSig.ARRAY.nested(TypeSig.all)), ("key", TypeSig.commonCudfTypes - .withPsNote(TypeEnum.DOUBLE, "NaN literals are not supported. Columnar input" + - s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false.") - .withPsNote(TypeEnum.FLOAT, "NaN literals are not supported. Columnar input" + - s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."), + .withPsNote( + Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), + "NaN literals are not supported. Columnar input" + + s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ArrayContains](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2834,9 +2827,9 @@ object GpuOverrides extends Logging { TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested() .withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " + s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true") - .withPsNote(TypeEnum.ARRAY, "Round-robin partitioning is not supported if " + - s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true") - .withPsNote(TypeEnum.MAP, "Round-robin partitioning is not supported if " + + .withPsNote( + Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "Round-robin partitioning is not supported if " + s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"), TypeSig.all), (shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)), @@ -2886,8 +2879,8 @@ object GpuOverrides extends Logging { (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT) .nested() - .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") - .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "not allowed for grouping expressions") .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions if containing Array or Map as child"), TypeSig.all), diff --git 
a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 0da3a6834e0..3de2ddee361 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -256,6 +256,18 @@ final class TypeSig private( new TypeSig(initialTypes + dataType, maxAllowedDecimalPrecision, childTypes, litOnlyTypes, notes.+((dataType, note))) + /** + * Mark every type in dataTypes as partially supported, attaching the same note to each. + * @param dataTypes the types this note is for; each is also added to the initial types. + * @param note the note itself, recorded once for every type in dataTypes. + * @return a new TypeSig built from this one's fields with the given types and notes added. + */ + def withPsNote(dataTypes: Seq[TypeEnum.Value], note: String): TypeSig = + new TypeSig( + dataTypes.foldLeft(initialTypes)(_+_), maxAllowedDecimalPrecision, childTypes, + litOnlyTypes, dataTypes.foldLeft(notes)((notes, dataType) => notes.+((dataType, note)))) + + private def isSupportedType(dataType: TypeEnum.Value): Boolean = initialTypes.contains(dataType) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 0fe567eab03..2c1b7abcc01 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2235,8 +2235,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck( "pivotColumn", (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.all), ParamCheck("valueColumn", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, @@ -2284,8 +2283,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + 
TypeSig.STRUCT) .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( @@ -2293,8 +2291,7 @@ object GpuOverrides extends Logging { TypeSig.orderable, Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) { override def tagAggForGpu(): Unit = { @@ -2319,8 +2316,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts ++ ExprChecks.windowOnly( @@ -2328,8 +2324,7 @@ object GpuOverrides extends Logging { TypeSig.orderable, Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) { override def tagAggForGpu(): Unit = { @@ -2633,26 +2628,11 @@ object GpuOverrides extends Logging { ("index/key", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128) .withPsNote(TypeEnum.INT, "Supported as array index. 
" + "Only Literals supported as map keys.") - .withPsNote(TypeEnum.BOOLEAN, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.BYTE, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.SHORT, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.LONG, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.FLOAT, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.DOUBLE, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.DATE, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.TIMESTAMP, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.STRING, "Unsupported as array index. " + - "Only Literals supported as map keys.") - .withPsNote(TypeEnum.DECIMAL, "Unsupported as array index. " + - "Only Literals supported as map keys."), + .withPsNote( + Seq(TypeEnum.BOOLEAN, TypeEnum.BYTE, TypeEnum.SHORT, TypeEnum.LONG, + TypeEnum.FLOAT, TypeEnum.DOUBLE, TypeEnum.DATE, TypeEnum.TIMESTAMP, + TypeEnum.STRING, TypeEnum.DECIMAL), + "Unsupported as array index. 
Only Literals supported as map keys."), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2740,8 +2720,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.ARRAY.nested(TypeSig.orderable)), (in, conf, p, r) => new UnaryExprMeta[ArrayMin](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2757,8 +2736,7 @@ object GpuOverrides extends Logging { TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) - .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) - .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + .withPsNote(Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), nanAggPsNote), TypeSig.ARRAY.nested(TypeSig.orderable)), (in, conf, p, r) => new UnaryExprMeta[ArrayMax](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2783,10 +2761,10 @@ object GpuOverrides extends Logging { ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL), TypeSig.ARRAY.nested(TypeSig.all)), ("key", TypeSig.commonCudfTypes - .withPsNote(TypeEnum.DOUBLE, "NaN literals are not supported. Columnar input" + - s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false.") - .withPsNote(TypeEnum.FLOAT, "NaN literals are not supported. Columnar input" + - s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."), + .withPsNote( + Seq(TypeEnum.DOUBLE, TypeEnum.FLOAT), + "NaN literals are not supported. 
Columnar input" + + s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ArrayContains](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -3765,9 +3743,9 @@ object GpuOverrides extends Logging { GpuTypeShims.additionalArithmeticSupportedTypes).nested() .withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " + s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true") - .withPsNote(TypeEnum.ARRAY, "Round-robin partitioning is not supported if " + - s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true") - .withPsNote(TypeEnum.MAP, "Round-robin partitioning is not supported if " + + .withPsNote( + Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "Round-robin partitioning is not supported if " + s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"), TypeSig.all), (shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)), @@ -3830,8 +3808,8 @@ object GpuOverrides extends Logging { (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT) .nested() - .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") - .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "not allowed for grouping expressions") .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions if containing Array or Map as child"), TypeSig.all), @@ -3846,8 +3824,8 @@ object GpuOverrides extends Logging { .nested() .withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " + "converted between CPU and GPU") - .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") - .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "not allowed for grouping expressions") .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions if containing Array or Map as child"), TypeSig.all), @@ -3864,8 
+3842,8 @@ object GpuOverrides extends Logging { .nested() .withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " + "converted between CPU and GPU") - .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") - .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP), + "not allowed for grouping expressions") .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions if containing Array or Map as child"), TypeSig.all), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 6119115e0a0..ebde6eff912 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -270,6 +270,17 @@ final class TypeSig private( new TypeSig(initialTypes + dataType, maxAllowedDecimalPrecision, childTypes, litOnlyTypes, notes.+((dataType, note))) + /** + * Mark every type in dataTypes as partially supported, attaching the same note to each. + * @param dataTypes the types this note is for; each is also added to the initial types. + * @param note the note itself, recorded once for every type in dataTypes. + * @return a new TypeSig built from this one's fields with the given types and notes added. + */ + def withPsNote(dataTypes: Seq[TypeEnum.Value], note: String): TypeSig = + new TypeSig( + dataTypes.foldLeft(initialTypes)(_+_), maxAllowedDecimalPrecision, childTypes, + litOnlyTypes, dataTypes.foldLeft(notes)((notes, dataType) => notes.+((dataType, note)))) + private def isSupportedType(dataType: TypeEnum.Value): Boolean = initialTypes.contains(dataType)