Consider a simple example. Suppose we want to count how many elements of an array fulfill a particular property. We cut the array in half, compute the counts of each half, and add them up. To put the recursive computation in a form that the framework can use, supply a class that extends "RecursiveTask<T>" (if the computation produces a result of type T) or "RecursiveAction" (if it doesn't produce a result). Override the "compute" method to generate and invoke subtasks and to combine their results.
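class Counter extends RecursiveTask<Integer> {
    // ... (fields values, from, to, filter and a constructor omitted)
    protected Integer compute() {
        if (to - from < THRESHOLD) {
            // Solve the problem directly: count the matching elements in [from, to)
            int count = 0;
            for (int i = from; i < to; i++) {
                if (filter.test(values[i])) count++;
            }
            return count;
        } else {
            // Split [from, to) at its midpoint and count both halves in parallel
            int mid = from + (to - from) / 2;
            Counter first = new Counter(values, from, mid, filter);
            Counter second = new Counter(values, mid, to, filter);
            invokeAll(first, second);
            return first.join() + second.join();
        }
    }
}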
The "invokeAll" method receives a number of tasks and blocks until all of them have completed. The "join" method then yields a task's result. Here we call "join" on each subtask and return the sum.
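The same split can also be written with explicit "fork" and "join" calls: fork one half so it can run asynchronously, compute the other half directly in the current thread, then join the forked half. The sketch below is an illustration of that idiom, not part of the original program; ForkingCounter and ForkingCounterDemo are hypothetical names, and the 1000-element cutoff is an assumption borrowed from the Counter class.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.function.DoublePredicate;

// Hypothetical demo: every fourth value (0.75) passes the filter x > 0.5
class ForkingCounterDemo {
    public static void main(String[] args) {
        double[] values = new double[10_000];
        for (int i = 0; i < values.length; i++) {
            values[i] = (i % 4) / 4.0; // 0.0, 0.25, 0.5, 0.75, repeating
        }
        long count = new ForkJoinPool().invoke(
                new ForkingCounter(values, 0, values.length, x -> x > 0.5));
        System.out.println(count); // prints 2500
    }
}

// Hypothetical variant of Counter that uses fork/compute/join instead of invokeAll
class ForkingCounter extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1000; // assumed cutoff
    private final double[] values;
    private final int from;
    private final int to;
    private final DoublePredicate filter;

    ForkingCounter(double[] values, int from, int to, DoublePredicate filter) {
        this.values = values;
        this.from = from;
        this.to = to;
        this.filter = filter;
    }

    @Override
    protected Long compute() {
        if (to - from < THRESHOLD) {
            // Small range: count sequentially
            long count = 0;
            for (int i = from; i < to; i++) {
                if (filter.test(values[i])) count++;
            }
            return count;
        }
        int mid = from + (to - from) / 2;
        ForkingCounter left = new ForkingCounter(values, from, mid, filter);
        ForkingCounter right = new ForkingCounter(values, mid, to, filter);
        left.fork();                       // schedule the left half asynchronously
        long rightCount = right.compute(); // count the right half in this thread
        return left.join() + rightCount;   // wait for the forked half and combine
    }
}
```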
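NOTE: There is also a "get" method that waits for the computation to finish and returns its result, but it is less attractive because it can throw the checked exceptions InterruptedException and ExecutionException, which we are not allowed to throw from the "compute" method. "join" reports a failed computation with an unchecked exception instead.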
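To make the difference concrete, here is a minimal sketch that submits a trivial task twice and reads the result once with "join" and once with "get". Answer and GetVersusJoin are hypothetical names used only for this illustration; only the "get" call forces the caller to handle InterruptedException and ExecutionException.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class GetVersusJoin {
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();

        // join(): no checked exceptions to handle
        ForkJoinTask<Integer> t1 = pool.submit(new Answer());
        int viaJoin = t1.join();

        // get(): declares InterruptedException and ExecutionException,
        // so the caller has to catch (or declare) them
        ForkJoinTask<Integer> t2 = pool.submit(new Answer());
        int viaGet;
        try {
            viaGet = t2.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }

        System.out.println(viaJoin + " " + viaGet); // prints "42 42"
        pool.shutdown();
    }
}

// Hypothetical trivial task, used only to contrast get() and join()
class Answer extends RecursiveTask<Integer> {
    @Override
    protected Integer compute() {
        return 42;
    }
}
```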
package ConcurrentTest;
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.function.DoublePredicate;
/**
 * Created by lenovo on 2018/6/29.
 */
public class ForkJoinFramwork {
    public static void main(String[] args) {
        // 100 million doubles occupy about 800 MB; a larger heap may be needed (e.g. -Xmx2g)
        final int SIZE = 100_000_000;
        double[] values = new double[SIZE];
        DoublePredicate filter = x -> x > 0.5;
        // Arrays.setAll fills the array sequentially on the calling thread
        long start = System.currentTimeMillis();
        Arrays.setAll(values, i -> Math.random());
        System.out.println("Sequential setAll costs: " + (System.currentTimeMillis() - start) + " ms.");
        // Count the values greater than 0.5 with the fork-join framework
        start = System.currentTimeMillis();
        Counter task = new Counter(values, 0, SIZE, filter);
        ForkJoinPool pool = new ForkJoinPool();
        long count = pool.invoke(task); // invoke() runs the task and returns its result
        System.out.println(count);
        System.out.println("Counter costs: " + (System.currentTimeMillis() - start) + " ms.");
    }
}
package ConcurrentTest;
import java.util.concurrent.RecursiveTask;
import java.util.function.DoublePredicate;
/**
 * Created by lenovo on 2018/6/29.
 */
public class Counter extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1000; // ranges shorter than this are counted directly
    private final double[] values;
    private final int from;
    private final int to;
    private final DoublePredicate filter;
    public Counter(double[] values, int from, int to, DoublePredicate filter) {
        this.values = values;
        this.from = from;
        this.to = to;
        this.filter = filter;
    }
    @Override
    protected Long compute() {
        if (to - from < THRESHOLD) {
            // Base case: count the matching elements in [from, to) sequentially
            long count = 0;
            for (int i = from; i < to; i++) {
                if (filter.test(values[i])) {
                    count++;
                }
            }
            return count;
        }
        // Recursive case: split [from, to) at its midpoint and count both halves in parallel
        int mid = from + (to - from) / 2;
        Counter first = new Counter(values, from, mid, filter);
        Counter second = new Counter(values, mid, to, filter);
        invokeAll(first, second);
        return first.join() + second.join();
    }
}
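The opening paragraph also mentions "RecursiveAction" for computations that produce no result. As a minimal sketch (SquareAll and SquareAllDemo are hypothetical names, and the 1000-element cutoff is an assumption carried over from Counter), the task below squares every element of an array in place. Its "compute" method returns void, so there is nothing to join; "invokeAll" simply waits until both halves are done.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class SquareAllDemo {
    public static void main(String[] args) {
        double[] values = new double[5000];
        Arrays.setAll(values, i -> i); // 0.0, 1.0, 2.0, ...
        new ForkJoinPool().invoke(new SquareAll(values, 0, values.length));
        System.out.println(Arrays.toString(Arrays.copyOf(values, 4))); // prints [0.0, 1.0, 4.0, 9.0]
    }
}

// Hypothetical RecursiveAction: squares every element of values[from, to) in place
class SquareAll extends RecursiveAction {
    private static final int THRESHOLD = 1000; // assumed cutoff, same as Counter
    private final double[] values;
    private final int from;
    private final int to;

    SquareAll(double[] values, int from, int to) {
        this.values = values;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() {
        if (to - from < THRESHOLD) {
            // Small range: update the elements directly
            for (int i = from; i < to; i++) {
                values[i] *= values[i];
            }
        } else {
            // Large range: split in half; invokeAll returns once both halves are done
            int mid = from + (to - from) / 2;
            invokeAll(new SquareAll(values, from, mid), new SquareAll(values, mid, to));
        }
    }
}
```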