我理解的Java并发基础(五):并发工具类和ThreadLocal

J.U.C并发包提供了几个非常有用的、用于并发流程控制的CountDownLatch、CyclicBarrier、Semaphore、类等。

1,CountDownLatch,闭锁。实现类似计数器的功能。CountDownLatch的常用API如下:

CountDownLatch(int count) // 构造方法,接受一个int类型参数表示总计数
void await() throws InterruptedException // 阻塞当前线程,直到计数=0,或者线程被中断
boolean await(long timeout, TimeUnit unit) throws InterruptedException // 一段时间内阻塞当前线程。如果计数=0被唤醒则返回true,如果超时被唤醒则返回false
void countDown() // 将计数-1
long getCount() // 获取当前计数值,常用于debug

举个例子。10个运动员比赛百米赛跑,只有这10个运动员都准备好之后,发令员才能发枪。

public class MyTest4CountDownLatch {
    
    public static void main(String[] args) {
        int playerNumbs = 10; // 运动员数量
        CountDownLatch cdl = new CountDownLatch(playerNumbs); // 创建一个CountDownLatch对象(cdl的初始计数=10)
        
        Thread starter = new Starter(cdl); // 创建一个发令员
        List<Thread> players = new ArrayList<>(playerNumbs); // 创建10个运动员(放到一个list集合中方便操作)
        for (int i = 0; i < playerNumbs; i++) {
            Thread player = new Player(i, cdl);
            players.add(player);
        }
        
        starter.start(); 
        for (int i = 0; i < playerNumbs; i++) {
            players.get(i).start();
        }
        
    }
    
}

class Player extends Thread{
    
    private CountDownLatch cdl;
    private int number;
    
    Player(int number, CountDownLatch cdl){
        this.cdl = cdl;
        this.number = number;
    }
    
    @Override
    public void run() {
        try {
            Thread.sleep((long) (Math.random() * 100));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("运动员" + number + "准备好了.");
        cdl.countDown(); // 该运动员准备就绪(cdl的计数-1)
    }
}

class Starter extends Thread {
    private CountDownLatch cdl;
    
    Starter(CountDownLatch cdl){
        this.cdl = cdl;
    }
    
    @Override
    public void run() {
        System.out.println("发令员举起手枪,等待所有运动员准备就绪.");
        try {
            cdl.await(); // 等待所有运动员准备就绪(等待cdl的计数=0)
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("所有运动员已准备就绪,发令员发枪。");
    }
}

2,CyclicBarrier,循环栅栏。等待线程的数量达到目标的时候,所有等待的线程同时执行。可重置。CyclicBarrier的常用API如下:

CyclicBarrier(int parties) // 构造方法,指明需要等待的计数
CyclicBarrier(int parties, Runnable barrierAction) // 构造方法,指明需要等待的计数 和 计数=0时的触发操作
int getParties() // 返回初始化指定的parties
int await() throws InterruptedException, BrokenBarrierException // 阻塞当前线程。返回剩余等待计数。如果返回paties-1则表示是第1个到达。如果返回0则表示最后一个到达
int await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException // 一段时间内阻塞等待。响应线程中断标识。返回值同await()
boolean isBroken() // 如果当前cyclicBarrier对象处于broken状态则返回true
void reset() // 重置计数
int getNumberWaiting() // 返回剩余计数

  CyclicBarrier是一个所有线程要么全通过,要么全不通过的工具。如果有线程调用await(long timeout, TimeUnit unit) 超时通过,则CyclicBarrier处于broken状态,其他别的正在等待的线程会收到InterruptedException, 后续调用await()的线程会收到BrokenBarrierException。

举个例子。地铁的修建是按分段来修建的,只有所有的分段都施工完成之后,才可以通车。

public class MyTest4CyclicBarrier {

    public static void main(String[] args) throws InterruptedException {
        int lineSegementNumbs = 3;
        CyclicBarrier cb = new CyclicBarrier(lineSegementNumbs); // 创建CyclicBarrier对象 (cb计数=3)
        
        for (int i = 0; i < lineSegementNumbs; i++) {
            new LineSegment(i, cb).start();
        }
        
        Thread.sleep(1000);
        System.out.println("CyclicBarrier重用"); // 自动重置计数
        
        for (int i = 0; i < lineSegementNumbs; i++) {
            new LineSegment(i, cb).start();
        }
        
        
    }
}

class LineSegment extends Thread{
    
    private int segementNum;
    private CyclicBarrier cb;
    
