
Spark – Source Code Analysis (Part 4)

Task Execution

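The previous post ended at the TaskRunner class. TaskRunner is defined inside the Executor class and implements Java's Runnable interface; each Task that runs in Spark is a TaskRunner instance. Here is its run() method: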

  override def run(): Unit = {
    threadId = Thread.currentThread.getId
    Thread.currentThread.setName(threadName)
    val threadMXBean = ManagementFactory.getThreadMXBean
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    var taskStartCpu: Long = 0
    startGCTime = computeTotalGcTime()
 
    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)
 
      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      task.localProperties = taskDescription.properties
      task.setTaskMemoryManager(taskMemoryManager)
...
      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      var threwException = true
      val value = try {
        val res = task.run(
          taskAttemptId = taskId,
          attemptNumber = taskDescription.attemptNumber,
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
 
        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }
...
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
...
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
...
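With the metrics and most error handling stripped out, the control flow of run() can be sketched in plain Scala. This is not Spark code: MiniTaskRunner, ToyMemoryManager and the Running/Finished/Failed states are hypothetical stand-ins, and "serializing" the result is just toString. It keeps the shape of the excerpt: report RUNNING, run the task body, always clean up memory in finally but only warn about a leak when the body did not throw, measure CPU time with ThreadMXBean, then report FINISHED with the result.

```scala
import java.lang.management.ManagementFactory
import scala.collection.mutable.ArrayBuffer

// Hypothetical task states, mirroring RUNNING / FINISHED / FAILED in the excerpt
sealed trait State
case object Running extends State
case object Finished extends State
case object Failed extends State

// Hypothetical stand-in for TaskMemoryManager: only counts bytes still held
final class ToyMemoryManager {
  private var allocated = 0L
  def allocate(bytes: Long): Unit = allocated += bytes
  def release(bytes: Long): Unit = allocated -= bytes
  // Frees whatever is still held and returns how many bytes that was
  def cleanUpAllAllocatedMemory(): Long = { val n = allocated; allocated = 0L; n }
}

final class MiniTaskRunner(
    taskId: Long,
    body: ToyMemoryManager => Int,                       // the task's actual work
    statusUpdate: (Long, State, Option[String]) => Unit  // stand-in for execBackend.statusUpdate
) extends Runnable {
  val warnings = ArrayBuffer.empty[String]
  var cpuTimeNs = 0L

  override def run(): Unit = {
    val mx = ManagementFactory.getThreadMXBean
    val memory = new ToyMemoryManager
    statusUpdate(taskId, Running, None)
    try {
      val cpuStart = if (mx.isCurrentThreadCpuTimeSupported) mx.getCurrentThreadCpuTime else 0L
      var threwException = true
      val value = try {
        val res = body(memory)
        threwException = false
        res
      } finally {
        // Always clean up, but only report a leak when the body completed normally
        val freed = memory.cleanUpAllAllocatedMemory()
        if (freed > 0 && !threwException)
          warnings += s"Managed memory leak detected; size = $freed bytes, TID = $taskId"
      }
      cpuTimeNs =
        if (mx.isCurrentThreadCpuTimeSupported) mx.getCurrentThreadCpuTime - cpuStart else 0L
      // Toy "serialization" of the result before reporting it
      statusUpdate(taskId, Finished, Some(value.toString))
    } catch {
      case e: Exception => statusUpdate(taskId, Failed, Option(e.getMessage))
    }
  }
}

// Demo: a task that allocates 64 bytes and never releases them
val updates = ArrayBuffer.empty[(Long, State, Option[String])]
val runner = new MiniTaskRunner(7L, mem => { mem.allocate(64); 1 + 2 },
  (id, state, data) => updates += ((id, state, data)))
val thread = new Thread(runner)
thread.start()
thread.join()
```

Running the runner on its own thread mirrors how the executor hands each TaskRunner to a thread pool; join() makes the recorded updates safely visible afterwards.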

TaskRunner.run() is fairly long, so only part of it is quoted. Step by step, it does the following:

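  1. Deserializes the serialized task carried by taskDescription (taskDescription.serializedTask) into a Task instance. This instance is a subclass of Task, i.e. one of the two types created earlier when the Tasks were generated: ResultTask or ShuffleMapTask. We'll use ResultTask as the example.
  2. Calls the Task's run() method, which in turn calls runTask(). Here is runTask() in the ResultTask class: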
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L
 
  func(context, rdd.iterator(partition, context))
}

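The method deserializes the RDD and the closure (func) contained in the Task, then invokes the closure on the iterator of this task's partition, rdd.iterator(partition, context). The closure is what actually processes the RDD's data.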
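A stripped-down sketch of that last line, assuming we skip the (de)serialization of taskBinary and simply hold the RDD and the closure directly. ToyRDD, ToyTaskContext and ToyResultTask are hypothetical names; the point is only that every task applies the same func, but each one sees the iterator of its own partition.

```scala
// Hypothetical stand-ins, not Spark's classes
final case class ToyTaskContext(partitionId: Int)

// An "RDD" whose partitions are plain in-memory vectors
final class ToyRDD[T](partitions: Vector[Vector[T]]) {
  def numPartitions: Int = partitions.size
  def iterator(partition: Int, context: ToyTaskContext): Iterator[T] =
    partitions(partition).iterator
}

// Keeps only the final step of ResultTask.runTask: func(context, rdd.iterator(partition, context))
final class ToyResultTask[T, U](
    rdd: ToyRDD[T],
    func: (ToyTaskContext, Iterator[T]) => U,
    partition: Int) {
  def runTask(context: ToyTaskContext): U = func(context, rdd.iterator(partition, context))
}

val rdd = new ToyRDD(Vector(Vector(1, 2, 3), Vector(4, 5)))
// The same closure is shipped to every task; each task sees only its own partition
val sumWithId = (ctx: ToyTaskContext, it: Iterator[Int]) => (ctx.partitionId, it.sum)
val results = (0 until rdd.numPartitions).map { p =>
  new ToyResultTask(rdd, sumWithId, p).runTask(ToyTaskContext(p))
}
```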

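  3. When Task.run() returns, control goes back to TaskRunner.run() along with the result data. The Task's main work is now done, so TaskRunner collects some metrics: the Task's run time, CPU time, GC time, and so on. These presumably end up in the UI, but I haven't traced that part in detail.
  4. The result data is then serialized and finally sent back to the Spark Driver through ExecutorBackend's statusUpdate() method. ExecutorBackend is a trait, mentioned in the previous post; CoarseGrainedExecutorBackend is one implementation, and its statusUpdate() looks like this: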
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
  val msg = StatusUpdate(executorId, taskId, state, data)
  driver match {
    case Some(driverRef) => driverRef.send(msg)
    case None => logWarning(s"Drop $msg because has not yet connected to driver")
  }
}

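CoarseGrainedExecutorBackend wraps data in a StatusUpdate and sends it as a message to the Driver through the RPC framework; if the backend has not connected to the driver yet, the message is dropped with a warning. Here is the receiving logic on the Spark Driver side: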

    override def receive: PartialFunction[Any, Unit] = {
      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }
...
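The round trip between the executor-side and driver-side excerpts can be sketched with a direct method call in place of the RPC layer. Everything here is hypothetical: StatusUpdate is redefined with String state and data, ToyDriver and ToyExecutorBackend are made up, and makeOffers is reduced to recording the executor ID. It reproduces three behaviors from the code above: the executor drops the message when it has no driver reference, a terminal state returns the task's cores to its executor and triggers a new offer, and updates from unknown executors are ignored.

```scala
import scala.collection.mutable

// Hypothetical simplified message; Spark's StatusUpdate carries a TaskState and a ByteBuffer
final case class StatusUpdate(executorId: String, taskId: Long, state: String, data: String)
final class ExecutorInfo(var freeCores: Int)

final class ToyDriver(cpusPerTask: Int) {
  // Mirrors TaskState.isFinished, which treats FINISHED, FAILED, KILLED and LOST as terminal
  private val finishedStates = Set("FINISHED", "FAILED", "KILLED", "LOST")
  val executorDataMap = mutable.Map.empty[String, ExecutorInfo]
  val offersMade = mutable.ArrayBuffer.empty[String]  // stand-in for makeOffers(executorId)
  val warnings = mutable.ArrayBuffer.empty[String]

  def receive(msg: Any): Unit = msg match {
    case StatusUpdate(executorId, taskId, state, _) if finishedStates(state) =>
      executorDataMap.get(executorId) match {
        case Some(info) =>
          info.freeCores += cpusPerTask  // return the task's cores to the executor
          offersMade += executorId       // offer the freed cores to the next task
        case None =>
          warnings += s"Ignored task status update ($taskId state $state) from unknown executor with ID $executorId"
      }
    case _ => ()  // e.g. RUNNING: nothing to reclaim
  }
}

final class ToyExecutorBackend(executorId: String, driver: Option[ToyDriver]) {
  val dropped = mutable.ArrayBuffer.empty[StatusUpdate]
  def statusUpdate(taskId: Long, state: String, data: String): Unit = {
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.receive(msg)  // a direct call instead of an RPC send
      case None            => dropped += msg          // "has not yet connected to driver"
    }
  }
}

val driver = new ToyDriver(cpusPerTask = 1)
driver.executorDataMap("exec-1") = new ExecutorInfo(freeCores = 3)
val backend = new ToyExecutorBackend("exec-1", Some(driver))
backend.statusUpdate(0L, "RUNNING", "")
backend.statusUpdate(0L, "FINISHED", "42")
new ToyExecutorBackend("exec-9", Some(driver)).statusUpdate(1L, "FAILED", "")
val offline = new ToyExecutorBackend("exec-2", None)
offline.statusUpdate(2L, "FINISHED", "7")
```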

After the Driver receives a StatusUpdate, processing goes as follows:

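  1. It first calls the task scheduler's statusUpdate() method; the scheduler deserializes the result data and marks the Task as finished.
  2. The scheduler then calls DAGScheduler's post method, wrapping the Task and the data in a CompletionEvent and putting it into DAGScheduler's event queue, eventProcessLoop, where it waits to be processed.
  3. Control then returns to receive(), which reclaims the Task's CPU cores (freeCores += scheduler.CPUS_PER_TASK).
  4. It calls makeOffers(executorId) to launch the next Task on that executor, and so on until all Tasks have run.
  5. Because DAGScheduler's event queue is drained continuously in a loop, the CompletionEvent just posted is soon picked up by DAGScheduler's doOnReceive() method and handled accordingly: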
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
...
  case completion: CompletionEvent =>
    dagScheduler.handleTaskCompletion(completion)
...
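eventProcessLoop decouples post() from handling: post() only enqueues, and a separate thread takes events off the queue one at a time and dispatches them, which is why the CompletionEvent is handled "soon" rather than immediately. A minimal sketch of that pattern, assuming a single consumer thread over a java.util.concurrent.LinkedBlockingDeque; ToyEventLoop, Completion and StopLoop are hypothetical. As far as I can tell, Spark's own org.apache.spark.util.EventLoop follows a similar design with extra error handling.

```scala
import java.util.concurrent.LinkedBlockingDeque
import scala.collection.mutable.ArrayBuffer

// Hypothetical events, loosely modelled on DAGSchedulerEvent / CompletionEvent
sealed trait SchedulerEvent
final case class Completion(taskId: Long, result: Any) extends SchedulerEvent
case object StopLoop extends SchedulerEvent

// post() only enqueues; one background thread takes events off the queue and handles them in order
final class ToyEventLoop(onReceive: SchedulerEvent => Unit) {
  private val queue = new LinkedBlockingDeque[SchedulerEvent]()
  private val eventThread = new Thread("toy-event-loop") {
    override def run(): Unit = {
      var running = true
      while (running) {
        queue.take() match {  // blocks until an event is available
          case StopLoop => running = false
          case event    => onReceive(event)
        }
      }
    }
  }
  eventThread.setDaemon(true)

  def start(): Unit = eventThread.start()
  def post(event: SchedulerEvent): Unit = queue.put(event)
  // Enqueue a stop marker behind any pending events, then wait for the thread to exit
  def stop(): Unit = { queue.put(StopLoop); eventThread.join() }
}

val handled = ArrayBuffer.empty[Long]
val loop = new ToyEventLoop({
  case Completion(taskId, _) => handled += taskId  // stand-in for handleTaskCompletion
  case _                     => ()
})
loop.start()
loop.post(Completion(1L, "a"))
loop.post(Completion(2L, "b"))
loop.stop()
```

Because the stop marker is queued behind the two completions, stop() returns only after both have been handled.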

Here is the logic of handleTaskCompletion():

private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
    val task = event.task
    val taskId = event.taskInfo.id
    val stageId = task.stageId
    val taskType = Utils.getFormattedClassName(task)
...
 
    // The stage may have already finished when we get this event -- eg. maybe it was a
    // speculative task. It is important that we send the TaskEnd event in any case, so listeners
    // are properly notified and can chose to handle it. For instance, some listeners are
    // doing their own accounting and if they don't get the task end event they think
    // tasks are still running when they really aren't.
    listenerBus.post(SparkListenerTaskEnd(
       stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics))
...
 
    val stage = stageIdToStage(task.stageId)
    event.reason match {
      case Success =>
        task match {
          case rt: ResultTask[_, _] =>
            // Cast to ResultStage here because it's part of the ResultTask
            // TODO Refactor this out to a function that accepts a ResultStage
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                  job.finished(rt.outputId) = true
                  job.numFinished += 1
                  // If the whole job has finished, remove it
                  if (job.numFinished == job.numPartitions) {
                    markStageAsFinished(resultStage)
                    cleanupStateForJobAndIndependentStages(job)
                    listenerBus.post(
                      SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
                  }
                  // taskSucceeded runs some user code that might throw an exception. Make sure
                  // we are resilient against that.
                  try {
                    job.listener.taskSucceeded(rt.outputId, event.result)
                  } catch {
                    case e: Exception =>
                      // TODO: Perhaps we want to mark the resultStage as failed?
                      job.listener.jobFailed(new SparkDriverExecutionException(e))
                  }
                }
...

This logic is really long, so again only part of it is shown.

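  1. listenerBus.post notifies all listeners that this Task has ended. listenerBus evidently holds a collection of listeners; once they receive this message, each can react accordingly.
  2. If this Task is a ResultTask and it completes the last unfinished partition of the Job (job.numFinished == job.numPartitions), the stage is marked as finished and SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded) notifies all listeners that the Job has ended.
  3. The job.listener.taskSucceeded(rt.outputId, event.result) and job.listener.jobFailed(new SparkDriverExecutionException(e)) lines report the Task's result to the JobWaiter; once all of the Job's tasks have succeeded (or the Job fails), the JobWaiter stops waiting.
  4. Control then returns to where the Job was originally submitted, DAGScheduler.runJob():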
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}

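Once the JobWaiter stops waiting, either the case scala.util.Success(_) branch logs that the Job finished, or the case scala.util.Failure(exception) branch logs the failure and rethrows the exception. At this point the whole Job submission flow is complete.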
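The bookkeeping that ends the wait can be sketched without the scheduler: an ActiveJob-like object marks each output partition finished at most once, and a JobWaiter-like object completes a Promise once every task has reported (or fails it on jobFailed), which is what runJob blocks on. ToyActiveJob and ToyJobWaiter are hypothetical simplifications; the demo calls them from a single thread, and the duplicate report for partition 0 stands in for something like a speculative copy finishing after the original.

```scala
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}

// Hypothetical simplified JobWaiter: completes its future once every task has reported
final class ToyJobWaiter[U](totalTasks: Int, resultHandler: (Int, U) => Unit) {
  private val promise = Promise[Unit]()
  private var finishedTasks = 0
  def completionFuture: Future[Unit] = promise.future

  def taskSucceeded(index: Int, result: Any): Unit = synchronized {
    resultHandler(index, result.asInstanceOf[U])
    finishedTasks += 1
    if (finishedTasks == totalTasks) promise.trySuccess(())
  }
  def jobFailed(e: Exception): Unit = promise.tryFailure(e)
}

// Hypothetical simplified ActiveJob bookkeeping from handleTaskCompletion
final class ToyActiveJob(numPartitions: Int, listener: ToyJobWaiter[Int]) {
  private val finished = Array.fill(numPartitions)(false)
  private var numFinished = 0

  // Returns true only for the task that completes the whole job
  def onResultTaskSuccess(outputId: Int, result: Int): Boolean =
    if (finished(outputId)) false  // already counted
    else {
      finished(outputId) = true
      numFinished += 1
      listener.taskSucceeded(outputId, result)
      numFinished == numPartitions
    }
}

val results = new Array[Int](3)
val waiter = new ToyJobWaiter[Int](3, (i, r) => results(i) = r)
val job = new ToyActiveJob(3, waiter)
val jobEnds = Seq(
  job.onResultTaskSuccess(0, 10),
  job.onResultTaskSuccess(0, 10),  // duplicate report, ignored
  job.onResultTaskSuccess(2, 30),
  job.onResultTaskSuccess(1, 20))
Await.ready(waiter.completionFuture, Duration.Inf)
val outcome = waiter.completionFuture.value.get match {
  case Success(_) => "Job finished"
  case Failure(e) => s"Job failed: ${e.getMessage}"
}

val failing = new ToyJobWaiter[Int](2, (_, _) => ())
failing.jobFailed(new RuntimeException("boom"))
```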

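My understanding is limited, so this post may read a bit messily; I haven't fully worked out the whole flow yet. Later I'll find time to draw a flowchart to make it easier to follow.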


Written: 2017-05-06