Java -- The Fork-Join Framework

Consider a simple example: counting how many elements of an array satisfy a particular property. We cut the array in half, count the matching elements in each half, and add the two counts. Halves that are still large are split again, and ranges below a threshold are counted directly. To put this recursive computation into a form the framework can use, supply a class that extends "RecursiveTask<T>" (if the computation produces a result of type T) or "RecursiveAction" (if it doesn't produce a result). Override the "compute" method to create and invoke subtasks and to combine their results.

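class Counter extends RecursiveTask<Integer> {
    ... // fields values, from, to, filter, the constant THRESHOLD, and a constructor
    protected Integer compute() {
        if (to - from < THRESHOLD) {
            // Solve the problem directly: count matching elements sequentially
            int count = 0;
            for (int i = from; i < to; i++) {
                if (filter.test(values[i])) count++;
            }
            return count;
        } else {
            int mid = from + (to - from) / 2;      // midpoint of [from, to)
            Counter first = new Counter(values, from, mid, filter);
            Counter second = new Counter(values, mid, to, filter);
            invokeAll(first, second);              // run both halves, wait for both
            return first.join() + second.join();   // combine the partial counts
        }
    }
}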

The "invokeAll" method receives a number of tasks and blocks until all of them have completed. The "join" method then yields the result of a completed task, so we apply "join" to each subtask and return the sum.
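Calling "invokeAll" is one way to run the two halves. A commonly used alternative is to fork one subtask, compute the other in the current worker thread, and then join the forked one. The sketch below assumes a hypothetical ForkCounter class that mirrors Counter. fork(), compute() and join() are standard ForkJoinTask/RecursiveTask methods, but the class name, the threshold of 1000 and the test data are illustrative choices only.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.function.DoublePredicate;

public class ForkComputeJoinDemo {

    // Hypothetical variant of Counter: forks one half, computes the other half in this thread
    static class ForkCounter extends RecursiveTask<Long> {
        private static final int THRESHOLD = 1000; // illustrative cut-off
        private final double[] values;
        private final int from; // inclusive
        private final int to;   // exclusive
        private final DoublePredicate filter;

        ForkCounter(double[] values, int from, int to, DoublePredicate filter) {
            this.values = values;
            this.from = from;
            this.to = to;
            this.filter = filter;
        }

        @Override
        protected Long compute() {
            if (to - from < THRESHOLD) {
                long count = 0;
                for (int i = from; i < to; i++) {
                    if (filter.test(values[i])) count++;
                }
                return count;
            }
            int mid = from + (to - from) / 2;
            ForkCounter left = new ForkCounter(values, from, mid, filter);
            ForkCounter right = new ForkCounter(values, mid, to, filter);
            left.fork();                        // schedule the left half asynchronously
            long rightCount = right.compute();  // do the right half in the current thread
            return left.join() + rightCount;    // wait for the left half and combine
        }
    }

    public static void main(String[] args) {
        double[] values = new double[10_000];
        // Repeating pattern 0.0, 0.25, 0.5, 0.75: only 0.75 satisfies x > 0.5
        for (int i = 0; i < values.length; i++) values[i] = (i % 4) / 4.0;
        long count = new ForkJoinPool().invoke(new ForkCounter(values, 0, values.length, x -> x > 0.5));
        System.out.println(count); // 2500
    }
}
```

Computing one half directly avoids handing both halves to the pool while the current thread just waits.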

NOTE: There is also a "get" method that waits for and returns the result, but it is less convenient here because it throws the checked exceptions InterruptedException and ExecutionException, which the "compute" method is not allowed to throw.
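The difference shows up when a task's result is read from outside the pool. The sketch below uses a hypothetical SumTask that adds up a range of numbers. Inside compute() it uses join(), which declares no checked exceptions. From main it submits the task with ForkJoinPool.submit and calls get(), which must be wrapped in a try/catch.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class GetVersusJoin {

    // Hypothetical task: sums the integers in [lo, hi)
    static class SumTask extends RecursiveTask<Long> {
        private final long lo, hi;

        SumTask(long lo, long hi) { this.lo = lo; this.hi = hi; }

        @Override
        protected Long compute() {
            if (hi - lo <= 1000) {
                long s = 0;
                for (long i = lo; i < hi; i++) s += i;
                return s;
            }
            long mid = lo + (hi - lo) / 2;
            SumTask a = new SumTask(lo, mid), b = new SumTask(mid, hi);
            invokeAll(a, b);
            return a.join() + b.join(); // join(): no checked exceptions, so usable inside compute()
        }
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        ForkJoinTask<Long> future = pool.submit(new SumTask(1, 100_001));
        try {
            // get() declares InterruptedException and ExecutionException
            System.out.println(future.get()); // 5000050000
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            System.err.println("task failed: " + e.getCause());
        }
        pool.shutdown();
    }
}
```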

package ConcurrentTest;

import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.function.DoublePredicate;

/**
 * Created by lenovo on 2018/6/29.
 */
public class ForkJoinFramwork {

    public static void main(String[] args) {
        // 100 million doubles occupy about 800 MB, so a larger heap than the default may be needed (e.g. -Xmx2g)
        int SIZE = 100000000;
        double[] values = new double[SIZE];
        DoublePredicate filter = x -> x > 0.5;

        // Arrays.setAll fills the array sequentially (Arrays.parallelSetAll is the parallel variant)
        long begin = System.nanoTime();
        Arrays.setAll(values, i -> Math.random());
        System.out.println("Filling the array took " + (System.nanoTime() - begin) / 1_000_000 + " ms.");

        begin = System.nanoTime();
        Counter task = new Counter(values, 0, SIZE, filter);
        ForkJoinPool pool = new ForkJoinPool();
        long count = pool.invoke(task); // invoke() runs the task and returns its result
        System.out.println(count);
        System.out.println("Counter took " + (System.nanoTime() - begin) / 1_000_000 + " ms.");
    }

}
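To sanity-check the count printed by the fork-join version, the same count can be computed with a DoubleStream, both sequentially and with parallel(). In the standard JDK implementation, parallel streams also split their work on a fork-join pool (the common pool). The class name, the smaller array size and the fixed random seed below are assumptions chosen so the sketch runs with the default heap.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.DoublePredicate;
import java.util.stream.DoubleStream;

public class StreamBaseline {

    // Counts the elements matching the filter, optionally with a parallel stream
    static long count(double[] values, DoublePredicate filter, boolean parallel) {
        DoubleStream stream = Arrays.stream(values);
        return (parallel ? stream.parallel() : stream).filter(filter).count();
    }

    public static void main(String[] args) {
        int size = 1_000_000;              // smaller than SIZE above
        double[] values = new double[size];
        Random random = new Random(42);    // fixed seed: repeatable data
        Arrays.setAll(values, i -> random.nextDouble());
        DoublePredicate filter = x -> x > 0.5;

        long start = System.nanoTime();
        long sequential = count(values, filter, false);
        long seqMillis = (System.nanoTime() - start) / 1_000_000;

        start = System.nanoTime();
        long parallel = count(values, filter, true);
        long parMillis = (System.nanoTime() - start) / 1_000_000;

        System.out.println("sequential: " + sequential + " in " + seqMillis + " ms");
        System.out.println("parallel:   " + parallel + " in " + parMillis + " ms");
    }
}
```

Both counts should agree with each other and with a fork-join count over the same data. Single-run timings like these are only rough indications, since JIT warm-up affects them.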

package ConcurrentTest;

import java.util.concurrent.RecursiveTask;
import java.util.function.DoublePredicate;

/**
 * Created by lenovo on 2018/6/29.
 */
public class Counter extends RecursiveTask<Long> {
    // Ranges shorter than this are counted sequentially instead of being split further
    private static final int THRESHOLD = 1000;

    private final double[] values;
    private final int from;     // inclusive start index
    private final int to;       // exclusive end index
    private final DoublePredicate filter;

    public Counter(double[] values, int from, int to, DoublePredicate filter) {
        this.values = values;
        this.from = from;
        this.to = to;
        this.filter = filter;
    }

    @Override
    protected Long compute() {
        long counter = 0;

        if (to - from < THRESHOLD) {
            // Small range: count the matching elements directly
            for (int i = from; i < to; i++) {
                if (filter.test(values[i])) {
                    counter++;
                }
            }
        } else {
            // Large range: split at the midpoint and count both halves in parallel
            int mid = from + (to - from) / 2;
            Counter subCounter1 = new Counter(values, from, mid, filter);
            Counter subCounter2 = new Counter(values, mid, to, filter);
            invokeAll(subCounter1, subCounter2);
            counter = subCounter1.join() + subCounter2.join();
        }

        return counter;
    }

}
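For a computation that produces no result, the opening paragraph mentions RecursiveAction. The sketch below uses a hypothetical ScaleAction that multiplies every element of an array by a factor in place. compute() returns void, and since the subtasks write to disjoint ranges, no synchronization and no join of results is needed. The class name, threshold and factor are illustrative assumptions.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class ScaleAction extends RecursiveAction {
    private static final int THRESHOLD = 1000; // illustrative cut-off
    private final double[] values;
    private final int from; // inclusive
    private final int to;   // exclusive
    private final double factor;

    public ScaleAction(double[] values, int from, int to, double factor) {
        this.values = values;
        this.from = from;
        this.to = to;
        this.factor = factor;
    }

    @Override
    protected void compute() {
        if (to - from < THRESHOLD) {
            // Each task owns a disjoint range, so in-place writes need no locking
            for (int i = from; i < to; i++) values[i] *= factor;
        } else {
            int mid = from + (to - from) / 2;
            // Nothing to combine: invokeAll just waits for both halves to finish
            invokeAll(new ScaleAction(values, from, mid, factor),
                      new ScaleAction(values, mid, to, factor));
        }
    }

    public static void main(String[] args) {
        double[] values = new double[10_000];
        Arrays.fill(values, 1.5);
        new ForkJoinPool().invoke(new ScaleAction(values, 0, values.length, 2.0));
        System.out.println(values[0] + " " + values[values.length - 1]); // 3.0 3.0
    }
}
```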


Reposted from blog.csdn.net/liangking81/article/details/80860781