[SPARK-17415][SQL] Better error message for driver-side broadcast join OOMs

## What changes were proposed in this pull request?

This is a trivial patch that catches any `OutOfMemoryError` thrown while building the broadcast hash relation on the driver and rethrows it wrapped in a more helpful error message.
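The new message points at two standard Spark settings. A minimal sketch of how a user might apply them (the session setup below is illustrative, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative session; any existing SparkSession works the same way.
val spark = SparkSession.builder()
  .appName("broadcast-join-oom-workaround")
  .getOrCreate()

// Workaround 1: disable automatic broadcast joins so the driver never has to
// collect the table and build the broadcast relation.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Workaround 2: give the driver more memory. spark.driver.memory must be set
// before the driver JVM starts (e.g. `spark-submit --driver-memory 8g ...`);
// it cannot be raised from inside a running session.
```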

## How was this patch tested?

Existing tests.

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes apache#14979 from sameeragarwal/broadcast-join-error.
sameeragarwal authored and hvanhovell committed Sep 11, 2016
1 parent 883c763 commit 767d480
Showing 1 changed file with 42 additions and 31 deletions.
@@ -21,13 +21,15 @@
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 
 import org.apache.spark.{broadcast, SparkException}
+import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -70,38 +72,47 @@ case class BroadcastExchangeExec(
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sparkContext, executionId) {
-        val beforeCollect = System.nanoTime()
-        // Note that we use .executeCollect() because we don't want to convert data to Scala types
-        val input: Array[InternalRow] = child.executeCollect()
-        if (input.length >= 512000000) {
-          throw new SparkException(
-            s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+        try {
+          val beforeCollect = System.nanoTime()
+          // Note that we use .executeCollect() because we don't want to convert data to Scala types
+          val input: Array[InternalRow] = child.executeCollect()
+          if (input.length >= 512000000) {
+            throw new SparkException(
+              s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+          }
+          val beforeBuild = System.nanoTime()
+          longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+          val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+          longMetric("dataSize") += dataSize
+          if (dataSize >= (8L << 30)) {
+            throw new SparkException(
+              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+          }
+
+          // Construct and broadcast the relation.
+          val relation = mode.transform(input)
+          val beforeBroadcast = System.nanoTime()
+          longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+
+          val broadcasted = sparkContext.broadcast(relation)
+          longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+
+          // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
+          // directly without setting an execution id. We should be tolerant to it.
+          if (executionId != null) {
+            sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
+              executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
+          }
+
+          broadcasted
+        } catch {
+          case oe: OutOfMemoryError =>
+            throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
+              s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
+              s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
+              s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
+              .initCause(oe.getCause)
         }
-        val beforeBuild = System.nanoTime()
-        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-        longMetric("dataSize") += dataSize
-        if (dataSize >= (8L << 30)) {
-          throw new SparkException(
-            s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-        }
-
-        // Construct and broadcast the relation.
-        val relation = mode.transform(input)
-        val beforeBroadcast = System.nanoTime()
-        longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
-
-        val broadcasted = sparkContext.broadcast(relation)
-        longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
-
-        // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
-        // directly without setting an execution id. We should be tolerant to it.
-        if (executionId != null) {
-          sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
-            executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
-        }
-
-        broadcasted
       }
     }(BroadcastExchangeExec.executionContext)
   }
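A minimal standalone sketch of the catch-and-rethrow pattern used in the hunk above: catch the `OutOfMemoryError`, throw a new one carrying an actionable message, and keep the original cause chain via `initCause`. The object name and allocation helper are illustrative only, not from Spark:

```scala
object RethrowOomSketch {
  // Purely illustrative: request an array large enough that most JVMs refuse it
  // with an OutOfMemoryError (e.g. "Requested array size exceeds VM limit").
  private def buildHugeTable(): Array[Long] = new Array[Long](Int.MaxValue)

  def main(args: Array[String]): Unit = {
    try {
      buildHugeTable()
    } catch {
      case oe: OutOfMemoryError =>
        // Same shape as the patch: wrap the failure in a message that tells the
        // user what to change, and preserve the underlying cause chain.
        throw new OutOfMemoryError(
          "Not enough memory to build the table; reduce its size or raise driver memory")
          .initCause(oe.getCause)
    }
  }
}
```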
