聊聊Fork/Join并发框架

本文正在参加「金石计划」

什么是Fork/Join

Fork/Join框架是一个实现了ExecutorService接口的多线程处理器,它专为那些可以通过递归分解成更细小的任务而设计,最大化的利用多核处理器来提高应用程序的性能。

与其他ExecutorService相关的实现相同的是,Fork/Join框架会将任务分配给线程池中的线程。而与之不同的是,Fork/Join框架在执行任务时使用了工作窃取算法

fork在英文里有分叉的意思,join在英文里连接、结合的意思。顾名思义,fork就是要使一个大任务分解成若干个小任务,而join就是最后将各个小任务的结果结合起来得到大任务的结果。

Fork/Join的运行流程大致如下所示:

image.png

需要注意的是,图里的次级子任务可以一直分下去,一直分到子任务足够小为止。用伪代码来表示如下:

 solve(任务):
     if(任务已经划分到足够小):
         顺序执行任务
     else:
         for(划分任务得到子任务)
             solve(子任务)
         结合所有子任务的结果到上一层循环
         return 最终结合的结果
复制代码

通过上面伪代码可以看出,我们通过递归嵌套的计算得到最终结果,这里有体现分而治之(divide and conquer) 的算法思想。

工作窃取算法

工作窃取算法指的是在多线程执行不同任务队列的过程中,某个线程执行完自己队列的任务后从其他线程的任务队列里窃取任务来执行。

工作窃取流程如下图所示:

image.png

值得注意的是,当一个线程窃取另一个线程的时候,为了减少两个任务线程之间的竞争,我们通常使用双端队列来存储任务。被窃取的任务线程都从双端队列的头部拿任务执行,而窃取其他任务的线程从双端队列的尾部执行任务。

另外,当一个线程在窃取任务时要是没有其他可用的任务了,这个线程会进入阻塞状态以等待再次“工作”。

Fork/Join的具体实现

Fork/Join框架简单来讲就是对任务的分割与子任务的合并,所以要实现这个框架,先得有任务

在Fork/Join框架里提供了抽象类ForkJoinTask来实现任务。

ForkJoinTask

ForkJoinTask是一个类似普通线程的实体,但是比普通线程轻量得多。

fork()方法

其实fork()只做了一件事,那就是把任务推入当前工作线程的工作队列里

来看下fork()的源码:

 public final ForkJoinTask<V> fork() {
     Thread t;
     // ForkJoinWorkerThread是执行ForkJoinTask的专有线程,由ForkJoinPool管理
     // 先判断当前线程是否是ForkJoin专有线程,如果是,则将任务push到当前线程所负责的队列里去
     if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
         ((ForkJoinWorkerThread)t).workQueue.push(this);
     else
          // 如果不是则将线程加入队列
         // 没有显式创建ForkJoinPool的时候走这里,提交任务到默认的common线程池中
         ForkJoinPool.common.externalPush(this);
     return this;
 }
复制代码

join()方法

Join() 的主要作用是阻塞当前线程并等待获取结果。

来看下join()的源码:

 public final V join() {
     int s;
     // doJoin()方法来获取当前任务的执行状态
     if ((s = doJoin() & DONE_MASK) != NORMAL)
         // 任务异常,抛出异常
         reportException(s);
     // 任务正常完成,获取返回值
     return getRawResult();
 }

 /**
  * doJoin()方法用来返回当前任务的执行状态
  **/
 private int doJoin() {
     int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
     // 先判断任务是否执行完毕,执行完毕直接返回结果(执行状态)
     return (s = status) < 0 ? s :
     // 如果没有执行完毕,先判断是否是ForkJoinWorkThread线程
     ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
         // 如果是,先判断任务是否处于工作队列顶端(意味着下一个就执行它)
         // tryUnpush()方法判断任务是否处于当前工作队列顶端,是返回true
         // doExec()方法执行任务
         (w = (wt = (ForkJoinWorkerThread)t).workQueue).
         // 如果是处于顶端并且任务执行完毕,返回结果
         tryUnpush(this) && (s = doExec()) < 0 ? s :
         // 如果不在顶端或者在顶端却没未执行完毕,那就调用awitJoin()执行任务
         // awaitJoin():使用自旋使任务执行完成,返回结果
         wt.pool.awaitJoin(w, this, 0L) :
     // 如果不是ForkJoinWorkThread线程,执行externalAwaitDone()返回任务结果
     externalAwaitDone();
 }
