上一篇:Java多线程(6)CAS详解
手写线程池
一. 线程阻塞获取任务
1. 编写任务队列 以生产者和消费者模型来实现
class TaskQueue<T> {
// 1. 任务队列实例
private Deque<T> queue = new ArrayDeque<>();
// 2. 任务队列大小
private final static int queueSize = 10;
// 3. 锁
private ReentrantLock lock = new ReentrantLock();
// 4. 生产者等待的条件变量
private Condition producerWaitSet = lock.newCondition();
// 5. 消费者等待的条件变量
private Condition consumerWaitSet = lock.newCondition();
// 生产者添加任务(阻塞)
public void addTask(T task){
// 加锁
lock.lock();
try{
// 如果 任务队列满了
while(queue.size() == queueSize){
try {
// 生产者(指的是添加任务的主线程) 等待
producerWaitSet.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 把任务添加到集合最后面的一个位置
queue.addLast(task);
// 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
consumerWaitSet.signal();
} finally {
//释放锁
lock.unlock();
}
}
// 消费者获取任务(阻塞)
public T getTask(){
lock.lock();
try{
// 如果任务(即生产者没有添加任务)为空
while(queue.isEmpty()){
try {
// 则让 消费者(这里的消费者指的是线程池的线程) 等待
consumerWaitSet.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 拿出第一个任务
T t = queue.removeFirst();
// 唤醒生产者(即添加任务的主线程)线程,继续添加线程
producerWaitSet.signal();
return t;
} finally {
//释放锁
lock.unlock();
}
}
}
这里难在对生产者消费者模型的理解上,以线程池为例,
- 消费者就是线程池里面的每一个线程
- 生产者就是添加任务的主线程
2. 编写线程池
class MyThreadPool {
// 1. 寄存的任务队列
private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
// 2. 线程集合 这里用封装的内部类Worker实现
private HashSet<Worker> workers = new HashSet<>();
// 3. 线程数
private int coreSize;
public MyThreadPool(int coreSize){
this.coreSize = coreSize;
}
// 执行任务 即主线程(生产者) 添加任务
public void executeTask(Runnable task){
// 判断还有没有 空闲的线程
if (workers.size() < coreSize){
// 如果有,就新建线程
Worker worker = new Worker(task);
// 添加进去,表示少了一个线程
workers.add(worker);
worker.start();
} else {
// 没有空闲的线程的话 就添加到任务队列 等待线程池的每个线程执行完当前的进程
taskQueue.addTask(task);
}
}
class Worker extends Thread {
private Runnable task;
public Worker(Runnable task) {
this.task = task;
}
@Override
public void run() {
// 如果当前任务执行完毕,并且获取不到新的任务下就退出while
while (task != null || (task = taskQueue.getTask())!=null){
try {
task.run();
} catch (Exception e){
e.printStackTrace();
} finally {
task = null;
}
}
synchronized (workers){
// 执行完任务就移除,表示多了一个线程
workers.remove(this);
}
}
}
}
这里的线程池给了一个有参的构造方法,参数是线程池的最大容量(即最大可同时执行的线程大小),然后给了一个开始任务的方法executeTask,后面的我不多说了,懂得都懂,不懂得多敲几编,代码注释都加上了,顺带我个人的理解
3. 测试代码
3.1 情况1 线程池容量10 给安排8个任务
public class TestThreadPool {
// 测试代码
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(10);
//让10个线程去 执行 8 个任务
for (int i = 0; i < 8; i++) {
int j = i+1;
threadPool.executeTask(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
}
});
}
}
}
3.2 情况2 线程池容量10 给安排30个任务
public class TestThreadPool {
// 测试代码
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(10);
//让10个线程去 执行 30 个任务
for (int i = 0; i < 30; i++) {
int j = i+1;
threadPool.executeTask(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
}
});
}
}
}
运行结果:
3.3 总结
可以发现任务都执行完了,诡异的是为什么主线程还是没有停止下来
这个时候可以看线程池内部类worker的 run方法 里面的whlie循环
因为这个方法让每个都进入了wait,所以我们在下面改造一下
二. 线程非阻塞获取任务
1. 生产者消费者模型
在上面的基础上新增一个增强版获取任务的方法
// 消费者获取任务增强版(非阻塞) (传递两个参数 1是时间 2是时间单位)
public T getTaskEnhance(long timeout, TimeUnit timeUnit){
lock.lock();
try{
// 统一时间格式
long nanos = timeUnit.toNanos(timeout);
while(queue.isEmpty()){
try {
// 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
if (nanos <= 0){
return null;
}
// awaitNanos方法 对比 await(时间,时间单位) 方法的区别就是 如果不到等待时间被打断 他会返回剩余时间
nanos = consumerWaitSet.awaitNanos(nanos);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
T t = queue.removeFirst();
producerWaitSet.signal();
return t;
} finally {
lock.unlock();
}
}
2. 线程池实现
因为之前的getTask方法是阻塞方法,所以我们可以调用这个新加的增强版方法,通过构造方法的方式传递时间,和时间单位
class MyThreadPool {
// 1. 寄存的任务队列
private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
// 2. 线程集合 这里用封装的内部类Worker实现
private HashSet<Worker> workers = new HashSet<>();
// 3. 线程数
private int coreSize;
// 4. 获取任务的超时时间
private long timeout;
// 5. 时间单位
private TimeUnit timeUnit;
public MyThreadPool(int coreSize, long timeout, TimeUnit timeUnit){
this.coreSize = coreSize;
this.timeout = timeout;
this.timeUnit = timeUnit;
}
// 执行任务 即主线程(生产者) 添加任务
public void executeTask(Runnable task){
// 判断还有没有 空闲的线程
if (workers.size() < coreSize){
// 如果有,就新建线程
Worker worker = new Worker(task);
// 添加进去,表示少了一个线程
workers.add(worker);
worker.start();
} else {
// 没有空闲的线程的话 就添加到任务队列 等待线程池的每个线程执行完当前的进程
taskQueue.addTask(task);
}
}
class Worker extends Thread {
private Runnable task;
public Worker(Runnable task) {
this.task = task;
}
@Override
public void run() {
// 如果当前任务执行完毕,并且获取不到新的任务下就退出while
while (task != null || (task = taskQueue.getTaskEnhance(timeout,timeUnit))!=null){
try {
task.run();
} catch (Exception e){
e.printStackTrace();
} finally {
task = null;
}
}
synchronized (workers){
// 执行完任务就移除,表示多了一个线程
workers.remove(this);
}
}
}
}
3. 测试代码
3.1 运行结果
这里我们只测试一种情况就是线程容量为10,任务个数为30
public class TestThreadPool {
// 测试代码
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(10,100,TimeUnit.MILLISECONDS);
//让10个线程去 执行 30 个任务
for (int i = 0; i < 30; i++) {
int j = i+1;
threadPool.executeTask(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
}
});
}
}
}
运行结果:
可以看到我们的主线程已经关闭了,任务也按预想的执行完成了
3.2 总结
还有一种情况没有测试 ,就是executeTask方法中
这里,添加的任务超过了任务队列最大容量,那么主线程就会一种死等,等待任务添加完成,这里的解决方案和上面的添加方法增强版是一样的
// 生产者添加任务增强版(非阻塞)(传递两个参数 1是时间 2是时间单位)
public boolean addTaskEnhance(T task,long timeout, TimeUnit timeUnit){
// 加锁
lock.lock();
try{
// 统一时间格式
long nanos = timeUnit.toNanos(timeout);
// 如果 任务队列满了
while(queue.size() == queueSize){
try {
// 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
if (nanos <= 0){
return false;
}
// 生产者(指的是添加任务的主线程) 等待
nanos = producerWaitSet.awaitNanos(nanos);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 把任务添加到集合最后面的一个位置
queue.addLast(task);
// 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
consumerWaitSet.signal();
} finally {
//释放锁
lock.unlock();
}
return true;
}
还有就是通过策略模式让用户自己选择
三. 策略模式(队列已满的情况下让用户选择死等还是其他操作)
1. 新建策略模式的接口类
@FunctionalInterface
interface RejectPolicy<T> {
void reject(TaskQueue<T> queue, T task);
}
2. 生产者消费者模型
class TaskQueue<T> {
// 1. 任务队列实例
private Deque<T> queue = new ArrayDeque<>();
// 2. 任务队列大小
private final static int queueSize = 10;
// 3. 锁
private ReentrantLock lock = new ReentrantLock();
// 4. 生产者等待的条件变量
private Condition producerWaitSet = lock.newCondition();
// 5. 消费者等待的条件变量
private Condition consumerWaitSet = lock.newCondition();
// 生产者添加任务(阻塞)
public void addTask(T task){
// 加锁
lock.lock();
try{
// 如果 任务队列满了
while(queue.size() == queueSize){
try {
// 生产者(指的是添加任务的主线程) 等待
producerWaitSet.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 把任务添加到集合最后面的一个位置
queue.addLast(task);
// 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
consumerWaitSet.signal();
} finally {
//释放锁
lock.unlock();
}
}
// 生产者添加任务增强版(非阻塞)(传递两个参数 1是时间 2是时间单位)
public boolean addTaskEnhance(T task,long timeout, TimeUnit timeUnit){
// 加锁
lock.lock();
try{
// 统一时间格式
long nanos = timeUnit.toNanos(timeout);
// 如果 任务队列满了
while(queue.size() == queueSize){
try {
// 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
if (nanos <= 0){
return false;
}
// 生产者(指的是添加任务的主线程) 等待
nanos = producerWaitSet.awaitNanos(nanos);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 把任务添加到集合最后面的一个位置
queue.addLast(task);
// 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
consumerWaitSet.signal();
} finally {
//释放锁
lock.unlock();
}
return true;
}
// 消费者获取任务(阻塞)
public T getTask(){
lock.lock();
try{
// 如果任务(即生产者没有添加任务)为空
while(queue.isEmpty()){
try {
// 则让 消费者(这里的消费者指的是线程池的线程) 等待
consumerWaitSet.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// 拿出第一个任务
T t = queue.removeFirst();
// 唤醒生产者(即添加任务的主线程)线程,继续添加线程
producerWaitSet.signal();
return t;
} finally {
//释放锁
lock.unlock();
}
}
// 消费者获取任务增强版(非阻塞) (传递两个参数 1是时间 2是时间单位)
public T getTaskEnhance(long timeout, TimeUnit timeUnit){
lock.lock();
try{
// 统一时间格式
long nanos = timeUnit.toNanos(timeout);
while(queue.isEmpty()){
try {
// 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
if (nanos <= 0){
return null;
}
// awaitNanos方法 对比 await(时间,时间单位) 的区别就是 如果不到等待时间被打断 他会返回剩余时间
nanos = consumerWaitSet.awaitNanos(nanos);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
T t = queue.removeFirst();
producerWaitSet.signal();
return t;
} finally {
lock.unlock();
}
}
// 生产者添加,如果任务队列满了的话,策略模式选让用户选择如何处理
public void tryAddTask(RejectPolicy<T> reject, T task){
// 加锁
lock.lock();
try{
// 如果 任务队列满了
if (queue.size() == queueSize){
reject .reject(this,task);
}
// 把任务添加到集合最后面的一个位置
queue.addLast(task);
// 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
consumerWaitSet.signal();
} finally {
//释放锁
lock.unlock();
}
}
}
3. 编写线程池
class MyThreadPool<T> {
// 1. 寄存的任务队列
private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
// 2. 线程集合 这里用封装的内部类Worker实现
private HashSet<Worker> workers = new HashSet<>();
// 3. 线程数
private int coreSize;
// 4. 获取任务的超时时间
private long timeout;
// 5. 时间单位
private TimeUnit timeUnit;
private RejectPolicy<Runnable> reject;
public MyThreadPool(int coreSize, long timeout, TimeUnit timeUnit,RejectPolicy<Runnable> reject){
this.coreSize = coreSize;
this.timeout = timeout;
this.timeUnit = timeUnit;
this.reject = reject;
}
// 执行任务 即主线程(生产者) 添加任务
public void executeTask(Runnable task){
// 判断还有没有 空闲的线程
if (workers.size() < coreSize){
// 如果有,就新建线程
Worker worker = new Worker(task);
// 添加进去,表示少了一个线程
workers.add(worker);
worker.start();
} else {
// 没有空闲的线程的话 就添加到任务队列 等待线程池的每个线程执行完当前的进程
taskQueue.tryAddTask(reject,task);
}
}
class Worker extends Thread {
private Runnable task;
public Worker(Runnable task) {
this.task = task;
}
@Override
public void run() {
// 如果当前任务执行完毕,并且获取不到新的任务下就退出while
while (task != null || (task = taskQueue.getTaskEnhance(timeout,timeUnit))!=null){
try {
task.run();
} catch (Exception e){
e.printStackTrace();
} finally {
task = null;
}
}
synchronized (workers){
// 执行完任务就移除,表示多了一个线程
workers.remove(this);
}
}
}
}
重点在这里:
4. 测试代码
public class TestThreadPool {
// 测试代码
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(10, 100, TimeUnit.MILLISECONDS, new RejectPolicy<Runnable>() {
@Override
public void reject(TaskQueue<Runnable> queue, Runnable task) {
//queue.addTask(task); // 如果任务队列满了,就死等
queue.addTaskEnhance(task,100,TimeUnit.MILLISECONDS); // 如果任务队列满了,就等100 毫秒,199毫秒添加不上就算了
// throw new RuntimeException(); 抛出异常
}
});
//让10个线程去 执行 30 个任务
for (int i = 0; i < 40; i++) {
int j = i+1;
threadPool.executeTask(new Runnable() {
@Override
public void run() {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
}
});
}
}
}
三种情况,生产者(主线程)想用哪种就哪种