Spark 源码分析(八):DAGScheduler 源码分析2(task 最佳位置计算)

栏目: 编程工具 · 发布时间: 5年前

内容简介:前面一篇文章已经讲了 DAGScheduler 中的 stage 划分算法。实际上就是每当执行到 RDD 的 action 算子时会去调用 DAGScheduler 的 handleJobSubmitted 方法,这个方法内部会根据当前的 RDD 创建一个 ResultStage,然后根据这个 ResultStage 对象创建一个 Job。再将这个 stage 对象传入 submitStage 方法,这个方法内部会调用一些其它方法,会根据当前 stage 中的那个 RDD 的依赖链往前推,依据 RDD 之间

前面一篇文章已经讲了 DAGScheduler 中的 stage 划分算法。

实际上就是每当执行到 RDD 的 action 算子时会去调用 DAGScheduler 的 handleJobSubmitted 方法,这个方法内部会根据当前的 RDD 创建一个 ResultStage,然后根据这个 ResultStage 对象创建一个 Job。

再将这个 stage 对象传入 submitStage 方法,这个方法内部会调用一些其它方法,会根据当前 stage 中的那个 RDD 的依赖链往前推,依据 RDD 之间的依赖关系,碰到宽依赖就创建一个新的 stage,窄依赖就将当前 RDD 加入当前 stage 中,一直到所有 RDD 都遍历完。

至此所有的 stage 就划分完了。

在前面的 submitStage 方法中会找到划分出的 stage 中的第一个 stage,然后调用 submitMissingTasks 方法。

if (missing.isEmpty) {
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          // 找到第一个 stage 去调用 submitMissingTasks 方法
          submitMissingTasks(stage, jobId.get)
        }
复制代码

submitMissingTasks 方法中做了这些事:

1,拿到 stage 中没有计算的 partition;

2,获取 task 对应的 partition 的最佳位置,这个是这里主要讲解的算法;

3,获取 taskBinary,将 stage 的 RDD 和 ShuffleDependency(或 func)广播到 Executor;

4,为 stage 创建 task;

这个方法的代码很多,我们主要分析下怎么分配 task 到最优的 partition 上去的,也就是计算 partitionId 和 taskId 的对应关系。

// 计算 taskId 和 partition 的对应关系	
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        // 如果是 ShuffleMapStage
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        // 如果是 ResultStage
        case s: ResultStage =>
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
      case NonFatal(e) =>
        stage.makeNewStageAttempt(partitionsToCompute.size)
        listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }
复制代码

可以看出主要是调用了 getPreferredLocs 这个方法,这个方法实际上是调用了 getPreferredLocsInternal 这个方法。

private[spark]
  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
    getPreferredLocsInternal(rdd, partition, new HashSet)
  }
复制代码

下面主要看下 getPreferredLocsInternal 这个方法做了哪些操作:

1,判断 RDD 的 partition 是否被操作过了,如果被操作过了就什么都不做;

2,查看当前 RDD 的 partition 的最佳计算位置是否被缓存过,如果被缓存过直接返回对应的缓存位置;

3,如果没有缓存,就调用 RDD 的 preferredLocations 去计算最佳位置,实际上就是看看当前 RDD 是否被 checkpoint 了,如果有就返回 checkpoint 的位置;

4,如果当前 RDD 既没有被缓存又没有 checkpoint 的话,就去遍历 RDD 的依赖链,如果有窄依赖,就去遍历父 RDD 的所有 partition,递归调用 getPreferredLocsInternal 方法。

这里实际上就是找出当前 stage 中是否存在某个 RDD 被缓存或者 checkpoint了,如果有就返回其缓存或者 checkpoint 的位置,添加到序列中,然后返回。如果当前 stage 中的所有 RDD 都没有被缓存或者 checkpoint 的话,那么 task 的最佳计算位置就返回 Nil。

