diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index 37a34fac66364..48db0c7d971c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -27,8 +27,15 @@ object UnsafeRowUtils {
    * - schema.fields.length == row.numFields should always be true
    * - UnsafeRow.calculateBitSetWidthInBytes(row.numFields) < row.getSizeInBytes should always be
    *   true if the expectedSchema contains at least one field.
-   * - For variable-length fields: if null bit says it's null then don't do anything, else extract
-   *   offset and size:
+   * - For variable-length fields:
+   *   - if null bit says it's null, then
+   *     - in general the offset-and-size should be zero
+   *     - special case: variable-length DecimalType is considered mutable in UnsafeRow, and to
+   *       support that, the offset is set to point to the variable-length part like a non-null
+   *       value, while the size is set to zero to signal that it's a null value. The offset
+   *       may also be set to zero, in which case this variable-length Decimal no longer supports
+   *       being mutable in the UnsafeRow.
+   *   - otherwise the field is not null, then extract offset and size:
    *   1) 0 <= size < row.getSizeInBytes should always be true. We can be even more precise than
    *      this, where the upper bound of size can only be as big as the variable length part of
    *      the row.
@@ -52,9 +59,7 @@
     var varLenFieldsSizeInBytes = 0
     expectedSchema.fields.zipWithIndex.foreach {
       case (field, index) if !UnsafeRow.isFixedLength(field.dataType) && !row.isNullAt(index) =>
-        val offsetAndSize = row.getLong(index)
-        val offset = (offsetAndSize >> 32).toInt
-        val size = offsetAndSize.toInt
+        val (offset, size) = getOffsetAndSize(row, index)
         if (size < 0 ||
             offset < bitSetWidthInBytes + 8 * row.numFields || offset + size > rowSizeInBytes) {
           return false
@@ -74,8 +79,26 @@
             if ((row.getLong(index) >> 32) != 0L) return false
           case _ =>
         }
-      case (_, index) if row.isNullAt(index) =>
-        if (row.getLong(index) != 0L) return false
+      case (field, index) if row.isNullAt(index) =>
+        field.dataType match {
+          case dt: DecimalType if !UnsafeRow.isFixedLength(dt) =>
+            // See special case in UnsafeRowWriter.write(int, Decimal, int, int) and
+            // UnsafeRow.setDecimal(int, Decimal, int).
+            // A variable-length Decimal may be marked as null while having non-zero offset and
+            // zero length. This allows the field to be updated (i.e. mutable variable-length data)
+
+            // Check the integrity of null value of variable-length DecimalType in UnsafeRow:
+            // 1. size must be zero
+            // 2. offset may be zero, in which case this variable-length field is no longer mutable
+            // 3. otherwise offset is non-zero, range check it the same way as a non-null value
+            val (offset, size) = getOffsetAndSize(row, index)
+            if (size != 0 || offset != 0 &&
+                (offset < bitSetWidthInBytes + 8 * row.numFields || offset > rowSizeInBytes)) {
+              return false
+            }
+          case _ =>
+            if (row.getLong(index) != 0L) return false
+        }
       case _ =>
     }
     if (bitSetWidthInBytes + 8 * row.numFields + varLenFieldsSizeInBytes > rowSizeInBytes) {
@@ -83,4 +106,11 @@
     }
     true
   }
+
+  def getOffsetAndSize(row: UnsafeRow, index: Int): (Int, Int) = {
+    val offsetAndSize = row.getLong(index)
+    val offset = (offsetAndSize >> 32).toInt
+    val size = offsetAndSize.toInt
+    (offset, size)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
index 4b6a3cfafd894..518d68ce1d285 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import java.math.{BigDecimal => JavaBigDecimal}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType}
 
 class UnsafeRowUtilsSuite extends SparkFunSuite {
 
@@ -52,4 +54,31 @@
       StructField("value2", IntegerType, false)))
     assert(!UnsafeRowUtils.validateStructuralIntegrity(testRow, invalidSchema))
   }
+
+  test("Handle special case for null variable-length Decimal") {
+    val schema = StructType(StructField("d", DecimalType(19, 0), nullable = true) :: Nil)
+    val unsafeRowProjection = UnsafeProjection.create(schema)
+    val row = unsafeRowProjection(new SpecificInternalRow(schema))
+
+    // row is empty at this point
+    assert(row.isNullAt(0) && UnsafeRowUtils.getOffsetAndSize(row, 0) == (16, 0))
+    assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
+
+    // set Decimal field to precision-overflowed value
+    val bigDecimalVal = Decimal(new JavaBigDecimal("12345678901234567890")) // precision=20, scale=0
+    row.setDecimal(0, bigDecimalVal, 19) // should overflow and become null
+    assert(row.isNullAt(0) && UnsafeRowUtils.getOffsetAndSize(row, 0) == (16, 0))
+    assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
+
+    // set Decimal field to valid non-null value
+    val bigDecimalVal2 = Decimal(new JavaBigDecimal("1234567890123456789")) // precision=19, scale=0
+    row.setDecimal(0, bigDecimalVal2, 19) // should succeed
+    assert(!row.isNullAt(0) && UnsafeRowUtils.getOffsetAndSize(row, 0) == (16, 8))
+    assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
+
+    // set Decimal field to null explicitly, after which this field no longer supports updating
+    row.setNullAt(0)
+    assert(row.isNullAt(0) && UnsafeRowUtils.getOffsetAndSize(row, 0) == (0, 0))
+    assert(UnsafeRowUtils.validateStructuralIntegrity(row, schema))
+  }
 }
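Note: as a quick illustration of what the new getOffsetAndSize helper decodes, the minimal standalone sketch below (plain Scala, not part of the patch; OffsetAndSizeSketch, pack and unpack are made-up names for illustration) packs and unpacks the 8-byte word that UnsafeRow keeps for each variable-length field: the high 32 bits hold the offset from the start of the row, the low 32 bits hold the size in bytes. The three asserts mirror the three states exercised by the new test for a one-field row, where the variable-length region starts at offset 16 (8-byte null bitset plus one 8-byte fixed-length slot).

object OffsetAndSizeSketch {
  // Pack (offset, size) the way UnsafeRow lays out a variable-length field's word:
  // offset in the high 32 bits, size in the low 32 bits.
  def pack(offset: Int, size: Int): Long = (offset.toLong << 32) | (size.toLong & 0xFFFFFFFFL)

  // Mirror of UnsafeRowUtils.getOffsetAndSize: split the word back into (offset, size).
  def unpack(offsetAndSize: Long): (Int, Int) =
    ((offsetAndSize >> 32).toInt, offsetAndSize.toInt)

  def main(args: Array[String]): Unit = {
    assert(unpack(pack(16, 8)) == (16, 8)) // non-null variable-length Decimal: offset 16, size 8
    assert(unpack(pack(16, 0)) == (16, 0)) // null via overflow: offset kept, size zeroed, still mutable
    assert(unpack(0L) == (0, 0))           // after setNullAt: whole word zeroed, no longer updatable in place
  }
}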