Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history.
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
  • Loading branch information
sperlingxx committed Dec 13, 2021
1 parent bbcaf43 commit 809bf17
Showing 1 changed file with 13 additions and 13 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,8 +20,7 @@ import scala.collection.JavaConverters.asScalaIteratorConverter
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import ai.rapids.cudf.NvtxColor
import com.nvidia.spark.rapids.{GpuExec, GpuMetric, NvtxWithMetrics}
import com.nvidia.spark.rapids.{GpuExec, GpuMetric}
import com.nvidia.spark.rapids.GpuMetric.{COLLECT_TIME, DESCRIPTION_COLLECT_TIME, ESSENTIAL_LEVEL}
import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode

Expand All @@ -46,26 +45,27 @@ case class GpuSubqueryBroadcastExec(
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val collectTime = gpuLongMetric(COLLECT_TIME)

Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkSession, executionId) {
withResource(new NvtxWithMetrics("broadcast collect", NvtxColor.GREEN, collectTime)) { _ =>
val beforeCollect = System.nanoTime()

val batchBc = child.executeBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
val batchBc = child.executeBroadcast[SerializeConcatHostBuffersDeserializeBatch]()

gpuLongMetric("dataSize") += batchBc.value.dataSize

val toUnsafe = UnsafeProjection.create(output, output)
withResource(batchBc.value.hostBatches) { hostBatches =>
hostBatches.flatMap { cb =>
cb.rowIterator().asScala
.map(toUnsafe(_).copy().asInstanceOf[InternalRow])
}
val toUnsafe = UnsafeProjection.create(output, output)
val result = withResource(batchBc.value.hostBatches) { hostBatches =>
hostBatches.flatMap { cb =>
cb.rowIterator().asScala
.map(toUnsafe(_).copy().asInstanceOf[InternalRow])
}
}

gpuLongMetric("dataSize") += batchBc.value.dataSize
gpuLongMetric(COLLECT_TIME) += System.nanoTime() - beforeCollect

result
}
}(GpuSubqueryBroadcastExec.executionContext)
}
Expand Down

0 comments on commit 809bf17

Please sign in to comment.