简介
Fork/Join框架是Java7提供了的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork/Join的运行流程如下图所示:
工作窃取算法
工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。那么为什么需要使用工作窃取算法呢?
假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。
工作窃取的运行流程图如下:
工作窃取算法的优点:充分利用线程进行并行计算,并减少了线程间的竞争;
工作窃取算法的缺点:在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。
Fork/Join框架的实现原理
在Java的Fork/Join框架中,它提供了两个类来帮助我们完成任务分割以及执行任务并合并结果:
1、ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:
- RecursiveAction:用于没有返回结果的任务。
- RecursiveTask :用于有返回结果的任务。
2、ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。
ForkJoinPool里面,有两个特别重要的成员如下:
volatile WorkQueue[] workQueues; // main registry final ForkJoinWorkerThreadFactory factory;
workQueues 用于保存向ForkJoinPool提交的任务,而具体的执行由ForkJoinWorkerThread执行,而ForkJoinWorkerThreadFactory可以用于生产出ForkJoinWorkerThread:
public static interface ForkJoinWorkerThreadFactory { /** * Returns a new worker thread operating in the given pool. * * @param pool the pool this thread works in * @return the new worker thread * @throws NullPointerException if the pool is null */ public ForkJoinWorkerThread newThread(ForkJoinPool pool); }
ForkJoinTask的fork方法实现原理
public final ForkJoinTask<V> fork() { Thread t; if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ((ForkJoinWorkerThread)t).workQueue.push(this); else ForkJoinPool.common.externalPush(this); return this; }
若当前线程是ForkJoinWorkerThread线程,则强制类型转换(向下转换)成ForkJoinWorkerThread,然后将任务push到这个线程负责的队列里面去,在ForkJoinWorkerThread类中有一个pool和一个workQueue字段:
// 线程工作的ForkJoinPool final ForkJoinPool pool; // the pool this thread works in // 工作窃取队列 final ForkJoinPool.WorkQueue workQueue; // work-stealing mechanics
workQueue的push()方法如下:
final void push(ForkJoinTask<?> task) { ForkJoinTask<?>[] a; ForkJoinPool p; int b = base, s = top, n; if ((a = array) != null) { // ignore if queue removed int m = a.length - 1; // fenced write for task visibility U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task); U.putOrderedInt(this, QTOP, s + 1); if ((n = s - b) <= 1) { if ((p = pool) != null) p.signalWork(p.workQueues, this); } else if (n >= m) growArray(); } }
该方法的主要功能就是将当前任务存放在ForkJoinTask数组array里。然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务。
ForkJoinTask的join方法实现原理
Join方法的主要作用是阻塞当前线程并等待获取结果,其源码如下:
public final V join() { int s; if ((s = doJoin() & DONE_MASK) != NORMAL) reportException(s); return getRawResult(); }
首先,它调用了doJoin()方法,通过doJoin()方法得到当前任务的状态来判断返回什么结果,任务状态有四种:已完成(NORMAL),被取消(CANCELLED),信号(SIGNAL)和出现异常(EXCEPTIONAL):
若状态不是NORMAL,则通过reportException(int)方法来处理状态:
private void reportException(int s) { if (s == CANCELLED) throw new CancellationException(); if (s == EXCEPTIONAL) rethrow(getThrowableException()); }
- 如果任务状态是已完成,则直接返回任务结果。
- 如果任务状态是被取消,则直接抛出CancellationException。
- 如果任务状态是抛出异常,则直接抛出对应的异常。
让我们再来分析下doJoin()方法的实现代码:
private int doJoin() { int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w; return (s = status) < 0 ? s : ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ? (w = (wt = (ForkJoinWorkerThread)t).workQueue). tryUnpush(this) && (s = doExec()) < 0 ? s : wt.pool.awaitJoin(w, this, 0L) : externalAwaitDone(); }
在doJoin()方法里,首先通过查看任务的状态,看任务是否已经执行完了,如果执行完了,则直接返回任务状态,如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成了,则设置任务状态为NORMAL,如果出现异常,则纪录异常,并将任务状态设置为EXCEPTIONAL。
执行任务是通过doExec()方法来完成的:
final int doExec() { int s; boolean completed; if ((s = status) >= 0) { try { completed = exec(); } catch (Throwable rex) { return setExceptionalCompletion(rex); } if (completed) s = setCompletion(NORMAL); } return s; }
真正的执行过程是由exec()方法来完成的:
protected abstract boolean exec();
这就是我们需要重写的方法,若是我们的任务继承自RecursiveAction,则我们需要重写RecursiveAction的compute()方法:
public abstract class RecursiveAction extends ForkJoinTask<Void> { private static final long serialVersionUID = 5232453952276485070L; /** * The main computation performed by this task. */ protected abstract void compute(); /** * Always returns {@code null}. * * @return {@code null} always */ public final Void getRawResult() { return null; } /** * Requires null completion value. */ protected final void setRawResult(Void mustBeNull) { } /** * Implements execution conventions for RecursiveActions. */ protected final boolean exec() { compute(); return true; } }若是我们的任务继承自RecursiveTask,则我们同样需要重写RecursiveTask的compute()方法:
public abstract class RecursiveTask<V> extends ForkJoinTask<V> { private static final long serialVersionUID = 5232453952276485270L; /** * The result of the computation. */ V result; /** * The main computation performed by this task. * @return the result of the computation */ protected abstract V compute(); public final V getRawResult() { return result; } protected final void setRawResult(V value) { result = value; } /** * Implements execution conventions for RecursiveTask. */ protected final boolean exec() { result = compute(); return true; } }
通过上面的分析可知,执行我们的业务代码是在调用了join()之后的,也就是说,fork仅仅是分割任务,只有当我们执行join的时候,我们的任务才会被执行。
Fork/Join框架应用示例
public class ForkJoinTest { static class SumTask extends RecursiveTask<Long> { static final int THRESHOLD = 100; long[] array; int start; int end; SumTask(long[] array, int start, int end) { this.array = array; this.start = start; this.end = end; } @Override protected Long compute() { if (end - start <= THRESHOLD) { // 如果任务足够小,直接计算: long sum = 0; for (int i = start; i < end; i++) { sum += array[i]; } try { Thread.sleep(1000); } catch (InterruptedException e) { } System.out.println(String.format("compute %d~%d = %d", start, end, sum)); return sum; } // 任务太大,一分为二: int middle = (end + start) / 2; System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end)); SumTask subtask1 = new SumTask(this.array, start, middle); SumTask subtask2 = new SumTask(this.array, middle, end); invokeAll(subtask1, subtask2); // subtask1.fork(); // subtask2.fork(); Long subresult1 = subtask1.join(); Long subresult2 = subtask2.join(); Long result = subresult1 + subresult2; System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result); return result; } } public static void main(String[] args) { // 创建随机数组成的数组: long[] array = new long[1000]; fillRandom(array); // fork/join task: ForkJoinPool fjp = new ForkJoinPool(4); // 最大并发数4 ForkJoinTask<Long> task = new SumTask(array, 0, array.length); long startTime = System.currentTimeMillis(); Long result = fjp.invoke(task); long endTime = System.currentTimeMillis(); System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms."); } private static void fillRandom(long[] array) { Random random = new Random(); int bound = 100000; for (int i = 0; i < array.length; i++) { array[i] = random.nextInt(bound); } } }
运行结果:
split 0~1000 ==> 0~500, 500~1000 split 0~500 ==> 0~250, 250~500 split 500~1000 ==> 500~750, 750~1000 split 250~500 ==> 250~375, 375~500 split 250~375 ==> 250~312, 312~375 split 0~250 ==> 0~125, 125~250 split 0~125 ==> 0~62, 62~125 split 500~750 ==> 500~625, 625~750 split 750~1000 ==> 750~875, 875~1000 split 500~625 ==> 500~562, 562~625 split 750~875 ==> 750~812, 812~875 compute 0~62 = 3321492 compute 750~812 = 3109921 compute 250~312 = 2809208 compute 500~562 = 3113648 compute 62~125 = 3177944 result = 3321492 + 3177944 ==> 6499436 compute 812~875 = 3123212 split 125~250 ==> 125~187, 187~250 result = 3109921 + 3123212 ==> 6233133 compute 562~625 = 3423239 result = 3113648 + 3423239 ==> 6536887 split 875~1000 ==> 875~937, 937~1000 split 625~750 ==> 625~687, 687~750 compute 312~375 = 3161144 result = 2809208 + 3161144 ==> 5970352 split 375~500 ==> 375~437, 437~500 compute 125~187 = 2999804 compute 625~687 = 3002258 compute 875~937 = 2825923 compute 375~437 = 2772717 compute 187~250 = 3647221 result = 2999804 + 3647221 ==> 6647025 result = 6499436 + 6647025 ==> 13146461 compute 937~1000 = 3277843 result = 2825923 + 3277843 ==> 6103766 compute 687~750 = 2819135 result = 6233133 + 6103766 ==> 12336899 result = 3002258 + 2819135 ==> 5821393 result = 6536887 + 5821393 ==> 12358280 result = 12358280 + 12336899 ==> 24695179 compute 437~500 = 3185771 result = 2772717 + 3185771 ==> 5958488 result = 5970352 + 5958488 ==> 11928840 result = 13146461 + 11928840 ==> 25075301 result = 25075301 + 24695179 ==> 49770480 Fork/join sum: 49770480 in 4031 ms.
我们是采用了廖雪峰老师的源代码作为应用示例讲解的,该代码就是通过Fork/Join框架来计算数组的和,计算耗时4031毫秒。通过该代码作为应用示例主要是为了告诉大家,使用Fork/Join模型的正确方式,在源代码中可以看到,SumTask继承自RecursiveTask,重写的compute方法为:
protected Long compute() { if (end - start <= THRESHOLD) { // 如果任务足够小,直接计算: long sum = 0; for (int i = start; i < end; i++) { sum += array[i]; } try { Thread.sleep(1000); } catch (InterruptedException e) { } System.out.println(String.format("compute %d~%d = %d", start, end, sum)); return sum; } // 任务太大,一分为二: int middle = (end + start) / 2; System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end)); SumTask subtask1 = new SumTask(this.array, start, middle); SumTask subtask2 = new SumTask(this.array, middle, end); invokeAll(subtask1, subtask2); // subtask1.fork(); // subtask2.fork(); Long subresult1 = subtask1.join(); Long subresult2 = subtask2.join(); Long result = subresult1 + subresult2; System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result); return result; }
compute()方法使用了invokeAll方法来分解任务,而不是它下面的
subtask1.fork(); subtask2.fork();
这两个方法,使用invokeAll方法的主要原因是为了充分利用线程池,在invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。
若是采用另外一种方式来运行,程序的运行时间为6028毫秒,可以看到,明显比invokeAll方式慢了很多。
参考资料
方腾飞:《Java并发编程的艺术》