Modifications based on the comments on PR 126.
tdas committed Mar 13, 2014
1 parent ae9da88 commit e61daa0
Showing 9 changed files with 71 additions and 60 deletions.
55 changes: 32 additions & 23 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -21,8 +21,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

import org.apache.spark.rdd.RDD

/** Listener class used for testing when any item has been cleaned by the Cleaner class */
private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
@@ -32,12 +30,12 @@ private[spark] trait CleanerListener {
/**
* Cleans RDDs and shuffle data.
*/
private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

/** Classes to represent cleaning tasks */
private sealed trait CleaningTask
private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask
private case class CleanShuffle(id: Int) extends CleaningTask
private case class CleanRDD(rddId: Int) extends CleaningTask
private case class CleanShuffle(shuffleId: Int) extends CleaningTask
// TODO: add CleanBroadcast

private val queue = new LinkedBlockingQueue[CleaningTask]
@@ -47,7 +45,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

private var stopped = false
@volatile private var stopped = false

/** Start the cleaner */
def start() {
@@ -57,26 +55,37 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {

/** Stop the cleaner */
def stop() {
synchronized { stopped = true }
stopped = true
cleaningThread.interrupt()
}

/** Clean (unpersist) RDD data. */
def cleanRDD(rdd: RDD[_]) {
enqueue(CleanRDD(rdd.sparkContext, rdd.id))
logDebug("Enqueued RDD " + rdd + " for cleaning up")
/**
* Clean (unpersist) RDD data. Do not perform any time or resource intensive
* computation in this function as this is called from a finalize() function.
*/
def cleanRDD(rddId: Int) {
enqueue(CleanRDD(rddId))
logDebug("Enqueued RDD " + rddId + " for cleaning up")
}

/** Clean shuffle data. */
/**
* Clean shuffle data. Do not perform any time or resource intensive
* computation in this function as this is called from a finalize() function.
*/
def cleanShuffle(shuffleId: Int) {
enqueue(CleanShuffle(shuffleId))
logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
}

/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener) {
listeners += listener
}
/** Enqueue a cleaning task */

/**
* Enqueue a cleaning task. Do not perform any time or resource intensive
* computation in this function as this is called from a finalize() function.
*/
private def enqueue(task: CleaningTask) {
queue.put(task)
}
@@ -86,24 +95,24 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
try {
while (!isStopped) {
val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
if (taskOpt.isDefined) {
taskOpt.foreach(task => {
logDebug("Got cleaning task " + taskOpt.get)
taskOpt.get match {
case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId)
task match {
case CleanRDD(rddId) => doCleanRDD(sc, rddId)
case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
}
}
})
}
} catch {
case ie: java.lang.InterruptedException =>
case ie: InterruptedException =>
if (!isStopped) logWarning("Cleaning thread interrupted")
}
}

/** Perform RDD cleaning */
private def doCleanRDD(sc: SparkContext, rddId: Int) {
logDebug("Cleaning rdd " + rddId)
sc.env.blockManager.master.removeRdd(rddId, false)
blockManagerMaster.removeRdd(rddId, false)
sc.persistentRdds.remove(rddId)
listeners.foreach(_.rddCleaned(rddId))
logInfo("Cleaned rdd " + rddId)
@@ -113,14 +122,14 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
private def doCleanShuffle(shuffleId: Int) {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
blockManager.master.removeShuffle(shuffleId)
blockManagerMaster.removeShuffle(shuffleId)
listeners.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
}

private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

private def blockManager = env.blockManager
private def blockManagerMaster = sc.env.blockManager.master

private def isStopped = synchronized { stopped }
private def isStopped = stopped
}
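
Several of the new comments above stress that cleanRDD, cleanShuffle, and enqueue may be called from a finalize() method, so they must do nothing more expensive than a queue insert; the real work happens on the cleaning thread. A minimal, self-contained sketch of that producer/consumer pattern (simplified names, not the actual ContextCleaner):

```scala
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Minimal stand-in for the queue-and-cleaning-thread pattern: callers
// (possibly finalizers) only enqueue; all real work happens on the thread.
object CleanerSketch {
  sealed trait Task
  case class CleanRDD(rddId: Int) extends Task
  case class CleanShuffle(shuffleId: Int) extends Task

  private val queue = new LinkedBlockingQueue[Task]
  @volatile private var stopped = false

  private val cleaningThread = new Thread("cleaner-sketch") {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped) {
          // Timed poll so the loop can notice `stopped` even when idle.
          Option(queue.poll(100, TimeUnit.MILLISECONDS)).foreach {
            case CleanRDD(id)     => println(s"would unpersist RDD $id")
            case CleanShuffle(id) => println(s"would remove shuffle $id data")
          }
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked poll
      }
    }
  }

  def start(): Unit = cleaningThread.start()

  // Cheap enough to call from a finalize(): just a queue insert.
  def cleanRDD(rddId: Int): Unit = queue.put(CleanRDD(rddId))
  def cleanShuffle(shuffleId: Int): Unit = queue.put(CleanShuffle(shuffleId))

  def stop(): Unit = { stopped = true; cleaningThread.interrupt() }
}
```

The timed poll lets the loop observe the volatile stopped flag, and interrupting the thread wakes a blocked poll promptly, mirroring the stop() and isStopped changes in the diff above.
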
52 changes: 24 additions & 28 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -17,22 +17,18 @@

package org.apache.spark

import scala.Some
import scala.collection.mutable.{HashSet, Map}
import scala.concurrent.Await

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashSet
import scala.Some
import scala.collection.mutable.{HashSet, Map}
import scala.concurrent.Await

import akka.actor._
import akka.pattern.ask

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AkkaUtils, TimeStampedHashMap, BoundedHashMap}
import org.apache.spark.util._

private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -55,7 +51,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}

/**
* Class that keeps track of the location of the location of the mapt output of
* Class that keeps track of the location of the map output of
* a stage. This is abstract because different versions of MapOutputTracker
* (driver and worker) use different HashMap to store its metadata.
*/
@@ -155,10 +151,6 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}

protected def cleanup(cleanupTime: Long) {
mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime)
}

def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
@@ -195,10 +187,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
/**
* Bounded HashMap for storing serialized statuses in the worker. This allows
* the HashMap to stay bounded in memory usage. Things dropped from this HashMap will be
* automatically repopulated by fetching them again from the driver.
* automatically repopulated by fetching them again from the driver. It's okay to
* keep the cache size small, as it is unlikely that there will be a very large number of
* stages active simultaneously in the worker.
*/
protected val MAX_MAP_STATUSES = 100
protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true)
protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](
conf.getInt("spark.mapOutputTracker.cacheSize", 100), true
)
}

/**
@@ -212,20 +207,18 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
private var cacheEpoch = epoch

/**
* Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped
* only by explicit deregistering or by ttl-based cleaning (if set). Other than these two
* Timestamp based HashMap for storing mapStatuses and cached serialized statuses
* in the master, so that statuses are dropped only by explicit deregistering or
* by TTL-based cleaning (if set). Other than these two
* scenarios, nothing should be dropped from this HashMap.
*/

protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()

/**
* Bounded HashMap for storing serialized statuses in the master. This allows
* the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
* automatically repopulated by serializing the lost statuses again .
*/
protected val MAX_SERIALIZED_STATUSES = 100
private val cachedSerializedStatuses =
new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true)
// For cleaning up TimeStampedHashMaps
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
@@ -264,6 +257,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)

def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
cachedSerializedStatuses.remove(shuffleId)
}

def incrementEpoch() {
@@ -303,20 +297,22 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}

def contains(shuffleId: Int): Boolean = {
mapStatuses.contains(shuffleId)
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
}

override def stop() {
super.stop()
metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}

override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}

def has(shuffleId: Int): Boolean = {
cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
}
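
The worker-side comment above relies on the bounded map keeping only recent entries and on evicted statuses simply being refetched from the driver, which is why a small cache (spark.mapOutputTracker.cacheSize, default 100) is enough. As a rough illustration of what such a bounded map could look like (an assumption about its behaviour, not the actual org.apache.spark.util.BoundedHashMap), a LinkedHashMap with removeEldestEntry gives the same effect:

```scala
import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

// Hypothetical bounded map: once maxEntries is exceeded, the eldest entry is
// evicted. Evicted map output statuses would just be fetched from the driver
// again, so losing them is harmless.
class BoundedMapSketch[K, V](maxEntries: Int, accessOrdered: Boolean) {
  private val underlying =
    new JLinkedHashMap[K, V](maxEntries, 0.75f, accessOrdered) {
      override def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean =
        size() > maxEntries
    }

  def put(key: K, value: V): Unit = underlying.synchronized { underlying.put(key, value) }
  def get(key: K): Option[V] = underlying.synchronized { Option(underlying.get(key)) }
}
```

The master-side maps, by contrast, are TimeStampedHashMaps whose entries are dropped only by explicit unregisterShuffle or by the TTL-based cleanup added above.
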

2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -206,7 +206,7 @@ class SparkContext(
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()

private[spark] val cleaner = new ContextCleaner(env)
private[spark] val cleaner = new ContextCleaner(this)
cleaner.start()

ui.start()
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1027,7 +1027,7 @@ abstract class RDD[T: ClassTag](

def cleanup() {
logInfo("Cleanup called on RDD " + id)
sc.cleaner.cleanRDD(this)
sc.cleaner.cleanRDD(id)
dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
.map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
.foreach(sc.cleaner.cleanShuffle)
@@ -266,7 +266,7 @@
: Stage =
{
val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
if (mapOutputTracker.has(shuffleDep.shuffleId)) {
if (mapOutputTracker.contains(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
for (i <- 0 until locs.size) {
@@ -169,8 +169,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
throw new IllegalStateException("Failed to find shuffle block: " + id)
}

/** Remove all the blocks / files related to a particular shuffle */
/** Remove all the blocks / files and metadata related to a particular shuffle */
def removeShuffle(shuffleId: ShuffleId) {
removeShuffleBlocks(shuffleId)
shuffleStates.remove(shuffleId)
}

/** Remove all the blocks / files related to a particular shuffle */
private def removeShuffleBlocks(shuffleId: ShuffleId) {
shuffleStates.get(shuffleId) match {
case Some(state) =>
if (consolidateShuffleFiles) {
@@ -194,7 +200,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
}

private def cleanup(cleanupTime: Long) {
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId))
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}
}
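
Both the master's new cleanup method and the ShuffleBlockManager cleanup above drive clearOldValues on timestamped maps from a periodic MetadataCleaner task. A hedged sketch of that TTL idea (simplified; not the real TimeStampedHashMap): each entry remembers when it was inserted, and clearOldValues drops everything older than a cutoff, optionally running a callback such as removeShuffleBlocks first.

```scala
import scala.collection.concurrent.TrieMap

// Hypothetical timestamped map: a periodic task computes a cutoff (now - TTL)
// and calls clearOldValues, which prunes stale entries and can invoke a
// callback (e.g. removing shuffle blocks) for each of them.
class TimeStampedMapSketch[K, V] {
  private val entries = TrieMap.empty[K, (V, Long)]

  def put(key: K, value: V): Unit = entries.put(key, (value, System.currentTimeMillis()))
  def get(key: K): Option[V] = entries.get(key).map(_._1)
  def remove(key: K): Unit = entries.remove(key)

  def clearOldValues(cutoffTime: Long, onRemove: (K, V) => Unit = (_, _) => ()): Unit =
    for ((key, (value, insertedAt)) <- entries if insertedAt < cutoffTime) {
      onRemove(key, value)
      entries.remove(key)
    }
}
```

Explicit paths (unregisterShuffle, removeShuffle) bypass the TTL and remove entries immediately, which is why removeShuffle above now also drops the shuffleStates entry itself.
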

@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(

private[spark] object MetadataCleanerType extends Enumeration {

val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER,
SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value

type MetadataCleanerType = Value

@@ -25,7 +25,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
val rdd = newRDD.persist()
rdd.count()
val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
cleaner.cleanRDD(rdd)
cleaner.cleanRDD(rdd.id)
tester.assertCleanup
}
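
CleanerTester presumably verifies cleanup by registering a CleanerListener via attachListener and waiting until the expected ids are reported. A hedged sketch of such a test helper (hypothetical; the real CleanerTester is not shown in this diff):

```scala
import scala.collection.mutable

// Hypothetical test listener: records the ids reported through the
// CleanerListener callbacks so a test can poll until cleanup is observed.
// (CleanerListener is private[spark], so this would live under org.apache.spark.)
class RecordingCleanerListener extends CleanerListener {
  private val cleanedRdds = mutable.Set.empty[Int]
  private val cleanedShuffles = mutable.Set.empty[Int]

  def rddCleaned(rddId: Int): Unit = synchronized { cleanedRdds += rddId }
  def shuffleCleaned(shuffleId: Int): Unit = synchronized { cleanedShuffles += shuffleId }

  /** Poll until every expected RDD id has been reported cleaned, or time out. */
  def awaitRddsCleaned(rddIds: Seq[Int], timeoutMs: Long = 10000L): Boolean = {
    val deadline = System.currentTimeMillis() + timeoutMs
    while (System.currentTimeMillis() < deadline) {
      if (synchronized { rddIds.forall(cleanedRdds.contains) }) return true
      Thread.sleep(100)
    }
    false
  }
}
```

A test could then attach such a listener with cleaner.attachListener before calling cleanRDD(rdd.id) and assert on awaitRddsCleaned(Seq(rdd.id)).
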

@@ -206,4 +206,4 @@ class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] {
protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = {
new TestMap[K1, V1]
}
}
}
