8. Spark Source Code Analysis (YARN Cluster Mode) - Task Execution, Map-Side Write, ShuffleMapTask, SortShuffleWriter, ResultStage, ResultTask

This series is based on spark-2.4.6.
From the previous section we know that, once tasks have been submitted, launchTasks is what actually dispatches them for execution. In this section we look at how that works.

private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    val serializedTask = TaskDescription.encode(task)
    if (serializedTask.limit() >= maxRpcMessageSize) {
      Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr =>
        try {
          var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
            "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
            "spark.rpc.message.maxSize or using broadcast variables for large values."
          msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize)
          taskSetMgr.abort(msg)
        } catch {
          case e: Exception => logError("Exception in error callback", e)
        }
      }
    } else {
      val executorData = executorDataMap(task.executorId)
      executorData.freeCores -= scheduler.CPUS_PER_TASK

      logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
        s"${executorData.executorHost}.")

      executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
    }
  }
}

The first step checks whether the serialized task exceeds maxRpcMessageSize. If it does, the TaskSetManager is aborted with an error message and the task is not executed; otherwise a LaunchTask message is sent to the Executor. Next let's see how the Executor handles it.
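If you ever hit this abort, the error message already suggests the two remedies. As a hedged illustration (not from the code analyzed here), the limit can be raised via spark.rpc.message.maxSize, whose value is interpreted in MiB, or large read-only data can be moved into a broadcast variable instead of being captured in the task closure:

import org.apache.spark.{SparkConf, SparkContext}

object RpcSizeExample {
  def main(args: Array[String]): Unit = {
    // Raise the RPC message size limit (value interpreted in MiB).
    val conf = new SparkConf()
      .setAppName("rpc-size-example")
      .setMaster("local[*]")
      .set("spark.rpc.message.maxSize", "256")
    val sc = new SparkContext(conf)

    // Preferred fix: ship large read-only data as a broadcast variable
    // (hypothetical lookup table used only for illustration).
    val lookup = sc.broadcast((1 to 1000).map(i => i -> i.toString).toMap)
    val hits = sc.parallelize(1 to 100)
      .map(i => lookup.value.getOrElse(i, "missing"))
      .count()
    println(hits)
    sc.stop()
  }
}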


As analyzed earlier, when a message arrives at the executor it is processed in org.apache.spark.rpc.netty.Inbox:

def process(dispatcher: Dispatcher): Unit = {
  var message: InboxMessage = null
  inbox.synchronized {
    if (!enableConcurrent && numActiveThreads != 0) {
      return
    }
    message = messages.poll()
    if (message != null) {
      numActiveThreads += 1
    } else {
      return
    }
  }
  while (true) {
    safelyCall(endpoint) {
      message match {
        case RpcMessage(_sender, content, context) =>
          try {
            endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })
          } catch {
            case e: Throwable =>
              context.sendFailure(e)
              throw e
          }

        case OneWayMessage(_sender, content) =>
          endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
            throw new SparkException(s"Unsupported message $message from ${_sender}")
          })

        case OnStart =>
          endpoint.onStart()
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
            inbox.synchronized {
              if (!stopped) {
                enableConcurrent = true
              }
            }
          }

        case OnStop =>
          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
          dispatcher.removeRpcEndpointRef(endpoint)
          endpoint.onStop()

        case RemoteProcessConnected(remoteAddress) =>
          endpoint.onConnected(remoteAddress)

        case RemoteProcessDisconnected(remoteAddress) =>
          endpoint.onDisconnected(remoteAddress)

        case RemoteProcessConnectionError(cause, remoteAddress) =>
          endpoint.onNetworkError(cause, remoteAddress)
      }
    }

    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 1) {
        numActiveThreads -= 1
        return
      }
      message = messages.poll()
      if (message == null) {
        numActiveThreads -= 1
        return
      }
    }
  }
}

The actual handling happens in CoarseGrainedExecutorBackend:

// CoarseGrainedExecutorBackend.receive
case LaunchTask(data) =>
  if (executor == null) {
    exitExecutor(1, "Received LaunchTask command but executor was null")
  } else {
    val taskDesc = TaskDescription.decode(data.value)
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc)
  }

// Executor
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

The received task is wrapped in a TaskRunner and submitted to the thread pool for execution:

override def run(): Unit = {
  threadId = Thread.currentThread.getId
  Thread.currentThread.setName(threadName)
  val threadMXBean = ManagementFactory.getThreadMXBean
  val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  Thread.currentThread.setContextClassLoader(replClassLoader)
  val ser = env.closureSerializer.newInstance()
  logInfo(s"Running $taskName (TID $taskId)")
  execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
  var taskStartTime: Long = 0
  var taskStartCpu: Long = 0
  startGCTime = computeTotalGcTime()

  try {
    Executor.taskDeserializationProps.set(taskDescription.properties)

    updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
    task = ser.deserialize[Task[Any]](
      taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
    task.localProperties = taskDescription.properties
    task.setTaskMemoryManager(taskMemoryManager)
    val killReason = reasonIfKilled
    if (killReason.isDefined) {
      throw new TaskKilledException(killReason.get)
    }
    if (!isLocal) {
      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
    }

    // Run the actual task and measure its runtime.
    taskStartTime = System.currentTimeMillis()
    taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    var threwException = true
    val value = Utils.tryWithSafeFinally {
      val res = task.run(
        taskAttemptId = taskId,
        attemptNumber = taskDescription.attemptNumber,
        metricsSystem = env.metricsSystem)
      threwException = false
      res
    } {
      val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
      val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

      if (freedMemory > 0 && !threwException) {
        val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
        if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
          throw new SparkException(errMsg)
        } else {
          logWarning(errMsg)
        }
      }

      if (releasedLocks.nonEmpty && !threwException) {
        val errMsg = s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
          releasedLocks.mkString("[", ", ", "]")
        if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
          throw new SparkException(errMsg)
        } else {
          logInfo(errMsg)
        }
      }
    }
    task.context.fetchFailed.foreach { fetchFailure =>
      // (the original source logs a warning here if user code swallowed a fetch failure)
    }
    val taskFinish = System.currentTimeMillis()
    val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    task.context.killTaskIfInterrupted()
    val resultSer = env.serializer.newInstance()
    val beforeSerialization = System.currentTimeMillis()
    val valueBytes = resultSer.serialize(value)
    val afterSerialization = System.currentTimeMillis()
    task.metrics.setExecutorDeserializeTime(
      (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
    task.metrics.setExecutorDeserializeCpuTime(
      (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
    task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
    task.metrics.setExecutorCpuTime(
      (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
    task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
    task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
    executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
    executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
    executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
    executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
    executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
    executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
    executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
      .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
    executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
    executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
      .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
    executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
      .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
    executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
      .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
    executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
      .inc(task.metrics.shuffleReadMetrics.localBytesRead)
    executorSource.METRIC_SHUFFLE_RECORDS_READ
      .inc(task.metrics.shuffleReadMetrics.recordsRead)
    executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
      .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
    executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
      .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
    executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
      .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
    executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
      .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
    executorSource.METRIC_INPUT_BYTES_READ
      .inc(task.metrics.inputMetrics.bytesRead)
    executorSource.METRIC_INPUT_RECORDS_READ
      .inc(task.metrics.inputMetrics.recordsRead)
    executorSource.METRIC_OUTPUT_BYTES_WRITTEN
      .inc(task.metrics.outputMetrics.bytesWritten)
    executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
      .inc(task.metrics.outputMetrics.recordsWritten)
    executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
    executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
    executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
    val accumUpdates = task.collectAccumulatorUpdates()
    val directResult = new DirectTaskResult(valueBytes, accumUpdates)
    val serializedDirectResult = ser.serialize(directResult)
    val resultSize = serializedDirectResult.limit()
    val serializedResult: ByteBuffer = {
      if (maxResultSize > 0 && resultSize > maxResultSize) {
        logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
          s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
          s"dropping it.")
        ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
      } else if (resultSize > maxDirectResultSize) {
        val blockId = TaskResultBlockId(taskId)
        env.blockManager.putBytes(
          blockId,
          new ChunkedByteBuffer(serializedDirectResult.duplicate()),
          StorageLevel.MEMORY_AND_DISK_SER)
        ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
      } else {
        logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
        serializedDirectResult
      }
    }
    setTaskFinishedAndClearInterruptStatus()
    execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

  } catch {
    // .......
  } finally {
    runningTasks.remove(taskId)
  }
}

A large part of the beginning prepares the environment in which the task runs; the task is then executed via task.run, which ultimately delegates to Task.runTask. runTask is abstract, so let's take ShuffleMapTask as the example:

// ShuffleMapTask
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    writer.stop(success = true).get
  } catch {
    // ....
  }
}

Here a ShuffleWriter is invoked to do the actual writing:

writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

Spark ships several ShuffleWriter implementations: SortShuffleWriter, BypassMergeSortShuffleWriter, and UnsafeShuffleWriter.

Let's look at the default SortShuffleWriter implementation:

override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

As we can see, it uses an ExternalSorter and calls sorter.insertAll(records):

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  val shouldCombine = aggregator.isDefined
  if (shouldCombine) {
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}

As shown above, if map-side combine is required the records go into a map, otherwise into an array-backed buffer. The map is a PartitionedAppendOnlyMap and the buffer is a PartitionedPairBuffer; either way the records first land in an in-memory collection. Taking PartitionedPairBuffer as the example, its insert implementation is:

def insert(partition: Int, key: K, value: V): Unit = {
  if (curSize == capacity) {
    growArray()
  }
  data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
  data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
  curSize += 1
  afterUpdate()
}

At the bottom, data is just a flat object array; for the i-th record (see the small sketch below):

  • slot 2*i stores the (partition, key) pair
  • slot 2*i+1 stores the value
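As a tiny stand-alone illustration of this two-slots-per-record layout (my own simplified sketch, not the Spark class itself; capacity management, size estimation, etc. are omitted):

// Simplified sketch of a PartitionedPairBuffer-like flat array (illustrative only).
class TinyPairBuffer[K, V](initialCapacity: Int = 64) {
  private var capacity = initialCapacity
  private var curSize = 0
  private var data = new Array[AnyRef](2 * capacity)

  def insert(partition: Int, key: K, value: V): Unit = {
    if (curSize == capacity) grow()
    data(2 * curSize) = (partition, key).asInstanceOf[AnyRef]   // slot 2*i: (partition, key)
    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]          // slot 2*i+1: value
    curSize += 1
  }

  private def grow(): Unit = {
    capacity *= 2
    data = java.util.Arrays.copyOf(data, 2 * capacity)
  }

  def iterator: Iterator[((Int, K), V)] = (0 until curSize).iterator.map { i =>
    (data(2 * i).asInstanceOf[(Int, K)], data(2 * i + 1).asInstanceOf[V])
  }
}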

After each insert, ExternalSorter calls maybeSpillCollection to decide whether the in-memory data needs to be spilled to disk. Note that "spilling" here means the buffered data is first sorted and then written out through a buffered disk writer, which is flushed to the spill file in batches. The spill is triggered when the estimated size of the in-memory collection exceeds the current memory threshold, which starts at spark.shuffle.spill.initialMemoryThreshold (default 5 * 1024 * 1024, i.e. 5MB) and grows as more execution memory is acquired. During the spill itself, the sorted records are written out and the disk writer is flushed every spark.shuffle.spill.batchSize records (default 10000):

private def maybeSpillCollection(usingMap: Boolean): Unit = {
  var estimatedSize = 0L
  if (usingMap) {
    estimatedSize = map.estimateSize()
    if (maybeSpill(map, estimatedSize)) {
      map = new PartitionedAppendOnlyMap[K, C]
    }
  } else {
    estimatedSize = buffer.estimateSize()
    if (maybeSpill(buffer, estimatedSize)) {
      buffer = new PartitionedPairBuffer[K, C]
    }
  }
  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}
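The decision itself is made by maybeSpill in Spillable. Paraphrased as a stand-alone sketch (simplified, not the exact 2.4.6 source): every 32 records it checks whether the collection has outgrown the current threshold, tries to acquire more execution memory, and spills only if the data still does not fit:

// Standalone sketch of the spill decision (my simplification of Spillable.maybeSpill).
class SpillDecision(initialThreshold: Long, acquireMemory: Long => Long) {
  private var myMemoryThreshold: Long = initialThreshold  // starts at spark.shuffle.spill.initialMemoryThreshold
  private var elementsRead: Long = 0L

  /** Returns true if the caller should spill its in-memory collection now. */
  def maybeSpill(currentMemory: Long): Boolean = {
    elementsRead += 1
    var shouldSpill = false
    // Only check every 32 records, and only once the collection exceeds the threshold.
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Try to grow the allotment to roughly twice the current size before spilling.
      val granted = acquireMemory(2 * currentMemory - myMemoryThreshold)
      myMemoryThreshold += granted
      // Spill only if we still do not fit under the (possibly enlarged) threshold.
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    if (shouldSpill) {
      elementsRead = 0
      myMemoryThreshold = initialThreshold  // after spilling, memory is released and the threshold resets
    }
    shouldSpill
  }
}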

If a spill is needed, Spillable hands off to the concrete spill method, which performs the actual write to disk:

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  spills += spillFile
}

The actual file writing ends up back in ExternalSorter, in spillMemoryIteratorToDisk:

private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
    : SpilledFile = {
  val (blockId, file) = diskBlockManager.createTempShuffleBlock()
  var objectsWritten: Long = 0
  val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
  val writer: DiskBlockObjectWriter =
    blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
  val batchSizes = new ArrayBuffer[Long]
  val elementsPerPartition = new Array[Long](numPartitions)

  def flush(): Unit = {
    val segment = writer.commitAndGet()
    batchSizes += segment.length
    _diskBytesSpilled += segment.length
    objectsWritten = 0
  }

  var success = false
  try {
    while (inMemoryIterator.hasNext) {
      val partitionId = inMemoryIterator.nextPartition()
      inMemoryIterator.writeNext(writer)
      elementsPerPartition(partitionId) += 1
      objectsWritten += 1

      if (objectsWritten == serializerBatchSize) {
        flush()
      }
    }
    if (objectsWritten > 0) {
      flush()
    } else {
      writer.revertPartialWritesAndClose()
    }
    success = true
  } finally {
    if (success) {
      writer.close()
    } else {
      writer.revertPartialWritesAndClose()
      if (file.exists()) {
        // (cleanup of the partial spill file elided)
      }
    }
  }
  SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}

Before writing, a temporary file (and its path) is obtained via diskBlockManager.createTempShuffleBlock(); the file name is generated from UUID.randomUUID:

def createTempShuffleBlock(): (TempShuffleBlockId, File) = {
  var blockId = new TempShuffleBlockId(UUID.randomUUID())
  while (getFile(blockId).exists()) {
    blockId = new TempShuffleBlockId(UUID.randomUUID())
  }
  (blockId, getFile(blockId))
}

Whenever the number of records buffered in the writer reaches 10000, a flush pushes them into the file created above.
Before being written, the buffered data is sorted by (partitionId, key): records are compared by partitionId first, and records within the same partitionId are compared by key. As a result, the data in each spill file is already sorted, and the records belonging to one partition are stored contiguously, next to each other.
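The "partition first, then key" order can be expressed as a comparator like the following sketch (my own illustration of the idea; the real comparator is built inside ExternalSorter / WritablePartitionedPairCollection):

import java.util.Comparator

// Illustrative partition-then-key comparator over (partitionId, key) pairs.
def partitionKeyComparatorSketch[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
  new Comparator[(Int, K)] {
    override def compare(a: (Int, K), b: (Int, K)): Int = {
      val partitionDiff = a._1 - b._1
      // Same partition: fall back to the key ordering (hash-based in Spark's default case).
      if (partitionDiff != 0) partitionDiff else keyComparator.compare(a._2, b._2)
    }
  }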

When all records have been processed, there may be several small spill files. They are then merged into a single data file, together with a corresponding index file:

val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)

Here sorter.writePartitionedFile(blockId, tmp) writes the data of all spill files into a single file, partition by partition, so that data belonging to the same partition ends up together; shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) then generates the index file.

def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {

  // Track the location of each partition's range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)

  if (spills.isEmpty) {
    // Case where we only have in-memory data
    val collection = if (aggregator.isDefined) map else buffer
    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)
      }
      val segment = writer.commitAndGet()
      lengths(partitionId) = segment.length
    }
  } else {
    // We must do a merge: get an iterator by partition and write everything directly.
    for ((id, elements) <- this.partitionedIterator) {
      if (elements.hasNext) {
        for (elem <- elements) {
          writer.write(elem._1, elem._2)
        }
        val segment = writer.commitAndGet()
        lengths(id) = segment.length
      }
    }
  }
  // ....
}

Here spills is the collection of small spill files mentioned above. Let's look at the case where spill files already exist; the main logic is in the iterator returned by this.partitionedIterator:

def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
  val usingMap = aggregator.isDefined
  val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
  if (spills.isEmpty) {
    if (!ordering.isDefined) {
      groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
    } else {
      // We do need to sort by both partition ID and key
      groupByPartition(destructiveIterator(
        collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
    }
  } else {
    merge(spills, destructiveIterator(
      collection.partitionedDestructiveSortedIterator(comparator)))
  }
}

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
  val readers = spills.map(new SpillReader(_))
  val inMemBuffered = inMemory.buffered
  (0 until numPartitions).iterator.map { p =>
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    if (aggregator.isDefined) {
      // Perform partial aggregation across partitions
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
      // sort the elements without trying to merge them
      (p, mergeSort(iterators, ordering.get))
    } else {
      (p, iterators.iterator.flatten)
    }
  }
}

As shown, this returns, for each partition, a collection of iterators: one SpillReader per spill file (each reads its file partition by partition) plus an iterator over the data still held in memory, which is the same iterator used when writing the spill files.
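For intuition, when an ordering is defined, the per-partition iterators, each already sorted by key, are combined with a k-way merge. A minimal stand-alone sketch of that idea (my own illustration, not Spark's mergeSort):

import scala.collection.BufferedIterator
import scala.collection.mutable

// Illustrative k-way merge of per-partition iterators that are each sorted by key.
def kWayMergeSketch[K, V](iterators: Seq[Iterator[(K, V)]],
                          keyOrdering: Ordering[K]): Iterator[(K, V)] = {
  // Order iterators by the reverse key ordering of their current head, so the queue
  // always yields the iterator whose head has the smallest key first.
  val heap = new mutable.PriorityQueue[BufferedIterator[(K, V)]]()(
    Ordering.by((it: BufferedIterator[(K, V)]) => it.head._1)(keyOrdering.reverse))
  heap ++= iterators.filter(_.hasNext).map(_.buffered)

  new Iterator[(K, V)] {
    override def hasNext: Boolean = heap.nonEmpty
    override def next(): (K, V) = {
      val smallest = heap.dequeue()                 // iterator with the smallest current key
      val record = smallest.next()
      if (smallest.hasNext) heap.enqueue(smallest)  // re-insert with its new head
      record
    }
  }
}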

for ((id, elements) <- this.partitionedIterator) {
  if (elements.hasNext) {
    for (elem <- elements) {
      writer.write(elem._1, elem._2)
    }
    val segment = writer.commitAndGet()
    lengths(id) = segment.length
  }
}

Each step yields the iterator for one partition, so one pass through it reads that entire partition (from all spill files plus memory) and writes it to the data file. After each partition is written, the number of bytes written for that partition is recorded. In sort-based shuffle each map task therefore ends up with exactly two files: one data file and one index file.
The index file is then generated from the per-partition sizes recorded above: the sizes are written in partition order, so a reader can locate each partition's byte range in the data file. At this point the map-side write is complete.
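Conceptually, writeIndexFileAndCommit turns the array of per-partition lengths into cumulative offsets; a reducer for partition i then reads the byte range [offset(i), offset(i+1)) from the data file. A minimal sketch of that offset computation (illustrative, not the exact IndexShuffleBlockResolver code):

// Illustrative only: per-partition lengths -> cumulative offsets, as stored in the index file.
def indexOffsetsSketch(partitionLengths: Array[Long]): Array[Long] = {
  val offsets = new Array[Long](partitionLengths.length + 1)
  offsets(0) = 0L
  var i = 0
  while (i < partitionLengths.length) {
    offsets(i + 1) = offsets(i) + partitionLengths(i)  // where partition i ends / i+1 starts
    i += 1
  }
  offsets
}

// Example: lengths Array(10, 0, 25) -> offsets Array(0, 10, 10, 35);
// partition 2 occupies bytes [10, 35) of the data file.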

To summarize the whole map-side write flow (a minimal job that exercises this path is sketched after the list):

  1. After the Executor receives the LaunchTask command from the Driver, it wraps the received task information in a TaskRunner and runs TaskRunner.run on the thread pool.
  2. TaskRunner.run eventually calls Task.runTask; in ShuffleMapTask the write is handed to a ShuffleWriter.
  3. The default SortShuffleWriter writes data into either a map or an array-backed buffer, depending on whether map-side aggregation is needed.
  4. SortShuffleWriter delegates to ExternalSorter, which inserts the records into the in-memory collection and, after each record, checks whether a spill to disk is needed; the memory threshold starts at spark.shuffle.spill.initialMemoryThreshold, default 5MB.
  5. When a spill is triggered, the in-memory data is sorted first, by (partitionId, key): partitionId is compared first, and keys are compared within the same partitionId. The sorted records are then written out through a buffered disk writer, so the data in each spill file is sorted and each partition's records sit contiguously next to each other. The writer is flushed every spark.shuffle.spill.batchSize records (default 10000); each spill creates a new file, while a batch flush within a spill only pushes buffered data into the current file and does not create a new one.
  6. When the task finishes, there are potentially many small spill files. They are merged into a single data file plus an index file. The merge is done partition by partition: for each partition, all the small files (and any remaining in-memory data) are traversed, that partition's data is read and written to the final data file, so each partition's data ends up packed contiguously and partially ordered, and the index file can locate each partition's data by its position in the file.
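As an orientation aid (a hypothetical word-count job, not taken from the original article): the map-side write described above is what each task of the stage before the shuffle executes, while the final stage after the shuffle runs ResultTasks:

import org.apache.spark.{SparkConf, SparkContext}

object WordCountStages {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("wordcount").setMaster("local[2]"))
    val counts = sc
      .parallelize(Seq("a b a", "b c"))
      .flatMap(_.split(" "))
      .map(w => (w, 1))
      .reduceByKey(_ + _)   // introduces a ShuffleDependency: the stage up to here runs ShuffleMapTasks
      .collect()            // the final stage is a ResultStage, executed as ResultTasks
    counts.foreach(println)
    sc.stop()
  }
}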

Everything above concerns ShuffleMapTask. Tasks generated for a ResultStage are ResultTasks instead. The only difference is that a ResultTask does not write local shuffle files the way a ShuffleMapTask does (via writer.write); it simply applies the user-supplied function to its partition:

override def runTask(context: TaskContext): U = {
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  func(context, rdd.iterator(partition, context))
}

As you can see, the ResultTask simply obtains the iterator over its own partition and applies the user function to it.
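For example (based on the public RDD API, as an illustration), a collect()-style action submits a job whose per-partition function just materializes the iterator; that per-partition function is exactly the func a ResultTask deserializes and applies:

// Illustrative: the per-partition function behind a collect()-style action.
// sc.runJob ships this function to ResultTasks, one per partition.
def collectLike[T: scala.reflect.ClassTag](sc: org.apache.spark.SparkContext,
                                           rdd: org.apache.spark.rdd.RDD[T]): Array[T] = {
  val perPartition: Iterator[T] => Array[T] = iter => iter.toArray
  val results: Array[Array[T]] = sc.runJob(rdd, perPartition)
  Array.concat(results: _*)
}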

Reposted from blog.csdn.net/LeoHan163/article/details/120909779