diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala
index 7b9e201ebe4..b8477a0c7de 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala
@@ -33,7 +33,10 @@ import org.apache.spark.util.SerializableConfiguration
  * An extension of `DataWritingCommand` that allows columnar execution.
  */
 trait GpuDataWritingCommand extends DataWritingCommand {
-  override lazy val metrics: Map[String, SQLMetric] = GpuWriteJobStatsTracker.metrics
+  lazy val basicMetrics: Map[String, SQLMetric] = GpuWriteJobStatsTracker.basicMetrics
+  lazy val taskMetrics: Map[String, SQLMetric] = GpuWriteJobStatsTracker.taskMetrics
+
+  override lazy val metrics: Map[String, SQLMetric] = basicMetrics ++ taskMetrics
 
   override final def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] =
     throw new UnsupportedOperationException(
@@ -44,7 +47,7 @@ trait GpuDataWritingCommand extends DataWritingCommand {
   def gpuWriteJobStatsTracker(
       hadoopConf: Configuration): GpuWriteJobStatsTracker = {
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-    GpuWriteJobStatsTracker(serializableHadoopConf)
+    GpuWriteJobStatsTracker(serializableHadoopConf, this)
   }
 
   def requireSingleBatch: Boolean
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuWriteStatsTracker.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuWriteStatsTracker.scala
index d4b41f3a8d5..0a6098a2162 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuWriteStatsTracker.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuWriteStatsTracker.scala
@@ -16,10 +16,11 @@
 
 package org.apache.spark.sql.rapids
 
+import com.nvidia.spark.rapids.GpuDataWritingCommand
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, WriteTaskStats}
+import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.SerializableConfiguration
 
@@ -60,9 +61,9 @@ object GpuWriteJobStatsTracker {
   val GPU_TIME_KEY = "gpuTime"
   val WRITE_TIME_KEY = "writeTime"
 
-  lazy val basicMetrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
+  def basicMetrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
 
-  lazy val taskMetrics: Map[String, SQLMetric] = {
+  def taskMetrics: Map[String, SQLMetric] = {
     val sparkContext = SparkContext.getActive.get
     Map(
       GPU_TIME_KEY -> SQLMetrics.createNanoTimingMetric(sparkContext, "GPU time"),
@@ -70,8 +71,7 @@ object GpuWriteJobStatsTracker {
     )
   }
 
-  def metrics: Map[String, SQLMetric] = basicMetrics ++ taskMetrics
-
-  def apply(serializableHadoopConf: SerializableConfiguration): GpuWriteJobStatsTracker =
-    new GpuWriteJobStatsTracker(serializableHadoopConf, basicMetrics, taskMetrics)
+  def apply(serializableHadoopConf: SerializableConfiguration,
+      command: GpuDataWritingCommand): GpuWriteJobStatsTracker =
+    new GpuWriteJobStatsTracker(serializableHadoopConf, command.basicMetrics, command.taskMetrics)
 }
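Not part of the patch itself, but for readers skimming the diff: the change moves metric ownership from the `GpuWriteJobStatsTracker` companion object onto each `GpuDataWritingCommand` instance, and the tracker is now built from that command's metrics. Below is a minimal, self-contained Scala sketch of that ownership pattern; `Metric`, `WritingCommand`, and `StatsTracker` are hypothetical stand-ins for `SQLMetric`, `GpuDataWritingCommand`, and `GpuWriteJobStatsTracker`, not the real Spark/RAPIDS classes.

```scala
// MetricsOwnershipSketch.scala -- hypothetical stand-ins, not the real plugin classes.
object MetricsOwnershipSketch {
  // Stand-in for Spark's SQLMetric: just a named mutable counter.
  final class Metric(val name: String) { var value: Long = 0L }

  // Stand-in for GpuDataWritingCommand after the change: metrics are lazy vals on the
  // instance, so each command creates and caches exactly one set of accumulators.
  trait WritingCommand {
    lazy val basicMetrics: Map[String, Metric] = Map("numFiles" -> new Metric("numFiles"))
    lazy val taskMetrics: Map[String, Metric] =
      Map("gpuTime" -> new Metric("gpuTime"), "writeTime" -> new Metric("writeTime"))
    lazy val metrics: Map[String, Metric] = basicMetrics ++ taskMetrics

    // Mirrors gpuWriteJobStatsTracker: the tracker is handed this command's own metrics,
    // like the GpuWriteJobStatsTracker(serializableHadoopConf, this) call in the diff.
    def statsTracker: StatsTracker = new StatsTracker(basicMetrics, taskMetrics)
  }

  // Stand-in for GpuWriteJobStatsTracker: updates the metrics it was constructed with.
  final class StatsTracker(basic: Map[String, Metric], task: Map[String, Metric]) {
    def addWriteTime(nanos: Long): Unit = task("writeTime").value += nanos
    def addFile(): Unit = basic("numFiles").value += 1
  }

  def main(args: Array[String]): Unit = {
    val a = new WritingCommand {}
    val b = new WritingCommand {}
    a.statsTracker.addWriteTime(42L)
    a.statsTracker.addFile()
    // Each command owns distinct metric objects, so b's metrics stay untouched.
    assert(a.taskMetrics("writeTime").value == 42L)
    assert(b.taskMetrics("writeTime").value == 0L)
  }
}
```

The effect sketched here is what the diff encodes: the companion object's members become plain factory methods, each write command caches its own metrics via lazy vals, and the tracker it creates updates exactly that command's accumulators rather than a set shared through object-level state.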