GpuBroadcastExchangeExec.scala
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.rapids.execution
import java.io._
import java.util.UUID
import java.util.concurrent._
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.ref.WeakReference
import scala.util.control.NonFatal
import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange}
import ai.rapids.cudf.JCudfSerialization.HostConcatResult
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNode, SparkShimImpl}
import org.apache.spark.SparkException
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
/**
* Class that is used to broadcast results (a contiguous host batch) to executors.
*
* This is instantiated in the driver, serialized to an output stream provided by Spark
* to broadcast, and deserialized on the executor. Both the driver's and executor's copies
* are cleaned via GC. Because Spark closes `AutoCloseable` broadcast results after spilling
* to disk, this class does not subclass `AutoCloseable`. Instead we implement a `closeInternal`
* method only to be triggered via GC.
*
* @param data HostConcatResult populated for a broadcast that has columns; otherwise it is null.
* It is transient because we want the executor to deserialize its `data` from Spark's
* torrent-backed input stream.
* @param output used to find the schema for this broadcast batch
* @param numRows number of rows for this broadcast batch
* @param dataLen size in bytes for this broadcast batch
*/
// scalastyle:off no.finalize
@SerialVersionUID(100L)
class SerializeConcatHostBuffersDeserializeBatch(
@transient var data: HostConcatResult,
output: Seq[Attribute],
var numRows: Int,
var dataLen: Long)
extends Serializable with Logging {
@transient private var dataTypes = output.map(_.dataType).toArray
// used for memoization of deserialization to GPU on Executor
@transient private var batchInternal: SpillableColumnarBatch = null
private def maybeGpuBatch: Option[SpillableColumnarBatch] = Option(batchInternal)
def batch: SpillableColumnarBatch = this.synchronized {
maybeGpuBatch.getOrElse {
withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
val spillable =
if (data == null || data.getTableHeader.getNumColumns == 0) {
// If `data` is null or there are no columns, this is a rows-only batch
SpillableColumnarBatch(
new ColumnarBatch(Array.empty, numRows),
SpillPriorities.ACTIVE_BATCHING_PRIORITY)
} else if (data.getTableHeader.getNumRows == 0) {
// If we have columns but no rows, we can use the emptyBatchFromTypes optimization
SpillableColumnarBatch(
GpuColumnVector.emptyBatchFromTypes(dataTypes),
SpillPriorities.ACTIVE_BATCHING_PRIORITY)
} else {
// Regular GPU batch with rows/cols
SpillableColumnarBatch(
data.toContiguousTable,
dataTypes,
SpillPriorities.ACTIVE_BATCHING_PRIORITY)
}
// At this point we no longer need the host data and should not need to touch it again.
// Note that we don't close this using `withResources` around the creation of the
// `SpillableColumnarBatch`. That is because if a retry exception is thrown we want to
// still be able to recreate this broadcast batch, so we can't close the host data
// until we are at this line.
data.safeClose()
data = null
batchInternal = spillable
spillable
}
}
}
/**
* Create a host columnar batch from either the serialized buffers or the device columnar
* batch. This method can be safely called on both the driver and the executors. For now, it
* is used on the driver side for reusing GPU broadcast results on the CPU.
*
* NOTE: The caller is responsible for releasing the returned host columnar batch.
*/
def hostBatch: ColumnarBatch = this.synchronized {
maybeGpuBatch.map { spillable =>
withResource(spillable.getColumnarBatch()) { batch =>
val hostColumns: Array[ColumnVector] = GpuColumnVector
.extractColumns(batch)
.safeMap(_.copyToHost())
new ColumnarBatch(hostColumns, numRows)
}
}.getOrElse {
withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
if (data == null) {
new ColumnarBatch(Array.empty, numRows)
} else {
val header = data.getTableHeader
val buffer = data.getHostBuffer
val hostColumns = SerializedHostTableUtils.buildHostColumns(
header, buffer, dataTypes)
val rowCount = header.getNumRows
new ColumnarBatch(hostColumns.toArray, rowCount)
}
}
}
}
private def writeObject(out: ObjectOutputStream): Unit = {
doWriteObject(out)
}
private def readObject(in: ObjectInputStream): Unit = {
doReadObject(in)
}
/**
* doWriteObject is invoked from both the driver, when it writes a collected broadcast result
* to a stream for torrent broadcast to the executors, and the executor, when its MemoryStore
* evicts a "broadcast_[id]" block to make room in host memory.
*
* The driver will have `data` populated on construction and the executor will deserialize
* the object and, as part of the deserialization, invoke `doReadObject`.
* This will populate `data` before any task has had a chance to call `.batch` on this class.
*
* If `batchInternal` is defined we are in the executor and the broadcast has already been
* materialized on the GPU (managed by the plugin's RapidsBufferCatalog), so we serialize it
* back to the stream directly from that spillable batch.
*
* Public for unit tests.
*
* @param out the stream to write to
*/
def doWriteObject(out: ObjectOutputStream): Unit = this.synchronized {
maybeGpuBatch.map {
case justRows: JustRowsColumnarBatch =>
JCudfSerialization.writeRowsToStream(out, justRows.numRows())
case scb: SpillableColumnarBatch =>
val table = withResource(scb.getColumnarBatch()) { cb =>
GpuColumnVector.from(cb)
}
withResource(table) { _ =>
JCudfSerialization.writeToStream(table, out, 0, table.getRowCount)
}
out.writeObject(dataTypes)
}.getOrElse {
if (data == null || data.getTableHeader.getNumColumns == 0) {
JCudfSerialization.writeRowsToStream(out, numRows)
} else if (numRows == 0) {
// We didn't get any data back, but we need to write out an empty table that matches the schema
withResource(GpuColumnVector.emptyHostColumns(dataTypes)) { hostVectors =>
JCudfSerialization.writeToStream(hostVectors, out, 0, 0)
}
out.writeObject(dataTypes)
} else {
val headers = Array(data.getTableHeader)
val buffers = Array(data.getHostBuffer)
JCudfSerialization.writeConcatedStream(headers, buffers, out)
out.writeObject(dataTypes)
}
}
}
/**
* Deserializes a broadcast result on the host into `data`, `numRows`, and `dataLen`.
*
* Public for unit tests.
*/
def doReadObject(in: ObjectInputStream): Unit = this.synchronized {
// no-op if we already have `batchInternal` or `data` set
if (batchInternal == null && data == null) {
withResource(new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ =>
val (header, buffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(in)
withResource(buffer) { _ =>
dataTypes = if (header.getNumColumns > 0) {
in.readObject().asInstanceOf[Array[DataType]]
} else {
Array.empty
}
// for a rowsOnly broadcast, null out the `data` member.
val rowsOnly = dataTypes.isEmpty
numRows = header.getNumRows
dataLen = header.getDataLen
data = if (!rowsOnly) {
JCudfSerialization.concatToHostBuffer(Array(header), Array(buffer))
} else {
null
}
}
}
}
}
def dataSize: Long = dataLen
/**
* This method is meant to only be called from `finalize` and it is not a regular
* AutoCloseable.close because we do not want Spark to close `batchInternal` when it spills
* the broadcast block's host torrent data.
*
* Reference: https://github.com/NVIDIA/spark-rapids/issues/8602
*
* Public for tests.
*/
def closeInternal(): Unit = this.synchronized {
Seq(data, batchInternal).safeClose()
data = null
batchInternal = null
}
@scala.annotation.nowarn("msg=method finalize in class Object is deprecated")
override def finalize(): Unit = {
super.finalize()
closeInternal()
}
}
// scalastyle:on no.finalize
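// A minimal usage sketch (hypothetical caller code, not part of this file), assuming the
// typical executor-side pattern: a GPU join task reads the broadcast value and materializes
// it on the GPU through `batch`, which memoizes the deserialized result.
//
//   val bcast: Broadcast[SerializeConcatHostBuffersDeserializeBatch] = ...
//   val spillable = bcast.value.batch            // deserialized to the GPU once, then cached
//   withResource(spillable.getColumnarBatch()) { cb =>
//     // build the join's hash table / nested-loop input from `cb`
//   }
//
// On the driver, `hostBatch` provides the host-side equivalent when a CPU plan needs to
// reuse the GPU broadcast result.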
// scalastyle:off no.finalize
/**
* Object used for executors to serialize a result for their partition that will be collected
* on the driver to be broadcasted out as part of the exchange.
* @param batch - GPU batch to be serialized and sent to the driver.
*/
@SerialVersionUID(100L)
class SerializeBatchDeserializeHostBuffer(batch: ColumnarBatch)
extends Serializable with AutoCloseable {
@transient private var columns = GpuColumnVector.extractBases(batch).map(_.copyToHost())
@transient var header: JCudfSerialization.SerializedTableHeader = null
@transient var buffer: HostMemoryBuffer = null
@transient private var numRows = batch.numRows()
private def writeObject(out: ObjectOutputStream): Unit = {
withResource(new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ =>
if (buffer != null) {
throw new IllegalStateException("Cannot re-serialize a batch this way...")
} else {
JCudfSerialization.writeToStream(columns, out, 0, numRows)
// When this is used in an RDD we want to close the batch once it is serialized out, or we
// will leak GPU memory (technically it would just wait for GC to release it, and probably
// not much because this is used for a broadcast that really should be small).
// In the broadcast case the life cycle of the object is tied to GC and there is no clean
// way to separate the two right now, so we accept the leak.
columns.safeClose()
columns = null
}
}
}
private def readObject(in: ObjectInputStream): Unit = {
withResource(new NvtxRange("HostDeserializeBatch", NvtxColor.PURPLE)) { _ =>
val (h, b) = SerializedHostTableUtils.readTableHeaderAndBuffer(in)
// buffer will only be cleaned up on GC, so cannot warn about leaks
b.noWarnLeakExpected()
header = h
buffer = b
numRows = h.getNumRows
}
}
def dataSize: Long = {
JCudfSerialization.getSerializedSizeInBytes(columns, 0, numRows)
}
override def close(): Unit = {
columns.safeClose()
columns = null
buffer.safeClose()
buffer = null
}
@scala.annotation.nowarn("msg=method finalize in class Object is deprecated")
override def finalize(): Unit = {
super.finalize()
close()
}
}
class GpuBroadcastMeta(
exchange: BroadcastExchangeExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) extends
SparkPlanMeta[BroadcastExchangeExec](exchange, conf, parent, rule) with Logging {
override def tagPlanForGpu(): Unit = {
if (!TrampolineUtil.isSupportedRelation(exchange.mode)) {
willNotWorkOnGpu(
s"unsupported BroadcastMode: ${exchange.mode}. " +
s"GPU supports only IdentityBroadcastMode and HashedRelationBroadcastMode")
}
def isSupported(rm: RapidsMeta[_, _, _]): Boolean = rm.wrapped match {
case _: BroadcastHashJoinExec => true
case _: BroadcastNestedLoopJoinExec => true
case _ => false
}
if (parent.isDefined) {
if (!parent.exists(isSupported)) {
willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
"with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
}
}
}
override def convertToGpu(): GpuExec = {
GpuBroadcastExchangeExec(exchange.mode, childPlans.head.convertIfNeeded())(
exchange.canonicalized.asInstanceOf[BroadcastExchangeExec])
}
}
abstract class GpuBroadcastExchangeExecBase(
mode: BroadcastMode,
child: SparkPlan) extends ShimBroadcastExchangeLike with ShimUnaryExecNode with GpuExec {
override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL
override val outputBatchesLevel: MetricsLevel = MODERATE_LEVEL
override lazy val additionalMetrics = Map(
"dataSize" -> createSizeMetric(ESSENTIAL_LEVEL, "data size"),
COLLECT_TIME -> createNanoTimingMetric(ESSENTIAL_LEVEL, DESCRIPTION_COLLECT_TIME),
BUILD_TIME -> createNanoTimingMetric(ESSENTIAL_LEVEL, DESCRIPTION_BUILD_TIME),
"broadcastTime" -> createNanoTimingMetric(ESSENTIAL_LEVEL, "time to broadcast"))
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
// For now all broadcasts produce a single batch. We might need to change that at some point
override def outputBatching: CoalesceGoal = RequireSingleBatch
@transient
protected val timeout: Long = SQLConf.get.broadcastTimeout
// prior to Spark 3.5.0, runId is defined as `def` rather than `val` so
// produces a new ID on each reference. We override with a `val` so that
// the value is assigned once.
override val runId: UUID = UUID.randomUUID
@transient
lazy val relationFuture: Future[Broadcast[Any]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
val dataSize = gpuLongMetric("dataSize")
val collectTime = gpuLongMetric(COLLECT_TIME)
val buildTime = gpuLongMetric(BUILD_TIME)
val broadcastTime = gpuLongMetric("broadcastTime")
val task = new Callable[Broadcast[Any]]() {
override def call(): Broadcast[Any] = {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkSession, executionId) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId ${runId})",
interruptOnCancel = true)
val broadcastResult = {
val collected =
withResource(new NvtxWithMetrics("broadcast collect", NvtxColor.GREEN,
collectTime)) { _ =>
val childRdd = child.executeColumnar()
// collect batches from the executors
val data = childRdd.map(withResource(_) { cb =>
new SerializeBatchDeserializeHostBuffer(cb)
})
data.collect()
}
withResource(new NvtxWithMetrics("broadcast build", NvtxColor.DARK_GREEN,
buildTime)) { _ =>
val emptyRelation = if (collected.isEmpty) {
SparkShimImpl.tryTransformIfEmptyRelation(mode)
} else {
None
}
emptyRelation.getOrElse {
GpuBroadcastExchangeExecBase.makeBroadcastBatch(
collected, output, numOutputBatches, numOutputRows, dataSize)
}
}
}
val broadcasted =
withResource(new NvtxWithMetrics("broadcast", NvtxColor.CYAN,
broadcastTime)) { _ =>
// Broadcast the relation
sparkContext.broadcast(broadcastResult)
}
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
promise.success(broadcasted)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
// SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
// will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
val ex = createOutOfMemoryException(oe)
promise.failure(ex)
throw ex
case e if !NonFatal(e) =>
val ex = new Exception(e)
promise.failure(ex)
throw ex
case e: Throwable =>
promise.failure(e)
throw e
}
}
}
}
GpuBroadcastExchangeExecBase.executionContext.submit[Broadcast[Any]](task)
}
protected def createOutOfMemoryException(oe: OutOfMemoryError) = {
new Exception(
new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
"worker nodes. As a workaround, you can either disable broadcast by setting " +
s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
.initCause(oe.getCause))
}
override protected def doPrepare(): Unit = {
// Materialize the future.
relationFuture
}
override protected def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException(
"GpuBroadcastExchange does not support the execute() code path.")
}
override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in $timeout secs.", ex)
if (!relationFuture.isDone) {
sparkContext.cancelJobGroup(runId.toString)
relationFuture.cancel(true)
}
throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
ex)
}
}
final def executeColumnarBroadcast[T](): Broadcast[T] = {
if (isCanonicalizedPlan) {
throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
}
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in $timeout secs.", ex)
if (!relationFuture.isDone) {
sparkContext.cancelJobGroup(runId.toString)
relationFuture.cancel(true)
}
throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
ex)
}
}
override def runtimeStatistics: Statistics = {
Statistics(
sizeInBytes = metrics("dataSize").value,
rowCount = Some(metrics(GpuMetric.NUM_OUTPUT_ROWS).value))
}
override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
throw new IllegalStateException(s"Internal Error ${this.getClass} has column support" +
s" mismatch:\n$this")
}
}
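// A minimal sketch (hypothetical caller, assumptions noted inline) of how a GPU join exec
// typically consumes this node: it requests the columnar broadcast and then materializes
// the batch on the GPU.
//
//   val exchange = buildPlan.asInstanceOf[GpuBroadcastExchangeExecBase]   // assumed cast
//   val relation =
//     exchange.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
//   val builtBatch = relation.value.batch        // SpillableColumnarBatch on the GPU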
object GpuBroadcastExchangeExecBase {
/**
* Create a thread factory that names threads with a prefix and also sets the threads to daemon.
*/
private def namedThreadFactory(prefix: String): ThreadFactory = {
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()
}
/**
* Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
* are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
*/
private def newDaemonCachedThreadPool(
prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
val threadFactory = namedThreadFactory(prefix)
val threadPool = new ThreadPoolExecutor(
maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks
maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingQueue, this one is not used
keepAliveSeconds,
TimeUnit.SECONDS,
new LinkedBlockingQueue[Runnable],
threadFactory)
threadPool.allowCoreThreadTimeOut(true)
threadPool
}
val executionContext = ExecutionContext.fromExecutorService(
newDaemonCachedThreadPool("gpu-broadcast-exchange",
SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
protected def checkRowLimit(numRows: Int) = {
// Spark restricts broadcast relations to fewer than 512000000 rows, and we enforce the
// same limit
// scalastyle:off line.size.limit
// https://github.com/apache/spark/blob/v3.1.1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala#L586
// scalastyle:on line.size.limit
if (numRows >= 512000000) {
throw new SparkException(
s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
}
}
protected def checkSizeLimit(sizeInBytes: Long) = {
// Spark restricts the size of broadcast relations to be less than 8GB
if (sizeInBytes >= MAX_BROADCAST_TABLE_BYTES) {
throw new SparkException(
s"Cannot broadcast the table that is larger than" +
s"${MAX_BROADCAST_TABLE_BYTES >> 30}GB: ${sizeInBytes >> 30} GB")
}
}
/**
* Concatenate deserialized host buffers into a single HostConcatResult that is then
* passed to a `SerializeConcatHostBuffersDeserializeBatch`.
*
* This result will in turn be broadcasted from the driver to the executors.
*/
def makeBroadcastBatch(
buffers: Array[SerializeBatchDeserializeHostBuffer],
output: Seq[Attribute],
numOutputBatches: GpuMetric,
numOutputRows: GpuMetric,
dataSize: GpuMetric): SerializeConcatHostBuffersDeserializeBatch = {
val rowsOnly = buffers.isEmpty || buffers.head.header.getNumColumns == 0
var numRows = 0
var dataLen: Long = 0
val hostConcatResult = if (rowsOnly) {
numRows = withResource(buffers) { _ =>
require(output.isEmpty,
"Rows-only broadcast resolved had non-empty " +
s"output ${output.mkString(",")}")
buffers.map(_.header.getNumRows).sum
}
checkRowLimit(numRows)
null
} else {
val hostConcatResult = withResource(buffers) { _ =>
JCudfSerialization.concatToHostBuffer(
buffers.map(_.header), buffers.map(_.buffer))
}
closeOnExcept(hostConcatResult) { _ =>
checkRowLimit(hostConcatResult.getTableHeader.getNumRows)
checkSizeLimit(hostConcatResult.getTableHeader.getDataLen)
}
// this result will be GC'ed later, so we mark it as such
hostConcatResult.getHostBuffer.noWarnLeakExpected()
numRows = hostConcatResult.getTableHeader.getNumRows
dataLen = hostConcatResult.getTableHeader.getDataLen
hostConcatResult
}
numOutputBatches += 1
numOutputRows += numRows
dataSize += dataLen
// create the batch we will broadcast out
new SerializeConcatHostBuffersDeserializeBatch(
hostConcatResult, output, numRows, dataLen)
}
}
case class GpuBroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan)
(val cpuCanonical: BroadcastExchangeExec)
extends GpuBroadcastExchangeExecBase(mode, child) {
override def otherCopyArgs: Seq[AnyRef] = Seq(cpuCanonical)
private var _isGpuPlanningComplete = false
/**
* Returns true if this node and children are finished being optimized by the RAPIDS Accelerator.
*/
def isGpuPlanningComplete: Boolean = _isGpuPlanningComplete
/**
* Method to call after all RAPIDS Accelerator optimizations have been applied
* to indicate this node and its children are done being planned by the RAPIDS Accelerator.
* Some optimizations, such as AQE exchange reuse fixup, need to know when a node will no longer
* be updated so it can be tracked for reuse.
*/
def markGpuPlanningComplete(): Unit = {
if (!_isGpuPlanningComplete) {
_isGpuPlanningComplete = true
ExchangeMappingCache.trackExchangeMapping(cpuCanonical, this)
}
}
override def doCanonicalize(): SparkPlan = {
GpuBroadcastExchangeExec(mode.canonicalized, child.canonicalized)(cpuCanonical)
}
}
/** Caches the mappings from canonical CPU exchanges to the GPU exchanges that replaced them */
object ExchangeMappingCache extends Logging {
// Cache is a mapping from CPU broadcast plan to GPU broadcast plan. The cache should not
// artificially hold onto unused plans, so we make both the keys and values weak. The values
// point to their corresponding keys, so the keys will not be collected unless the value
// can be collected. The values will be held during normal Catalyst planning until those
// plans are no longer referenced, allowing both the key and value to be reaped at that point.
private val cache = new mutable.WeakHashMap[Exchange, WeakReference[Exchange]]
/** Try to find a recent GPU exchange that has replaced the specified CPU canonical plan. */
def findGpuExchangeReplacement(cpuCanonical: Exchange): Option[Exchange] = {
cache.get(cpuCanonical).flatMap(_.get)
}
/** Add a GPU exchange to the exchange cache */
def trackExchangeMapping(cpuCanonical: Exchange, gpuExchange: Exchange): Unit = {
val old = findGpuExchangeReplacement(cpuCanonical)
if (!old.exists(_.asInstanceOf[GpuBroadcastExchangeExec].isGpuPlanningComplete)) {
cache.put(cpuCanonical, WeakReference(gpuExchange))
}
}
}
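// A minimal sketch (hypothetical AQE fixup rule, not part of this file) of how the cache is
// consulted: given a CPU BroadcastExchangeExec left in the plan, look up the GPU exchange
// that already replaced its canonical form and reuse it rather than planning a new one.
//
//   ExchangeMappingCache.findGpuExchangeReplacement(
//       cpuExchange.canonicalized.asInstanceOf[Exchange]) match {
//     case Some(gpuExchange) => ReusedExchangeExec(cpuExchange.output, gpuExchange)
//     case None              => cpuExchange
//   }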