    LineSegment(int segementNum, CyclicBarrier cb){
        this.segementNum = segementNum;
        this.cb = cb;
    }
    
    
    @Override
    public void run() {
        try {
            Thread.sleep((long) (Math.random() * 100));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        
        System.out.println("线路段" + segementNum +"已施工完成.");
        try {
            cb.await(); // cb计数-1,等待cb计数=0
        } catch (InterruptedException | BrokenBarrierException e) {
            e.printStackTrace();
        }
        System.out.println("通车了...");
    }
    
}

3,Semaphore,信号量。通常用来表示许可数量,用于限制可以访问某些资源(物理或逻辑的)的线程数目。Semaphore的常用API如下:

Semaphore(int permits) // 构造方法,指定许可数量,默认非公平
Semaphore(int permits, boolean fair) // 构造方法,指定许可数量,制定是否公平获取
void acquire() throws InterruptedException // 阻塞等待许可。阻塞期间响应线程的中断标识
void acquireUninterruptibly() // 阻塞等待许可,不相应线程中断标识。
boolean tryAcquire() // 尝试获取许可。如果失败则返回false,如果获取成功则返回true
boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException // 一段时间内阻塞等待许可,阻塞期间响应线程的中断标识
void release() // 释放一个许可

// 以下API中,int类型的permits参数表示一次获取或释放多个许可
void acquire(int permits) throws InterruptedException // 同 acquire()
void acquireUninterruptibly(int permits) // 同 acquireUninterruptibly()
boolean tryAcquire(int permits) // 同 tryAcquire()
boolean tryAcquire(int permits, long timeout, TimeUnit unit) throws InterruptedException // 同 tryAcquire(long timeout, TimeUnit unit)
void release(int permits) // 同 release()

// 以下API常用于控制、监控或debug
int availablePermits() // 剩余可用的许可数量
int drainPermits() // 将所有剩余许可置为0
int getQueueLength() // 等待获取许可的线程数量

举个例子。桥每次最多能通过两个人,每个人通过桥的时间为10秒,桥东西两侧各有10个人同时等待准备过桥。输出每次过桥人的姓名、过桥方向和过桥时间。

public class MyTest4Semaphore {

    public static void main(String[] args) throws InterruptedException {
        String[] gruopEast = {"张三0", "张三1", "张三2", "张三3", "张三4", "张三5", "张三6", "张三7", "张三8", "张三9"};
        String[] gruopWest = {"李四0", "李四1", "李四2", "李四3", "李四4", "李四5", "李四6", "李四7", "李四8", "李四9"};

        int takeTime = 10;

        CyclicBarrier cb = new CyclicBarrier(20); // 使用CyclicBarrier来达到“同时准备过桥”的目的。
        Semaphore semaphore = new Semaphore(2);

        for (int i = 0; i < gruopEast.length; i++) {
            new Player(gruopEast[i], "西", takeTime, cb, semaphore).start();
        }

        for (int i = 0; i < gruopWest.length; i++) {
            new Player(gruopWest[i], "东", takeTime, cb, semaphore).start();
        }
    }
}

class Player extends Thread {
    private String name;
    private String destination;
    private int takeTimeSeconds;
    private CyclicBarrier cb;
    private Semaphore semaphore;

    public Player(String name, String destination, int takeTimeSeconds, CyclicBarrier cb, Semaphore semaphore) {
        super();
        this.name = name;
        this.destination = destination;
        this.takeTimeSeconds = takeTimeSeconds;
        this.cb = cb;
        this.semaphore = semaphore;
    }

    @Override
    public void run(){
        try {
            cb.await();
            semaphore.acquire(); // 阻塞获取许可
            System.out.println(name + "准备向" + destination + "过桥,需要花费" + takeTimeSeconds + "秒");
            Thread.sleep(1000 * takeTimeSeconds);
            semaphore.release(); // 释放许可
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        } catch (BrokenBarrierException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }
}

4,Phaser,JDK1.7新增,功能上比CyclicBarrier、CountDownLatch更强大,提供更为丰富的API。常用于多线程参与多阶段完成的场景。在不同阶段,可以等待不同数量人的人完成,再进入下一阶段。在进入下一个阶段的时候,用户还可以重写onAdvance()来实现更佳自定义更加灵活的场景。Phaser的常用API如下:

Phaser(int parties) // 构造方法,指明有多少个线程参与
int register() // 当前线程注册到当前阶段
int arrive() // 当前线程已完成自己的工作
int arriveAndDeregister() // 当前线程已完成自己的工作,并取消注册(不参与下一个阶段)。
int arriveAndAwaitAdvance()  // 当前线程已完成自己的工作,等待进入一下个阶段(参与下一个阶段)
int awaitAdvance(int phase) // 阻塞等待第phase个阶段
int awaitAdvanceInterruptibly(int phase) //  阻塞等待第phase个阶段,阻塞期间响应线程中断标识
int awaitAdvanceInterruptibly(int phase, long timeout, TimeUnit unit) // 阻塞等待第phase个阶段一段时间,阻塞期间响应线程中断标识
void forceTermination() // 强制将该phaser对象设置为中指状态
int getPhase() // 获取当前第phase个阶段
int getRegisteredParties() // 获取当前阶段的注册数量
int getArrivedParties() // 获取当前阶段已经完成的线程数量
int getUnarrivedParties() // 获取当前阶段还未完成的线程数量
boolean isTerminated() // 判断该phaser是否可用
boolean onAdvance(int phase, int registeredParties) // protected方法,当一个阶段完成后触发,进入下一个阶段之前的动作。

