重写ThreadPoolExecutor

ThreadPoolExecutor

简单介绍一下ThreadPoolExecutor的加入规则:

corePoolSize: maximumPoolSize, workQueue;

假设核心线程全部建立,并且不销毁

当前任务加入线程池后, 如果正在执行的任务数量少于corePoolSize, 那么直接加入corePoolSize已经开辟的线程中运行; 

但如果运行的任务数量等于corePoolSize时, 该任务会添加到队列, 直到队列放置不下时; 判断存在的线程数是否达到maximumPoolSize,如果没有达到,开辟新线程继续运行.

重写需求如下:

1.线程池中始终保持核心线程数, 当任务加入ThreadPoolExecutor时,如果有空闲线程,自动到空闲线程中运行

2.如果核心线程都在运行, 依然有新加入任务时, 这时如果线程数量少于maximumPoolSize, 创建新线程运行该任务.

3.如果运行的线程数等于maximumPoolSize时, 将请求任务加入队列, 不拒绝任务. 


修改如下:

import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class MyThreadPoolExecutor extends java.util.concurrent.ThreadPoolExecutor {
    private final AtomicInteger submittedTaskCount = new AtomicInteger(0);

    public MyThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
                                TaskQueue workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
        workQueue.setExecutor(this);
        prestartAllCoreThreads();
    }

    //执行完成后计数器减1
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        super.afterExecute(r, t);
        submittedTaskCount.decrementAndGet();
    }

    public int getSubmittedTaskCount() {
        return submittedTaskCount.get();
    }

    @Override
    public void execute(Runnable command) {
        submittedTaskCount.incrementAndGet();
        try {
            super.execute(command);
        } catch (RejectedExecutionException rx) {
            if(super.getQueue() instanceof TaskQueue) {
                final TaskQueue queue = (TaskQueue) (super.getQueue());
                try {
                    if (!queue.force(command)) {  // 无限队列, 理论上不会失败;
                        submittedTaskCount.decrementAndGet();
                        throw new RejectedExecutionException("Queue capacity is full.");
                    }
                } catch(Exception x){
                    submittedTaskCount.decrementAndGet();
                    throw new RejectedExecutionException(x);
                }
            } else {
                submittedTaskCount.decrementAndGet();
                throw rx;
            }
        }
    }
}
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;

public class TaskQueue extends LinkedBlockingQueue<Runnable>{
    private MyThreadPoolExecutor executor;

    public TaskQueue(){
        super();
    }
    public void setExecutor(MyThreadPoolExecutor executor) {
        this.executor = executor;
    }
    public boolean force(Runnable o) {
        if (executor.isShutdown()) {
            throw new RejectedExecutionException("Executor not running, can't force a command into the queue");
        }
        return super.offer(o); //forces the item onto the queue, to be used if the task is rejected
    }

    @Override
    public boolean offer(Runnable o) {
        int currentPoolThreadSize = executor.getPoolSize();
        // 线程数达到最大,添加到队列
        if(currentPoolThreadSize >= executor.getMaximumPoolSize()) {
            return super.offer(o);
        }
        // 有空闲线程,直接添加到队列
        if(executor.getSubmittedTaskCount() < currentPoolThreadSize) {
            return super.offer(o);
        }
        // 当前线程池数还不是最大,创建线程
        if(currentPoolThreadSize < executor.getMaximumPoolSize()){
            return false;
        }

        return super.offer(o);
    }
}

任务队列继承LinkedBlockingQueue, 无上限队列

测试:

import java.util.concurrent.TimeUnit;

public class Main {
    static MyThreadPoolExecutor mutilSecPool = new MyThreadPoolExecutor(2, 7, 2, TimeUnit.SECONDS, new TaskQueue());

    static void print(String s){
        for(int i = 0; i < 2; ++i){
            System.out.println("当前运行"+s);
            try {
               Thread.sleep(1000);

            }catch (Exception e){
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args){
        System.out.println("线程池线程数: " + mutilSecPool.getPoolSize());
        mutilSecPool.submit(()->print("1"));
        mutilSecPool.submit(()->print("2"));
        System.out.println("线程池线程数: " + mutilSecPool.getPoolSize());
        mutilSecPool.submit(()->print("3"));
        mutilSecPool.submit(()->print("4"));
        mutilSecPool.submit(()->print("5"));
        mutilSecPool.submit(()->print("6"));
        mutilSecPool.submit(()->print("7"));
        mutilSecPool.submit(()->print("8"));
        mutilSecPool.submit(()->print("9"));
        mutilSecPool.submit(()->print("10"));
        System.out.println("线程池线程数: " + mutilSecPool.getPoolSize());
        try {
            Thread.sleep(10000);
            System.out.println("线程池线程数: " + mutilSecPool.getPoolSize());
        }catch (Exception e){
            e.printStackTrace();
        }
        mutilSecPool.shutdown();
    }
}

测试出来的结果基本符合预期的, 只是第二次打印线程池线程数目的时候比预期多了一个,这个原因还没有查清楚

参考:点击打开链接

参考:点击打开链接

如果需要更加完善的,建议参考tomcat ThreadPoolExecutor源码

猜你喜欢

转载自blog.csdn.net/wangzhuo0978/article/details/80338334