From 6f5271c4f95c0bca6fb66dd303e98ca2e8b077bf Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 3 Jun 2021 11:24:36 -0500 Subject: [PATCH] Add checks for format option --- .../tool/qualification/Qualification.scala | 37 +++++++++---------- .../qualification/QualificationArgs.scala | 5 +++ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/Qualification.scala b/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/Qualification.scala index 2d101b47c64..9785ab65147 100644 --- a/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/Qualification.scala +++ b/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/Qualification.scala @@ -105,25 +105,24 @@ object Qualification extends Logging { def writeQualification(df: DataFrame, outputDir: String, format: String): Unit = { - // val fileWriter = apps.head.fileWriter - // val dfRenamed = apps.head.renameQualificationColumns(df) - if (format.equals("csv")) { - df.repartition(1).write.option("header", "true"). - mode("overwrite").csv(s"$outputDir/rapids_4_spark_qualification_output") - logInfo(s"Output log location: $outputDir") - } else { - // This tool's output log file name - val logFileName = "rapids_4_spark_qualification_output.log" - val outputFilePath = new Path(s"$outputDir/$logFileName") - val fs = FileSystem.get(outputFilePath.toUri, new Configuration()) - val outFile = fs.create(outputFilePath) - outFile.writeUTF(ToolUtils.showString(df, 1000)) - outFile.flush() - outFile.close() - logInfo(s"Output log location: $outputFilePath") + format match { + case "csv" => + df.repartition(1).write.option("header", "true"). + mode("overwrite").csv(s"$outputDir/rapids_4_spark_qualification_output") + logInfo(s"Output log location: $outputDir") + case "text" => + // This tool's output log file name + val logFileName = "rapids_4_spark_qualification_output.log" + val outputFilePath = new Path(s"$outputDir/$logFileName") + val fs = FileSystem.get(outputFilePath.toUri, new Configuration()) + val outFile = fs.create(outputFilePath) + // outFile.writeUTF(ToolUtils.showString(df, 1000)) + df.repartition(1).write.option("header", "true"). + mode("overwrite").text(s"$outputDir/rapids_4_spark_qualification_output_text") + outFile.flush() + outFile.close() + logInfo(s"Output log location: $outputFilePath") + case _ => logError("Invalid format") } - - // fileWriter.write("\n" + ToolUtils.showString(dfRenamed, - // apps(0).numOutputRows)) } } diff --git a/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala b/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala index 9231d079184..3c1ce63812e 100644 --- a/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala +++ b/rapids-4-spark-tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala @@ -68,5 +68,10 @@ For usage see below: required = false, default = Some(false), descr = "Include the executor CPU time percent. It will take longer with this option.") + + validateOpt(outputFormat) { + case Some("text") | Some("csv") => Right(Unit) + case _ => Left(s"Invalid format - must be 'csv' or 'text'.") + } verify() }