复制代码

下面是ForkJoinPool.join()的流程图:

image.png

通常情况下我们不需要直接继承 ForkJoinTask 类,而只需要继承它的子类,Fork/Join 框架提供了以下两个子类:

  • RecursiveAction:用于没有返回结果的任务。
  • RecursiveTask :用于有返回结果的任务。

ForkJoinPool

ForkJoinPool是用于执行ForkJoinTask任务的执行(线程)池。

ForkJoinPool管理着执行池中的线程和任务队列,此外,执行池是否还接受任务,显示线程的运行状态也是在这里处理。

ForkJoinPool的源码如下:

 @sun.misc.Contended
 public class ForkJoinPool extends AbstractExecutorService {
     // 任务队列
     volatile WorkQueue[] workQueues;   
 ​
     // 线程的运行状态
     volatile int runState;  
 ​
     // 创建ForkJoinWorkerThread的默认工厂,可以通过构造函数重写
     public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory;
 ​
     // 公用的线程池,其运行状态不受shutdown()和shutdownNow()的影响
     static final ForkJoinPool common;
 ​
     // 私有构造方法,没有任何安全检查和参数校验,由makeCommonPool直接调用
     // 其他构造方法都是源自于此方法
     // parallelism: 并行度,
     // 默认调用java.lang.Runtime.availableProcessors() 方法返回可用处理器的数量
     private ForkJoinPool(int parallelism,
                          ForkJoinWorkerThreadFactory factory, // 工作线程工厂
                          UncaughtExceptionHandler handler, // 拒绝任务的handler
                          int mode, // 同步模式
                          String workerNamePrefix) { // 线程名prefix
         this.workerNamePrefix = workerNamePrefix;
         this.factory = factory;
         this.ueh = handler;
         this.config = (parallelism & SMASK) | mode;
         long np = (long)(-parallelism); // offset ctl counts
         this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
     }
 }
复制代码

WorkQueue

双端队列,ForkJoinTask存放在这里。

当工作线程在处理自己的工作队列时,会从队列首取任务来执行(FIFO);如果是窃取其他队列的任务时,窃取的任务位于所属任务队列的队尾(LIFO)。

ForkJoinPool与传统线程池最显著的区别就是它维护了一个工作队列数组

runState

ForkJoinPool的运行状态。

Fork/Join的异常处理

ForkJoinTask 在执行的时候可能会抛出异常,但是我们没办法在主线程里直接捕获异常,所以 ForkJoinTask 提供了 isCompletedAbnormally() 方法来检查任务是否已经抛出异常或已经被取消了,并且可以通过 ForkJoinTask 的 getException 方法获取异常。使用如下代码:

if(task.isCompletedAbnormally()){
   System.out.println(task.getException());
}
复制代码

getException 方法返回 Throwable 对象,如果任务被取消了则返回 CancellationException。如果任务没有完成或者没有抛出异常则返回 null。

Fork/Join的使用

上面我们说ForkJoinPool负责管理线程和任务,ForkJoinTask实现fork和join操作,所以要使用Fork/Join框架就离不开这两个类了,只是在实际开发中我们常用ForkJoinTask的子类RecursiveTask 和RecursiveAction来替代ForkJoinTask。

下面我们用一个计算斐波那契数列第n项的例子来看一下Fork/Join的使用:

斐波那契数列数列是一个线性递推数列,从第三项开始,每一项的值都等于前两项之和:

1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89······

