Fix tests for Spark 3.2.0 shim (NVIDIA#1869)
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Mar 4, 2021
1 parent 41c39ba commit de3db66
Showing 14 changed files with 112 additions and 49 deletions.
22 changes: 12 additions & 10 deletions integration_tests/src/main/python/orc_test.py
@@ -169,26 +169,28 @@ def test_input_meta(spark_tmp_path):
'input_file_block_start()',
'input_file_block_length()'))

def setup_orc_file_no_column_names(spark):
drop_query = "DROP TABLE IF EXISTS test_orc_data"
create_query = "CREATE TABLE `test_orc_data` (`_col1` INT, `_col2` STRING, `_col3` INT) USING orc"
insert_query = "INSERT INTO test_orc_data VALUES(13, '155', 2020)"
def setup_orc_file_no_column_names(spark, table_name):
drop_query = "DROP TABLE IF EXISTS {}".format(table_name)
create_query = "CREATE TABLE `{}` (`_col1` INT, `_col2` STRING, `_col3` INT) USING orc".format(table_name)
insert_query = "INSERT INTO {} VALUES(13, '155', 2020)".format(table_name)
spark.sql(drop_query).collect
spark.sql(create_query).collect
spark.sql(insert_query).collect

def test_missing_column_names():
def test_missing_column_names(spark_tmp_table_factory):
if is_spark_300():
pytest.skip("Apache Spark 3.0.0 does not handle ORC files without column names")

with_cpu_session(setup_orc_file_no_column_names)
table_name = spark_tmp_table_factory.get()
with_cpu_session(lambda spark : setup_orc_file_no_column_names(spark, table_name))
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.sql("SELECT _col3,_col2 FROM test_orc_data"))
lambda spark : spark.sql("SELECT _col3,_col2 FROM {}".format(table_name)))

def test_missing_column_names_filter():
def test_missing_column_names_filter(spark_tmp_table_factory):
if is_spark_300():
pytest.skip("Apache Spark 3.0.0 does not handle ORC files without column names")

with_cpu_session(setup_orc_file_no_column_names)
table_name = spark_tmp_table_factory.get()
with_cpu_session(lambda spark : setup_orc_file_no_column_names(spark, table_name))
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.sql("SELECT _col3,_col2 FROM test_orc_data WHERE _col2 = '155'"))
lambda spark : spark.sql("SELECT _col3,_col2 FROM {} WHERE _col2 = '155'".format(table_name)))
1 change: 1 addition & 0 deletions jenkins/spark-nightly-build.sh
@@ -28,6 +28,7 @@ mvn -U -B -Pspark302tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local
mvn -U -B -Pspark303tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
mvn -U -B -Pspark311tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
mvn -U -B -Pspark312tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
mvn -U -B -Pspark320tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR

# Parse cudf and spark files from local mvn repo
jenkins/printJarVersion.sh "CUDFVersion" "$M2DIR/ai/rapids/cudf/${CUDF_VER}" "cudf-${CUDF_VER}" "-${CUDA_CLASSIFIER}.jar" $SERVER_ID
4 changes: 3 additions & 1 deletion jenkins/spark-premerge-build.sh
@@ -39,11 +39,13 @@ tar zxf $SPARK_HOME.tgz -C $ARTF_ROOT && \

mvn -U -B $MVN_URM_MIRROR '-P!snapshot-shims,pre-merge' clean verify -Dpytest.TEST_TAGS='' -Dpytest.TEST_TYPE="pre-commit" -Dpytest.TEST_PARALLEL=4
# Run the unit tests for other Spark versions but dont run full python integration tests
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS=''
# NOT ALL TESTS NEEDED FOR PREMERGE
#env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS=''
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark302tests,snapshot-shims test -Dpytest.TEST_TAGS=''
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark303tests,snapshot-shims test -Dpytest.TEST_TAGS=''
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark311tests,snapshot-shims test -Dpytest.TEST_TAGS=''
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark312tests,snapshot-shims test -Dpytest.TEST_TAGS=''
env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark320tests,snapshot-shims test -Dpytest.TEST_TAGS=''

# The jacoco coverage should have been collected, but because of how the shade plugin
# works and jacoco we need to clean some things up so jacoco will only report for the
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.attachTree
@@ -42,13 +42,13 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.python.WindowInPandasExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub, ShuffleManagerShimBase}
@@ -62,6 +62,21 @@ import org.apache.spark.unsafe.types.CalendarInterval
class Spark300Shims extends SparkShims {

override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION
override def parquetRebaseReadKey: String =
SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key
override def parquetRebaseWriteKey: String =
SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key
override def avroRebaseReadKey: String =
SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key
override def avroRebaseWriteKey: String =
SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key
override def parquetRebaseRead(conf: SQLConf): String =
conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ)
override def parquetRebaseWrite(conf: SQLConf): String =
conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE)

override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
AlterTableRecoverPartitionsCommand(tableName)

override def getScalaUDFAsExpression(
function: AnyRef,
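
The overrides added to Spark300Shims above pin the pre-3.2 (LEGACY_*) datetime-rebase config entries and the pre-3.2 partition-repair command behind the shim interface, so common plugin code never names a version-specific SQLConf entry directly. A minimal sketch of the caller side; RebaseModeLogger and logRebaseMode are hypothetical names written only for illustration:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper object: common code asks the loaded shim for the right
// key/value instead of referencing SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ itself.
object RebaseModeLogger {
  def logRebaseMode(conf: SQLConf): Unit = {
    val shims = ShimLoader.getSparkShims
    // Resolves to the LEGACY_* entries via the Spark300Shims overrides above on 3.0.x/3.1.x,
    // and to the renamed 3.2.0 entries via the Spark320Shims overrides below.
    val readKey = shims.parquetRebaseReadKey
    val readMode = shims.parquetRebaseRead(conf) // "EXCEPTION", "CORRECTED", or "LEGACY"
    println(s"parquet rebase read key=$readKey, read mode=$readMode")
  }
}
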
@@ -1077,7 +1077,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
// at least a single block
val stream = new ByteArrayOutputStream(ByteArrayOutputFile.BLOCK_SIZE)
val outputFile: OutputFile = new ByteArrayOutputFile(stream)
sharedConf.setConfString(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key,
sharedConf.setConfString(ShimLoader.getSparkShims.parquetRebaseWriteKey,
LegacyBehaviorPolicy.CORRECTED.toString)
val recordWriter = SQLConf.withExistingConf(sharedConf) {
parquetOutputFileFormat.getRecordWriter(outputFile, sharedHadoopConf)
@@ -1218,7 +1218,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, false)
hadoopConf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, false)

hadoopConf.set(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key,
hadoopConf.set(ShimLoader.getSparkShims.parquetRebaseWriteKey,
LegacyBehaviorPolicy.CORRECTED.toString)

hadoopConf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key,
@@ -20,15 +20,38 @@ import com.nvidia.spark.rapids.ShimVersion
import com.nvidia.spark.rapids.shims.spark311.Spark311Shims
import com.nvidia.spark.rapids.spark320.RapidsShuffleManager

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.{RepairTableCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.internal.SQLConf

class Spark320Shims extends Spark311Shims {

override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION320

override def parquetRebaseReadKey: String =
SQLConf.PARQUET_REBASE_MODE_IN_READ.key
override def parquetRebaseWriteKey: String =
SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key
override def avroRebaseReadKey: String =
SQLConf.AVRO_REBASE_MODE_IN_READ.key
override def avroRebaseWriteKey: String =
SQLConf.AVRO_REBASE_MODE_IN_WRITE.key
override def parquetRebaseRead(conf: SQLConf): String =
conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ)
override def parquetRebaseWrite(conf: SQLConf): String =
conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)

override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
RepairTableCommand(tableName,
// These match the one place that this is called, if we start to call this in more places
// we will need to change the API to pass these values in.
enableAddPartitions = true,
enableDropPartitions = false)


override def getRapidsShuffleManagerClass: String = {
classOf[RapidsShuffleManager].getCanonicalName
}
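
Two Spark 3.2.0 differences are absorbed by the Spark320Shims overrides above: the datetime-rebase configs dropped their LEGACY_ prefix, and AlterTableRecoverPartitionsCommand was replaced by RepairTableCommand, which takes explicit add/drop flags. A hedged sketch of the caller side; the wrapper object is hypothetical, while the plugin's real call site is GpuCreateDataSourceTableAsSelectCommand further down in this diff:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.command.RunnableCommand

// Hypothetical wrapper, for illustration only.
object RepairTableExample {
  // On 3.0.x/3.1.x this returns AlterTableRecoverPartitionsCommand(table); on 3.2.0 it
  // returns RepairTableCommand(table, enableAddPartitions = true, enableDropPartitions = false),
  // preserving the old behavior of adding missing partitions without dropping any.
  def repairCommandFor(table: TableIdentifier): RunnableCommand =
    ShimLoader.getSparkShims.v1RepairTableCommand(table)
}
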
9 changes: 4 additions & 5 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,10 +17,9 @@
package com.nvidia.spark

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm
import com.nvidia.spark.rapids.{Arm, ShimLoader}

import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil

object RebaseHelper extends Arm {
@@ -67,9 +66,9 @@ object RebaseHelper extends Arm {

def newRebaseExceptionInRead(format: String): Exception = {
val config = if (format == "Parquet") {
SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key
ShimLoader.getSparkShims.parquetRebaseReadKey
} else if (format == "Avro") {
SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key
ShimLoader.getSparkShims.avroRebaseReadKey
} else {
throw new IllegalStateException("unrecognized format " + format)
}
@@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DateType, DecimalType, MapType, StructType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DateType, DecimalType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch

object GpuParquetFileFormat {
@@ -83,7 +83,7 @@ object GpuParquetFileFormat {
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType])
}

sqlConf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE) match {
ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match {
case "EXCEPTION" => //Good
case "CORRECTED" => //Good
case "LEGACY" =>
@@ -148,8 +148,8 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {

val conf = ContextUtil.getConfiguration(job)

val dateTimeRebaseException =
"EXCEPTION".equals(conf.get(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key))
val dateTimeRebaseException = "EXCEPTION".equals(
sparkSession.sqlContext.getConf(ShimLoader.getSparkShims.parquetRebaseWriteKey))

val committerClass =
conf.getClass(
@@ -182,16 +182,16 @@ object GpuParquetScanBase {
meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion")
}

sqlConf.get(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key) match {
sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key} is EXCEPTION")
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key} is LEGACY")
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
@@ -294,7 +294,7 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
private val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
private val isCorrectedRebase =
"CORRECTED" == sqlConf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ)
"CORRECTED" == ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)

def filterBlocks(
file: PartitionedFile,
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExprId, NullOrdering, SortDirection, SortOrder}
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
@@ -69,6 +70,14 @@ case class EMRShimVersion(major: Int, minor: Int, patch: Int) extends ShimVersio

trait SparkShims {
def getSparkShimVersion: ShimVersion
def parquetRebaseReadKey: String
def parquetRebaseWriteKey: String
def avroRebaseReadKey: String
def avroRebaseWriteKey: String
def parquetRebaseRead(conf: SQLConf): String
def parquetRebaseWrite(conf: SQLConf): String
def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand

def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
def isBroadcastExchangeLike(plan: SparkPlan): Boolean
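
With the rebase keys and the repair command now part of the SparkShims contract above, every shim from 3.0.0 through 3.2.0 must supply them, and callers can treat the resolved key as ordinary session configuration. A small sketch assuming a live SparkSession; RebaseConfigExample and forceCorrectedParquetWrites are hypothetical names, not plugin API:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.SparkSession

object RebaseConfigExample {
  // Sets the version-appropriate Parquet write rebase key to CORRECTED, the same
  // value the ParquetCachedBatchSerializer change above forces for its writer.
  def forceCorrectedParquetWrites(spark: SparkSession): Unit = {
    val key = ShimLoader.getSparkShims.parquetRebaseWriteKey
    spark.conf.set(key, "CORRECTED")
  }
}
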
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,13 +18,13 @@ package org.apache.spark.sql.rapids

import java.net.URI

import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuDataWritingCommand}
import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuDataWritingCommand, ShimLoader}

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, CommandUtils}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -84,7 +84,8 @@ case class GpuCreateDataSourceTableAsSelectCommand(
case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty &&
sparkSession.sqlContext.conf.manageFilesourcePartitions =>
// Need to recover partitions into the metastore so our saved data is visible.
sessionState.executePlan(AlterTableRecoverPartitionsCommand(table.identifier)).toRdd
sessionState.executePlan(
ShimLoader.getSparkShims.v1RepairTableCommand(table.identifier)).toRdd
case _ =>
}
}
@@ -127,8 +127,9 @@ class AdaptiveQueryExecSuite
}
}

test("Join partitioned tables") {
test("Join partitioned tables DPP fallback") {
assumeSpark301orLater
assumePriorToSpark320 // In 3.2.0 AQE works with DPP

val conf = new SparkConf()
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
@@ -443,7 +444,19 @@ class AdaptiveQueryExecSuite
case DatabricksShimVersion(3, 0, 0) => false
case _ => true
}
assume(isValidTestForSparkVersion)
assume(isValidTestForSparkVersion, "SPARK 3.1.0 or later required")
}

private def assumePriorToSpark320 = {
val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
val isValidTestForSparkVersion = sparkShimVersion match {
case ver: SparkShimVersion =>
(ver.major == 3 && ver.minor < 2) || ver.major < 3
case ver: DatabricksShimVersion =>
(ver.major == 3 && ver.minor < 2) || ver.major < 3
case _ => true
}
assume(isValidTestForSparkVersion, "Prior to SPARK 3.2.0 required")
}

def checkSkewJoin(
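
The new assumePriorToSpark320 helper above relies on ScalaTest's assume, which cancels (skips) a test rather than failing it when the predicate is false, so the DPP-fallback test quietly drops out on 3.2.0 where AQE and DPP work together. A self-contained sketch of the same gating pattern, with made-up version values and ScalaTest 3.1+ style imports assumed:

import org.scalatest.funsuite.AnyFunSuite

class VersionGateExample extends AnyFunSuite {
  // Stand-ins for ShimLoader.getSparkShims.getSparkShimVersion in the real suite.
  private val major = 3
  private val minor = 2

  private def assumePriorTo320(): Unit =
    assume(major < 3 || (major == 3 && minor < 2), "prior to Spark 3.2.0 required")

  test("runs only against Spark builds before 3.2.0") {
    assumePriorTo320() // with major=3, minor=2 this cancels the test instead of failing it
    assert(1 + 1 == 2)
  }
}
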
13 changes: 9 additions & 4 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -935,10 +935,15 @@ object CastOpSuite {
"2010-1-7 T")
}

val timestampWithoutDate = Seq(
"23:59:59.333666Z",
"T21:34:56.333666Z"
)
val timestampWithoutDate = if (validOnly && !castStringToTimestamp) {
// 3.2.0+ throws exceptions on string to date ANSI cast errors
Seq.empty
} else {
Seq(
"23:59:59.333666Z",
"T21:34:56.333666Z"
)
}

val allValues = specialDates ++
validYear ++
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,19 +17,12 @@
package com.nvidia.spark.rapids

import java.io.File
import java.lang.reflect.Method
import java.nio.charset.StandardCharsets

import ai.rapids.cudf.{ColumnVector, DType, Table, TableWriter}
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.types.{ByteType, DataType}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Tests for writing Parquet files with the GPU.
