Java并发编程之Fork/Join框架详解

简介

Fork/Join框架是Java7提供了的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork/Join的运行流程如下图所示:


工作窃取算法

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。那么为什么需要使用工作窃取算法呢?

假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

工作窃取的运行流程图如下:


工作窃取算法的优点:充分利用线程进行并行计算,并减少了线程间的竞争;

工作窃取算法的缺点:在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。

Fork/Join框架的实现原理

在Java的Fork/Join框架中,它提供了两个类来帮助我们完成任务分割以及执行任务并合并结果

1、ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:

  • RecursiveAction:用于没有返回结果的任务。
  • RecursiveTask :用于有返回结果的任务。

2、ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

ForkJoinPool里面,有两个特别重要的成员如下:

volatile WorkQueue[] workQueues;     // main registry
final ForkJoinWorkerThreadFactory factory;

workQueues 用于保存向ForkJoinPool提交的任务,而具体的执行由ForkJoinWorkerThread执行,而ForkJoinWorkerThreadFactory可以用于生产出ForkJoinWorkerThread:

public static interface ForkJoinWorkerThreadFactory {
    /**
        * Returns a new worker thread operating in the given pool.
        *
        * @param pool the pool this thread works in
        * @return the new worker thread
        * @throws NullPointerException if the pool is null
        */
    public ForkJoinWorkerThread newThread(ForkJoinPool pool);
}

ForkJoinTask的fork方法实现原理

public final ForkJoinTask<V> fork() {
    Thread t;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        ForkJoinPool.common.externalPush(this);
    return this;
}

若当前线程是ForkJoinWorkerThread线程,则强制类型转换(向下转换)成ForkJoinWorkerThread,然后将任务push到这个线程负责的队列里面去,在ForkJoinWorkerThread类中有一个pool和一个workQueue字段:

// 线程工作的ForkJoinPool
final ForkJoinPool pool;                // the pool this thread works in
// 工作窃取队列
final ForkJoinPool.WorkQueue workQueue; // work-stealing mechanics

workQueue的push()方法如下:

final void push(ForkJoinTask<?> task) {
    ForkJoinTask<?>[] a; ForkJoinPool p;
    int b = base, s = top, n;
    if ((a = array) != null) {    // ignore if queue removed
        int m = a.length - 1;     // fenced write for task visibility
        U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
        U.putOrderedInt(this, QTOP, s + 1);
        if ((n = s - b) <= 1) {
            if ((p = pool) != null)
                p.signalWork(p.workQueues, this);
        }
        else if (n >= m)
            growArray();
    }
}

该方法的主要功能就是将当前任务存放在ForkJoinTask数组array里。然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务。

ForkJoinTask的join方法实现原理

Join方法的主要作用是阻塞当前线程并等待获取结果,其源码如下:

public final V join() {
    int s;
    if ((s = doJoin() & DONE_MASK) != NORMAL)
        reportException(s);
    return getRawResult();
}

首先,它调用了doJoin()方法,通过doJoin()方法得到当前任务的状态来判断返回什么结果,任务状态有四种:已完成(NORMAL),被取消(CANCELLED),信号(SIGNAL)和出现异常(EXCEPTIONAL):

若状态不是NORMAL,则通过reportException(int)方法来处理状态:

private void reportException(int s) {
    if (s == CANCELLED)
        throw new CancellationException();
    if (s == EXCEPTIONAL)
        rethrow(getThrowableException());
}

  • 如果任务状态是已完成,则直接返回任务结果。
  • 如果任务状态是被取消,则直接抛出CancellationException。
  • 如果任务状态是抛出异常,则直接抛出对应的异常。

让我们再来分析下doJoin()方法的实现代码:

private int doJoin() {
    int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
    return (s = status) < 0 ? s :
        ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (w = (wt = (ForkJoinWorkerThread)t).workQueue).
        tryUnpush(this) && (s = doExec()) < 0 ? s :
        wt.pool.awaitJoin(w, this, 0L) :
        externalAwaitDone();
}

