Support String to Decimal 128 [databricks] (#4172)
* Add support for string->decimal128

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Added test for casting decimal 128 to string

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added method for enabling/disabling negative decimal support

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* refactored method to parent class

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* we should test the op on cpu and gpu if pre spark311

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* override the right method for databricks

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* only test negative scale decimal if Spark version < 3.1.1

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fixed the precedence of the predicates

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Don't skip the test for spark 3.3.0+

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* updated message if the test is skipped

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added shim method and removed the comment

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* configs.md check in

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* updated copyright

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* remove decimal 128 from unsupported types

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Jan 5, 2022
1 parent 86a1208 commit 56245ab
Showing 13 changed files with 82 additions and 58 deletions.
1 change: 0 additions & 1 deletion docs/configs.md
@@ -63,7 +63,6 @@ Name | Description | Default Value
 <a name="sql.castFloatToDecimal.enabled"></a>spark.rapids.sql.castFloatToDecimal.enabled|Casting from floating point types to decimal on the GPU returns results that have tiny difference compared to results returned from CPU.|false
 <a name="sql.castFloatToIntegralTypes.enabled"></a>spark.rapids.sql.castFloatToIntegralTypes.enabled|Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.|false
 <a name="sql.castFloatToString.enabled"></a>spark.rapids.sql.castFloatToString.enabled|Casting from floating point types to string on the GPU returns results that have a different precision than the default results of Spark.|false
-<a name="sql.castStringToDecimal.enabled"></a>spark.rapids.sql.castStringToDecimal.enabled|When set to true, enables casting from strings to decimal type on the GPU. Currently string to decimal type on the GPU might produce results which slightly differed from the correct results when the string represents any number exceeding the max precision that CAST_STRING_TO_FLOAT can keep. For instance, the GPU returns 99999999999999987 given input string "99999999999999999". The cause of divergence is that we can not cast strings containing scientific notation to decimal directly. So, we have to cast strings to floats firstly. Then, cast floats to decimals. The first step may lead to precision loss.|false
 <a name="sql.castStringToFloat.enabled"></a>spark.rapids.sql.castStringToFloat.enabled|When set to true, enables casting from strings to float types (float, double) on the GPU. Currently hex values aren't supported on the GPU. Also note that casting from string to float types on the GPU returns incorrect results when the string represents any number "1.7976931348623158E308" <= x < "1.7976931348623159E308" and "-1.7976931348623158E308" >= x > "-1.7976931348623159E308" in both these cases the GPU returns Double.MaxValue while CPU returns "+Infinity" and "-Infinity" respectively|false
 <a name="sql.castStringToTimestamp.enabled"></a>spark.rapids.sql.castStringToTimestamp.enabled|When set to true, casting from string to timestamp is supported on the GPU. The GPU only supports a subset of formats when casting strings to timestamps. Refer to the CAST documentation for more details.|false
 <a name="sql.concurrentGpuTasks"></a>spark.rapids.sql.concurrentGpuTasks|Set the number of tasks that can execute concurrently per GPU. Tasks may temporarily block when the number of concurrent tasks in the executor exceeds this amount. Allowing too many concurrent tasks on the same GPU may lead to GPU out of memory errors.|1
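The removed castStringToDecimal row above documents the old divergence this commit eliminates. A minimal spark-shell sketch of the case it describes (assumes a running `spark` session with the RAPIDS plugin; output depends on plugin version):

```scala
// Sketch only: the precision-loss case from the removed config text.
// Assumes a spark-shell session; `spark` is not set up here.
import spark.implicits._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

val df = Seq("99999999999999999").toDF("s") // 17 nines
// The old GPU path went string -> float -> decimal, and the float step
// could round this to 99999999999999987; casting through decimal128
// removes the need for that detour.
df.select(col("s").cast(DecimalType(19, 0))).show(truncate = false)
```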
19 changes: 10 additions & 9 deletions integration_tests/src/main/python/cast_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,13 +14,11 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_py4j_exception
 from data_gen import *
 from functools import reduce
-from spark_session import is_before_spark_311, is_before_spark_320
+from spark_session import is_before_spark_311, is_before_spark_320, is_before_spark_330, with_gpu_session
 from marks import allow_non_gpu, approximate_float
-from pyspark.sql.types import *
-from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at
 
 def test_cast_empty_string_to_int():
     assert_gpu_and_cpu_are_equal_collect(
@@ -167,14 +165,12 @@ def test_cast_long_to_decimal_overflow():
             f.col('a').cast(DecimalType(18, -1))),
         conf={'spark.sql.legacy.allowNegativeScaleOfDecimal': True})
 
-
-
 # casting these types to string should be passed
 basic_gens_for_cast_to_string = [byte_gen, short_gen, int_gen, long_gen, string_gen, boolean_gen, date_gen, null_gen, timestamp_gen] + decimal_gens_no_neg
 # casting these types to string is not exact match, marked as xfail when testing
 not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale]
 # casting these types to string is not supported, marked as xfail when testing
-not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())]
+not_support_gens_for_cast_to_string = [MapGen(ByteGen(False), ByteGen())]
 
 single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string]
 nested_array_gens_for_cast_to_string = [
@@ -304,4 +300,9 @@ def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy):
             "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
             "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
     )
 
+@pytest.mark.skipif(not is_before_spark_311() and is_before_spark_330(), reason="RAPIDS doesn't support casting string to decimal for negative scale decimal in this version of Spark because of SPARK-37451")
+def test_cast_string_to_negative_scale_decimal():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, StringGen("[0-9]{9}")).select(
+            f.col('a').cast(DecimalType(8, -3))), conf={'spark.sql.legacy.allowNegativeScaleOfDecimal': True})
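For reference, a rough spark-shell analogue of the new test above (a sketch, not part of the commit; assumes a `spark` session with the plugin and that the legacy flag is set before the type is constructed):

```scala
// Rough analogue of test_cast_string_to_negative_scale_decimal above.
// Illustrative only; assumes `spark` exists.
import spark.implicits._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

spark.conf.set("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
// DecimalType(8, -3) keeps at most 8 significant digits and rounds to
// thousands: "123456789" becomes 123457000.
val df = Seq("123456789").toDF("a")
df.select(col("a").cast(DecimalType(8, -3))).show()
```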
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,4 +37,6 @@ class Spark322Shims extends Spark322PlusShims with Spark30Xuntil33XShims {
       metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = {
     new FileScanRDD(sparkSession, readFunction, filePartitions)
   }
+
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = false
 }
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,4 +21,6 @@ import com.nvidia.spark.rapids.shims.v2._
 
 class Spark330Shims extends Spark33XShims {
   override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION
+
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = true
 }
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -128,4 +128,6 @@ trait Spark30XdbShimsBase extends SparkShims {
   }
 
   override def shouldFallbackOnAnsiTimestamp(): Boolean = false
+
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = true
 }
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -383,6 +383,8 @@ abstract class Spark30XShims extends Spark301until320Shims with Logging {
     adaptivePlan.initialPlan
   }
 
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = true
+
   // this is to help with an optimization in Spark 3.1, so we disable it by default in Spark 3.0.x
   override def isEmptyRelation(relation: Any): Boolean = false
   override def tryTransformIfEmptyRelation(mode: BroadcastMode): Option[Any] = None
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -555,6 +555,8 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
 
   override def hasCastFloatTimestampUpcast: Boolean = false
 
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = false
+
   override def supportsColumnarAdaptivePlans: Boolean = false
 
   override def columnarAdaptivePlan(a: AdaptiveSparkPlanExec, goal: CoalesceSizeGoal): SparkPlan = {
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -128,4 +128,6 @@ trait Spark31XdbShimsBase extends SparkShims {
   }
 
   override def shouldFallbackOnAnsiTimestamp(): Boolean = false
+
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = true
 }
@@ -1094,6 +1094,8 @@ trait Spark320until322Shims extends SparkShims with RebaseShims with Logging {
     adaptivePlan.initialPlan
   }
 
+  override def isCastingStringToNegDecimalScaleSupported: Boolean = false
+
   override def columnarAdaptivePlan(a: AdaptiveSparkPlanExec,
       goal: CoalesceSizeGoal): SparkPlan = {
     a.copy(supportsColumnar = true)
16 changes: 10 additions & 6 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -76,9 +76,9 @@ final class CastExprMeta[INPUT <: CastBase](
           s"set ${RapidsConf.ENABLE_CAST_DECIMAL_TO_STRING} to true if semantically " +
           s"equivalent decimal strings are sufficient for your application.")
       }
-      if (dt.precision > DType.DECIMAL64_MAX_PRECISION) {
+      if (dt.precision > DType.DECIMAL128_MAX_PRECISION) {
         willNotWorkOnGpu(s"decimal to string with a " +
-          s"precision > ${DType.DECIMAL64_MAX_PRECISION} is not supported yet")
+          s"precision > ${DType.DECIMAL128_MAX_PRECISION} is not supported yet")
       }
     case ( _: DecimalType, _: FloatType | _: DoubleType) if !conf.isCastDecimalToFloatEnabled =>
       willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal " +
@@ -95,7 +95,7 @@
         "converting floating point data types to strings and this can produce results that " +
         "differ from the default behavior in Spark. To enable this operation on the GPU, set" +
         s" ${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to true.")
-    case (_: StringType, dt: DecimalType) if dt.precision + 1 > Decimal.MAX_LONG_DIGITS =>
+    case (_: StringType, dt: DecimalType) if dt.precision + 1 > DecimalType.MAX_PRECISION =>
       willNotWorkOnGpu(s"Because of rounding requirements we cannot support $dt on the GPU")
     case (_: StringType, _: FloatType | _: DoubleType) if !conf.isCastStringToFloatEnabled =>
       willNotWorkOnGpu("Currently hex values aren't supported on the GPU. Also note " +
@@ -115,6 +115,11 @@
       YearParseUtil.tagParseStringAsDate(conf, this)
     case (_: StringType, _: DateType) =>
       YearParseUtil.tagParseStringAsDate(conf, this)
+    case (_: StringType, dt:DecimalType) =>
+      if (dt.scale < 0 && !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported) {
+        willNotWorkOnGpu("RAPIDS doesn't support casting string to decimal for " +
+          "negative scale decimal in this version of Spark because of SPARK-37451")
+      }
     case (structType: StructType, StringType) =>
       structType.foreach { field =>
         recursiveTagExprForGpuCheck(field.dataType, StringType, depth + 1)
@@ -911,8 +916,7 @@ object GpuCast extends Arm {
   // needed. This step is required so we can round up if needed in the final step
   // 4. Now cast newDt to dt (Decimal to Decimal)
   def getInterimDecimalPromoteIfNeeded(dt: DecimalType): DecimalType = {
-    if (dt.precision + 1 > Decimal.MAX_LONG_DIGITS) {
-      //We don't support Decimal 128
+    if (dt.precision + 1 > DecimalType.MAX_PRECISION) {
       throw new IllegalArgumentException("One or more values exceed the maximum supported " +
         "Decimal precision while conversion")
     }
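The `getInterimDecimalPromoteIfNeeded` hunk above is why one extra digit of precision matters: the string is first parsed into an interim decimal one digit wider, and the final decimal-to-decimal cast performs the rounding. A plain-JVM, REPL-style sketch of the idea (java.math.BigDecimal stands in here for the plugin's cuDF decimal kernels):

```scala
// Plain-JVM sketch of the promote-then-round idea; BigDecimal stands in
// for the cuDF code path.
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Target type Decimal(3, 2). Parsing "9.995" directly at scale 2 by
// truncation would give 9.99; keeping an interim extra digit preserves
// the trailing 5 so the final rounding step can produce 10.00. That
// result needs precision 4, which is why the code above rejects types
// where dt.precision + 1 would exceed the maximum supported precision.
val interim = new JBigDecimal("9.995")              // extra digit kept
val rounded = interim.setScale(2, RoundingMode.HALF_UP)
println(rounded)                                    // prints 10.00
```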
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -605,18 +605,6 @@ object RapidsConf {
     .booleanConf
     .createWithDefault(false)
 
-  val ENABLE_CAST_STRING_TO_DECIMAL = conf("spark.rapids.sql.castStringToDecimal.enabled")
-    .doc("When set to true, enables casting from strings to decimal type on the GPU. Currently " +
-      "string to decimal type on the GPU might produce results which slightly differed from the " +
-      "correct results when the string represents any number exceeding the max precision that " +
-      "CAST_STRING_TO_FLOAT can keep. For instance, the GPU returns 99999999999999987 given " +
-      "input string \"99999999999999999\". The cause of divergence is that we can not cast " +
-      "strings containing scientific notation to decimal directly. So, we have to cast strings " +
-      "to floats firstly. Then, cast floats to decimals. The first step may lead to precision " +
-      "loss.")
-    .booleanConf
-    .createWithDefault(false)
-
   val ENABLE_CAST_STRING_TO_TIMESTAMP = conf("spark.rapids.sql.castStringToTimestamp.enabled")
     .doc("When set to true, casting from string to timestamp is supported on the GPU. The GPU " +
       "only supports a subset of formats when casting strings to timestamps. Refer to the CAST " +
@@ -1561,8 +1549,6 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val isCastStringToFloatEnabled: Boolean = get(ENABLE_CAST_STRING_TO_FLOAT)
 
-  lazy val isCastStringToDecimalEnabled: Boolean = get(ENABLE_CAST_STRING_TO_DECIMAL)
-
   lazy val isCastFloatToIntegralTypesEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)
 
   lazy val isCsvTimestampReadEnabled: Boolean = get(ENABLE_CSV_TIMESTAMPS)
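The removed block above follows the plugin's fluent config-builder shape. A self-contained sketch of that shape (ConfDemo, ConfBuilder, and friends below are simplified stand-ins, not the plugin's actual classes):

```scala
// Simplified stand-in for the conf(...).doc(...).booleanConf
// .createWithDefault(...) pattern above; not the plugin's real API.
object ConfDemo {
  final case class ConfEntry(key: String, doc: String, default: Boolean)

  final class ConfBuilder(key: String) {
    private var docText = ""
    def doc(d: String): ConfBuilder = { docText = d; this }
    def booleanConf: BooleanBuilder = new BooleanBuilder(key, docText)
  }

  final class BooleanBuilder(key: String, doc: String) {
    def createWithDefault(b: Boolean): ConfEntry = ConfEntry(key, doc, b)
  }

  def conf(key: String): ConfBuilder = new ConfBuilder(key)

  val ENABLE_EXAMPLE: ConfEntry = conf("spark.rapids.sql.example.enabled")
    .doc("Illustrative flag only; not a real RAPIDS config.")
    .booleanConf
    .createWithDefault(false)
}
```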
@@ -99,6 +99,7 @@ trait SparkShims {
   def int96ParquetRebaseWrite(conf: SQLConf): String
   def int96ParquetRebaseReadKey: String
   def int96ParquetRebaseWriteKey: String
+  def isCastingStringToNegDecimalScaleSupported: Boolean
 
   def getParquetFilters(
       schema: MessageType,
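Putting the pieces together: the trait above gains the abstract method, each version shim overrides it (true for 3.0.x, Databricks, and 3.3.0+; false for 3.1.x and 3.2.x per the hunks above), and the GpuCast tagging code consults it through the shim loader. A condensed, self-contained sketch (the names mirror the plugin's SparkShims/ShimLoader idea but are illustrative stand-ins):

```scala
// Condensed sketch of the whole mechanism; illustrative stand-in only.
object ShimSketch {
  trait Shims {
    def isCastingStringToNegDecimalScaleSupported: Boolean
  }

  // Shims for versions hit by SPARK-37451 (e.g. 3.1.x, 3.2.x) report false...
  object Spark31xLike extends Shims {
    override def isCastingStringToNegDecimalScaleSupported: Boolean = false
  }

  // ...while pre-3.1.1 and 3.3.0+ shims report true.
  object Spark330Like extends Shims {
    override def isCastingStringToNegDecimalScaleSupported: Boolean = true
  }

  // Mirrors the GpuCast tagging hunk: fall back to CPU when unsupported.
  def tagStringToDecimal(scale: Int, shims: Shims): Option[String] =
    if (scale < 0 && !shims.isCastingStringToNegDecimalScaleSupported)
      Some("RAPIDS doesn't support casting string to decimal for negative " +
        "scale decimal in this version of Spark because of SPARK-37451")
    else None

  def main(args: Array[String]): Unit = {
    println(tagStringToDecimal(-3, Spark31xLike)) // Some(...): falls back
    println(tagStringToDecimal(-3, Spark330Like)) // None: stays on the GPU
  }
}
```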