Support float/double castings for ORC reading [databricks] #6319
@@ -0,0 +1,59 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *
from spark_session import with_cpu_session


def create_orc(data_gen_list, data_path):
    # generate an ORC dataframe and dump it to the local file 'data_path'
    with_cpu_session(
        lambda spark: gen_df(spark, data_gen_list).write.mode('overwrite').orc(data_path)
    )

# TODO: merge test_casting_from_float and test_casting_from_double into one test
# TODO: need a float_gen with range [a, b]; if float/double >= 1e13, then float/double -> timestamp overflows
'''
We need these test cases:
1. val * 1e3 <= LONG_MAX && val * 1e6 <= LONG_MAX (no overflow)
2. val * 1e3 <= LONG_MAX && val * 1e6 > LONG_MAX (caught java.lang.ArithmeticException)
3. val * 1e3 > LONG_MAX (caught java.lang.ArithmeticException)
'''
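# For concreteness (illustrative values only; LONG_MAX is about 9.2e18):
#   val = 1e9  hits case 1: val * 1e3 and val * 1e6 both fit in a long
#   val = 1e13 hits case 2: val * 1e3 fits, but val * 1e6 overflows
#   val = 1e16 hits case 3: even val * 1e3 overflows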
@pytest.mark.parametrize('to_type', ['double', 'boolean', 'tinyint', 'smallint', 'int', 'bigint', 'timestamp'])
def test_casting_from_float(spark_tmp_path, to_type):
    orc_path = spark_tmp_path + '/orc_casting_from_float'
    data_gen = [('float_column', float_gen)]
    create_orc(data_gen, orc_path)
    schema_str = "float_column {}".format(to_type)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: spark.read.schema(schema_str).orc(orc_path)
    )

Reviewer: nit: too many blank lines between tests. I believe the convention is to have 2, although we do not do this consistently in our tests.
Author: Fixed.

@pytest.mark.parametrize('to_type', ['float', 'boolean', 'tinyint', 'smallint', 'int', 'bigint', 'timestamp'])
def test_casting_from_double(spark_tmp_path, to_type):
    orc_path = spark_tmp_path + '/orc_casting_from_double'
    data_gen = [('double_column', double_gen)]
    create_orc(data_gen, orc_path)
    schema_str = "double_column {}".format(to_type)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: spark.read.schema(schema_str).orc(orc_path)
    )
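
# A possible shape for the overflow cases listed above (a sketch only, not part
# of this patch): assert_gpu_and_cpu_error is already imported for this purpose.
# The generator arguments below are hypothetical -- they assume a double
# generator can be restricted to values >= 1e13 so the timestamp cast overflows.
#
# def test_casting_from_double_to_timestamp_overflow(spark_tmp_path):
#     orc_path = spark_tmp_path + '/orc_casting_overflow'
#     create_orc([('double_column', DoubleGen(min_exp=44, max_exp=63))], orc_path)
#     assert_gpu_and_cpu_error(
#         lambda spark: spark.read.schema("double_column timestamp").orc(orc_path).collect(),
#         conf={},
#         error_message="ArithmeticException")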

@@ -186,6 +186,78 @@ object GpuOrcScan extends Arm {
      }
    }

  /**
   * Get the overflow flags in booleans: true means no overflow, while false means overflow.
   *
   * @param doubleMillis the input double column
   * @param millis the long column cast from doubleMillis
   */
  private def getOverflowFlags(doubleMillis: ColumnView, millis: ColumnView): ColumnView = {
    // No overflow when
    //   doubleMillis <= Long.MAX_VALUE &&
    //   doubleMillis >= Long.MIN_VALUE &&
    //   ((millis >= 0) == (doubleMillis >= 0))
    val rangeCheck = withResource(Scalar.fromLong(Long.MaxValue)) { max =>
      withResource(doubleMillis.lessOrEqualTo(max)) { upperCheck =>
        withResource(Scalar.fromLong(Long.MinValue)) { min =>
          withResource(doubleMillis.greaterOrEqualTo(min)) { lowerCheck =>
            upperCheck.and(lowerCheck)
          }
        }
      }
    }
    withResource(rangeCheck) { _ =>
      val signCheck = withResource(Scalar.fromInt(0)) { zero =>
        withResource(millis.greaterOrEqualTo(zero)) { longSign =>
          withResource(doubleMillis.greaterOrEqualTo(zero)) { doubleSign =>
            longSign.equalTo(doubleSign)
          }
        }
      }
      withResource(signCheck) { _ =>
        rangeCheck.and(signCheck)
      }
    }
  }
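
  // Scalar sketch of the same predicate (illustrative only, not GPU code):
  //   def noOverflow(doubleMillis: Double, millis: Long): Boolean =
  //     doubleMillis <= Long.MaxValue && doubleMillis >= Long.MinValue &&
  //       ((millis >= 0) == (doubleMillis >= 0))
  // The sign check catches the cases where the rounded long wrapped around
  // and flipped sign even though the double was in range.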

  /**
   * Borrowed from ORC "ConvertTreeReaderFactory".
   * Scala does not support such numeric literals, so parse them from strings.
   */
  private val MIN_LONG_AS_DOUBLE = java.lang.Double.valueOf("-0x1p63")

  /**
   * We cannot store Long.MAX_VALUE as a double without losing precision. Instead, we store
   * Long.MAX_VALUE + 1 == -Long.MIN_VALUE, and then offset all comparisons by 1.
   */
  private val MAX_LONG_AS_DOUBLE_PLUS_ONE = java.lang.Double.valueOf("0x1p63")
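
  // Illustrative REPL check of those hexadecimal literals (not part of the patch):
  //   java.lang.Double.valueOf("0x1p63")   // 9.223372036854776E18, i.e. 2^63
  //   Long.MaxValue.toDouble               // same value: 2^63 - 1 rounds up to 2^63
  // which is why the upper bound is stored as MAX_LONG_AS_DOUBLE_PLUS_ONE.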

  /**
   * Return a boolean column indicating whether the rows in col can fit in a long.
   * It assumes the input type is float or double.
   */
  private def doubleCanFitInLong(col: ColumnView): ColumnVector = {
    // It is true when
    //   (MIN_LONG_AS_DOUBLE - doubleValue < 1.0) &&
    //   (doubleValue < MAX_LONG_AS_DOUBLE_PLUS_ONE)
    val lowRet = withResource(Scalar.fromDouble(MIN_LONG_AS_DOUBLE)) { sMin =>
      withResource(Scalar.fromDouble(1.0)) { sOne =>
        withResource(sMin.sub(col)) { diff =>
          diff.lessThan(sOne)
        }
      }
    }
    withResource(lowRet) { _ =>
      withResource(Scalar.fromDouble(MAX_LONG_AS_DOUBLE_PLUS_ONE)) { sMax =>
        withResource(col.lessThan(sMax)) { highRet =>
          lowRet.and(highRet)
        }
      }
    }
  }
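
  // Boundary behaviour (illustrative): for v == -0x1p63 (exactly Long.MinValue),
  // MIN_LONG_AS_DOUBLE - v == 0.0 < 1.0, so it fits; for v == 0x1p63,
  // v < MAX_LONG_AS_DOUBLE_PLUS_ONE is false, so Long.MaxValue + 1 is rejected.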

  /**
   * Cast the column to the target type for ORC schema evolution.
   * It is designed to support all the cases where `canCast` returns true.

@@ -211,6 +283,68 @@ object GpuOrcScan extends Arm {
      } else {
        downCastAnyInteger(col, toDt)
      }

      // float/double (float64) to {bool, integer types, double/float, string, timestamp}
      // float to bool/integral
      case (DType.FLOAT32 | DType.FLOAT64, DType.BOOL8 | DType.INT8 | DType.INT16 | DType.INT32
          | DType.INT64) =>
        // Follow the CPU ORC conversion:
        // first replace rows that cannot fit in a long with nulls,
        // next convert to long,
        // then down cast the long to the target integral type.
        val longDoubles = withResource(doubleCanFitInLong(col)) { fitLongs =>
          col.copyWithBooleanColumnAsValidity(fitLongs)
        }
        withResource(longDoubles) { _ =>
          withResource(longDoubles.castTo(DType.INT64)) { longs =>
            toDt match {
              case DType.BOOL8 => longs.castTo(toDt)
              case DType.INT64 => longs.incRefCount()
              case _ => downCastAnyInteger(longs, toDt)
            }
          }
        }
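
        // Per-row effect of the three steps above (illustrative values):
        //   3.9f  -> fits in a long -> 3L -> e.g. INT8 3 (or BOOL8 true)
        //   1e30f -> cannot fit     -> the row becomes null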

      // float/double to double/float
      case (DType.FLOAT32 | DType.FLOAT64, DType.FLOAT32 | DType.FLOAT64) =>
        col.castTo(toDt)

      // FIXME float/double to string, there are some precision error issues
      case (DType.FLOAT32 | DType.FLOAT64, DType.STRING) =>
        GpuCast.castFloatingTypeToString(col)

Reviewer: Can you please file a follow-on issue for us to go back and see what we can do to fix this?
Author: Ok, after merging this, I will file an issue to describe this problem.

      // float/double -> timestamp
      case (DType.FLOAT32 | DType.FLOAT64, DType.TIMESTAMP_MICROSECONDS) =>
        // Follow the CPU ORC conversion:
        //   val doubleMillis = doubleValue * 1000
        //   val millis = Math.round(doubleMillis)
        //   if (noOverflow) millis else null
        val milliSeconds = withResource(Scalar.fromDouble(1000.0)) { thousand =>
          // ORC assumes the value is in seconds
          withResource(col.mul(thousand, DType.FLOAT64)) { doubleMillis =>
            withResource(doubleMillis.round()) { millis =>
              withResource(getOverflowFlags(doubleMillis, millis)) { overflows =>
                millis.copyWithBooleanColumnAsValidity(overflows)
              }
            }
          }
        }
        // Cast milliseconds to microseconds
        // Note that converting (milliSeconds * 1000) to INT64 may overflow a long.
        withResource(milliSeconds) { _ =>
          // Test whether there is any long overflow:
          // if milliSeconds.max() > LONG_MAX, then milliSeconds.max().getLong returns LONG_MAX;
          // if milliSeconds.max() * 1000 > LONG_MAX, then 'Math.multiplyExact' throws an
          // exception (as the CPU code does).
          Math.multiplyExact(milliSeconds.max().getLong, 1000.toLong)
          withResource(milliSeconds.mul(Scalar.fromDouble(1000.0))) { microSeconds =>
            withResource(microSeconds.castTo(DType.INT64)) { longVec =>
              longVec.castTo(DType.TIMESTAMP_MICROSECONDS)
            }
          }
        }
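
        // Worked example (illustrative): an input of 1.5 (seconds)
        //   doubleMillis = 1.5 * 1000.0    = 1500.0
        //   millis       = round(1500.0)   = 1500.0 (no overflow, so kept)
        //   microSeconds = 1500.0 * 1000.0 = 1500000.0 -> INT64 1500000
        //   -> TIMESTAMP_MICROSECONDS 1970-01-01 00:00:01.500000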

      // TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895
      case (f, t) =>
        throw new QueryExecutionException(s"Unsupported type casting: $f -> $t")

@@ -239,6 +373,12 @@ object GpuOrcScan extends Arm {
        }
      case VARCHAR =>
        to.getCategory == STRING

      case FLOAT | DOUBLE =>
        to.getCategory match {
          case BOOLEAN | BYTE | SHORT | INT | LONG | FLOAT | DOUBLE | STRING | TIMESTAMP => true
          case _ => false
        }
      // TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895
      case _ =>
        false
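
A quick way to exercise the new paths end to end (a minimal sketch with a hypothetical output path; it mirrors the Python tests above):

// Write a double column to ORC, then read it back with an int read schema,
// which drives canCast(DOUBLE -> INT) and the FLOAT/DOUBLE -> integral cast.
spark.range(10).selectExpr("cast(id as double) as double_column")
  .write.mode("overwrite").orc("/tmp/orc_cast_demo")
spark.read.schema("double_column int").orc("/tmp/orc_cast_demo").show()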
Reviewer: Do these TODO comments still need addressing in this PR, or require follow-on issues to be filed?
Author: Yep, I plan to address these TODOs in this PR.