在doJoin()方法里,首先通过查看任务的状态,看任务是否已经执行完了,如果执行完了,则直接返回任务状态,如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成了,则设置任务状态为NORMAL,如果出现异常,则纪录异常,并将任务状态设置为EXCEPTIONAL。

执行任务是通过doExec()方法来完成的:

final int doExec() {
    int s; boolean completed;
    if ((s = status) >= 0) {
        try {
            completed = exec();
        } catch (Throwable rex) {
            return setExceptionalCompletion(rex);
        }
        if (completed)
            s = setCompletion(NORMAL);
    }
    return s;
}

真正的执行过程是由exec()方法来完成的:

protected abstract boolean exec();

这就是我们需要重写的方法,若是我们的任务继承自RecursiveAction,则我们需要重写RecursiveAction的compute()方法:

public abstract class RecursiveAction extends ForkJoinTask<Void> {
    private static final long serialVersionUID = 5232453952276485070L;

    /**
     * The main computation performed by this task.
     */
    protected abstract void compute();

    /**
     * Always returns {@code null}.
     *
     * @return {@code null} always
     */
    public final Void getRawResult() { return null; }

    /**
     * Requires null completion value.
     */
    protected final void setRawResult(Void mustBeNull) { }

    /**
     * Implements execution conventions for RecursiveActions.
     */
    protected final boolean exec() {
        compute();
        return true;
    }
}
若是我们的任务继承自RecursiveTask,则我们同样需要重写RecursiveTask的compute()方法:
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /**
     * The result of the computation.
     */
    V result;

    /**
     * The main computation performed by this task.
     * @return the result of the computation
     */
    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    /**
     * Implements execution conventions for RecursiveTask.
     */
    protected final boolean exec() {
        result = compute();
        return true;
    }
}

通过上面的分析可知,执行我们的业务代码是在调用了join()之后的,也就是说,fork仅仅是分割任务,只有当我们执行join的时候,我们的任务才会被执行。

Fork/Join框架应用示例

public class ForkJoinTest {
	static class SumTask extends RecursiveTask<Long> {
	    static final int THRESHOLD = 100;
	    long[] array;
	    int start;
	    int end;

	    SumTask(long[] array, int start, int end) {
	    	this.array = array;
	        this.start = start;
	        this.end = end;
	    }

	    @Override
	    protected Long compute() {
	        if (end - start <= THRESHOLD) {
	            // 如果任务足够小,直接计算:
	            long sum = 0;
	            for (int i = start; i < end; i++) {
	                sum += array[i];
	            }
	            try {
	                Thread.sleep(1000);
	            } catch (InterruptedException e) {
	            }
	            System.out.println(String.format("compute %d~%d = %d", start, end, sum));
	            return sum;
	        }
	        // 任务太大,一分为二:
	        int middle = (end + start) / 2;
	        System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
	        SumTask subtask1 = new SumTask(this.array, start, middle);
	        SumTask subtask2 = new SumTask(this.array, middle, end);
	        invokeAll(subtask1, subtask2);
//	        subtask1.fork();
//	        subtask2.fork();
	        Long subresult1 = subtask1.join();
	        Long subresult2 = subtask2.join();
	        Long result = subresult1 + subresult2;
	        System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
	        return result;
	    }
	}

	public static void main(String[] args) {
		// 创建随机数组成的数组:
	    long[] array = new long[1000];
	    fillRandom(array);
	    // fork/join task:
	    ForkJoinPool fjp = new ForkJoinPool(4); // 最大并发数4
	    ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
	    long startTime = System.currentTimeMillis();
	    Long result = fjp.invoke(task);
	    long endTime = System.currentTimeMillis();
	    System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
	}
	