如果设f(n)为该数列的第n项(n∈N*),那么有:f(n) = f(n-1) + f(n-2)。

 public class FibonacciTest {
 ​
   static class Fibonacci extends RecursiveTask<Integer> {
 ​
         int n;
 ​
         public Fibonacci(int n) {
             this.n = n;
         }
 ​
         // 主要的实现逻辑都在compute()里
         @Override
         protected Integer compute() {
             // 这里先假设 n >= 0
             if (n <= 1) {
                 return n;
             } else {
                 // f(n-1)
                 Fibonacci f1 = new Fibonacci(n - 1);
                 f1.fork();
                 // f(n-2)
                 Fibonacci f2 = new Fibonacci(n - 2);
                 f2.fork();
                 // f(n) = f(n-1) + f(n-2)
                 return f1.join() + f2.join();
             }
         }
     }


    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        System.out.println("CPU核数:" + Runtime.getRuntime().availableProcessors());
        long start = System.currentTimeMillis();
        Fibonacci fibonacci = new Fibonacci(40);
        Future<Integer> future = forkJoinPool.submit(fibonacci);
        System.out.println(future.get());
        long end = System.currentTimeMillis();
        System.out.println(String.format("耗时:%d millis", end - start));
    }
}
复制代码

上面例子在本机的输出:

CPU核数:6
102334155
耗时:5222 millis
复制代码

需要注意的是,上述计算时间复杂度为O(2^n),随着n的增长计算效率会越来越低,这也是上面的例子中n不敢取太大的原因。

此外,也并不是所有的任务都适合Fork/Join框架,比如上面的例子任务划分过于细小反而体现不出效率,下面我们试试用普通的递归来求f(n)的值,看看是不是要比使用Fork/Join快:

 // 普通递归,复杂度为O(2^n)
 public int plainRecursion(int n) {
     if (n == 1 || n == 2) {
         return 1;
     } else {
         return plainRecursion(n -1) + plainRecursion(n - 2);
     }
 }
 ​
 @Test
 public void testPlain() {
     long start = System.currentTimeMillis();
     int result = plainRecursion(40);
     long end = System.currentTimeMillis();
     System.out.println("计算结果:" + result);
     System.out.println(String.format("耗时:%d millis",  end -start));
 }
复制代码

普通递归的例子输出:

 计算结果:102334155
 耗时:436 millis
复制代码

通过输出可以很明显的看出来,使用普通递归的效率都要比使用Fork/Join框架要高很多。

这里我们再用另一种思路来计算:

// 通过循环来计算,复杂度为O(n)
private static int computeFibonacci(int n) {
    // 假设n >= 0
    if (n <= 1) {
        return n;
    } else {
        int first = 1;
        int second = 1;
        int third = 0;
        for (int i = 3; i <= n; i ++) {
            // 第三个数是前两个数之和
            third = first + second;
            // 前两个数右移
            first = second;
            second = third;
        }
        return third;
    }
}

public static void main(String[] args) {
    long start = System.currentTimeMillis();
    int result = computeFibonacci(40);
    long end = System.currentTimeMillis();
    System.out.println("计算结果:" + result);
    System.out.println(String.format("耗时:%d millis",  end -start));
}
复制代码

上面例子在笔者所用电脑的输出为:

计算结果:102334155
耗时:0 millis
复制代码

这里耗时为0不代表没有耗时,是表明这里计算的耗时几乎可以忽略不计,大家可以在自己的电脑试试,即使是n取大很多量级的数据(注意int溢出的问题)耗时也是很短的。

为什么在这里普通的递归或循环效率更快呢?因为Fork/Join是使用多个线程协作来计算的,所以会有线程通信和线程切换的开销。

如果要计算的任务比较简单(比如我们案例中的斐波那契数列),那当然是直接使用单线程会更快一些。但如果要计算的东西比较复杂,计算机又是多核的情况下,就可以充分利用多核CPU来提高计算速度。

猜你喜欢

转载自juejin.im/post/7218698736814080058