Skip to content

Commit

Permalink
Fix bug with InSet and Strings (NVIDIA#437)
Browse files Browse the repository at this point in the history
* Fix bug with InSet and Strings

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>

* Addressed review comments

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Jul 27, 2020
1 parent 18ff143 commit aa998f9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
16 changes: 15 additions & 1 deletion integration_tests/src/main/python/cmp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import incompat, approximate_float
from spark_session import with_cpu_session
from pyspark.sql.types import *
import pyspark.sql.functions as f

Expand Down Expand Up @@ -137,10 +138,23 @@ def test_filter_with_lit(expr):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, LongGen()).filter(expr))

# Spark supports two different versions of 'IN', and it depends on the
# spark.sql.optimizer.inSetConversionThreshold conf.
# This is to test entries under that value.
@pytest.mark.parametrize('data_gen', eq_gens, ids=idfn)
def test_in(data_gen):
    """Check CPU/GPU parity for Column.isin with few enough entries that
    Spark keeps the expression as 'In' (below the InSet conversion threshold)."""
    # nulls are not supported for in on the GPU yet
    # Read the threshold from the CPU session so the test tracks whatever the
    # session is actually configured with, then stay one entry under it.
    num_entries = int(with_cpu_session(
        lambda spark: spark.conf.get('spark.sql.optimizer.inSetConversionThreshold'))) - 1
    scalars = list(gen_scalars(data_gen, num_entries, force_no_nulls=True))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).select(f.col('a').isin(scalars)))

# Spark supports two different versions of 'IN', and it depends on the
# spark.sql.optimizer.inSetConversionThreshold conf.
# This is to test entries over that value.
@pytest.mark.parametrize('data_gen', eq_gens, ids=idfn)
def test_in_set(data_gen):
    """Check CPU/GPU parity for Column.isin with enough entries that Spark
    converts the expression to 'InSet' (above the conversion threshold)."""
    # nulls are not supported for in on the GPU yet
    threshold = int(with_cpu_session(
        lambda spark: spark.conf.get('spark.sql.optimizer.inSetConversionThreshold')))
    # One past the threshold forces the InSet code path.
    scalars = list(gen_scalars(data_gen, threshold + 1, force_no_nulls=True))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).select(f.col('a').isin(scalars)))

Original file line number Diff line number Diff line change
Expand Up @@ -1052,14 +1052,14 @@ object GpuOverrides {
if (in.hset.contains(null)) {
willNotWorkOnGpu("nulls are not supported")
}
val literalTypes = in.hset.map(Literal(_).dataType).toSeq
val literalTypes = in.hset.map(LiteralHelper(_).dataType).toSeq
if (!areAllSupportedTypes(literalTypes:_*)) {
val unsupported = literalTypes.filter(!areAllSupportedTypes(_)).mkString(", ")
willNotWorkOnGpu(s"unsupported literal types: $unsupported")
}
}
override def convertToGpu(): GpuExpression =
GpuInSet(childExprs.head.convertToGpu(), in.hset.map(Literal(_)).toSeq)
GpuInSet(childExprs.head.convertToGpu(), in.hset.map(LiteralHelper(_)).toSeq)
}),
expr[LessThan](
"< operator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ import javax.xml.bind.DatatypeConverter
import ai.rapids.cudf.{DType, Scalar}
import org.json4s.JsonAST.{JField, JNull, JString}

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String

/**
 * Builds a Catalyst [[Literal]] from a raw value.
 *
 * InSet's value set holds [[UTF8String]] instances for string columns, and
 * plain `Literal(_)` does not map a `UTF8String` to [[StringType]] on its own
 * (the bug this helper was introduced to fix), so the data type is supplied
 * explicitly in that case. All other values fall back to `Literal.apply`.
 */
object LiteralHelper {
  def apply(value: Any): Literal = value match {
    case utf8: UTF8String => Literal(utf8, StringType)
    case other            => Literal(other)
  }
}

object GpuScalar {
def scalaTypeToDType(v: Any): DType = {
v match {
Expand Down

0 comments on commit aa998f9

Please sign in to comment.