	private static void fillRandom(long[] array) {
		Random random = new Random();
		int bound = 100000;
		
		for (int i = 0; i < array.length; i++) {
			array[i] = random.nextInt(bound);
		}
	}
}

运行结果:

split 0~1000 ==> 0~500, 500~1000
split 0~500 ==> 0~250, 250~500
split 500~1000 ==> 500~750, 750~1000
split 250~500 ==> 250~375, 375~500
split 250~375 ==> 250~312, 312~375
split 0~250 ==> 0~125, 125~250
split 0~125 ==> 0~62, 62~125
split 500~750 ==> 500~625, 625~750
split 750~1000 ==> 750~875, 875~1000
split 500~625 ==> 500~562, 562~625
split 750~875 ==> 750~812, 812~875
compute 0~62 = 3321492
compute 750~812 = 3109921
compute 250~312 = 2809208
compute 500~562 = 3113648
compute 62~125 = 3177944
result = 3321492 + 3177944 ==> 6499436
compute 812~875 = 3123212
split 125~250 ==> 125~187, 187~250
result = 3109921 + 3123212 ==> 6233133
compute 562~625 = 3423239
result = 3113648 + 3423239 ==> 6536887
split 875~1000 ==> 875~937, 937~1000
split 625~750 ==> 625~687, 687~750
compute 312~375 = 3161144
result = 2809208 + 3161144 ==> 5970352
split 375~500 ==> 375~437, 437~500
compute 125~187 = 2999804
compute 625~687 = 3002258
compute 875~937 = 2825923
compute 375~437 = 2772717
compute 187~250 = 3647221
result = 2999804 + 3647221 ==> 6647025
result = 6499436 + 6647025 ==> 13146461
compute 937~1000 = 3277843
result = 2825923 + 3277843 ==> 6103766
compute 687~750 = 2819135
result = 6233133 + 6103766 ==> 12336899
result = 3002258 + 2819135 ==> 5821393
result = 6536887 + 5821393 ==> 12358280
result = 12358280 + 12336899 ==> 24695179
compute 437~500 = 3185771
result = 2772717 + 3185771 ==> 5958488
result = 5970352 + 5958488 ==> 11928840
result = 13146461 + 11928840 ==> 25075301
result = 25075301 + 24695179 ==> 49770480
Fork/join sum: 49770480 in 4031 ms.

我们是采用了廖雪峰老师的源代码作为应用示例讲解的,该代码就是通过Fork/Join框架来计算数组的和,计算耗时4031毫秒。通过该代码作为应用示例主要是为了告诉大家,使用Fork/Join模型的正确方式,在源代码中可以看到,SumTask继承自RecursiveTask,重写的compute方法为:

protected Long compute() {
    if (end - start <= THRESHOLD) {
        // 如果任务足够小,直接计算:
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
        }
        System.out.println(String.format("compute %d~%d = %d", start, end, sum));
        return sum;
    }
    // 任务太大,一分为二:
    int middle = (end + start) / 2;
    System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
    SumTask subtask1 = new SumTask(this.array, start, middle);
    SumTask subtask2 = new SumTask(this.array, middle, end);
    invokeAll(subtask1, subtask2);
//	        subtask1.fork();
//	        subtask2.fork();
    Long subresult1 = subtask1.join();
    Long subresult2 = subtask2.join();
    Long result = subresult1 + subresult2;
    System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
    return result;
}

compute()方法使用了invokeAll方法来分解任务,而不是它下面的

subtask1.fork();
subtask2.fork();

这两个方法,使用invokeAll方法的主要原因是为了充分利用线程池,在invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。

若是采用另外一种方式来运行,程序的运行时间为6028毫秒,可以看到,明显比invokeAll方式慢了很多。

参考资料

方腾飞:《Java并发编程的艺术》

Java的Fork/Join任务,你写对了吗?

猜你喜欢

转载自blog.csdn.net/qq_38293564/article/details/80610519