Java并发编程指南(五):Fork/Join框架

这个框架被设计用来解决可以使用分而治之技术将任务分解成更小的问题。在一个任务中,检查你想要解决问题的大小,如果它大于一个既定的大小,把它分解成更小的任务,然后用这个框架来执行。


这个框架基于以下两种操作:

  • fork操作:当你把任务分成更小的任务和使用这个框架执行它们。
  • join操作:当一个任务等待它创建的任务的结束。
work-stealing算法:

当一个任务正在等待它使用join操作创建的子任务的结 束时,执行这个任务的线程(工作线程)查找其他未被执行的任务并开始它的执行。通过这种方式,线程充分利用它们的运行时间,从而提高了应用程序的性能。

Fork/Join框架执行的任务有以下局限性:

  • 任务只能使用fork()和join()操作,作为同步机制。如果使用其他同步机制,工作线程不能执行其他任务,当它们在同步操作时。比如,在Fork/Join框架中,你使任务进入睡眠,正在执行这个任务的工作线程将不会执行其他任务,在这睡眠期间内。
  • 任务不应该执行I/O操作,如读或写数据文件。
  • 任务不能抛出检查异常,它必须包括必要的代码来处理它们。
Fork/Join框架的核心是由以下两个类:
  • ForkJoinPool:它实现ExecutorService接口和work-stealing算法。它管理工作线程和提供关于任务的状态和它们执行的信息。
  • ForkJoinTask: 它是将在ForkJoinPool中执行的任务的基类。它提供在任务中执行fork()和join()操作的机制,并且这两个方法控制任务的状态。通常, 为了实现你的Fork/Join任务,你将实现这个类的两个子类的子类:RecursiveAction对于没有返回结果的任务和RecursiveTask 对于返回结果的任务。

1. 创建一个Fork/Join池 :


//1.创建类Product,将用来存储产品的名称和价格。
class Product {
    private String name;
    private double price;
    public String getName() {
        return name;
    }
    public void setName(String name) {
        this.name = name;
    }
    public double getPrice() {
        return price;
    }
    public void setPrice(double price) {
        this.price = price;
    }
}

//2.创建ProductListGenerator类,用来产生随机产品的数列。
class ProductListGenerator {
    public List<Product> generate(int size) {
        List<Product> ret = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            Product product = new Product();
            product.setName("Product" + i);
            product.setPrice(10);
            ret.add(product);
        }
        return ret;
    }
}

//3.创建Task类,指定它继承RecursiveAction类。
class Task extends RecursiveAction {
    private static final long serialVersionUID = 1L;
    private List<Product> products;
    //11.声明两个私有的、int类型的属性first和last。这些属性将决定这个任务产品的阻塞过程。
    private int first;
    private int last;
    //12.声明一个私有的、double类型的属性increment,用来存储产品价格的增长。
    private double increment;
    //13.实现这个类的构造器,初始化所有属性。
    public Task(List<Product> products, int first, int last, double increment) {
        this.products = products;
        this.first = first;
        this.last = last;
        this.increment = increment;
    }
    //14.实现compute()方法 ,该方法将实现任务的逻辑。
    @Override
    protected void compute() {
        //15.如果任务量足够小,就直接在当前线程计算任务。
        if (last - first < 10) {
            updatePrices();
            //16.如果任务量不够小,则创建两个新的Task对象,一个处理产品的前半部分,另一个处理产品的后半部分,
            // 然后在ForkJoinPool中,使用invokeAll()方法执行它们。
        } else {
            int middle = (last + first) / 2;
            System.out.printf("Task: Pending tasks: %s\n", getQueuedTaskCount());
            Task t1 = new Task(products, first, middle + 1, increment);
            Task t2 = new Task(products, middle + 1, last, increment);
            invokeAll(t1, t2);
        }
    }
    //17.实现updatePrices()方法。这个方法更新产品队列中位于first值和last值之间的产品。
    private void updatePrices() {
        for (int i = first; i < last; i++) {
            Product product = products.get(i);
            product.setPrice(product.getPrice() * (1 + increment));
        }
    }
}

