Spark – Source Code Analysis (Part 4)
Task Execution
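The previous post introduced the TaskRunner class. It is defined inside the Executor class and implements Java's Runnable interface; every Task that Spark runs is one TaskRunner instance. Here is its run() method: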
    override def run(): Unit = {
      threadId = Thread.currentThread.getId
      Thread.currentThread.setName(threadName)
      val threadMXBean = ManagementFactory.getThreadMXBean
      val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
      val deserializeStartTime = System.currentTimeMillis()
      val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      var taskStartCpu: Long = 0
      startGCTime = computeTotalGcTime()
      try {
        // Must be set before updateDependencies() is called, in case fetching dependencies
        // requires access to properties contained within (e.g. for access control).
        Executor.taskDeserializationProps.set(taskDescription.properties)
        updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
        task = ser.deserialize[Task[Any]](
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskDescription.properties
        task.setTaskMemoryManager(taskMemoryManager)
        ...
        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        var threwException = true
        val value = try {
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = taskDescription.attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } finally {
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0 && !threwException) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logWarning(errMsg)
            }
          }
          ...
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()
        ...
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
        ...
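The method is fairly long, so only part of it is shown. Step by step, it does the following:
- Deserialize the serialized task carried by the incoming taskDescription into a Task instance. The instance is a subclass of Task, one of the two types created when the tasks were generated: ResultTask and ShuffleMapTask. We will use ResultTask as the example.
- Call the Task's run() method, which in turn calls runTask(). Below is runTask() in the ResultTask class: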
    override def runTask(context: TaskContext): U = {
      // Deserialize the RDD and the func using the broadcast variables.
      val threadMXBean = ManagementFactory.getThreadMXBean
      val deserializeStartTime = System.currentTimeMillis()
      val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
      _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
      _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
      } else 0L
      func(context, rdd.iterator(partition, context))
    }
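Inside the method, both the RDD and the closure carried by the Task are deserialized, and then the closure is executed; it is the closure that processes the RDD's data.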
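To make the "deserialize the pair, then apply the closure" step concrete, here is a minimal sketch that does not use Spark at all. It assumes plain Java serialization as a stand-in for Spark's closure serializer, a Vector[Int] as a stand-in for one RDD partition, and a function without the TaskContext argument. The names ClosureRoundTrip, pack and demo are hypothetical and exist only for this illustration.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical illustration only: Java serialization stands in for Spark's closure serializer.
object ClosureRoundTrip {
  def serialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    bytes.toByteArray
  }

  def deserialize[A](data: Array[Byte]): A = {
    val in = new ObjectInputStream(new ByteArrayInputStream(data))
    try in.readObject().asInstanceOf[A] finally in.close()
  }

  // "Driver" side: bundle the data of one partition with the function to run on it.
  def pack(partition: Vector[Int], func: Iterator[Int] => Int): Array[Byte] =
    serialize((partition, func))

  // "Executor" side: unpack both halves, then let the function consume the data,
  // mirroring the shape of func(context, rdd.iterator(partition, context)).
  def runTask(taskBinary: Array[Byte]): Int = {
    val (partition, func) = deserialize[(Vector[Int], Iterator[Int] => Int)](taskBinary)
    func(partition.iterator)
  }

  // Usage: ship a summing closure together with four numbers and run it after the round trip.
  def demo(): Int = runTask(pack(Vector(1, 2, 3, 4), (it: Iterator[Int]) => it.sum))
}
```

The point of the sketch is only the order of operations: the function and the data travel together as bytes, and nothing is computed until the receiving side has rebuilt both and applied one to the other.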
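- When the Task's run() method finishes, control returns to run() in TaskRunner together with the result data. The Task's main work is now done, so some metrics are collected, including the Task's run time, CPU time and GC time. These metrics are presumably sent to the UI, but I have not traced that part in detail.
- The result data is then serialized and finally sent back to the Spark driver through ExecutorBackend's statusUpdate() method. ExecutorBackend is a trait, mentioned in the previous post; CoarseGrainedExecutorBackend is one of its implementations, and its statusUpdate() is implemented as follows: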
    override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
      val msg = StatusUpdate(executorId, taskId, state, data)
      driver match {
        case Some(driverRef) => driverRef.send(msg)
        case None => logWarning(s"Drop $msg because has not yet connected to driver")
      }
    }
The data is wrapped in a StatusUpdate, which is sent as a message to the driver through the RPC framework.
The receiving logic on the driver side:
    override def receive: PartialFunction[Any, Unit] = {
      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }
      ...
After the driver receives the StatusUpdate, it is processed as follows:
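- First the scheduler's statusUpdate() method is called; the scheduler deserializes the data and marks the Task as finished.
- The scheduler then calls the DAGScheduler's post method, which wraps the Task and its data in a CompletionEvent and puts it into eventProcessLoop, the DAGScheduler's event queue, to await processing.
- Control then returns to receive(), where the CPU cores held by the Task are reclaimed.
- makeOffers(executorId) is called to launch the next Task, and this repeats until all Tasks have run.
- Because events are taken out of the DAGScheduler's event queue and processed in a continuous loop, the CompletionEvent that was just put in is soon picked up by the DAGScheduler's doOnReceive() method, which takes the corresponding action: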
    private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
      ...
      case completion: CompletionEvent =>
        dagScheduler.handleTaskCompletion(completion)
      ...
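The post-then-dispatch flow described above can be sketched without Spark as a blocking queue drained by a single thread. This is a simplified illustration under stated assumptions, not Spark's implementation: MiniEventLoop, SchedulerEvent, StopLoop and stopAndWait are hypothetical names, and the CompletionEvent here carries only a task id and a result, whereas Spark's real event carries much more.

```scala
import java.util.concurrent.{CountDownLatch, LinkedBlockingDeque, TimeUnit}

// Hypothetical, simplified event types for illustration.
sealed trait SchedulerEvent
final case class CompletionEvent(taskId: Long, result: Any) extends SchedulerEvent
case object StopLoop extends SchedulerEvent

// A tiny event loop: producers call post(), one daemon thread drains the queue in order.
class MiniEventLoop(onCompletion: CompletionEvent => Unit) {
  private val queue = new LinkedBlockingDeque[SchedulerEvent]()
  private val stopped = new CountDownLatch(1)

  private val worker = new Thread("mini-event-loop") {
    override def run(): Unit = {
      var running = true
      while (running) {
        queue.take() match {               // blocks until an event is available
          case c: CompletionEvent => onCompletion(c)
          case StopLoop           => running = false
        }
      }
      stopped.countDown()
    }
  }
  worker.setDaemon(true)

  def start(): Unit = worker.start()

  // Called from any thread; returns immediately after enqueueing.
  def post(event: SchedulerEvent): Unit = queue.put(event)

  // Ask the loop to exit after everything posted so far, then wait for it to finish.
  def stopAndWait(timeoutMs: Long): Boolean = {
    post(StopLoop)
    stopped.await(timeoutMs, TimeUnit.MILLISECONDS)
  }
}
```

The design choice worth noting is that the thread reporting a finished Task never runs the completion logic itself. It only enqueues an event, so all scheduling decisions are made on one thread and need no extra locking among themselves.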
Now look at the logic of handleTaskCompletion():
    private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
      val task = event.task
      val taskId = event.taskInfo.id
      val stageId = task.stageId
      val taskType = Utils.getFormattedClassName(task)
      ...
      // The stage may have already finished when we get this event -- eg. maybe it was a
      // speculative task. It is important that we send the TaskEnd event in any case, so listeners
      // are properly notified and can chose to handle it. For instance, some listeners are
      // doing their own accounting and if they don't get the task end event they think
      // tasks are still running when they really aren't.
      listenerBus.post(SparkListenerTaskEnd(
        stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics))
      ...
      val stage = stageIdToStage(task.stageId)
      event.reason match {
        case Success =>
          task match {
            case rt: ResultTask[_, _] =>
              // Cast to ResultStage here because it's part of the ResultTask
              // TODO Refactor this out to a function that accepts a ResultStage
              val resultStage = stage.asInstanceOf[ResultStage]
              resultStage.activeJob match {
                case Some(job) =>
                  if (!job.finished(rt.outputId)) {
                    updateAccumulators(event)
                    job.finished(rt.outputId) = true
                    job.numFinished += 1
                    // If the whole job has finished, remove it
                    if (job.numFinished == job.numPartitions) {
                      markStageAsFinished(resultStage)
                      cleanupStateForJobAndIndependentStages(job)
                      listenerBus.post(
                        SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
                    }
                    // taskSucceeded runs some user code that might throw an exception. Make sure
                    // we are resilient against that.
                    try {
                      job.listener.taskSucceeded(rt.outputId, event.result)
                    } catch {
                      case e: Exception =>
                        // TODO: Perhaps we want to mark the resultStage as failed?
                        job.listener.jobFailed(new SparkDriverExecutionException(e))
                    }
                  }
                ...
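The logic here is really quite long, so only part of it is shown:
- listenerBus.post notifies all listeners that this Task has finished. listenerBus is evidently a collection holding various listeners, each of which can react once it receives the message.
- SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded): if the Task is a ResultTask and it was the last unfinished partition of the Job (job.numFinished == job.numPartitions), all listeners are notified that the Job has ended.
- The job.listener.taskSucceeded(rt.outputId, event.result) and job.listener.jobFailed(new SparkDriverExecutionException(e)) lines report the Task's outcome to the JobWaiter object, after which the JobWaiter can stop waiting.
- At this point control returns to the place where the Job was originally submitted: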
    def runJob[T, U](
        rdd: RDD[T],
        func: (TaskContext, Iterator[T]) => U,
        partitions: Seq[Int],
        callSite: CallSite,
        resultHandler: (Int, U) => Unit,
        properties: Properties): Unit = {
      val start = System.nanoTime
      val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
      // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
      // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
      // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
      // safe to pass in null here. For more detail, see SPARK-13747.
      val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
      waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
      waiter.completionFuture.value.get match {
        case scala.util.Success(_) =>
          logInfo("Job %d finished: %s, took %f s".format
            (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
        case scala.util.Failure(exception) =>
          logInfo("Job %d failed: %s, took %f s".format
            (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
          // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
          val callerStackTrace = Thread.currentThread().getStackTrace.tail
          exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
          throw exception
      }
    }
Once the JobWaiter finishes waiting, the case scala.util.Success(_) or case scala.util.Failure(exception) branch logs the outcome, and with that the whole Job submission flow is complete.
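The waiting mechanism between the scheduler and the submitting thread can be sketched with a Promise that is completed once every task has reported success, or failed as soon as the job fails. This is a minimal sketch modelled on the taskSucceeded and jobFailed calls seen above. MiniJobWaiter is a hypothetical name, and the details (an atomic counter, a single Promise[Unit]) are assumptions made for illustration rather than a copy of Spark's JobWaiter.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Future, Promise}

// Hypothetical, stripped-down waiter: one promise that completes when all tasks have succeeded.
class MiniJobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
  private val finishedTasks = new AtomicInteger(0)
  private val jobPromise = Promise[Unit]()

  // A job with no tasks has nothing to wait for.
  if (totalTasks == 0) jobPromise.success(())

  // The submitting thread blocks on this future.
  def completionFuture: Future[Unit] = jobPromise.future

  // Called by the scheduler once per successful task; hands the result to user code first.
  def taskSucceeded(index: Int, result: T): Unit = {
    synchronized { resultHandler(index, result) }
    if (finishedTasks.incrementAndGet() == totalTasks) jobPromise.success(())
  }

  // Called by the scheduler when the job cannot complete; a second failure is ignored.
  def jobFailed(exception: Exception): Unit = {
    jobPromise.tryFailure(exception)
    ()
  }
}
```

A caller would create the waiter with the number of partitions and a handler that stores each partition's result, hand it to the scheduler, and then wait on completionFuture. Matching on the future's value afterwards gives exactly the Success and Failure branches that runJob logs.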
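My understanding is limited, so this post may read as somewhat disorganized, and I have not yet sorted out the entire flow. I will find time later to draw a flowchart to make it easier to follow.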