  每个Phaser实例都会维护一个phase number,初始值为0。每当所有注册的任务都到达Phaser时,phase number累加,并在超过Integer.MAX_VALUE后清零。arrive()和arriveAndDeregister()方法用于记录到达;其中arrive(),某个参与者完成任务后调用;arriveAndDeregister(),任务完成,取消自己的注册。arriveAndAwaitAdvance(),自己完成等待其他参与者完成,进入阻塞,直到Phaser成功进入下个阶段。

举个例子。有一个项目,分为4个阶段完成。第一个阶段有工人A,工人B,工人C共同参与。第二阶段由工人C、工人D共同完成,第三阶段由工人B、工人E共同完成,第四阶段由工人A、B、C共同完成。前一个阶段完成之后,才能进入下一个阶段的工作。每个阶段完成之后,先向经理汇报,然后进入下一阶段。

public class MyTest4Phaser {

    public static void main(String[] args) {
        Phaser phaser = new MyProject(5);

        int[] workerAPhaseArray = new int[] {0, 3};
        int[] workerBPhaseArray = new int[] {0, 2, 3};
        int[] workerCPhaseArray = new int[] {0, 1, 3};
        int[] workerDPhaseArray = new int[] {1};
        int[] workerEPhaseArray = new int[] {2};

        int phaseAmount = 4;

        Worker workerA = new Worker("A", workerAPhaseArray, phaser, phaseAmount);
        Worker workerB = new Worker("B", workerBPhaseArray, phaser, phaseAmount);
        Worker workerC = new Worker("C", workerCPhaseArray, phaser, phaseAmount);
        Worker workerD = new Worker("D", workerDPhaseArray, phaser, phaseAmount);
        Worker workerE = new Worker("E", workerEPhaseArray, phaser, phaseAmount);

        workerA.start();
        workerB.start();
        workerC.start();
        workerD.start();
        workerE.start();

    }
}

class MyProject extends Phaser {

    public MyProject(int parties){
        super(parties);
    }

    @Override
    public boolean onAdvance(int currenPphase, int registeredParties) {
        if (currenPphase < 3) {
            System.out.println("通知经理:已经完成了第" + currenPphase + "阶段任务,准备执行第" + (currenPphase + 1) + "阶段任务。");
        }else {
            System.out.println("通知经理:已完成了所有任务");
        }
        return registeredParties == 0;
    }
}

class Worker extends Thread {
    protected String name;
    protected Phaser phaser;
    private int[] phaseArray;
    private int phaseAmount;

    public Worker (String name,int[] phaseArray, Phaser phaser, int phaseAmount){
        super(name);
        this.name = name;
        this.phaseArray = phaseArray;
        this.phaser = phaser;
        this.phaseAmount = phaseAmount;
        if (phaseArray == null || phaseArray.length == 0) throw new RuntimeException("工人参与的阶段错误");
    }

    public void doWork(){
        Set<Integer> set = intArray2Set(phaseArray);
        int lastPhase = phaseArray[phaseArray.length-1]; 
        for (int i = 0; i < phaseAmount && (!phaser.isTerminated()); i++) {
            int currentPhase = phaser.getPhase();
            if (set.contains(currentPhase)) {
                outPrint(currentPhase);
                if (lastPhase == currentPhase) {
                    phaser.arriveAndDeregister();
                    break;
                }else {
                    phaser.arriveAndAwaitAdvance();
                }
            }else{
                phaser.arriveAndAwaitAdvance();
            }
        }

    }

    public Set<Integer> intArray2Set(int[] phaseArray){
        return Arrays.stream(phaseArray).boxed().collect(Collectors.toSet());
    }

