-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support float/double castings for ORC reading [databricks] #6319
Changes from 9 commits
c312ecb
a729c86
bb2d3a6
ef1163d
2930ab8
e785de1
c6e05a5
69e9d14
0cf548a
0b2a675
e6e1aa9
d427a4d
db0f0d2
46e71f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,7 @@ case class GpuOrcScan( | |
} | ||
|
||
object GpuOrcScan extends Arm { | ||
|
||
def tagSupport(scanMeta: ScanMeta[OrcScan]): Unit = { | ||
val scan = scanMeta.wrapped | ||
val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema) | ||
|
@@ -186,6 +187,78 @@ object GpuOrcScan extends Arm { | |
} | ||
} | ||
|
||
/** | ||
* Get the overflow flags in booleans. | ||
* true means no overflow, while false means an overflow occurred. | ||
* | ||
* @param doubleMillis the input double column | ||
* @param millis the long column cast from the doubleMillis | ||
*/ | ||
private def getOverflowFlags(doubleMillis: ColumnView, millis: ColumnView): ColumnView = { | ||
// No overflow when | ||
// doubleMillis <= Long.MAX_VALUE && | ||
// doubleMillis >= Long.MIN_VALUE && | ||
// ((millis >= 0) == (doubleMillis >= 0)) | ||
val rangeCheck = withResource(Scalar.fromLong(Long.MaxValue)) { max => | ||
withResource(doubleMillis.lessOrEqualTo(max)) { upperCheck => | ||
withResource(Scalar.fromLong(Long.MinValue)) { min => | ||
withResource(doubleMillis.greaterOrEqualTo(min)) { lowerCheck => | ||
upperCheck.and(lowerCheck) | ||
} | ||
} | ||
} | ||
} | ||
withResource(rangeCheck) { _ => | ||
val signCheck = withResource(Scalar.fromInt(0)) { zero => | ||
withResource(millis.greaterOrEqualTo(zero)) { longSign => | ||
withResource(doubleMillis.greaterOrEqualTo(zero)) { doubleSign => | ||
longSign.equalTo(doubleSign) | ||
} | ||
} | ||
} | ||
withResource(signCheck) { _ => | ||
rangeCheck.and(signCheck) | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Borrowed from ORC "ConvertTreeReaderFactory" | ||
* Scala does not support such numeric literal, so parse from string. | ||
*/ | ||
private val MIN_LONG_AS_DOUBLE = java.lang.Double.valueOf("-0x1p63") | ||
|
||
/** | ||
* We cannot store Long.MAX_VALUE as a double without losing precision. Instead, we store | ||
* Long.MAX_VALUE + 1 == -Long.MIN_VALUE, and then offset all comparisons by 1. | ||
*/ | ||
private val MAX_LONG_AS_DOUBLE_PLUS_ONE = java.lang.Double.valueOf("0x1p63") | ||
|
||
/** | ||
* Return a boolean column indicating whether the rows in col can fit in a long. | ||
* It assumes the input type is float or double. | ||
*/ | ||
private def doubleCanFitInLong(col: ColumnView): ColumnVector = { | ||
// It is true when | ||
// (MIN_LONG_AS_DOUBLE - doubleValue < 1.0) && | ||
// (doubleValue < MAX_LONG_AS_DOUBLE_PLUS_ONE) | ||
val lowRet = withResource(Scalar.fromDouble(MIN_LONG_AS_DOUBLE)) { sMin => | ||
withResource(Scalar.fromDouble(1.0)) { sOne => | ||
withResource(sMin.sub(col)) { diff => | ||
diff.lessThan(sOne) | ||
} | ||
} | ||
} | ||
withResource(lowRet) { _ => | ||
withResource(Scalar.fromDouble(MAX_LONG_AS_DOUBLE_PLUS_ONE)) { sMax => | ||
withResource(col.lessThan(sMax)) { highRet => | ||
lowRet.and(highRet) | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
||
/** | ||
* Cast the column to the target type for ORC schema evolution. | ||
* It is designed to support all the cases that `canCast` returns true. | ||
|
@@ -233,6 +306,73 @@ object GpuOrcScan extends Arm { | |
DType.TIMESTAMP_MICROSECONDS) => | ||
OrcCastingShims.castIntegerToTimestamp(col, fromDt) | ||
|
||
// float to bool/integral | ||
case (DType.FLOAT32 | DType.FLOAT64, DType.BOOL8 | DType.INT8 | DType.INT16 | DType.INT32 | ||
| DType.INT64) => | ||
// Follow the CPU ORC conversion: | ||
// First replace rows that cannot fit in long with nulls, | ||
// next convert to long, | ||
// then down cast long to the target integral type. | ||
val longDoubles = withResource(doubleCanFitInLong(col)) { fitLongs => | ||
col.copyWithBooleanColumnAsValidity(fitLongs) | ||
} | ||
withResource(longDoubles) { _ => | ||
withResource(longDoubles.castTo(DType.INT64)) { longs => | ||
toDt match { | ||
case DType.BOOL8 => longs.castTo(toDt) | ||
case DType.INT64 => longs.incRefCount() | ||
case _ => downCastAnyInteger(longs, toDt) | ||
} | ||
} | ||
} | ||
|
||
// float/double to double/float | ||
case (DType.FLOAT32 | DType.FLOAT64, DType.FLOAT32 | DType.FLOAT64) => | ||
col.castTo(toDt) | ||
|
||
// float/double to string | ||
// cuDF keeps 9 decimal digits after the decimal point, and CPU keeps more than 10. | ||
// So when casting float/double to string, the result of GPU is different from CPU. | ||
// We use the conf 'spark.rapids.sql.format.orc.floatTypesToString.enable' to control | ||
// whether it is enabled or not. | ||
case (DType.FLOAT32 | DType.FLOAT64, DType.STRING) => | ||
GpuCast.castFloatingTypeToString(col) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please file a follow on issue for us to go back an see what we can do to fix this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Ok, after merging this, I will file an issue to describe this problem. |
||
|
||
// float/double -> timestamp | ||
case (DType.FLOAT32 | DType.FLOAT64, DType.TIMESTAMP_MICROSECONDS) => | ||
// Follow the CPU ORC conversion. | ||
// val doubleMillis = doubleValue * 1000, | ||
// val millis = Math.round(doubleMillis) | ||
// if (noOverflow) millis else null | ||
val milliSeconds = withResource(Scalar.fromDouble(1000.0)) { thousand => | ||
// ORC assumes value is in seconds | ||
withResource(col.mul(thousand, DType.FLOAT64)) { doubleMillis => | ||
withResource(doubleMillis.round()) { millis => | ||
withResource(getOverflowFlags(doubleMillis, millis)) { overflowFlags => | ||
millis.copyWithBooleanColumnAsValidity(overflowFlags) | ||
} | ||
} | ||
} | ||
} | ||
// Cast milli-seconds to micro-seconds | ||
firestarman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// We need to pay attention that when convert (milliSeconds * 1000) to INT64, there may be | ||
// INT64-overflow. | ||
// In this step, ORC casting of CPU throw an exception rather than replace such values with | ||
// null. We followed the CPU code here. | ||
withResource(milliSeconds) { _ => | ||
// Test whether there is long-overflow | ||
// If milliSeconds.max() * 1000 > LONG_MAX, then 'Math.multiplyExact' will | ||
// throw an exception (as CPU code does). | ||
if (milliSeconds.max() != null) { | ||
testLongMultiplicationOverflow(milliSeconds.max().getDouble.toLong, 1000L) | ||
} | ||
withResource(milliSeconds.mul(Scalar.fromDouble(1000.0))) { microSeconds => | ||
withResource(microSeconds.castTo(DType.INT64)) { longVec => | ||
longVec.castTo(DType.TIMESTAMP_MICROSECONDS) | ||
} | ||
} | ||
} | ||
|
||
// TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895 | ||
case (f, t) => | ||
throw new QueryExecutionException(s"Unsupported type casting: $f -> $t") | ||
|
@@ -246,7 +386,8 @@ object GpuOrcScan extends Arm { | |
* but the ones between GPU supported types. | ||
* Each supported casting is implemented in "castColumnTo". | ||
*/ | ||
def canCast(from: TypeDescription, to: TypeDescription): Boolean = { | ||
def canCast(from: TypeDescription, to: TypeDescription, | ||
isOrcFloatTypesToStringEnable: Boolean): Boolean = { | ||
import org.apache.orc.TypeDescription.Category._ | ||
if (!to.getCategory.isPrimitive || !from.getCategory.isPrimitive) { | ||
// Don't convert from any to complex, or from complex to any. | ||
|
@@ -268,7 +409,16 @@ object GpuOrcScan extends Arm { | |
} | ||
case VARCHAR => | ||
toType == STRING | ||
case _ => false | ||
|
||
case FLOAT | DOUBLE => | ||
toType match { | ||
case BOOLEAN | BYTE | SHORT | INT | LONG | FLOAT | DOUBLE | TIMESTAMP => true | ||
case STRING => isOrcFloatTypesToStringEnable | ||
case _ => false | ||
} | ||
// TODO more types, tracked in https://github.com/NVIDIA/spark-rapids/issues/5895 | ||
case _ => | ||
false | ||
} | ||
} | ||
|
||
|
@@ -313,7 +463,8 @@ case class GpuOrcMultiFilePartitionReaderFactory( | |
private val debugDumpPrefix = Option(rapidsConf.orcDebugDumpPrefix) | ||
private val numThreads = rapidsConf.multiThreadReadNumThreads | ||
private val maxNumFileProcessed = rapidsConf.maxNumOrcFilesParallel | ||
private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, filters) | ||
private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, filters, | ||
rapidsConf.isOrcFloatTypesToStringEnable) | ||
private val ignoreMissingFiles = sqlConf.ignoreMissingFiles | ||
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles | ||
|
||
|
@@ -400,7 +551,8 @@ case class GpuOrcPartitionReaderFactory( | |
private val debugDumpPrefix = Option(rapidsConf.orcDebugDumpPrefix) | ||
private val maxReadBatchSizeRows: Integer = rapidsConf.maxReadBatchSizeRows | ||
private val maxReadBatchSizeBytes: Long = rapidsConf.maxReadBatchSizeBytes | ||
private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, pushedFilters) | ||
private val filterHandler = GpuOrcFileFilterHandler(sqlConf, broadcastedConf, pushedFilters, | ||
rapidsConf.isOrcFloatTypesToStringEnable) | ||
|
||
override def supportColumnarReads(partition: InputPartition): Boolean = true | ||
|
||
|
@@ -931,7 +1083,8 @@ private object OrcTools extends Arm { | |
private case class GpuOrcFileFilterHandler( | ||
@transient sqlConf: SQLConf, | ||
broadcastedConf: Broadcast[SerializableConfiguration], | ||
pushedFilters: Array[Filter]) extends Arm { | ||
pushedFilters: Array[Filter], | ||
isOrcFloatTypesToStringEnable: Boolean) extends Arm { | ||
|
||
private[rapids] val isCaseSensitive = sqlConf.caseSensitiveAnalysis | ||
|
||
|
@@ -1026,7 +1179,7 @@ private case class GpuOrcFileFilterHandler( | |
val isCaseSensitive = readerOpts.getIsSchemaEvolutionCaseAware | ||
|
||
val (updatedReadSchema, fileIncluded) = checkSchemaCompatibility(orcReader.getSchema, | ||
readerOpts.getSchema, isCaseSensitive) | ||
readerOpts.getSchema, isCaseSensitive, isOrcFloatTypesToStringEnable) | ||
// GPU has its own read schema, so unset the reader include to read all the columns | ||
// specified by its read schema. | ||
readerOpts.include(null) | ||
|
@@ -1206,11 +1359,13 @@ private case class GpuOrcFileFilterHandler( | |
private def checkSchemaCompatibility( | ||
fileSchema: TypeDescription, | ||
readSchema: TypeDescription, | ||
isCaseAware: Boolean): (TypeDescription, Array[Boolean]) = { | ||
isCaseAware: Boolean, | ||
isOrcFloatTypesToStringEnable: Boolean): (TypeDescription, Array[Boolean]) = { | ||
// all default to false | ||
val fileIncluded = new Array[Boolean](fileSchema.getMaximumId + 1) | ||
val isForcePos = OrcShims.forcePositionalEvolution(conf) | ||
(checkTypeCompatibility(fileSchema, readSchema, isCaseAware, fileIncluded, isForcePos), | ||
(checkTypeCompatibility(fileSchema, readSchema, isCaseAware, fileIncluded, isForcePos, | ||
isOrcFloatTypesToStringEnable), | ||
fileIncluded) | ||
} | ||
|
||
|
@@ -1224,7 +1379,8 @@ private case class GpuOrcFileFilterHandler( | |
readType: TypeDescription, | ||
isCaseAware: Boolean, | ||
fileIncluded: Array[Boolean], | ||
isForcePos: Boolean): TypeDescription = { | ||
isForcePos: Boolean, | ||
isOrcFloatTypesToStringEnable: Boolean): TypeDescription = { | ||
(fileType.getCategory, readType.getCategory) match { | ||
case (TypeDescription.Category.STRUCT, TypeDescription.Category.STRUCT) => | ||
// Check for the top or nested struct types. | ||
|
@@ -1252,7 +1408,7 @@ private case class GpuOrcFileFilterHandler( | |
.zipWithIndex.foreach { case ((fileFieldName, fType), idx) => | ||
getReadFieldType(fileFieldName, idx).foreach { case (rField, rType) => | ||
val newChild = checkTypeCompatibility(fType, rType, | ||
isCaseAware, fileIncluded, isForcePos) | ||
isCaseAware, fileIncluded, isForcePos, isOrcFloatTypesToStringEnable) | ||
prunedReadSchema.addField(rField, newChild) | ||
} | ||
} | ||
|
@@ -1262,19 +1418,22 @@ private case class GpuOrcFileFilterHandler( | |
// for struct children. | ||
case (TypeDescription.Category.LIST, TypeDescription.Category.LIST) => | ||
val newChild = checkTypeCompatibility(fileType.getChildren.get(0), | ||
readType.getChildren.get(0), isCaseAware, fileIncluded, isForcePos) | ||
readType.getChildren.get(0), isCaseAware, fileIncluded, isForcePos, | ||
isOrcFloatTypesToStringEnable) | ||
fileIncluded(fileType.getId) = true | ||
TypeDescription.createList(newChild) | ||
case (TypeDescription.Category.MAP, TypeDescription.Category.MAP) => | ||
val newKey = checkTypeCompatibility(fileType.getChildren.get(0), | ||
readType.getChildren.get(0), isCaseAware, fileIncluded, isForcePos) | ||
readType.getChildren.get(0), isCaseAware, fileIncluded, isForcePos, | ||
isOrcFloatTypesToStringEnable) | ||
val newValue = checkTypeCompatibility(fileType.getChildren.get(1), | ||
readType.getChildren.get(1), isCaseAware, fileIncluded, isForcePos) | ||
readType.getChildren.get(1), isCaseAware, fileIncluded, isForcePos, | ||
isOrcFloatTypesToStringEnable) | ||
fileIncluded(fileType.getId) = true | ||
TypeDescription.createMap(newKey, newValue) | ||
case (ft, rt) if ft.isPrimitive && rt.isPrimitive => | ||
if (OrcShims.typeDescriptionEqual(fileType, readType) || | ||
GpuOrcScan.canCast(fileType, readType)) { | ||
GpuOrcScan.canCast(fileType, readType, isOrcFloatTypesToStringEnable)) { | ||
// Since type casting is supported, here should return the file type. | ||
fileIncluded(fileType.getId) = true | ||
fileType.clone() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -864,6 +864,16 @@ object RapidsConf { | |
.booleanConf | ||
.createWithDefault(true) | ||
|
||
val ENABLE_ORC_FLOAT_TYPES_TO_STRING = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add some docs to this that indicate what will happen if we run into this situation. For most configs when we are in this kind of a situation we fall back to the CPU, but here we will throw an exception and the job will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have updated this in |
||
conf("spark.rapids.sql.format.orc.floatTypesToString.enable") | ||
.doc("When reading an ORC file, the source data schemas(schemas of ORC file) may differ " + | ||
"from the target schemas (schemas of the reader), we need to handle the castings from " + | ||
"source type to target type. Since float/double numbers in GPU have different precision " + | ||
"with CPU, when casting float/double to string, the result of GPU is different from " + | ||
"result of CPU spark.") | ||
.booleanConf | ||
.createWithDefault(true) | ||
|
||
val ORC_READER_TYPE = conf("spark.rapids.sql.format.orc.reader.type") | ||
.doc("Sets the ORC reader type. We support different types that are optimized for " + | ||
"different environments. The original Spark style reader can be selected by setting this " + | ||
|
@@ -1856,6 +1866,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { | |
|
||
lazy val isOrcWriteEnabled: Boolean = get(ENABLE_ORC_WRITE) | ||
|
||
lazy val isOrcFloatTypesToStringEnable: Boolean = get(ENABLE_ORC_FLOAT_TYPES_TO_STRING) | ||
|
||
lazy val isOrcPerFileReadEnabled: Boolean = | ||
RapidsReaderType.withName(get(ORC_READER_TYPE)) == RapidsReaderType.PERFILE | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we off if we don't do this? It feels odd that we would get a different answer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, it's okay to remove
approximate_float
, we can still pass the test. But I think we should pay attention to the method of comparing floating-point numbers for equality.
For example,
I don't know whether the conversion
float -> double
in GPU is same as CPU.We should check two float types numbers if they're equal via
abs(val1 - val2) < EPSLION
, whereEPSILON
is the allowable accuracy error.