Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU version of ToPrettyString [databricks] #9221

Merged
merged 28 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3711905
poc working
razajafri Sep 8, 2023
d48ab00
Refactoring
razajafri Sep 9, 2023
ec8e575
enable spark 350 tests
razajafri Sep 9, 2023
b3d238d
cleanup
razajafri Sep 11, 2023
223047e
Signing off
razajafri Sep 12, 2023
3708a2c
added unit tests
razajafri Sep 14, 2023
e8869c2
addressed review comments
razajafri Sep 14, 2023
adf05dd
Renamed test
razajafri Sep 14, 2023
09aee4f
cleanup
razajafri Sep 14, 2023
4bfca7a
add function for getting spark conf values
razajafri Sep 14, 2023
f435b61
fixed scalastyle
razajafri Sep 19, 2023
57d87fe
addressed review comments
razajafri Sep 19, 2023
8781b4f
added hexString to show binary to string
razajafri Sep 20, 2023
5b4daf8
Merge remote-tracking branch 'origin/branch-23.10' into SP-9171-topre…
razajafri Sep 20, 2023
650e25e
Merge remote-tracking branch 'origin/branch-23.10' into SP-9171-topre…
razajafri Sep 21, 2023
79ccc0b
refactored cast options
razajafri Sep 21, 2023
8f3ae34
Refactored GpuCast to take a CastOptions to make it easier to pass in …
razajafri Sep 21, 2023
d3c6bee
Signing off
razajafri Sep 21, 2023
c1ac19f
Merge remote-tracking branch 'private/SP-9284-add-castoptions' into S…
razajafri Sep 21, 2023
691120e
addressed review concerns
razajafri Sep 21, 2023
063252c
addressed review comments
razajafri Sep 22, 2023
5a92e5c
removed CastOptions to make merging easy
razajafri Sep 22, 2023
1c765ed
Merge remote-tracking branch 'private/SP-9284-add-castoptions' into S…
razajafri Sep 22, 2023
3841994
added more tests for basic types
razajafri Sep 23, 2023
c889f97
Merge remote-tracking branch 'origin/branch-23.10' into SP-9171-topre…
razajafri Sep 23, 2023
8a4ff65
added more unit tests
razajafri Sep 25, 2023
8cf22e6
Added tests for more types
razajafri Sep 26, 2023
9ce742a
call toHex() instead of fromStrings()
razajafri Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_spark_exception
from data_gen import *
from spark_session import is_before_spark_320, is_before_spark_330, is_spark_340_or_later, is_spark_350_or_later, \
is_databricks113_or_later, with_gpu_session
from spark_session import is_before_spark_320, is_before_spark_330, is_spark_340_or_later, \
is_databricks113_or_later
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
Expand Down Expand Up @@ -297,7 +297,6 @@ def _assert_cast_to_string_equal (data_gen, conf):

@pytest.mark.parametrize('data_gen', all_array_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@pytest.mark.xfail(condition=is_spark_350_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/9065')
def test_cast_array_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
Expand All @@ -317,7 +316,6 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', basic_map_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@pytest.mark.xfail(condition=is_spark_350_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/9065')
def test_cast_map_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
Expand All @@ -337,7 +335,6 @@ def test_cast_map_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', [StructGen([[str(i), gen] for i, gen in enumerate(basic_array_struct_gens_for_cast_to_string)] + [["map", MapGen(ByteGen(nullable=False), null_gen)]])], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@pytest.mark.xfail(condition=is_spark_350_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/9065')
def test_cast_struct_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.NoSuchElementException;

import ai.rapids.cudf.Scalar;
import com.nvidia.spark.rapids.GpuCast;
import com.nvidia.spark.rapids.CastOperation;
import com.nvidia.spark.rapids.GpuColumnVector;
import com.nvidia.spark.rapids.GpuScalar;
import com.nvidia.spark.rapids.iceberg.data.GpuDeleteFilter;
Expand Down Expand Up @@ -157,7 +157,7 @@ static ColumnarBatch addUpcastsIfNeeded(ColumnarBatch batch, Schema expectedSche
DataType expectedSparkType = SparkSchemaUtil.convert(expectedColumnTypes.get(i).type());
GpuColumnVector oldColumn = columns[i];
columns[i] = GpuColumnVector.from(
GpuCast.doCast(oldColumn.getBase(), oldColumn.dataType(), expectedSparkType,
CastOperation.apply(oldColumn.getBase(), oldColumn.dataType(), expectedSparkType,
false, false, false), expectedSparkType);
}
ColumnarBatch newBatch = new ColumnarBatch(columns, batch.numRows());
Expand Down
21 changes: 21 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/FloatUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,25 @@ object FloatUtils {
}
}
}

private[rapids] def castFloatingTypeToString(input: ColumnView): ColumnVector = {
  // Replace every occurrence of `target` with `replacement` in `col`,
  // closing both temporary scalars before returning.
  def replaceAll(col: ColumnView, target: String, replacement: String): ColumnVector =
    withResource(Scalar.fromString(target)) { from =>
      withResource(Scalar.fromString(replacement)) { to =>
        col.stringReplace(from, to)
      }
    }

  withResource(input.castTo(DType.STRING)) { asString =>
    // cudf renders exponents as "e+" while Spark prints "E"
    withResource(replaceAll(asString, "e+", "E")) { exponentFixed =>
      // cudf renders infinities as "Inf" while Spark prints "Infinity"
      // (also covers "-Inf" -> "-Infinity" since this is a substring replace)
      replaceAll(exponentFixed, "Inf", "Infinity")
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf
import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.GpuCast.doCast
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -144,7 +143,7 @@ case class ApproxPercentileFromTDigestExpr(
// array and return that (after converting from Double to finalDataType)
withResource(cv.getBase.approxPercentile(Array(p))) { percentiles =>
withResource(percentiles.extractListElement(0)) { childView =>
withResource(doCast(childView, DataTypes.DoubleType, finalDataType,
withResource(CastOperation(childView, DataTypes.DoubleType, finalDataType,
ansiMode = false, legacyCastToString = false,
stringToDateAnsiModeEnabled = false)) { childCv =>
GpuColumnVector.from(childCv.copyToColumnVector(), dataType)
Expand All @@ -159,7 +158,7 @@ case class ApproxPercentileFromTDigestExpr(
GpuColumnVector.from(percentiles.incRefCount(), dataType)
} else {
withResource(percentiles.getChildColumnView(0)) { childView =>
withResource(doCast(childView, DataTypes.DoubleType, finalDataType,
withResource(CastOperation(childView, DataTypes.DoubleType, finalDataType,
ansiMode = false, legacyCastToString = false,
stringToDateAnsiModeEnabled = false)) { childCv =>
withResource(percentiles.replaceListChild(childCv)) { x =>
Expand Down
Loading
Loading