Skip to content

Commit

Permalink
Struct to string casting functionality (#1814)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Lee <ryanlee@nvidia.com>

Co-authored-by: Ryan Lee <ryanlee@nvidia.com>
  • Loading branch information
rwlee and rwlee authored Mar 30, 2021
1 parent 6e3970c commit 200c72d
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 51 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18175,7 +18175,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS (the struct's children must also support being cast to string)</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -18579,7 +18579,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS (the struct's children must also support being cast to string)</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
60 changes: 60 additions & 0 deletions integration_tests/src/main/python/struct_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,63 @@ def test_orderby_struct_2(data_gen):
lambda spark : append_unique_int_col_to_df(spark, unary_op_df(spark, data_gen)),
'struct_table',
'select struct_table.a, struct_table.uniq_int from struct_table order by uniq_int')

# conf with legacy cast to string on
legacy_complex_types_to_string = {'spark.sql.legacy.castComplexTypesToString.enabled': 'true'}
@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", short_gen], ["fourth", int_gen], ["fifth", long_gen], ["sixth", string_gen], ["seventh", date_gen]])], ids=idfn)
def test_legacy_cast_struct_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")),
conf = legacy_complex_types_to_string)

@pytest.mark.parametrize('data_gen', [StructGen([["first", float_gen]])], ids=idfn)
@pytest.mark.xfail(reason='casting float to string is not an exact match')
def test_legacy_cast_struct_with_float_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")),
conf = legacy_complex_types_to_string)

@pytest.mark.parametrize('data_gen', [StructGen([["first", double_gen]])], ids=idfn)
@pytest.mark.xfail(reason='casting double to string is not an exact match')
def test_legacy_cast_struct_with_double_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")),
conf = legacy_complex_types_to_string)

@pytest.mark.parametrize('data_gen', [StructGen([["first", timestamp_gen]])], ids=idfn)
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/219')
def test_legacy_cast_struct_with_timestamp_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")),
conf = legacy_complex_types_to_string)

@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", short_gen], ["fourth", int_gen], ["fifth", long_gen], ["sixth", string_gen], ["seventh", date_gen]])], ids=idfn)
def test_cast_struct_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")))

@pytest.mark.parametrize('data_gen', [StructGen([["first", float_gen]])], ids=idfn)
@pytest.mark.xfail(reason='casting float to string is not an exact match')
def test_cast_struct_with_float_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")))

@pytest.mark.parametrize('data_gen', [StructGen([["first", double_gen]])], ids=idfn)
@pytest.mark.xfail(reason='casting double to string is not an exact match')
def test_cast_struct_with_double_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")))

@pytest.mark.parametrize('data_gen', [StructGen([["first", timestamp_gen]])], ids=idfn)
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/219')
def test_cast_struct_with_timestamp_to_string(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a').cast("STRING")))
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ class Spark300Shims extends SparkShims {
InMemoryFileIndex.shouldFilterOut(path)
}

override def getLegacyComplexTypeToString(): Boolean = true

// Arrow version changed between Spark versions
override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = {
val arrowBuf = vec.getDataBuffer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ class Spark311Shims extends Spark301Shims {
HadoopFSUtilsShim.shouldIgnorePath(path)
}

override def getLegacyComplexTypeToString(): Boolean = {
SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING)
}

// Arrow version changed between Spark versions
override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = {
val arrowBuf = vec.getDataBuffer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}

object FloatUtils extends Arm {

def nanToZero(cv: ColumnVector): ColumnVector = {
def nanToZero(cv: ColumnView): ColumnVector = {
if (cv.getType() != DType.FLOAT32 && cv.getType() != DType.FLOAT64) {
throw new IllegalArgumentException("Only Floats and Doubles allowed")
}
Expand Down
Loading

0 comments on commit 200c72d

Please sign in to comment.