    protected void outPrint(int i){
        try {
            System.out.println("工人" + name + "正在执行第" + i + "阶段任务");
            Thread.sleep(1000);
            System.out.println("工人" + name + "执行完成第" + i + "阶段任务");
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void run() {
        doWork();
    }
}

5,Exchanger<V>,线程间的数据交换器。两个线程在一个安全点彼此交换数据。该类比较简单,就3个API:

public Exchanger() // 
V exchange(V x) throws InterruptedException // 
V exchange(V x, long timeout, TimeUnit unit) throws InterruptedException, TimeoutException // 

  第一个调用exchange()的线程会阻塞等待,直到第二个线程调用exchange()来完成彼此数据的交换。

举个例子。飞机驾驶员有主飞和副飞,重要消息需要二者互相确认的。

public class MyTest4Exchanger {

    public static void main(String[] args) {
        
        Exchanger<String> exchanger = new Exchanger<>();
        Thread primary = new Pilot("主飞", "10分钟后降落", exchanger);
        Thread secondary = new Pilot("副飞", "10分钟后降落", exchanger);
        
        primary.start();
        secondary.start();
    }
}

class Pilot extends Thread {
    
    private String pilotName;
    private String receivedMsg;
    private Exchanger<String> exchanger;
    
    Pilot(String pilotName, String receivedMsg, Exchanger<String> exchanger){
        this.pilotName = pilotName;
        this.receivedMsg = receivedMsg;
        this.exchanger = exchanger;
    }
    
    @Override
    public void run() {
        String ownReceivedMsg = receivedMsg;
        String otherReceivedMsg = null;
        try {
            otherReceivedMsg = exchanger.exchange(ownReceivedMsg);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(pilotName + "汇报:" + (ownReceivedMsg.equals(otherReceivedMsg) ? "消息一致." : "不消息一致."));
    }
    
}

  以上介绍的5个工具类均为线程间的互相通讯的工具类。还有一个线程私有的工具类,ThreadLocal。不过,ThreadLocal是存在于java.lang包下的。

6,ThreadLocal<T>,本地线程变量。ThreadLocal为每个使用该变量的线程提供独立的变量副本,所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本。ThreadLocal的API也比较简单:

ThreadLocal() // 构造方法
T get() // 获取当前线程中存储的本地变量
void set(T value) // 将value设置到当前线程的本地变量中存储
void remove() // 删除当前线程中存储的本地变量

举个例子。在多数据源处理中,以读写分离为例,可以将数据源的标识放到ThreadLocal中,使用aop来自动完成切换工作。本例就简单的模拟一下。(例子中不采用AOP了,直接代码中体现)

public class MyTest4ThreadLocal {

    
    public static void main(String[] args) {
        
        Map<String, String> name2DataSource = new HashMap<>(); // 缓存每个datasource,value以String代替
        name2DataSource.put("read", "ReadDataBase");
        name2DataSource.put("write", "WriteDataBase");
        
        for(int i = 0; i < 10; i++){
            new BussinessThread(i, name2DataSource).start();
        }
        
    }
}

class DataSourceHolder {
    
    private static final ThreadLocal<String> dataSources = new ThreadLocal<>();

    public static void setDataSourceKey(String customType) {
        dataSources.set(customType);
    }

    public static String getDataSourceKey() {
        return (String) dataSources.get();
    }

    public static void clearDataSourceKey() {
        dataSources.remove();
    }
}

class BussinessThread extends Thread {
    
    private Map<String, String> name2DataSource;
    private int number;
    BussinessThread(int number, Map<String, String> name2DataSource){
        this.number = number;
        this.name2DataSource = name2DataSource;
    }
    
    @Override
    public void run() {
        
        System.out.println("业务线程" + number + "准备进行读操作");
        DataSourceHolder.setDataSourceKey("read");
        
        try {
            Thread.sleep((long)(Math.random()*1000));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        
        System.out.println("业务线程" + number + "读操作对应的数据库是:" + name2DataSource.get(DataSourceHolder.getDataSourceKey()));
        DataSourceHolder.clearDataSourceKey();
        
        // 再测试写操作。
        try {
            Thread.sleep((long)(Math.random()*1000));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        
        System.out.println("业务线程" + number + "准备进行写操作");
        DataSourceHolder.setDataSourceKey("write");
        
        try {
            Thread.sleep((long)(Math.random()*1000));
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        
        System.out.println("业务线程" + number + "写操作对应的数据库是:" + name2DataSource.get(DataSourceHolder.getDataSourceKey()));
        DataSourceHolder.clearDataSourceKey();
    }
}

  ThreadLocal在使用完之后一定要手动threadlocal.remove()。原因有二: 1,如果使用不当会造成内存泄漏。线程类Thread中都有一份ThreadLocalMap的变量用来存储线程本地变量。由于ThreadLocalMap的生命周期跟Thread一样长,如果使用完之后没有手动threadlocal.remove()删除则会产生内存泄漏。 2,使用线程池的使用,线程是反复利用的资源,回收前的线程的副本变量会可能对再次时造成影响。
  so,正确使用ThreadLocal的姿势要注意两点: 1,ThreadLocal设置为类的静态变量。这样就只维持一份。 2,set(T value)设置,get()使用之后,一定要记得remove()删除。

参考资料:

  • 以上内容为笔者日常琐屑积累,已无从考究引用。如果有,请站内信提示。

猜你喜欢

转载自my.oschina.net/u/3466682/blog/1632744