Support predicates on ANSI day time interval type (#4946)
Signed-off-by: Chong Gao <res_life@163.com>
Chong Gao authored Mar 16, 2022
1 parent 9b6f5dc commit 410e42b
Showing 6 changed files with 262 additions and 9 deletions.
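Before the file-by-file diff, a minimal sketch of the kind of query this change targets: comparison predicates over ANSI day-time interval columns, which previously forced a fallback to the CPU. The snippet is illustrative only (the path and column name are assumptions) and presumes Spark 3.3.0+ with the RAPIDS Accelerator enabled.

import org.apache.spark.sql.SparkSession

object IntervalPredicateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("interval-predicate-sketch").getOrCreate()
    // With this change, a filter on a DAY TO SECOND interval column like the
    // one below is eligible to run on the GPU instead of falling back to CPU.
    spark.read.parquet("/tmp/intervals")   // hypothetical input path
      .where("_c1 > interval '10 0:0:0' day to second")
      .show()
    spark.stop()
  }
}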
integration_tests/src/main/python/cmp_test.py (100 additions, 1 deletion)
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
- from spark_session import with_cpu_session
+ from spark_session import with_cpu_session, is_before_spark_330
from pyspark.sql.types import *
import pyspark.sql.functions as f

@@ -32,6 +32,19 @@ def test_eq(data_gen):
f.col('b') == f.lit(None).cast(data_type),
f.col('a') == f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_eq_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') == s1,
s2 == f.col('b'),
f.lit(None).cast(data_type) == f.col('a'),
f.col('b') == f.lit(None).cast(data_type),
f.col('a') == f.col('b')))

@pytest.mark.parametrize('data_gen', eq_gens_with_decimal_gen, ids=idfn)
def test_eq_ns(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -44,6 +57,19 @@ def test_eq_ns(data_gen):
f.col('b').eqNullSafe(f.lit(None).cast(data_type)),
f.col('a').eqNullSafe(f.col('b'))))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_eq_ns_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a').eqNullSafe(s1),
s2.eqNullSafe(f.col('b')),
f.lit(None).cast(data_type).eqNullSafe(f.col('a')),
f.col('b').eqNullSafe(f.lit(None).cast(data_type)),
f.col('a').eqNullSafe(f.col('b'))))

@pytest.mark.parametrize('data_gen', eq_gens_with_decimal_gen, ids=idfn)
def test_ne(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -56,6 +82,19 @@ def test_ne(data_gen):
f.col('b') != f.lit(None).cast(data_type),
f.col('a') != f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_ne_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') != s1,
s2 != f.col('b'),
f.lit(None).cast(data_type) != f.col('a'),
f.col('b') != f.lit(None).cast(data_type),
f.col('a') != f.col('b')))

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_lt(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -68,6 +107,19 @@ def test_lt(data_gen):
f.col('b') < f.lit(None).cast(data_type),
f.col('a') < f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_lt_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') < s1,
s2 < f.col('b'),
f.lit(None).cast(data_type) < f.col('a'),
f.col('b') < f.lit(None).cast(data_type),
f.col('a') < f.col('b')))

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_lte(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -80,6 +132,20 @@ def test_lte(data_gen):
f.col('b') <= f.lit(None).cast(data_type),
f.col('a') <= f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_lte_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') <= s1,
s2 <= f.col('b'),
f.lit(None).cast(data_type) <= f.col('a'),
f.col('b') <= f.lit(None).cast(data_type),
f.col('a') <= f.col('b')))


@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_gt(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -92,6 +158,19 @@ def test_gt(data_gen):
f.col('b') > f.lit(None).cast(data_type),
f.col('a') > f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_gt_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') > s1,
s2 > f.col('b'),
f.lit(None).cast(data_type) > f.col('a'),
f.col('b') > f.lit(None).cast(data_type),
f.col('a') > f.col('b')))

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_gte(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -104,12 +183,32 @@ def test_gte(data_gen):
f.col('b') >= f.lit(None).cast(data_type),
f.col('a') >= f.col('b')))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_gte_for_interval():
data_gen = DayTimeIntervalGen()
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).select(
f.col('a') >= s1,
s2 >= f.col('b'),
f.lit(None).cast(data_type) >= f.col('a'),
f.col('b') >= f.lit(None).cast(data_type),
f.col('a') >= f.col('b')))

@pytest.mark.parametrize('data_gen', eq_gens_with_decimal_gen + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_isnull(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.isnull(f.col('a'))))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_isnull_for_interval():
data_gen = DayTimeIntervalGen()
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.isnull(f.col('a'))))

@pytest.mark.parametrize('data_gen', [FloatGen(), DoubleGen()], ids=idfn)
def test_isnan(data_gen):
assert_gpu_and_cpu_are_equal_collect(
…
integration_tests/src/main/python/parquet_test.py (13 additions, 2 deletions)
@@ -14,7 +14,7 @@

import pytest

- from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
+ from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import *
from pyspark.sql.types import *
@@ -806,4 +806,15 @@ def test_parquet_read_daytime_interval_gpu_file(spark_tmp_path):
# write DayTimeInterval with GPU
with_gpu_session(lambda spark :gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
- lambda spark: spark.read.parquet(data_path))
\ No newline at end of file
+ lambda spark: spark.read.parquet(data_path))


@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_parquet_push_down_on_interval_type(spark_tmp_path):
gen_list = [('_c1', DayTimeIntervalGen())]
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.parquet(data_path))
assert_gpu_and_cpu_are_equal_sql(
lambda spark: spark.read.parquet(data_path),
"testData",
"select * from testData where _c1 > interval '10 0:0:0' day to second")
GpuTypeShims.scala (default shim, Spark versions before 3.3.0)
@@ -62,4 +62,16 @@ object GpuTypeShims {
}

def isParquetColumnarWriterSupportedForType(colType: DataType): Boolean = false

/**
* Whether the Shim supports converting the given type to GPU Scalar
*/
def supportToScalarForType(t: DataType): Boolean = false

/**
* Convert the given value to Scalar
*/
def toScalarForType(t: DataType, v: Any) = {
throw new RuntimeException(s"Can not convert $v to scalar for type $t.")
}
}
GpuTypeShims.scala (Spark 3.3.0 shim)
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims

- import ai.rapids.cudf.DType
+ import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.ColumnarCopyHelper
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{LongConverter, NotNullLongConverter, TypeConverter}

@@ -118,4 +118,29 @@ object GpuTypeShims {
case DayTimeIntervalType(_, _) => true
case _ => false
}

/**
* Whether the Shim supports converting the given type to GPU Scalar
*/
def supportToScalarForType(t: DataType): Boolean = {
t match {
case _: DayTimeIntervalType => true
case _ => false
}
}

/**
* Convert the given value to Scalar
*/
def toScalarForType(t: DataType, v: Any) = {
t match {
case _: DayTimeIntervalType => v match {
case l: Long => Scalar.fromLong(l)
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
s" for DayTimeIntervalType, expecting Long")
}
case _ =>
throw new RuntimeException(s"Can not convert $v to scalar for type $t.")
}
}
}
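A rough sketch of what the new hooks provide: DayTimeIntervalType values are physically Long microseconds, and the shim wraps them in a cudf Scalar. The values below are illustrative; the real call sites live in the plugin's literal and predicate handling.

import ai.rapids.cudf.Scalar
import com.nvidia.spark.rapids.shims.GpuTypeShims
import org.apache.spark.sql.types.DayTimeIntervalType

object IntervalScalarSketch {
  def main(args: Array[String]): Unit = {
    // 10 days as microseconds: the physical Long encoding of the type.
    val tenDaysMicros: Long = 10L * 24 * 60 * 60 * 1000 * 1000
    if (GpuTypeShims.supportToScalarForType(DayTimeIntervalType())) {
      val s: Scalar = GpuTypeShims.toScalarForType(DayTimeIntervalType(), tenDaysMicros)
      try {
        // the scalar can stand in for the literal side of a GPU comparison
      } finally {
        s.close() // cudf scalars hold native memory and must be closed
      }
    }
  }
}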
Spark33XShims.scala
@@ -18,20 +18,21 @@ package com.nvidia.spark.rapids.shims

import com.nvidia.spark.InMemoryTableScanMeta
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuOverrides
import org.apache.parquet.schema.MessageType

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, DynamicPruningExpression, Expression, FileSourceMetadataAttribute, TimeAdd}
+ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.json.rapids.shims.Spark33XFileOptionsShims
- import org.apache.spark.sql.execution.{BaseSubqueryExec, CoalesceExec, FileSourceScanExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec}
+ import org.apache.spark.sql.execution.{BaseSubqueryExec, CoalesceExec, FileSourceScanExec, FilterExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.internal.SQLConf
- import org.apache.spark.sql.rapids.GpuFileSourceScanExec
+ import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.shims.GpuTimeAdd
import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval
@@ -154,6 +155,101 @@ trait Spark33XShims extends Spark33XFileOptionsShims {

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeAdd(lhs, rhs)
}),
GpuOverrides.expr[IsNull](
"Checks if a value is null",
ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(),
TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child)
}),
GpuOverrides.expr[IsNotNull](
"Checks if a value is not null",
ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(),
TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child)
}),
GpuOverrides.expr[EqualNullSafe](
"Check if the values are equal including nulls <=>",
ExprChecks.binaryProject(
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.comparable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.comparable)),
(a, conf, p, r) => new BinaryExprMeta[EqualNullSafe](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuEqualNullSafe(lhs, rhs)
}),
GpuOverrides.expr[EqualTo](
"Check if the values are equal",
ExprChecks.binaryProjectAndAst(
TypeSig.comparisonAstTypes,
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.comparable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.comparable)),
(a, conf, p, r) => new BinaryAstExprMeta[EqualTo](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuEqualTo(lhs, rhs)
}),
GpuOverrides.expr[GreaterThan](
"> operator",
ExprChecks.binaryProjectAndAst(
TypeSig.comparisonAstTypes,
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThan(lhs, rhs)
}),
GpuOverrides.expr[GreaterThanOrEqual](
">= operator",
ExprChecks.binaryProjectAndAst(
TypeSig.comparisonAstTypes,
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThanOrEqual(lhs, rhs)
}),
GpuOverrides.expr[LessThan](
"< operator",
ExprChecks.binaryProjectAndAst(
TypeSig.comparisonAstTypes,
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThan(lhs, rhs)
}),
GpuOverrides.expr[LessThanOrEqual](
"<= operator",
ExprChecks.binaryProjectAndAst(
TypeSig.comparisonAstTypes,
TypeSig.BOOLEAN, TypeSig.BOOLEAN,
("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable),
("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThanOrEqual(lhs, rhs)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
@@ -274,7 +370,15 @@ trait Spark33XShims extends Spark33XFileOptionsShims {
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(),
TypeSig.all),
- (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r))
+ (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)),
GpuOverrides.exec[FilterExec](
"The backend for most filter statements",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), TypeSig.all),
(filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
override def convertToGpu(): GpuExec =
GpuFilterExec(childExprs.head.convertToGpu(), childPlans.head.convertIfNeeded())
})
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
super.getExecs ++ map
}
…