private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    // 如果这个 rdd 的 partition 已经计算过了位置了就忽略
  	// 因为这个方法是被递归调用的
    if (!visited.add((rdd, partition))) {
      // Nil has already been returned for previously visited partitions.
      return Nil
    }
    // 如果这个 partition 被缓存过就返回缓存的位置
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // 调用 RDD 内部的 preferredLocations 方法去找最佳计算位置,实际上内部是看当前
    // RDD 是否 checkpoint 了,如果做了 checkpoint 就会返回 checkpoint 的位置
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }

    /**
    * 如果该 RDD 既没有没缓存有没有 checkpoint 的话那么就会去遍历他的依赖链,发现是窄依赖的时候
    * 去就去递归调用 getPreferredLocsInternal 去看看该 RDD 是否被缓存或者 checkpoint 了。如果
    * 是,就返回缓存或者 checkpoint 的位置。如果一直没找到的话就返回 Nil
    **/
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            return locs
          }
        }

      case _ =>
    }

    Nil
  }
复制代码

当获取到 task 的最佳位置后,根据 stage 的类型匹配,为每个 partition 的数据创建一个 task,如果是 ShuffleMapStage 就创建 ShuffleMapTask,如果是 ResultStage 就创建 ResultTask。然后将整个 stage 创建的所有 task都放到一个 Seq 中。

创建 task 的过程会将每个 task前面计算出来的最佳位置和 taskBinary 等参数带进去。

val tasks: Seq[Task[_]] = try {
      stage match {
        // 如果是 ShuffleMapStage
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            // 创建 ShuffleMapTask
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId),
              Option(sc.applicationId), sc.applicationAttemptId)
          }

        // 如果是 ResultStage
        case stage: ResultStage =>
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            // 创建 ResultTask
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics,
              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
          }
      }
    } catch {
      case NonFatal(e) =>
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }
复制代码

创建好该 stage 的 tasks 后,如果 tasks 的长度大于 0,会将这些 task 创建一个 TaskSet ,然后调用 TaskScheduler 的 submitTasks 方法,提交 TaskSet 给 TaskScheduler。

如果 tasks 的长度小于等于 0 的话,会将当前 stage 标记完成,然后调用 submitWaitingChildStages 方法,提交当前 stage 的子 stage。

// 如果 tasks 长度大于 0
if (tasks.size > 0) {
      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
      stage.pendingPartitions ++= tasks.map(_.partitionId)
      logDebug("New pending partitions: " + stage.pendingPartitions)
      // 将 tasks 封装到 TaskSet 内部,然后通过 taskScheduler 的 submitTasks 方法提交
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    } else {// tasks 长度小于等于 0
      // Because we posted SparkListenerStageSubmitted earlier, we should mark
      // the stage as completed here in case there are no tasks to run
  		// 标记当前 stage 已完成
      markStageAsFinished(stage, None)

      val debugString = stage match {
        case stage: ShuffleMapStage =>
          s"Stage ${stage} is actually done; " +
            s"(available: ${stage.isAvailable}," +
            s"available outputs: ${stage.numAvailableOutputs}," +
            s"partitions: ${stage.numPartitions})"
        case stage : ResultStage =>
          s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
      }
      logDebug(debugString)
			// 提交当前 stage 的子 stage
      submitWaitingChildStages(stage)
    }
复制代码

至此 Stage 的 TaskSet 已经提交给 TaskScheduler 了,下面就是看 TaskScheduler 怎么对 Task 进行调度处理了。


以上所述就是小编给大家介绍的《Spark 源码分析(八):DAGScheduler 源码分析2(task 最佳位置计算)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Head Rush Ajax

Head Rush Ajax

Brett McLaughlin、Eric Freeman、Elisabeth Freeman / O'Reilly Media, Inc. / 2006-03-01 / USD 34.99

Ajax, or Asynchronous JavaScript and XML, is a term describing the latest rage in web development. Ajax is used to create interactive web applications with XML-based web services, and using JavaScript......一起来看看 《Head Rush Ajax》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具