org.scalatest
scalatest_${scala.binary.version}
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
index 9f9bdafb42d..e9140b81a5a 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.execution.{InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.util.QueryExecutionListener
object BenchUtils {
@@ -69,7 +70,7 @@ object BenchUtils {
runBench(
spark,
createDataFrame,
- WriteParquet(path, mode, writeOptions),
+ WriteCsv(path, mode, writeOptions),
queryDescription,
filenameStub,
iterations,
@@ -399,6 +400,12 @@ object BenchUtils {
if (count1 == count2) {
println(s"Both DataFrames contain $count1 rows")
+
+ if (!ignoreOrdering && (df1.rdd.getNumPartitions > 1 || df2.rdd.getNumPartitions > 1)) {
+ throw new IllegalStateException("Cannot run with ignoreOrdering=false because one or " +
+ "more inputs have multiple partitions")
+ }
+
val result1 = collectResults(df1, ignoreOrdering, useIterator)
val result2 = collectResults(df2, ignoreOrdering, useIterator)
@@ -435,7 +442,11 @@ object BenchUtils {
// apply sorting if specified
val resultDf = if (ignoreOrdering) {
// let Spark do the sorting
- df.sort(df.columns.map(col): _*)
+ val nonFloatCols = df.schema.fields
+ .filter(field => !(field.dataType == DataTypes.FloatType ||
+ field.dataType == DataTypes.DoubleType))
+ .map(field => col(field.name))
+ df.sort(nonFloatCols: _*)
} else {
df
}
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala
new file mode 100644
index 00000000000..7b54ca4f8c8
--- /dev/null
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.tests.common
+
+import org.rogach.scallop.ScallopConf
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Utility for comparing two csv or parquet files, such as the output from a benchmark, to
+ * verify that they match, allowing for differences in precision.
+ *
+ * This utility is intended to be run via spark-submit.
+ *
+ * Example usage:
+ *
+ *
+ * $SPARK_HOME/bin/spark-submit --jars $SPARK_RAPIDS_PLUGIN_JAR,$CUDF_JAR \
+ * --master local[*] \
+ * --class com.nvidia.spark.rapids.tests.common.CompareResults \
+ * $SPARK_RAPIDS_PLUGIN_INTEGRATION_TEST_JAR \
+ * --input1 /path/to/result1 \
+ * --input2 /path/to/result2 \
+ * --input-format parquet
+ *
+ */
+object CompareResults {
+  /**
+   * Entry point: parses the command line, loads both inputs with the requested reader and
+   * passes them to `BenchUtils.compareResults`.
+   *
+   * @param arg Command-line arguments, parsed by [[Conf]]
+   */
+  def main(arg: Array[String]): Unit = {
+    val conf = new Conf(arg)
+
+    val spark = SparkSession.builder.appName("CompareResults").getOrCreate()
+
+    // Match case-insensitively and reject unsupported formats with a clear message instead of
+    // letting them surface as a scala.MatchError.
+    val (df1, df2) = conf.inputFormat().toLowerCase match {
+      case "csv" =>
+        (spark.read.csv(conf.input1()), spark.read.csv(conf.input2()))
+      case "parquet" =>
+        (spark.read.parquet(conf.input1()), spark.read.parquet(conf.input2()))
+      case other =>
+        println(s"Invalid input format: $other")
+        // sys.exit is typed as Nothing, so the match still yields a pair of DataFrames
+        sys.exit(-1)
+    }
+
+    BenchUtils.compareResults(
+      df1,
+      df2,
+      conf.ignoreOrdering(),
+      conf.useIterator(),
+      conf.maxErrors(),
+      conf.epsilon())
+  }
+}
+
+/**
+ * Command-line options for [[CompareResults]].
+ *
+ * Each option carries a `descr` so that the description is shown in the generated help output
+ * rather than being visible only in the source.
+ *
+ * @param arguments The raw command-line arguments
+ */
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  /** Path to first data set */
+  val input1 = opt[String](required = true, descr = "Path to first data set")
+  /** Path to second data set */
+  val input2 = opt[String](required = true, descr = "Path to second data set")
+  /** Input format (csv or parquet) */
+  val inputFormat = opt[String](required = true, descr = "Input format (csv or parquet)")
+  /** Sort the data collected from the DataFrames before comparing them. */
+  val ignoreOrdering = opt[Boolean](required = false, default = Some(false),
+    descr = "Sort the data collected from the DataFrames before comparing them")
+  /**
+   * When set to true, use `toLocalIterator` to load one partition at a time into driver memory,
+   * reducing memory usage at the cost of performance because processing will be single-threaded.
+   */
+  val useIterator = opt[Boolean](required = false, default = Some(false),
+    descr = "Use toLocalIterator to load one partition at a time into driver memory, " +
+      "reducing memory usage at the cost of performance")
+  /** Maximum number of differences to report */
+  val maxErrors = opt[Int](required = false, default = Some(10),
+    descr = "Maximum number of differences to report")
+  /** Allow for differences in precision when comparing floating point values */
+  val epsilon = opt[Double](required = false, default = Some(0.00001),
+    descr = "Allow for differences in precision when comparing floating point values")
+  verify()
+}
\ No newline at end of file
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcds/TpcdsLikeBench.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcds/TpcdsLikeBench.scala
index 2f137f82029..34790a8ba3f 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcds/TpcdsLikeBench.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcds/TpcdsLikeBench.scala
@@ -17,9 +17,10 @@
package com.nvidia.spark.rapids.tests.tpcds
import com.nvidia.spark.rapids.tests.common.BenchUtils
+import org.rogach.scallop.ScallopConf
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
+import org.apache.spark.sql.{SaveMode, SparkSession}
object TpcdsLikeBench extends Logging {
@@ -32,6 +33,8 @@ object TpcdsLikeBench extends Logging {
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
* @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def collect(
spark: SparkSession,
@@ -55,7 +58,12 @@ object TpcdsLikeBench extends Logging {
*
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
* @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def writeCsv(
spark: SparkSession,
@@ -85,7 +93,12 @@ object TpcdsLikeBench extends Logging {
*
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
- * @param iterations The number of times to run the query.
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
+ * @param iterations The number of times to run the query
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def writeParquet(
spark: SparkSession,
@@ -111,15 +124,49 @@ object TpcdsLikeBench extends Logging {
* The main method can be invoked by using spark-submit.
*/
def main(args: Array[String]): Unit = {
- val input = args(0)
+    // Parse the command line; the supported options are declared in Conf.
+    val conf = new Conf(args)
val spark = SparkSession.builder.appName("TPC-DS Like Bench").getOrCreate()
- TpcdsLikeSpark.setupAllParquet(spark, input)
-
- args.drop(1).foreach(query => {
- println(s"*** RUNNING TPC-DS QUERY $query")
- collect(spark, query)
- })
+    // Set up the input tables using the reader that matches the requested input format.
+    conf.inputFormat().toLowerCase match {
+      case "parquet" => TpcdsLikeSpark.setupAllParquet(spark, conf.input())
+      case "csv" => TpcdsLikeSpark.setupAllCSV(spark, conf.input())
+      case other =>
+        println(s"Invalid input format: $other")
+        System.exit(-1)
+    }
+    println(s"*** RUNNING TPC-DS QUERY ${conf.query()}")
+    conf.output.toOption match {
+      case Some(path) =>
+        // outputFormat is optional, so read it as an Option: applying outputFormat() to an
+        // unset option fails before the "unspecified" branch below could report the problem.
+        conf.outputFormat.toOption.map(_.toLowerCase) match {
+          case Some("parquet") =>
+            writeParquet(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case Some("csv") =>
+            writeCsv(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case _ =>
+            println("Invalid or unspecified output format")
+            System.exit(-1)
+        }
+      case _ =>
+        // No output path was given, so run the query through collect instead of writing it.
+        collect(spark, conf.query(), conf.iterations())
+    }
}
}
+
+/**
+ * Command-line options for the TPC-DS like benchmark.
+ *
+ * @param arguments The raw command-line arguments
+ */
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  /** Path to the input data */
+  val input = opt[String](required = true, descr = "Path to the input data")
+  /** Format of the input data (parquet or csv) */
+  val inputFormat = opt[String](required = true, descr = "Input format (parquet or csv)")
+  /** The name of the query to run e.g. "q5" */
+  val query = opt[String](required = true, descr = "The name of the query to run e.g. q5")
+  /** The number of times to run the query */
+  val iterations = opt[Int](default = Some(3), descr = "The number of times to run the query")
+  /** Optional path to write the results to; without it the query is run via collect */
+  val output = opt[String](required = false,
+    descr = "Path to write the results to; when omitted the query is run via collect")
+  /** Format used when writing results (parquet or csv) */
+  val outputFormat = opt[String](required = false,
+    descr = "Output format (parquet or csv); needed when output is specified")
+  verify()
+}
+
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/Benchmarks.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/Benchmarks.scala
deleted file mode 100644
index 637705c9e8c..00000000000
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/Benchmarks.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.nvidia.spark.rapids.tests.tpch
-
-import java.util.concurrent.TimeUnit
-
-import com.nvidia.spark.rapids.tests.DebugRange
-
-import org.apache.spark.sql.SparkSession
-
-object Benchmarks {
- def session: SparkSession = {
- val builder = SparkSession.builder.appName("TPCHLikeJob")
-
- val master = System.getenv("SPARK_MASTER")
- if (master != null) {
- builder.master(master)
- }
-
- val spark = builder.getOrCreate()
- spark.sparkContext.setLogLevel("warn")
-
- spark.sqlContext.clearCache()
-
- spark
- }
-
- def runQueries(output: String, args: Array[String]): Unit = {
- (1 until args.length).foreach(index => {
- val query = args(index)
- System.err.println(s"QUERY: ${query}")
- val df = query match {
- case "1" => Q1Like(session)
- case "2" => Q2Like(session)
- case "3" => Q3Like(session)
- case "4" => Q4Like(session)
- case "5" => Q5Like(session)
- case "6" => Q6Like(session)
- case "7" => Q7Like(session)
- case "8" => Q8Like(session)
- case "9" => Q9Like(session)
- case "10" => Q10Like(session)
- case "11" => Q11Like(session)
- case "12" => Q12Like(session)
- case "13" => Q13Like(session)
- case "14" => Q14Like(session)
- case "15" => Q15Like(session)
- case "16" => Q16Like(session)
- case "17" => Q17Like(session)
- case "18" => Q18Like(session)
- case "19" => Q19Like(session)
- case "20" => Q20Like(session)
- case "21" => Q21Like(session)
- case "22" => Q22Like(session)
- }
- val start = System.nanoTime()
- val range = new DebugRange(s"QUERY: ${query}")
- try {
- df.write.mode("overwrite").csv(output + "/" + query)
- } finally {
- range.close()
- }
- val end = System.nanoTime()
- System.err.println(s"QUERY: ${query} took ${TimeUnit.NANOSECONDS.toMillis(end - start)} ms")
- })
- }
-}
-
-object CSV {
-
- def main(args: Array[String]): Unit = {
- val input = args(0)
- val output = args(1)
-
- val session = Benchmarks.session
-
- TpchLikeSpark.setupAllCSV(session, input)
- Benchmarks.runQueries(output, args.slice(1, args.length))
- }
-}
-
-object Parquet {
-
- def main(args: Array[String]): Unit = {
- val input = args(0)
- val output = args(1)
-
- val session = Benchmarks.session
-
- TpchLikeSpark.setupAllParquet(session, input)
- Benchmarks.runQueries(output, args.slice(1, args.length))
- }
-}
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeBench.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeBench.scala
new file mode 100644
index 00000000000..f1b5c8bf106
--- /dev/null
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeBench.scala
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.tests.tpch
+
+import com.nvidia.spark.rapids.tests.common.BenchUtils
+import org.rogach.scallop.ScallopConf
+
+import org.apache.spark.sql.{SaveMode, SparkSession}
+
+object TpchLikeBench {
+
+ /**
+ * This method performs a benchmark of executing a query and collecting the results to the
+ * driver and can be called from Spark shell using the following syntax:
+ *
+ * TpchLikeBench.collect(spark, "q5", 3)
+ *
+ * @param spark The Spark session
+ * @param query The name of the query to run e.g. "q5"
+ * @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
+ */
+  def collect(
+      spark: SparkSession,
+      query: String,
+      iterations: Int = 3,
+      gcBetweenRuns: Boolean = false): Unit = {
+    // Label for this benchmark, passed to BenchUtils alongside the query name.
+    val benchLabel = s"tpch-$query-collect"
+    BenchUtils.collect(
+      spark,
+      session => getQuery(query)(session),
+      query,
+      benchLabel,
+      iterations,
+      gcBetweenRuns)
+  }
+
+ /**
+ * This method performs a benchmark of executing a query and writing the results to CSV files
+ * and can be called from Spark shell using the following syntax:
+ *
+ * TpchLikeBench.writeCsv(spark, "q5", "/path/to/write", iterations = 3)
+ *
+ * @param spark The Spark session
+ * @param query The name of the query to run e.g. "q5"
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
+ * @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
+ */
+  def writeCsv(
+      spark: SparkSession,
+      query: String,
+      path: String,
+      mode: SaveMode = SaveMode.Overwrite,
+      writeOptions: Map[String, String] = Map.empty,
+      iterations: Int = 3,
+      gcBetweenRuns: Boolean = false): Unit = {
+    // Label for this benchmark, passed to BenchUtils alongside the query name.
+    val benchLabel = s"tpch-$query-csv"
+    BenchUtils.writeCsv(
+      spark,
+      session => getQuery(query)(session),
+      query,
+      benchLabel,
+      iterations,
+      gcBetweenRuns,
+      path,
+      mode,
+      writeOptions)
+  }
+
+ /**
+ * This method performs a benchmark of executing a query and writing the results to Parquet files
+ * and can be called from Spark shell using the following syntax:
+ *
+ * TpchLikeBench.writeParquet(spark, "q5", "/path/to/write", iterations = 3)
+ *
+ * @param spark The Spark session
+ * @param query The name of the query to run e.g. "q5"
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
+ * @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
+ */
+  def writeParquet(
+      spark: SparkSession,
+      query: String,
+      path: String,
+      mode: SaveMode = SaveMode.Overwrite,
+      writeOptions: Map[String, String] = Map.empty,
+      iterations: Int = 3,
+      gcBetweenRuns: Boolean = false): Unit = {
+    // Label for this benchmark, passed to BenchUtils alongside the query name.
+    val benchLabel = s"tpch-$query-parquet"
+    BenchUtils.writeParquet(
+      spark,
+      session => getQuery(query)(session),
+      query,
+      benchLabel,
+      iterations,
+      gcBetweenRuns,
+      path,
+      mode,
+      writeOptions)
+  }
+
+ /**
+ * The main method can be invoked by using spark-submit.
+ */
+  def main(args: Array[String]): Unit = {
+    // Parse the command line; the supported options are declared in Conf.
+    val conf = new Conf(args)
+
+    val spark = SparkSession.builder.appName("TPC-H Like Bench").getOrCreate()
+    // Set up the input tables using the reader that matches the requested input format.
+    conf.inputFormat().toLowerCase match {
+      case "parquet" => TpchLikeSpark.setupAllParquet(spark, conf.input())
+      case "csv" => TpchLikeSpark.setupAllCSV(spark, conf.input())
+      case other =>
+        println(s"Invalid input format: $other")
+        System.exit(-1)
+    }
+
+    println(s"*** RUNNING TPC-H QUERY ${conf.query()}")
+    conf.output.toOption match {
+      case Some(path) =>
+        // outputFormat is optional, so read it as an Option: applying outputFormat() to an
+        // unset option fails before the "unspecified" branch below could report the problem.
+        conf.outputFormat.toOption.map(_.toLowerCase) match {
+          case Some("parquet") =>
+            writeParquet(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case Some("csv") =>
+            writeCsv(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case _ =>
+            println("Invalid or unspecified output format")
+            System.exit(-1)
+        }
+      case _ =>
+        // No output path was given, so collect the results to the driver instead.
+        collect(spark, conf.query(), conf.iterations())
+    }
+  }
+
+  /**
+   * Look up a TPC-H like query by name and apply it to the given session.
+   *
+   * @param query The name of the query, "q1" through "q22"
+   * @param spark The Spark session passed to the query
+   * @throws IllegalArgumentException if the query name is not one of the supported names
+   */
+  private def getQuery(query: String)(spark: SparkSession) = {
+    query match {
+      case "q1" => Q1Like(spark)
+      case "q2" => Q2Like(spark)
+      case "q3" => Q3Like(spark)
+      case "q4" => Q4Like(spark)
+      case "q5" => Q5Like(spark)
+      case "q6" => Q6Like(spark)
+      case "q7" => Q7Like(spark)
+      case "q8" => Q8Like(spark)
+      case "q9" => Q9Like(spark)
+      case "q10" => Q10Like(spark)
+      case "q11" => Q11Like(spark)
+      case "q12" => Q12Like(spark)
+      case "q13" => Q13Like(spark)
+      case "q14" => Q14Like(spark)
+      case "q15" => Q15Like(spark)
+      case "q16" => Q16Like(spark)
+      case "q17" => Q17Like(spark)
+      case "q18" => Q18Like(spark)
+      case "q19" => Q19Like(spark)
+      case "q20" => Q20Like(spark)
+      case "q21" => Q21Like(spark)
+      case "q22" => Q22Like(spark)
+      // Fail with a descriptive error rather than a bare scala.MatchError
+      case _ => throw new IllegalArgumentException(s"Unknown TPC-H query: $query")
+    }
+  }
+}
+
+/**
+ * Command-line options for the TPC-H like benchmark.
+ *
+ * @param arguments The raw command-line arguments
+ */
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  /** Path to the input data */
+  val input = opt[String](required = true, descr = "Path to the input data")
+  /** Format of the input data (parquet or csv) */
+  val inputFormat = opt[String](required = true, descr = "Input format (parquet or csv)")
+  /** The name of the query to run e.g. "q5" */
+  val query = opt[String](required = true, descr = "The name of the query to run e.g. q5")
+  /** The number of times to run the query */
+  val iterations = opt[Int](default = Some(3), descr = "The number of times to run the query")
+  /** Optional path to write the results to; without it the results are collected */
+  val output = opt[String](required = false,
+    descr = "Path to write the results to; when omitted the results are collected to the driver")
+  /** Format used when writing results (parquet or csv) */
+  val outputFormat = opt[String](required = false,
+    descr = "Output format (parquet or csv); needed when output is specified")
+  verify()
+}
\ No newline at end of file
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcxbb/TpcxbbLikeBench.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcxbb/TpcxbbLikeBench.scala
index c650592d0f0..7049947dde8 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcxbb/TpcxbbLikeBench.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/tpcxbb/TpcxbbLikeBench.scala
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids.tests.tpcxbb
import com.nvidia.spark.rapids.tests.common.BenchUtils
+import org.rogach.scallop.ScallopConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
@@ -28,10 +29,11 @@ object TpcxbbLikeBench extends Logging {
* driver and can be called from Spark shell using the following syntax:
*
* TpcxbbLikeBench.collect(spark, "q5", 3)
- *
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
* @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def collect(
spark: SparkSession,
@@ -55,7 +57,12 @@ object TpcxbbLikeBench extends Logging {
*
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
* @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def writeCsv(
spark: SparkSession,
@@ -85,7 +92,12 @@ object TpcxbbLikeBench extends Logging {
*
* @param spark The Spark session
* @param query The name of the query to run e.g. "q5"
+ * @param path The path to write the results to
+ * @param mode The SaveMode to use when writing the results
+ * @param writeOptions Write options
* @param iterations The number of times to run the query.
+ * @param gcBetweenRuns Whether to call `System.gc` between iterations to cause Spark to
+ * call `unregisterShuffle`
*/
def writeParquet(
spark: SparkSession,
@@ -108,15 +120,40 @@ object TpcxbbLikeBench extends Logging {
}
def main(args: Array[String]): Unit = {
- val input = args(0)
+    // Parse the command line; the supported options are declared in Conf.
+    val conf = new Conf(args)
val spark = SparkSession.builder.appName("TPCxBB Bench").getOrCreate()
- TpcxbbLikeSpark.setupAllParquet(spark, input)
- args.drop(1).foreach(query => {
- println(s"*** RUNNING TPCx-BB QUERY $query")
- collect(spark, query)
- })
+    // Set up the input tables using the reader that matches the requested input format.
+    conf.inputFormat().toLowerCase match {
+      case "parquet" => TpcxbbLikeSpark.setupAllParquet(spark, conf.input())
+      case "csv" => TpcxbbLikeSpark.setupAllCSV(spark, conf.input())
+      case other =>
+        println(s"Invalid input format: $other")
+        System.exit(-1)
+    }
+
+    println(s"*** RUNNING TPCx-BB QUERY ${conf.query()}")
+    conf.output.toOption match {
+      case Some(path) =>
+        // outputFormat is optional, so read it as an Option: applying outputFormat() to an
+        // unset option fails before the "unspecified" branch below could report the problem.
+        conf.outputFormat.toOption.map(_.toLowerCase) match {
+          case Some("parquet") =>
+            writeParquet(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case Some("csv") =>
+            writeCsv(
+              spark,
+              conf.query(),
+              path,
+              iterations = conf.iterations())
+          case _ =>
+            println("Invalid or unspecified output format")
+            System.exit(-1)
+        }
+      case _ =>
+        // No output path was given, so run the query through collect instead of writing it.
+        collect(spark, conf.query(), conf.iterations())
+    }
}
def getQuery(query: String): SparkSession => DataFrame = {
@@ -160,6 +197,15 @@ object TpcxbbLikeBench extends Logging {
case 30 => Q30Like.apply
case _ => throw new IllegalArgumentException(s"Unknown TPCx-BB query number: $queryIndex")
}
-
}
}
+
+/**
+ * Command-line options for the TPCx-BB like benchmark.
+ *
+ * @param arguments The raw command-line arguments
+ */
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  /** Path to the input data */
+  val input = opt[String](required = true, descr = "Path to the input data")
+  /** Format of the input data (parquet or csv) */
+  val inputFormat = opt[String](required = true, descr = "Input format (parquet or csv)")
+  /** The name of the query to run e.g. "q5" */
+  val query = opt[String](required = true, descr = "The name of the query to run e.g. q5")
+  /** The number of times to run the query */
+  val iterations = opt[Int](default = Some(3), descr = "The number of times to run the query")
+  /** Optional path to write the results to; without it the query is run via collect */
+  val output = opt[String](required = false,
+    descr = "Path to write the results to; when omitted the query is run via collect")
+  /** Format used when writing results (parquet or csv) */
+  val outputFormat = opt[String](required = false,
+    descr = "Output format (parquet or csv); needed when output is specified")
+  verify()
+}
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index c6d3bab8798..f00ab4ac993 100644
--- a/pom.xml
+++ b/pom.xml
@@ -303,6 +303,11 @@
flatbuffers-java
1.11.0
+