
Fix #SPARK-1149 Bad partitioners can cause Spark to hang #44

Closed · wants to merge 18 commits
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -847,6 +847,8 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val rddPartitions = rdd.partitions.map(_.index)
require(partitions.forall(rddPartitions.contains(_)), "partition index out of range")
Contributor:
Do you mind explaining a bit more the case where these two will not match? I'm just wondering if it makes more sense to check this invariant inside of the getPartitions function of ShuffleRDD.scala - but maybe there are other code paths where this could get messed up that don't go through that.

Contributor Author:

val partitioner = new Partitioner {
  override def numPartitions: Int = 2
  override def getPartition(key: Any): Int = key.hashCode() % 2
}

val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (-1, 7)))
val shuffled = pairs.partitionBy(partitioner)
shuffled.count

Here an error is thrown, which is the expected result.

val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (6, 7)))
val shuffled = pairs.partitionBy(partitioner)
shuffled.lookup(-1)

Here, although the error is logged, Spark hangs.

Contributor:

Just scanned the code. This issue (partitions not matching rdd.partitions.map(_.index)) can only happen when you run a computation that relies on the correctness of the partitioner.

In the current implementation, there are only two such cases.

The first is lookup, where the computation relies on the correctness of getPartition():

def lookup(key: K): Seq[V] = {
  self.partitioner match {
    case Some(p) =>
      val index = p.getPartition(key)
      def process(it: Iterator[(K, V)]): Seq[V] = {
        val buf = new ArrayBuffer[V]
        for ((k, v) <- it if k == key) {
          buf += v
        }
        buf
      }
      val res = self.context.runJob(self, process _, Array(index), false)
      res(0)
    case None =>
      self.filter(_._1 == key).map(_._2).collect()
  }
}

The other case is ShuffleMapTask:

// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
  val pair = elem.asInstanceOf[Product2[Any, Any]]
  val bucketId = dep.partitioner.getPartition(pair._1)
  shuffle.writers(bucketId).write(pair)
}
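
Note that shuffle.writers is indexed by bucket id, so when a bad partitioner returns a negative bucketId the task fails immediately. A minimal sketch (my illustration, with a plain array standing in for shuffle.writers):

// Array indexing rejects the negative bucket id a bad partitioner produces.
val writers = Array("writer-0", "writer-1")  // stands in for shuffle.writers
val bucketId = (-1).hashCode % 2             // == -1, as with key.hashCode() % 2
try {
  writers(bucketId)
} catch {
  case e: ArrayIndexOutOfBoundsException =>
    println("task fails with: " + e)         // this failure is what gets logged
}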

I'm not sure which fix is better: adding a check in SparkContext, or adding specific checks in these two places separately.

I just felt that, without looking at the code, I couldn't see why partitions would fail to match rdd.partitions (and looking at how SparkContext runs the job only adds to the confusion, because partitions is derived exactly from 0 until rdd.partitions.size):

/**
 * Run a job on all partitions in an RDD and return the results in an array.
 */
def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
  runJob(rdd, func, 0 until rdd.partitions.size, false)
}

Contributor Author:

I don't understand what you mean. The partition index goes out of range because of an improperly designed Partitioner:

val partitioner = new Partitioner {
  override def numPartitions: Int = 2
  override def getPartition(key: Any): Int = key.hashCode() % 2
}

Here partitioner.getPartition(-1) returns -1, since (-1).hashCode % 2 == -1.
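
For comparison (a sketch, not part of this patch): Spark's built-in HashPartitioner avoids this by mapping negative hash codes back into the range [0, numPartitions), along these lines:

// A partitioner that keeps every index in [0, numPartitions), in the
// spirit of HashPartitioner's non-negative-mod handling.
val safePartitioner = new Partitioner {
  override def numPartitions: Int = 2
  override def getPartition(key: Any): Int = {
    val rawMod = key.hashCode() % numPartitions
    rawMod + (if (rawMod < 0) numPartitions else 0)
  }
}

safePartitioner.getPartition(-1)  // == 1, always in range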

Contributor:

@witgo so in the case you mentioned, why not put this check in the constructor of ShuffleRDD? It seems more natural to check it there rather than inside of runJob.

Contributor Author:

The correctness of the partitioner depends on the input keys. In the current example there is no problem as long as every key >= 0. This cannot be detected in the constructor: partitionBy never calls getPartition, so the bad index only appears once a task actually processes a negative key.

val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
@@ -950,6 +952,8 @@ class SparkContext(
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
val rddPartitions = rdd.partitions.map(_.index)
require(partitions.forall(rddPartitions.contains(_)), "partition index out of range")
Contributor:

This check as written is going to have quadratic complexity. If you have 100 partitions for example, you're going to create a list of length 100 at the top and then check for all 100 partitions whether they're in that list, getting 10,000 operations. Can't you just check that all the indices in partitions are between 0 and rdd.partitions.size? I don't think RDDs can have non-contiguous partition numbers, though there might have been some stuff in the past with partition pruning that I may be misremembering.

Contributor Author:

Sure, this should be better:

require(partitions.toSet.diff(rdd.partitions.map(_.index).toSet).isEmpty, "partition index out of range")

Or this?

val partitionRange = (0 until rdd.partitions.size)
require(partitions.forall(partitionRange.contains(_)), "partition index out of range")
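
A quick illustration of the range-based variant (hypothetical partition count, just to show the behavior):

val partitionRange = 0 until 4                // pretend rdd.partitions.size == 4
Seq(0, 2).forall(partitionRange.contains(_))  // true: every index in range
Seq(-1).forall(partitionRange.contains(_))    // false: require would reject -1

It also avoids building intermediate sets: Range.contains is a constant-time bounds check, so the whole validation stays linear in the number of requested partitions.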

val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(