Commit
Move input metadata tests to pyspark (#165)
revans2 authored Jun 15, 2020
1 parent 8aac414 commit 2c2883d
Showing 7 changed files with 81 additions and 203 deletions.
18 changes: 18 additions & 0 deletions integration_tests/src/main/python/csv_test.py
@@ -216,3 +216,21 @@ def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part):
.option('timestampFormat', full_format)\
.csv(data_path),
conf=_enable_ts_conf)

def test_input_meta(spark_tmp_path):
gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False)
first_data_path = spark_tmp_path + '/CSV_DATA/key=0'
with_cpu_session(
lambda spark : gen_df(spark, gen).write.csv(first_data_path))
second_data_path = spark_tmp_path + '/CSV_DATA/key=1'
with_cpu_session(
lambda spark : gen_df(spark, gen).write.csv(second_data_path))
data_path = spark_tmp_path + '/CSV_DATA'
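    # The filter forces a coalesce, so this also checks that coalescing is
    # disabled properly when the input file metadata expressions are used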
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.schema(gen.data_type)\
.csv(data_path)\
.filter(f.col('a') > 0)\
.selectExpr('a',
'input_file_name()',
'input_file_block_start()',
'input_file_block_length()'))
33 changes: 33 additions & 0 deletions integration_tests/src/main/python/misc_expr_test.py
@@ -0,0 +1,33 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
import pyspark.sql.functions as f

def test_mono_id():
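    # monotonically_increasing_id is derived from the partition id and the row
    # position within the partition, so CPU and GPU results should match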
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, short_gen).select(
f.col('a'),
f.monotonically_increasing_id()))

def test_part_id():
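    # spark_partition_id depends only on the partitioning, which is the same
    # for the CPU and GPU runs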
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, short_gen).select(
f.col('a'),
f.spark_partition_id()))
15 changes: 15 additions & 0 deletions integration_tests/src/main/python/orc_test.py
@@ -186,3 +186,18 @@ def test_compress_write_round_trip(spark_tmp_path, compress):
data_path,
conf={'spark.sql.orc.compression.codec': compress})

def test_input_meta(spark_tmp_path):
first_data_path = spark_tmp_path + '/ORC_DATA/key=0'
with_cpu_session(
lambda spark : unary_op_df(spark, long_gen).write.orc(first_data_path))
second_data_path = spark_tmp_path + '/ORC_DATA/key=1'
with_cpu_session(
lambda spark : unary_op_df(spark, long_gen).write.orc(second_data_path))
data_path = spark_tmp_path + '/ORC_DATA'
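    # As with the CSV test, the filter forces a coalesce to check that it is
    # disabled properly for the input file metadata expressions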
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.orc(data_path)\
.filter(f.col('a') > 0)\
.selectExpr('a',
'input_file_name()',
'input_file_block_start()',
'input_file_block_length()'))
15 changes: 15 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -209,3 +209,18 @@ def test_compress_write_round_trip(spark_tmp_path, compress):
data_path,
conf={'spark.sql.parquet.compression.codec': compress})

def test_input_meta(spark_tmp_path):
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
lambda spark : unary_op_df(spark, long_gen).write.parquet(first_data_path))
second_data_path = spark_tmp_path + '/PARQUET_DATA/key=1'
with_cpu_session(
lambda spark : unary_op_df(spark, long_gen).write.parquet(second_data_path))
data_path = spark_tmp_path + '/PARQUET_DATA'
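    # As with the CSV and ORC tests, the filter forces a coalesce to check that
    # it is disabled properly for the input file metadata expressions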
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path)\
.filter(f.col('a') > 0)\
.selectExpr('a',
'input_file_name()',
'input_file_block_start()',
'input_file_block_length()'))
100 changes: 0 additions & 100 deletions tests/src/test/resources/lots_o_longs_more.csv

This file was deleted.

96 changes: 0 additions & 96 deletions tests/src/test/scala/ai/rapids/spark/ProjectExprSuite.scala
@@ -17,7 +17,6 @@
package ai.rapids.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

class ProjectExprSuite extends SparkQueryCompareTestSuite {
@@ -49,99 +48,4 @@ class ProjectExprSuite extends SparkQueryCompareTestSuite {
conf = forceHostColumnarToGpu()) {
frame => frame.select("time")
}

def booleanWithNullsDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[java.lang.Boolean](
true,
false,
true,
false,
null,
null,
true,
false
).toDF("bools")
}

def bytesDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[Byte](
0.toByte,
2.toByte,
3.toByte,
(-1).toByte,
(-10).toByte,
(-128).toByte,
127.toByte
).toDF("bytes")
}

def bytesWithNullsDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[java.lang.Byte](
0.toByte,
2.toByte,
3.toByte,
(-1).toByte,
(-10).toByte,
(-128).toByte,
127.toByte,
null,
null,
0.toByte
).toDF("bytes")
}

def shortsDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[Short](
0.toShort,
23456.toShort,
3.toShort,
(-1).toShort,
(-10240).toShort,
(-32768).toShort,
32767.toShort
).toDF("shorts")
}

def shortsWithNullsDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[java.lang.Short](
0.toShort,
23456.toShort,
3.toShort,
(-1).toShort,
(-10240).toShort,
(-32768).toShort,
32767.toShort,
null,
null,
0.toShort
).toDF("shorts")
}

testSparkResultsAreEqual("input_file_name", longsFromMultipleCSVDf, repart = 0) {
// The filter forces a coalesce so we can test that we disabled coalesce properly in this case
frame => frame.filter(col("longs") > 0).select(col("longs"), input_file_name())
}

testSparkResultsAreEqual("input_file_block_start", longsFromMultipleCSVDf, repart = 0) {
// The filter forces a coalesce so we can test that we disabled coalesce properly in this case
frame => frame.filter(col("longs") > 0).selectExpr("longs", "input_file_block_start()")
}

testSparkResultsAreEqual("input_file_block_length", longsFromMultipleCSVDf, repart = 0) {
// The filter forces a coalesce so we can test that we disabled coalesce properly in this case
frame => frame.filter(col("longs") > 0).selectExpr("longs", "input_file_block_length()")
}

testSparkResultsAreEqual("monotonically_increasing_id", shortsDf) {
frame => frame.select(col("shorts"), monotonically_increasing_id())
}

testSparkResultsAreEqual("spark_partition_id", shortsDf) {
frame => frame.select(col("shorts"), spark_partition_id())
}
}
@@ -1501,13 +1501,6 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
)))(_)
}

def longsFromMultipleCSVDf = {
fromCsvPatternDf("./", "lots_o_longs*.csv", StructType(Array(
StructField("longs", LongType, true),
StructField("more_longs", LongType, true)
)))(_)
}

def longsFromCSVDf = {
fromCsvDf("lots_o_longs.csv", StructType(Array(
StructField("longs", LongType, true),
