Spark - Source Code Analysis (Part 3)

Generating Tasks

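In the previous post we saw that after Spark submits a Stage, the submitStage() method in DAGScheduler recursively walks up to the topmost parent Stage that the Stage depends on, and then passes that topmost Stage to submitMissingTasks(). The method is defined as follows: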

...
  /** Called when stage's parents are available and we can now do its task. */
  private def submitMissingTasks(stage: Stage, jobId: Int) {
    logDebug("submitMissingTasks(" + stage + ")")
    ...
    runningStages += stage
    stage match {
      case s: ShuffleMapStage =>
        outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
      case s: ResultStage =>
        outputCommitCoordinator.stageStart(
          stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
    }
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    ...
    taskBinary = sc.broadcast(taskBinaryBytes)
...

Only part of the method is shown above. Once the parent Stage has been passed in, the method does several things, which we go through step by step.
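One step visible in the excerpt is building taskIdToLocations, a map from each partition id to its preferred locations. The following is a minimal sketch of that shape only. TaskLocation here is a simplified stand-in, and knownLocations and the one-argument getPreferredLocs are hypothetical helpers, not Spark's real API.

```scala
object PreferredLocationsSketch {
  // Simplified stand-in; NOT Spark's real TaskLocation.
  final case class TaskLocation(host: String)

  // Hypothetical lookup table playing the role of getPreferredLocs(rdd, partition).
  val knownLocations: Map[Int, Seq[TaskLocation]] = Map(
    0 -> Seq(TaskLocation("host-a")),
    1 -> Seq(TaskLocation("host-b"), TaskLocation("host-c"))
  )

  // A partition with no known location simply has no preference.
  def getPreferredLocs(partition: Int): Seq[TaskLocation] =
    knownLocations.getOrElse(partition, Seq.empty)

  // Same shape as `partitionsToCompute.map { id => (id, ...) }.toMap` in the excerpt.
  def taskIdToLocations(partitionsToCompute: Seq[Int]): Map[Int, Seq[TaskLocation]] =
    partitionsToCompute.map(id => (id, getPreferredLocs(id))).toMap
}
```

Calling PreferredLocationsSketch.taskIdToLocations(Seq(0, 1, 2)) gives one entry per partition, with an empty sequence for partition 2.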

Submitting Tasks

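TaskSchedulerImpl implements the TaskScheduler trait and overrides its submitTasks method, so the taskScheduler.submitTasks() call made in submitMissingTasks() (in the part elided above) actually runs TaskSchedulerImpl's version, which is defined as follows: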

override def submitTasks(taskSet: TaskSet) {
    val tasks = taskSet.tasks
    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
    this.synchronized {
      val manager = createTaskSetManager(taskSet, maxTaskFailures)
      val stage = taskSet.stageId
      val stageTaskSets =
        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
      stageTaskSets(taskSet.stageAttemptId) = manager
      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
        ts.taskSet != taskSet && !ts.isZombie
      }
      if (conflictingTaskSet) {
        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
      }
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
      if (!isLocal && !hasReceivedTask) {
        starvationTimer.scheduleAtFixedRate(new TimerTask() {
          override def run() {
            if (!hasLaunchedTask) {
              logWarning("Initial job has not accepted any resources; " +
                "check your cluster UI to ensure that workers are registered " +
                "and have sufficient resources")
            } else {
              this.cancel()
            }
          }
        }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
      }
      hasReceivedTask = true
    }
    backend.reviveOffers()
  }
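The bookkeeping at the start of submitTasks can be sketched in isolation: managers are stored per stage and per attempt, and a new TaskSet is rejected if another non-zombie TaskSet already exists for the same stage. This is a sketch under assumptions. Manager is a simplified stand-in for TaskSetManager, and register is a hypothetical name.

```scala
import scala.collection.mutable

object TaskSetConflictSketch {
  // Simplified stand-in for TaskSetManager: an id plus a zombie flag.
  final class Manager(val taskSetId: String, var isZombie: Boolean = false)

  // stageId -> (stageAttemptId -> manager), like taskSetsByStageIdAndAttempt.
  private val byStageAndAttempt =
    mutable.HashMap.empty[Int, mutable.HashMap[Int, Manager]]

  def register(stageId: Int, attemptId: Int, manager: Manager): Unit = synchronized {
    val attempts =
      byStageAndAttempt.getOrElseUpdate(stageId, mutable.HashMap.empty[Int, Manager])
    attempts(attemptId) = manager
    // Any other manager for this stage that is still active is a conflict.
    val conflicting = attempts.values.exists(m => (m ne manager) && !m.isZombie)
    if (conflicting) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage $stageId: " +
          attempts.values.map(_.taskSetId).mkString(","))
    }
  }
}
```

As in the original method, the new manager is recorded before the check, so a second attempt is accepted only after the earlier one has been marked as a zombie.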

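submitTasks ends by calling reviveOffers() on the backend object, which is of type SchedulerBackend. SchedulerBackend is a trait (roughly the equivalent of a Java interface) with several implementations. The concrete one is chosen when the SparkContext is created, based on the --master argument given by the user. In YARN mode (here, cluster deploy mode) the implementation is YarnClusterSchedulerBackend, which descends from CoarseGrainedSchedulerBackend. reviveOffers() is defined in CoarseGrainedSchedulerBackend and YarnClusterSchedulerBackend does not override it, so the call here actually runs CoarseGrainedSchedulerBackend's reviveOffers():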

override def reviveOffers() {
  driverEndpoint.send(ReviveOffers)
}

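The ReviveOffers passed to send() is a case object that serves as a message carried by the RPC framework. It is delivered to the DriverEndpoint on the Spark driver, which has a corresponding receive method to handle it: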

override def receive: PartialFunction[Any, Unit] = {
  case StatusUpdate(executorId, taskId, state, data) =>
    scheduler.statusUpdate(taskId, state, data.value)
    if (TaskState.isFinished(state)) {
      executorDataMap.get(executorId) match {
        case Some(executorInfo) =>
          executorInfo.freeCores += scheduler.CPUS_PER_TASK
          makeOffers(executorId)
        case None =>
          // Ignoring the update since we don't know about the executor.
          logWarning(s"Ignored task status update ($taskId state $state) " +
            s"from unknown executor with ID $executorId")
      }
    }
 
  case ReviveOffers =>
    makeOffers()
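The path from backend.reviveOffers() to the ReviveOffers case in receive can be sketched without any RPC machinery. This is a minimal illustration, assuming an in-process endpoint. Every name ending in Sketch is hypothetical, and the log buffer only exists so that the dispatch can be observed.

```scala
import scala.collection.mutable.ArrayBuffer

object ReviveOffersSketch {
  sealed trait Message
  case object ReviveOffers extends Message
  final case class StatusUpdate(executorId: String, taskId: Long, finished: Boolean)
    extends Message

  // In-process stand-in for the driver endpoint; real Spark goes through its RPC layer.
  final class DriverEndpointSketch {
    val log = ArrayBuffer.empty[String]

    // Messages are dispatched by pattern matching, as in the excerpt above.
    def receive: PartialFunction[Any, Unit] = {
      case StatusUpdate(executorId, _, finished) =>
        if (finished) log += s"makeOffers($executorId)"
      case ReviveOffers =>
        log += "makeOffers()"
    }

    def send(message: Any): Unit =
      if (receive.isDefinedAt(message)) receive(message)
      else log += s"dropped: $message"
  }

  trait SchedulerBackendSketch { def reviveOffers(): Unit }

  class CoarseGrainedSketch(endpoint: DriverEndpointSketch) extends SchedulerBackendSketch {
    override def reviveOffers(): Unit = endpoint.send(ReviveOffers)
  }

  // Does not override reviveOffers(), so the parent's version runs.
  class YarnClusterSketch(endpoint: DriverEndpointSketch)
    extends CoarseGrainedSketch(endpoint)
}
```

Calling reviveOffers() on a YarnClusterSketch typed as SchedulerBackendSketch ends up in the ReviveOffers branch of receive, which mirrors the inheritance described above.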

On receiving ReviveOffers, the driver calls makeOffers():

// Make fake resource offers on all executors
private def makeOffers() {
  // Make sure no executor is killed while some task is launching on it
  val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
    // Filter out executors under killing
    val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
    val workOffers = activeExecutors.map { case (id, executorData) =>
      new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
    }.toIndexedSeq
    scheduler.resourceOffers(workOffers)
  }
  if (!taskDescs.isEmpty) {
    launchTasks(taskDescs)
  }
}

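makeOffers() first picks out all live executors. activeExecutors holds ExecutorData objects, each recording the host the executor runs on, its RPC address, its number of free cores, and so on. The resource information of each executor is then extracted and wrapped in a WorkerOffer. Here is the ExecutorData class itself, since nothing is more direct than the code: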

private[org.apache.spark.scheduler.cluster]
class ExecutorData(
    val executorEndpoint: RpcEndpointRef,
    val executorAddress: RpcAddress,
    override val executorHost: String,
    var freeCores: Int,
    override val totalCores: Int,
    override val logUrlMap: Map[String, String])
  extends ExecutorInfo
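The filter-then-map step in makeOffers() can be sketched with plain case classes. ExecutorDataSketch keeps only the two fields the offer needs, and buildOffers and the isAlive predicate are hypothetical names standing in for the logic around executorIsAlive.

```scala
object MakeOffersSketch {
  // Simplified stand-ins for ExecutorData and WorkerOffer.
  final case class ExecutorDataSketch(executorHost: String, freeCores: Int)
  final case class WorkerOffer(executorId: String, host: String, cores: Int)

  // Keep only live executors, then expose each one's free cores as an offer.
  def buildOffers(
      executors: Map[String, ExecutorDataSketch],
      isAlive: String => Boolean): IndexedSeq[WorkerOffer] =
    executors.toIndexedSeq
      .filter { case (id, _) => isAlive(id) }
      .map { case (id, data) => WorkerOffer(id, data.executorHost, data.freeCores) }
}
```

An executor that fails the isAlive check never appears in the result, so the scheduler cannot place a task on it.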

Allocating Resources to Tasks

Once the resource information has been collected, resources are allocated to all Tasks. This is the resourceOffers(workOffers) call above:

def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
  // Mark each slave as alive and remember its hostname
  // Also track if new executor is added
  var newExecAvail = false
  for (o <- offers) {
    if (!hostToExecutors.contains(o.host)) {
      hostToExecutors(o.host) = new HashSet[String]()
    }
    if (!executorIdToRunningTaskIds.contains(o.executorId)) {
      hostToExecutors(o.host) += o.executorId
      executorAdded(o.executorId, o.host)
      executorIdToHost(o.executorId) = o.host
      executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
      newExecAvail = true
    }
    for (rack <- getRackForHost(o.host)) {
      hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
    }
  }
 
  // Before making any offers, remove any nodes from the blacklist whose blacklist has expired. Do
  // this here to avoid a separate thread and added synchronization overhead, and also because
  // updating the blacklist is only relevant when task offers are being made.
  blacklistTrackerOpt.foreach(_.applyBlacklistTimeout())
 
  val filteredOffers = blacklistTrackerOpt.map { blacklistTracker =>
    offers.filter { offer =>
      !blacklistTracker.isNodeBlacklisted(offer.host) &&
        !blacklistTracker.isExecutorBlacklisted(offer.executorId)
    }
  }.getOrElse(offers)
 
  val shuffledOffers = shuffleOffers(filteredOffers)
  // Build a list of tasks to assign to each worker.
  val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
  val availableCpus = shuffledOffers.map(o => o.cores).toArray
  val sortedTaskSets = rootPool.getSortedTaskSetQueue
  for (taskSet <- sortedTaskSets) {
    logDebug("parentName: %s, name: %s, runningTasks: %s".format(
      taskSet.parent.name, taskSet.name, taskSet.runningTasks))
    if (newExecAvail) {
      taskSet.executorAdded()
    }
  }
 
  // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
  // of locality levels so that it gets a chance to launch local tasks on all of them.
  // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
  for (taskSet <- sortedTaskSets) {
    var launchedAnyTask = false
    var launchedTaskAtCurrentMaxLocality = false
    for (currentMaxLocality <- taskSet.myLocalityLevels) {
      do {
        launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(
          taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks)
        launchedAnyTask |= launchedTaskAtCurrentMaxLocality
      } while (launchedTaskAtCurrentMaxLocality)
    }
    if (!launchedAnyTask) {
      taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
    }
  }
 
  if (tasks.size > 0) {
    hasLaunchedTask = true
  }
  return tasks
}

This method does roughly the following:
Filter out offers from blacklisted nodes and executors
Randomly shuffle the order of the executors
Allocate resources to each Task
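The three steps listed above can be sketched as one small function. This is a deliberately simplified model, not Spark's algorithm: it ignores TaskSet ordering and locality levels, blacklists by host only, and hands out tasks round-robin. The names assign, Assignment and blacklistedHosts are hypothetical.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object ResourceOffersSketch {
  final case class WorkerOffer(executorId: String, host: String, cores: Int)
  final case class Assignment(taskId: Long, executorId: String)

  def assign(
      offers: IndexedSeq[WorkerOffer],
      blacklistedHosts: Set[String],
      pendingTasks: Seq[Long],
      cpusPerTask: Int,
      rng: Random): Seq[Assignment] = {
    // Step 1: drop offers that come from blacklisted hosts.
    val filtered = offers.filterNot(o => blacklistedHosts.contains(o.host))
    // Step 2: shuffle so that tasks do not always land on the same executors.
    val shuffled = rng.shuffle(filtered)
    // Step 3: hand out tasks one per executor per round while CPUs remain.
    val availableCpus = shuffled.map(_.cores).toArray
    val assignments = ArrayBuffer.empty[Assignment]
    val remaining = pendingTasks.iterator
    var launched = true
    while (launched && remaining.hasNext) {
      launched = false
      for (i <- shuffled.indices) {
        if (remaining.hasNext && availableCpus(i) >= cpusPerTask) {
          assignments += Assignment(remaining.next(), shuffled(i).executorId)
          availableCpus(i) -= cpusPerTask
          launched = true
        }
      }
    }
    assignments.toSeq
  }
}
```

The outer loop stops as soon as a full round places nothing, which plays the same role as the do/while on launchedTaskAtCurrentMaxLocality in the real method.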

Dispatching Tasks to Their Executors

After resources are allocated, control returns to makeOffers(). The next step is to launch all the Tasks, which is the launchTasks(taskDescs) call:

// Launch tasks returned by a set of resource offers
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    val serializedTask = TaskDescription.encode(task)
    if (serializedTask.limit >= maxRpcMessageSize) {
      scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
        try {
          var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
            "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
            "spark.rpc.message.maxSize or using broadcast variables for large values."
          msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
          taskSetMgr.abort(msg)
        } catch {
          case e: Exception => logError("Exception in error callback", e)
        }
      }
    }
    else {
      val executorData = executorDataMap(task.executorId)
      executorData.freeCores -= scheduler.CPUS_PER_TASK
 
      logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
        s"${executorData.executorHost}.")
 
      executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
    }
  }
}
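The two branches of launchTasks, abort when the serialized task is too large and otherwise reserve cores and send, can be sketched as follows. TaskDesc, Sent, Aborted and launch are hypothetical names. The payload byte array stands in for the result of TaskDescription.encode, and nothing is actually sent over the network.

```scala
import java.nio.ByteBuffer
import scala.collection.mutable

object LaunchTasksSketch {
  final case class TaskDesc(taskId: Long, executorId: String, payload: Array[Byte])

  sealed trait Outcome
  final case class Sent(taskId: Long, executorId: String) extends Outcome
  final case class Aborted(taskId: Long, bytes: Int) extends Outcome

  def launch(
      tasks: Seq[Seq[TaskDesc]],
      freeCores: mutable.Map[String, Int],
      maxMessageSize: Int,
      cpusPerTask: Int): Seq[Outcome] =
    tasks.flatten.map { task =>
      val serialized = ByteBuffer.wrap(task.payload)
      if (serialized.limit() >= maxMessageSize) {
        // Too large for one RPC message: abort instead of sending.
        Aborted(task.taskId, serialized.limit())
      } else {
        // Reserve the cores on the chosen executor, then "send" the task.
        freeCores(task.executorId) -= cpusPerTask
        Sent(task.taskId, task.executorId)
      }
    }
}
```

Note that cores are deducted only on the send path, so an aborted task leaves the executor's free-core count untouched.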

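The loop takes each Task, serializes it, wraps it in a LaunchTask message, and sends it to an executor. Which executor was already decided in the resource-allocation step above (each TaskDescription carries its executorId); go back and have a look if you have forgotten. The RPC framework delivers the LaunchTask to the executor, where it is handled in CoarseGrainedExecutorBackend: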

  override def receive: PartialFunction[Any, Unit] = {
    case RegisteredExecutor =>
      logInfo("Successfully registered with driver")
      try {
        executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
      } catch {
        case NonFatal(e) =>
          exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
      }
 
    case RegisterExecutorFailed(message) =>
      exitExecutor(1, "Slave registration failed: " + message)
 
    case LaunchTask(data) =>
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        val taskDesc = TaskDescription.decode(data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        executor.launchTask(this, taskDesc)
      }
 
    case KillTask(taskId, _, interruptThread) =>
      if (executor == null) {
        exitExecutor(1, "Received KillTask command but executor was null")
      } else {
        executor.killTask(taskId, interruptThread)
      }
...

The received Task is deserialized and then handed to the launchTask method of the Executor class:

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

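As you can see, it creates a TaskRunner object and submits it to the thread pool for execution. TaskRunner implements Java's Runnable interface, and each TaskRunner holds the executor-side context, that is, the context object (an ExecutorBackend). Since TaskRunner implements Runnable, it also overrides run(); next we will step into run() to see the execution logic in TaskRunner.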
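The launchTask pattern, a Runnable registered in a concurrent map and handed to a thread pool, can be sketched with the standard java.util.concurrent classes. This is only an illustration of the pattern. BackendSketch, ExecutorSketch, the fixed-size pool and the body parameter are assumptions, and what Spark's run() really does is left to the next post.

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

object TaskRunnerSketch {
  // Stand-in for ExecutorBackend: the context a runner reports back to.
  trait BackendSketch { def statusUpdate(taskId: Long, state: String): Unit }

  final class ExecutorSketch(poolSize: Int) {
    private val threadPool = Executors.newFixedThreadPool(poolSize)
    val runningTasks = new ConcurrentHashMap[Long, TaskRunner]()

    // Placeholder runner: `body` is a hypothetical stand-in for the task's work.
    final class TaskRunner(context: BackendSketch, taskId: Long, body: () => Unit)
      extends Runnable {
      override def run(): Unit = {
        context.statusUpdate(taskId, "RUNNING")
        try {
          body()
          context.statusUpdate(taskId, "FINISHED")
        } finally {
          runningTasks.remove(taskId)
        }
      }
    }

    // Same three steps as the excerpt: create the runner, register it, execute it.
    def launchTask(context: BackendSketch, taskId: Long, body: () => Unit): Unit = {
      val tr = new TaskRunner(context, taskId, body)
      runningTasks.put(taskId, tr)
      threadPool.execute(tr)
    }

    // Stop accepting work and wait briefly for running tasks to finish.
    def shutdownAndWait(): Boolean = {
      threadPool.shutdown()
      threadPool.awaitTermination(5, TimeUnit.SECONDS)
    }
  }
}
```

Because the runner holds the context, it can report status changes back without the thread pool knowing anything about the backend.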

That is all for today's analysis; the next post continues from here.


Written on: 2017-04-29