class Main5 {
    public static void main(String[] args) {
        //19.使用ProductListGenerator类创建一个包括10000个产品的数列。
        ProductListGenerator generator = new ProductListGenerator();
        List<Product> products = generator.generate(10000);
        //20.创建一个新的Task对象,用来更新产品队列中的产品。first参数使用值0,last参数使用值10000(产品数列的大小)。
        Task task = new Task(products, 0, products.size(), 0.20);
        //21.使用无参构造器创建ForkJoinPool对象。
        ForkJoinPool pool = new ForkJoinPool();
        //22.在池中使用execute()方法执行这个任务 。
        pool.execute(task);
        //23.实现一个显示关于每隔5毫秒池中的变化信息的代码块。将池中的一些参数值写入到控制台,直到任务完成它的执行。
        do {
            System.out.printf("Main: Thread Count: %d\n", pool.getActiveThreadCount());
            System.out.printf("Main: Thread Steal: %d\n", pool.getStealCount());
            System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
            try {
                TimeUnit.MILLISECONDS.sleep(5);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        //24.使用shutdown()方法关闭这个池。
        pool.shutdown();
        //25.使用isCompletedNormally()方法检查假设任务完成时没有出错,在这种情况下,写入一条信息到控制台。
        if (task.isCompletedNormally()) {
            System.out.printf("Main: The process has completed normally.\n");
        }
        //26.在增长之后,所有产品的价格应该是12。将价格不是12的所有产品的名称和价格写入到控制台,用来检查它们错误地增长它们的价格。
        for (int i = 0; i < products.size(); i++) {
            Product product = products.get(i);
            if (product.getPrice() != 12) {
                System.out.printf("Product %s: %f\n", product.getName(), product.getPrice());
            }
        }
        //27.写入一条信息到控制台表明程序的结束。
        System.out.println("Main: End of the program.\n");
    }
}

2. 加入任务的结果

//1.创建一个Document类,它将产生用来模拟文档的字符串的二维数组。
class Document {
    //2.创建一个带有一些单词的字符串数组。这个数组将被用来生成字符串二维数组。
    private String words[] = {"the", "hello", "goodbye", "packt", "java", "thread", "pool", "random", "class", "main"};
    //3.实现generateDocument()方法。它接收以下参数:行数、每行的单词数。这个例子返回一个字符串二维数组,来表示将要查找的单词。
    public String[][] generateDocument(int numLines, int numWords, String word) {
        //4.首先,创建生成这个文档必需的对象:字符串二维对象和生成随机数的Random对象。
        int counter = 0;
        String document[][] = new String[numLines][numWords];
        Random random = new Random();
        //5.用字符串填充这个数组。存储在每个位置的字符串是单词数组的随机位置,统计这个程序将要在生成的数组中查找的单词出现的次数。你可以使用这个值来检查程序是否执行正确。
        for (int i = 0; i < numLines; i++) {
            for (int j = 0; j < numWords; j++) {
                int index = random.nextInt(words.length);
                document[i][j] = words[index];
                if (document[i][j].equals(word)) {
                    counter++;
                }
            }
        }
        //6.将单词出现的次数写入控制台,并返回生成的二维数组。
        System.out.println("DocumentMock: The word appears " + counter + " times in the document");
        return document;
    }
}

//7.创建一个DocumentTask类,指定它继承RecursiveTask类,并参数化为Integer类型。该类将实现统计单词在一组行中出现的次数的任务。
class DocumentTask extends RecursiveTask<Integer> {
    //8.声明一个私有的String类型的二维数组document,两个私有的int类型的属性名为start和end,一个私有的String类型的属性名为word。
    private String document[][];
    private int start, end;
    private String word;

    //9.实现这个类的构造器,用来初始化这些属性。
    public DocumentTask(String document[][], int start, int end, String word) {
        this.document = document;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    //10.实现compute()方法。如果属性end和start的差小于10,那么这个任务统计单词位于行在调用processLines()方法的这些位置中出现的次数。
    @Override
    protected Integer compute() {
        int result = 0;
        if (end - start < 10) {
            result = processLines(document, start, end, word);
            //11.否则,用两个对象分解行组,创建两个新的DocumentTask对象用来处理这两个组,并且在池中使用invokeAll()方法来执行它们。
        } else {
            int mid = (start + end) / 2;
            DocumentTask task1 = new DocumentTask(document, start, mid, word);
            DocumentTask task2 = new DocumentTask(document, mid, end, word);
            invokeAll(task1, task2);
            try {
                result = task1.get() + task2.get();
            } catch (InterruptedException | ExecutionException e) {
                e.printStackTrace();
            }
        }
        return result;
    }

    //13.实现processLines()方法。它接收以下参数:字符串二维数组、start属性、end属性、任务将要查找的word属性。
    private Integer processLines(String[][] document, int start, int end, String word) {
        //14.对于任务要处理的每行,创建LineTask对象来处理整行,并且将它们存储在任务数列中。
        List<LineTask> tasks = new ArrayList<>();
        for (int i = start; i < end; i++) {
            LineTask task = new LineTask(document[i], 0, document[i].length, word);
            tasks.add(task);
        }
        //15.在那个数列中使用invokeAll()执行所有任务。
        invokeAll(tasks);
        //16.合计所有这些任务返回的值,并返回这个结果。
        int result = 0;
        for (int i = 0; i < tasks.size(); i++) {
            LineTask task = tasks.get(i);
            try {
                result = result + task.get();
            } catch (InterruptedException | ExecutionException e) {
                e.printStackTrace();
            }
        }
        return result;
    }
}

//18.创建LineTask类,指定它继承RecursiveTask类,并参数化为Integer类型。这个类将实现统计单词在一行中出现的次数的任务。
class LineTask extends RecursiveTask<Integer> {
    //19.声明这个类的序列号版本UID。这个元素是必需的,因为RecursiveTask类的父类,ForkJoinTask类实现了Serializable接口。声明一个私有的、String类型的属性line,两个私有的、int类型的属性start和end,一个私有的、String类型的属性word。
    private static final long serialVersionUID = 1L;
    private String line[];
    private int start, end;
    private String word;

    //20.实现这个类的构造器,初始化这些属性。
    public LineTask(String line[], int start, int end, String word) {
        this.line = line;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    //21.实现这个类的compute()方法。如果属性end和start之差小于100,这个任务在行中由start和end属性使用count()方法决定的片断中查找单词。
    @Override
    protected Integer compute() {
        Integer result = null;
        if (end - start < 100) {
            result = count(line, start, end, word);
            //22.否则,将行中的单词组分成两部分,创建两个新的LineTask对象来处理这两个组,在池中使用invokeAll()方法执行它们。
        } else {
            int mid = (start + end) / 2;
            LineTask task1 = new LineTask(line, start, mid, word);
            LineTask task2 = new LineTask(line, mid, end, word);
            invokeAll(task1, task2);
            //23.然后,使用groupResults()方法将这两个任务返回的值相加。最后,返回这个任务计算的结果。
            try {
                result = groupResults(task1.get(), task2.get());
            } catch (InterruptedException | ExecutionException e) {
                e.printStackTrace();
            }
        }
        return result;
    }

    //24.实现count()方法。它接收以下参数:完整行的字符串数组、start属性、end属性、任务将要查找的word属性。
    private Integer count(String[] line, int start, int end, String word) {
        //25.比较这个任务将要查找的word属性中的在start和end属性之间的位置的单词,如果它们相等,则增加count变量。
        int counter;
        counter = 0;
        for (int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        //26.为了显示示例的执行,令任务睡眠10毫秒。
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //27.返回counter变量的值。
        return counter;
    }

    //28.实现groupResults()方法。它合计两个数的值,并返回这个结果。
    private Integer groupResults(Integer number1, Integer number2) {
        Integer result;
        result = number1 + number2;
        return result;
    }
}

//29.实现示例的主类,通过创建Main类,并实现main()方法。
class Main6 {
    public static void main(String[] args) {
        //30.使用DocumentMock类,创建一个带有100行,每行1000个单词的Document。
        Document mock = new Document();
        String[][] document = mock.generateDocument(100, 1000, "the");
        //31.创建一个新的DocumentTask对象,用来更新整个文档的产品。参数start值为0,参数end值为100。
        DocumentTask task = new DocumentTask(document, 0, 100, "the");
        //32.使用无参构造器创建一个ForkJoinPool对象,在池中使用execute()方法执行这个任务。
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        //33.实现一个代码块,用来显示关于池变化的信息。每秒向控制台写入池的某些参数的值,直到任务完成它的执行。
        do {
            System.out.printf("******************************************\n");
            System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
            System.out.printf("Main: Active Threads: %d\n", pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d\n", pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d\n", pool.getStealCount());
            System.out.printf("******************************************\n");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        //34.使用shutdown()方法关闭这个池。
        pool.shutdown();
        //35.使用awaitTermination()方法等待任务的结束。
        try {
            System.out.printf("Main: The word appears %d in the document", task.get());
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace();
        }
        //36.打印单词在文档中出现的次数。检查这个数是否与DocumentMock类中写入的数一样。
        try {
            System.out.printf("Main: The word appears %d in the document", task.get());
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace();
        }
    }
}


3. 异步运行任务

使用fork()方法把Task对象提交给池,任务将会被异步的执行,使用join()方法等待已提交到池的所有任务的结束。


4. 在任务中抛出异常

在ForkJoinTask类的compute()方法中,你不能抛出任何已检查异常,因为在这个方法的实现中,它没有包含任何抛出(异常)声明。你必须包含必要的代码来处理异常。但是,你可以抛出(或者它可以被任何方法或使用内部方法的对象抛出)一个未检查异常。ForkJoinTask和ForkJoinPool类的行为与你可能的期望不同。程序不会结束执行,并且你将不会在控制台看到任何关于异常的信息。它只是被吞没,好像它没抛出(异常)。你可以使用ForkJoinTask类的一些方法,得知一个任务是否抛出异常及其异常种类。

扫描二维码关注公众号,回复: 191442 查看本文章

使用isCompletedAbnormally()方法,检查这个任务或它的子任务是否已经抛出异常。在这种情况下,将抛出的异常写入到控制台。使用ForkJoinTask类的getException()方法获取那个异常。

你还可以使用ForkJoinTask类的completeExceptionally()方法将异常信息返回。


5.取消任务:

当你在一个ForkJoinPool类中执行ForkJoinTask对象,在它们开始执行之前,你可以取消执行它们。ForkJoinTask类提供cancel()方法用于这个目的。当你想要取消一个任务时,有一些点你必须考虑一下,这些点如下:
  • ForkJoinPool类并没有提供任何方法来取消正在池中运行或等待的所有任务。
  • 当你取消一个任务时,你不能取消一个已经执行的任务。



猜你喜欢

转载自blog.csdn.net/sunjin9418/article/details/79558644