
Spark – Source Code Analysis (4)

Task Execution

The previous post ended at the TaskRunner class. It is defined inside the Executor class and implements Java's Runnable interface; every Task that Spark runs is executed as one TaskRunner instance (a sketch of this pattern follows the method below). Here is its run() method:

  override def run(): Unit = {
    threadId = Thread.currentThread.getId
    Thread.currentThread.setName(threadName)
    val threadMXBean = ManagementFactory.getThreadMXBean
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    var taskStartCpu: Long = 0
    startGCTime = computeTotalGcTime()
 
    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)
 
      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      task.localProperties = taskDescription.properties
      task.setTaskMemoryManager(taskMemoryManager)
...
      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      var threwException = true
      val value = try {
        val res = task.run(
          taskAttemptId = taskId,
          attemptNumber = taskDescription.attemptNumber,
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
 
        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }
...
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
...
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
...
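
Before summarizing the method step by step, here is the pattern from the intro in miniature: each Task runs as a Runnable handed to the executor's thread pool. This is only an illustrative sketch with made-up names (SketchTaskRunner), not Spark's actual classes:

```
import java.util.concurrent.Executors

// A minimal sketch of "one Task = one Runnable on a thread pool".
// Illustrative names only; not Spark's real classes.
object TaskRunnerSketch {
  final class SketchTaskRunner(taskId: Long) extends Runnable {
    override def run(): Unit = {
      // A real TaskRunner would deserialize the task here, run it,
      // and report status back to the driver.
      println(s"Running task $taskId on ${Thread.currentThread().getName}")
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newCachedThreadPool() // Spark's Executor keeps a similar pool
    (1L to 3L).foreach(id => pool.execute(new SketchTaskRunner(id)))
    pool.shutdown()
  }
}
```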

The real run() method above looks a bit long, and only part of it is quoted. Here is a step-by-step summary:

  1. The incoming taskDescription object is deserialized into a Task instance. This instance is a subclass of Task, one of the two types created back when the Tasks were generated: ResultTask or ShuffleMapTask. We will use ResultTask as the example.
  2. Task's `run()` method is called, and `run()` in turn calls `runTask()`. Here is the `runTask()` method of the ResultTask class:
    ```
    override def runTask(context: TaskContext): U = {
      // Deserialize the RDD and the func using the broadcast variables.
      val threadMXBean = ManagementFactory.getThreadMXBean
      val deserializeStartTime = System.currentTimeMillis()
      val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
      _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
      _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
      } else 0L

      func(context, rdd.iterator(partition, context))
    }
    ```

    Inside this method, the RDD and the closure contained in the Task are deserialized, and the closure is then applied to the RDD's data (see the sketch following this list).
  3. When Task's `run()` method finishes, control returns to TaskRunner's `run()` method along with the result data. The Task's main work is done at this point, so some statistics are collected: the Task's run time, CPU time, GC time, and so on. These figures are presumably sent on to the UI, but I haven't traced that part in detail.
  4. The result data is then serialized and finally sent back to the Spark Driver through ExecutorBackend's `statusUpdate()` method. ExecutorBackend is a trait, mentioned in the previous post; CoarseGrainedExecutorBackend is one of its implementations, and it implements `statusUpdate()` as follows:

override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
  val msg = StatusUpdate(executorId, taskId, state, data)
  driver match {
    case Some(driverRef) => driverRef.send(msg)
    case None => logWarning(s"Drop $msg because has not yet connected to driver")
  }
}

The data is wrapped into a StatusUpdate, which is sent as a message to the Driver side through the RPC framework.
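
To make step 2 above concrete: running a ResultTask ultimately boils down to applying the deserialized closure to the partition's iterator. Below is a self-contained sketch of just that idea; every name in it is illustrative rather than Spark's real API:

```
// Running a ResultTask, in miniature: apply the user closure to the
// partition's iterator. All names here are illustrative.
object ResultTaskSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for what rdd.iterator(partition, context) would yield
    val partitionData: Iterator[Int] = Iterator(1, 2, 3, 4)
    // Stand-in for the deserialized `func`, e.g. what a collect() ships
    val func: Iterator[Int] => Array[Int] = _.toArray
    // The essence of func(context, rdd.iterator(partition, context))
    val result = func(partitionData)
    println(result.mkString(",")) // prints 1,2,3,4
  }
}
```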
The receiving logic on the Spark Driver side:
override def receive: PartialFunction[Any, Unit] = {
  case StatusUpdate(executorId, taskId, state, data) =>
    scheduler.statusUpdate(taskId, state, data.value)
    if (TaskState.isFinished(state)) {
      executorDataMap.get(executorId) match {
        case Some(executorInfo) =>
          executorInfo.freeCores += scheduler.CPUS_PER_TASK
          makeOffers(executorId)
        case None =>
          // Ignoring the update since we don't know about the executor.
          logWarning(s"Ignored task status update ($taskId state $state) " +
            s"from unknown executor with ID $executorId")
      }
    }

After the Driver receives the StatusUpdate, the processing flow is as follows:
  1. First the scheduler's `statusUpdate()` method is called; the scheduler deserializes the result data and marks the Task as finished.
  2. The scheduler then calls the DAGScheduler's post method, wrapping the Task and its data into a CompletionEvent and placing it into the DAGScheduler's event queue eventProcessLoop to await processing.
  3. Back in `receive()`, the CPU cores held by this Task are reclaimed.
  4. `makeOffers(executorId)` is called to launch the next Task, and so on until all Tasks have run.
  5. Because the DAGScheduler's event queue is continuously drained in a loop, the CompletionEvent we just enqueued is soon taken out by the DAGScheduler's `doOnReceive()` method, which dispatches the corresponding action (a sketch of this event-loop pattern follows the code below):

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
  case completion: CompletionEvent =>
    dagScheduler.handleTaskCompletion(completion)
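
The "event queue drained in a loop" from step 5 is the classic single-threaded event-loop pattern. Here is a minimal sketch of it, simplified and with illustrative names; Spark's real version is the org.apache.spark.util.EventLoop class that eventProcessLoop extends:

```
import java.util.concurrent.LinkedBlockingQueue

// A minimal event loop: producers post events onto a blocking queue and a
// single daemon thread drains it, dispatching each event to a handler
// (the role doOnReceive plays). Simplified, illustrative sketch.
class SimpleEventLoop[E](name: String)(handle: E => Unit) {
  private val queue = new LinkedBlockingQueue[E]()
  private val thread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit =
      try {
        while (true) handle(queue.take()) // blocks until an event arrives
      } catch {
        case _: InterruptedException => () // loop stopped
      }
  }
  def start(): Unit = thread.start()
  def stop(): Unit = thread.interrupt()
  def post(event: E): Unit = queue.put(event) // e.g. enqueue a CompletionEvent
}
```

Here post() corresponds to the scheduler enqueuing a CompletionEvent, and handle to doOnReceive() taking it out.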


Now look at the processing logic of handleTaskCompletion():

private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
  val task = event.task
  val taskId = event.taskInfo.id
  val stageId = task.stageId
  val taskType = Utils.getFormattedClassName(task)

  // The stage may have already finished when we get this event -- eg. maybe it was a
  // speculative task. It is important that we send the TaskEnd event in any case, so listeners
  // are properly notified and can chose to handle it. For instance, some listeners are
  // doing their own accounting and if they don't get the task end event they think
  // tasks are still running when they really aren't.
  listenerBus.post(SparkListenerTaskEnd(
    stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics))

  val stage = stageIdToStage(task.stageId)
  event.reason match {
    case Success =>
      task match {
        case rt: ResultTask[_, _] =>
          // Cast to ResultStage here because it's part of the ResultTask
          // TODO Refactor this out to a function that accepts a ResultStage
          val resultStage = stage.asInstanceOf[ResultStage]
          resultStage.activeJob match {
            case Some(job) =>
              if (!job.finished(rt.outputId)) {
                updateAccumulators(event)
                job.finished(rt.outputId) = true
                job.numFinished += 1
                // If the whole job has finished, remove it
                if (job.numFinished == job.numPartitions) {
                  markStageAsFinished(resultStage)
                  cleanupStateForJobAndIndependentStages(job)
                  listenerBus.post(
                    SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
                }
                // taskSucceeded runs some user code that might throw an exception. Make sure
                // we are resilient against that.
                try {
                  job.listener.taskSucceeded(rt.outputId, event.result)
                } catch {
                  case e: Exception =>
                    // TODO: Perhaps we want to mark the resultStage as failed?
                    job.listener.jobFailed(new SparkDriverExecutionException(e))
                }
              }

This logic really is quite long, so only part of it is shown here:
1. `listenerBus.post` notifies all listeners that this Task has finished. listenerBus is evidently a collection holding all kinds of listeners, and once they learn of the event each can react accordingly.
2. `SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)` — if this Task is a ResultTask and it completed the Job's last unfinished partition, all listeners are notified that the Job has finished.
3. The `job.listener.taskSucceeded(rt.outputId, event.result)` and `job.listener.jobFailed(new SparkDriverExecutionException(e))` lines report the Task's outcome to the JobWaiter object, which can then stop waiting.
4. At this point we are back where the Job was originally submitted:

def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls scala.concurrent.blocking,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, awaitPermission is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}
Once the JobWaiter's wait ends, the `case scala.util.Success(_)` and `case scala.util.Failure(exception)` branches log the outcome, and with that the entire flow of submitting a Job is complete.
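
How does the JobWaiter "stop waiting"? Here is a minimal sketch of the idea, simplified from the real class (which also implements Spark's JobListener trait and uses atomic counters): each partition's result goes to the caller's resultHandler, and a future — the completionFuture that runJob above blocks on — is completed once every partition has reported:

```
import scala.concurrent.{Future, Promise}

// Simplified, illustrative sketch of the JobWaiter idea.
class SimpleJobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
  private val promise = Promise[Unit]()
  private var finishedTasks = 0

  // What runJob blocks on
  def completionFuture: Future[Unit] = promise.future

  // Called back (via job.listener.taskSucceeded) for each finished partition
  def taskSucceeded(index: Int, result: T): Unit = synchronized {
    resultHandler(index, result) // hand this partition's result to the caller
    finishedTasks += 1
    if (finishedTasks == totalTasks) promise.trySuccess(()) // wake up runJob
  }

  // Called back via job.listener.jobFailed; runJob will rethrow the exception
  def jobFailed(exception: Exception): Unit = promise.tryFailure(exception)
}
```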

My understanding is limited, so this article may look somewhat disorganized; I haven't untangled the full flow yet, and will find time later to draw a flow chart to aid understanding.


Written: 2017-05-06