diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala index cbc561c8811..2815bbbdd64 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala @@ -20,10 +20,12 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer + import ai.rapids.cudf import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, Scalar, Schema, Table} import com.nvidia.spark.rapids._ import org.apache.hadoop.conf.Configuration + import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -377,21 +379,24 @@ class JsonPartitionReader( private def sanitizeNumbers(input: ColumnVector): ColumnVector = { // Note that this is not 100% consistent with Spark versions prior to Spark 3.3.0 // due to https://issues.apache.org/jira/browse/SPARK-38060 - val regex = if (parsedOptions.allowNonNumericNumbers) { - "^" + - "(?:" + - "(?:-?[0-9]+(?:\\.[0-9]+)?(?:[eE][\\-\\+]?[0-9]+)?)" + - "|NaN" + - "|(?:[\\+\\-]INF)" + - "|(?:[\\-\\+]?Infinity)" + - ")" + - "$" + // cuDF `isFloat` supports some inputs that are not valid JSON numbers, such as `.1`, `1.`, + // and `+1` so we use a regular expression to match valid JSON numbers instead + val jsonNumberRegexp = "^-?[0-9]+(?:\\.[0-9]+)?(?:[eE][\\-\\+]?[0-9]+)?$" + val isValid = if (parsedOptions.allowNonNumericNumbers) { + withResource(ColumnVector.fromStrings("NaN", "+INF", "-INF", "+Infinity", + "Infinity", "-Infinity")) { nonNumeric => + withResource(input.matchesRe(jsonNumberRegexp)) { isJsonNumber => + withResource(input.contains(nonNumeric)) { nonNumeric => + isJsonNumber.or(nonNumeric) + } + } + } } else { - "^-?[0-9]+(?:\\.[0-9]+)?(?:[eE][\\-\\+]?[0-9]+)?$" + input.matchesRe(jsonNumberRegexp) } - withResource(input.matchesRe(regex)) { validJsonDecimal => + withResource(isValid) { _ => withResource(Scalar.fromNull(DType.STRING)) { nullString => - validJsonDecimal.ifElse(input, nullString) + isValid.ifElse(input, nullString) } } }