封装ThreadLocal

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO

联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬

为什么要封装ThreadLocal?

原因有两点:

1、对于Thread,如果希望在Interceptor中存入UserInfo并在Service层通过ThreadLocal把UserInfo出来,必须保证Interceptor和Service此时用的是同一个ThreadLocal。

但是一个对象如何同时出现在Interceptor和Service呢?各自new一个ThreadLocal可不行,因为此时是两个对象了。比如,在Interceptor创建的对象是紫霞,而Service创建的是青霞,紫霞在至尊宝存入的东西,后面的青霞可没办法出来,因为Thread内部的ThreadLocalMap是以ThreadLocal作为key的(看上面爱心的key)。

但如果我们在ThreadLocalUtil中new一个ThreadLocal对象作为成员变量,就可以在Service中取出来了:

扫描二维码关注公众号,回复: 17191065 查看本文章

即:把ThreadLocal对象封装在ThreadLocalUtil中,分别在Interceptor和Service中使用它。

2、原生的ThreadLocal无法满足复杂的业务场景。

比如现在我封装了一个最简单的ThreadLocal(装饰者模式,为的是解决第一个问题):

/**
 * @author mx
 */
public class MyThreadLocal {

    private MyThreadLocal() {
    }

    private static final ThreadLocal<Object> THREAD_CONTEXT = new ThreadLocal<>();

    public static void put(Object obj) {
        THREAD_CONTEXT.set(obj);
    }

    public static Object get() {
        return THREAD_CONTEXT.get();
    }

    public static void remove() {
        THREAD_CONTEXT.remove();
    }
}

MyThreadLocal确实解决了第一个问题,复用了ThreadLocal,保证了Interceptor和Service用到的ThreadLocal是同一个对象。

但是,有两个缺陷:

  • 无法存取多个不同的值
  • 语意不明

比如,Service层希望往ThreadLocal里再添加一个Score对象,好让DAO层能获取到。你要怎么做?

另外,MyThreadLocal.get()其实很突兀,语意不明,光看代码你根本不知道get出来的是什么东西。

基于以上两个原因,我们必须封装ThreadLocal。对于第二个问题,其实可以考虑把原先的value改为Map类型。比如原本是 threadLocal1:User或者threadLocal1:Score,确实只能存一个值,而且很容易发生覆盖。但是如果把Map作为value存进去,形成双层Map就灵活多了:

{
	"threadLocal1" : {
        "USER_INFO" : User,
        "SCORE" : Score
    }
}

思路分析到这,我们开始写代码。一般情况下,我们只需要考虑一个ThreadLocal和多个Thread,这也是实际编程最常见的方式,所以下面的代码只会封装一个ThreadLocalUtil,里面也只有一个ThreadLocal。

ThreadLocalUtil第一版

/**
 * @author mx
 */
public class ThreadLocalUtil {

    private ThreadLocalUtil() {
    }

    /**
     * ThreadLocal是紫霞仙子,至尊宝是Thread
     * ThreadLocal的泛型规定了紫霞仙子劈开至尊宝时,能给他心里塞的东西的类型。
     * <p>
     * 比如
     * 将ThreadLocal泛型指定为String,那么造了一个ThreadLocalMap后,这个map只能存 threadLocal:"这是字符串" 这样的键值对
     * 将ThreadLocal泛型指定为Integer,那么造了一个ThreadLocalMap后,这个map只能存 threadLocal:1111111111 这样的键值对
     *
     * 由于单纯的value会发生值覆盖,所以我们使用Map<String, Object>作为value
     */
    private static final ThreadLocal<Map<String, Object>> THREAD_CONTEXT = new ThreadLocal<>();


    /**
     * 存入线程变量
     *
     * @param key
     * @param object
     */
    public static void put(String key, Object object) {
        /**
         * 至尊宝(一个Thread)经过这段代码,遇到了紫霞(THREAD_CONTEXT)。大家可以点进get()看看,内部操作是:
         * 1.把至尊宝的心取出来(从Thread中取出ThreadLocalMap)
         *
         * ThreadLocalMap的构造类似于这样
         * {
         * ...THREAD_CONTEXT: {
         * ........."USER_INFO":"{'name':'bravo', 'age':18}",
         * ........."SCORE":"{'Math':99, 'English': 97}"
         * ......}
         * }
         *
         * 2.ThreadLocalMap.Entry e = map.getEntry(this); 把自己(THREAD_CONTEXT)作为key,取出属于自己的value,此时value是一个Map<String, Object>。
         * 3.所以最终THREAD_CONTEXT.get()返回的Map<String, Object> map
         *
         */
        Map<String, Object> map = THREAD_CONTEXT.get();
        // 第一次从ThreadLocalMap中根据threadLocal取出的value可能是null
        if (map == null) {
            map = new HashMap<>();
            // 把map作为value放进去
            THREAD_CONTEXT.set(map);
        }
        /**
         * 假设本次存的是 USER_INFO:{"name":"bravo", "age":18}
         * 此时ThreadLocalMap中的结构是
         * {
         * ...THREAD_CONTEXT: {
         * ........."USER_INFO":"{'name':'bravo', 'age':18}",
         * ......}
         * }
         *
         */
        map.put(key, object);
    }

    /**
     * 取出线程变量
     *
     * @param key
     * @return
     */
    public static Object get(String key) {
        // 先获取Map
        Map<String, Object> map = THREAD_CONTEXT.get();
        // 从Map中得到USER_INFO
        return map != null ? map.get(key) : null;
    }

    /**
     * 移除当前线程的指定变量
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ......}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ......}
     * }
     * 并不是移除所有,而是只移除USER_INFO
     *
     * @param key
     */
    public static void remove(String key) {
        Map<String, Object> map = THREAD_CONTEXT.get();
        map.remove(key);
    }

    /**
     * 移除当前线程的所有变量
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ......}
     * }
     * 变成
     * {
     * }
     */
    public static void clear() {
        THREAD_CONTEXT.remove();
    }
}

建议大家从上面的MyThreadLocal开始,尝试自己一步步封装,ThreadLocalUtil第一步也不难,完全可以自己写。

ThreadLocalUtil第二版

上面的版本其实马马虎虎能用了,就是看起来不是特别优雅,很多地方需要判断null。如果你希望自己的工具类优雅些,逼格高一点,可以尝试下面这种:

/**
 * @author mx
 */
public class ThreadLocalUtil {

    private ThreadLocalUtil() {
    }

    /**
     * 注意右边new的不是原生的ThreadLocal,而是我自定义的MapThreadLocal,它继承自ThreadLocal
     *
     * @see MapThreadLocal
     */
    private final static ThreadLocal<Map<String, Object>> THREAD_CONTEXT = new MapThreadLocal();

    /**
     * 根据key获取value
     * 比如key为USER_INFO,则返回"{'name':'bravo', 'age':18}"
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object get(String key) {
        // getContextMap()表示要先获取THREAD_CONTEXT的value,也就是Map<String, Object>。然后再从Map<String, Object>中根据key获取
        return getContextMap().get(key);
    }

    /**
     * put操作,原理同上
     *
     * @param key
     * @param value
     */
    public static void put(String key, Object value) {
        getContextMap().put(key, value);
    }

    /**
     * 清除map里的某个值
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object remove(String key) {
        return getContextMap().remove(key);
    }

    /**
     * 清除整个Map<String, Object>
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {}
     * }
     */
    public static void remove() {
        getContextMap().clear();
    }

    /**
     * 从ThreadLocalMap中清除当前ThreadLocal存储的内容
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * }
     */
    public static void clear() {
        THREAD_CONTEXT.remove();
    }

    /**
     * 从ThreadLocalMap
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 中获取Map<String, Object>
     * {
     * ..."USER_INFO":"{'name':'bravo', 'age':18}",
     * ..."SCORE":"{'Math':99, 'English': 97}"
     * }
     *
     * @return
     */
    private static Map<String, Object> getContextMap() {
        return THREAD_CONTEXT.get();
    }

    /**
     * 内部类,继承自ThreadLocal,和第一版一样,仍旧指定value为Map<String, Object>
     * 之所以要自定义MapThreadLocal,是为了重写原生ThreadLocal的initialValue()
     * 把ThreadLocal第一版中判断null的操作隐藏掉,让代码优雅一些(但对于初学者来说,理解难度也提升了)
     */
    private static class MapThreadLocal extends ThreadLocal<Map<String, Object>> {

        @Override
        protected Map<String, Object> initialValue() {
            return new HashMap<String, Object>(8) {

                private static final long serialVersionUID = 3637958959138295593L;

                @Override
                public Object put(String key, Object value) {
                    return super.put(key, value);
                }
            };
        }
    }
}

第二版的难点有两个:

  • 多了一个getContextMap(),部分人会晕。其实这个操作就是得到当前ThreadLocal对应Map<String, Object>
  • 为什么重写initialValue()可以避免判断null?

另外,不用担心每次都会创建新的Map覆盖原有的,get()方法内部本身会判断,如果已经有ThreadLocalMap其实是直接取值返回的。

如果还是觉得难理解,我建议取消getContextMap(),把里面的代码拷贝到各个方法中,好理解些。

最后的最后,不要因为这个工具类是自己封装的就怀疑是不是会重新导致线程安全问题。

只有同时满足下面3个条件,才有可能发生线程安全问题:

  • 多线程环境
  • 有共享数据
  • 有多条语句操作共享数据/单条语句本身非原子操作

但实际上,ThreadLocal的机制本身就避免了资源共享...因为每个线程内部都有自己的ThreadLocalMap(每个线程都有自己的资源,相互独立)。

所以记住,ThreadLocal本身和线程安全没啥关系,但你可以用它来解决线程安全问题,而且它的解决办法很粗暴,就是从根源上杜绝了资源共享。

之前说最后应该调用threadLocal.remove(),而对应ThreadLocalUtil,应该调用clear(),它对应的才是threadLocal.remove()。

ThreadLocalUtil第三版

之前封装ThreadLocal时一直在解决两个问题:

  • 原生的ThreadLocal对每个Thread的操作是基于单值的Key-Value,而我们期望基于Key-MapValue的操作
  • 如果不重写initValue(),需要在外部处理Map的初始化问题

对于initValue()的重写,其实不需要专门写一个内部类(很多人不习惯内部类),有两种替代方式:

  • 给THREAD_CONTEXT赋值时,直接new ThreadLocal()并用匿名类方式重写initValue()
  • 让ThreadLocalUtil继承ThreadLocal,然后重写initValue()

第一种方式最简单,这里演示第二种。

另外,之前在知乎专栏讨论过,ThreadLocalMap是定义在ThreadLocal内部的,由于包权限问题,我们无法直接使用。而我们的ThreadLocalUtil其实本质是就像个Map,所以第三版我改了名字,干脆咱也叫ThreadLocalMap,就当Map使用,只不过是线程内共享的。

当然,你也可以有更好的封装,可以下方留言:

/**
 * @author mx
 */
public class ThreadLocalMap extends ThreadLocal<Map<String, Object>> {

    private ThreadLocalMap() {
    }

    private final static ThreadLocal<Map<String, Object>> THREAD_CONTEXT = new ThreadLocalMap();

    /**
     * 根据key获取value
     * 比如key为USER_INFO,则返回"{'name':'bravo', 'age':18}"
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object get(String key) {
        // getContextMap()表示要先获取THREAD_CONTEXT的value,也就是Map<String, Object>。然后再从Map<String, Object>中根据key获取
        return THREAD_CONTEXT.get().get(key);
    }

    /**
     * put操作,原理同上
     *
     * @param key
     * @param value
     */
    public static void put(String key, Object value) {
        THREAD_CONTEXT.get().put(key, value);
    }

    /**
     * 清除map里的某个值
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object remove(String key) {
        return THREAD_CONTEXT.get().remove(key);
    }

    /**
     * 清除整个Map<String, Object>
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {}
     * }
     */
    public static void clear() {
        THREAD_CONTEXT.get().clear();
    }

    /**
     * 从ThreadLocalMap中清除当前ThreadLocal存储的内容
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * }
     */
    public static void clearAll() {
        THREAD_CONTEXT.remove();
    }

    @Override
    protected Map<String, Object> initialValue() {
        return new HashMap<String, Object>(8) {

            private static final long serialVersionUID = 3637958959138295593L;

            @Override
            public Object put(String key, Object value) {
                return super.put(key, value);
            }
        };
    }

}

测试用例:

/**
 * @author mx
 */
public class ThreadLocalMapTest {

    public static void main(String[] args) {

        ThreadLocalMap.put("mainKey", "mainValue");

        new Thread(()->{
            ThreadLocalMap.put("threadKey", "threadValue");

            System.out.println("get main value in thread:" + ThreadLocalMap.get("mainKey"));
            System.out.println("get thread value in thread:" + ThreadLocalMap.get("threadKey"));
        }).start();

        System.out.println("get thread value in main:" + ThreadLocalMap.get("threadKey"));
        System.out.println("get main value in main:" + ThreadLocalMap.get("mainKey"));

    }

}

ThreadLocalUtil第四版

public class ThreadLocalMap {

    private ThreadLocalMap() {
    }

    /**
     * ThreadLocal的静态方法withInitial()会返回一个SuppliedThreadLocal对象
     * 而SuppliedThreadLocal<T> extends ThreadLocal<T>
     * 我们存进去的Map会作为的返回值:
     * protected T initialValue() {
     *    return supplier.get();
     * }
     * 
     * 所以也相当于重写了initialValue()
     * 
     */
    private final static ThreadLocal<Map<String, Object>> THREAD_CONTEXT = ThreadLocal.withInitial(
            () -> new HashMap<>(8)
    );

    /**
     * 根据key获取value
     * 比如key为USER_INFO,则返回"{'name':'bravo', 'age':18}"
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object get(String key) {
        // getContextMap()表示要先获取THREAD_CONTEXT的value,也就是Map<String, Object>。然后再从Map<String, Object>中根据key获取
        return THREAD_CONTEXT.get().get(key);
    }

    /**
     * put操作,原理同上
     *
     * @param key
     * @param value
     */
    public static void put(String key, Object value) {
        THREAD_CONTEXT.get().put(key, value);
    }

    /**
     * 清除map里的某个值
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     *
     * @param key
     * @return
     */
    public static Object remove(String key) {
        return THREAD_CONTEXT.get().remove(key);
    }

    /**
     * 清除整个Map<String, Object>
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * ...THREAD_CONTEXT: {}
     * }
     */
    public static void clear() {
        THREAD_CONTEXT.get().clear();
    }

    /**
     * 从ThreadLocalMap中清除当前ThreadLocal存储的内容
     * 比如把
     * {
     * ...THREAD_CONTEXT: {
     * ........."USER_INFO":"{'name':'bravo', 'age':18}",
     * ........."SCORE":"{'Math':99, 'English': 97}"
     * ...}
     * }
     * 变成
     * {
     * }
     */
    public static void clearAll() {
        THREAD_CONTEXT.remove();
    }

}

Spring对ThreadLocal的封装

比如编写AOP日志时,经常会用到的RequestContextHolder,其实内部也维护了ThreadLocal。

那么Spring是如何做到remove的呢?使用过滤器(我们使用了拦截器)。

对于ThreadLocal的应用还有很多很多,这里就举这么一个例子叭~

补充:线程重用导致用户信息错乱的Bug

虽然小册两篇ThreadLocal相关的文章都反复强调用完之后最好及时remove(),但似乎都没有给出特别具有说服力的案例。最近在看极客时间朱晔老师的《Java业务开发常见错误100例》时,发现一个很不错的案例,这里特别拿来补充。

贴一段里面的代码:


private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);

@GetMapping("wrong")
public Map wrong(@RequestParam("userId") Integer userId) {
    //设置用户信息之前先查询一次ThreadLocal中的用户信息
    String before  = Thread.currentThread().getName() + ":" + currentUser.get();
    //设置用户信息到ThreadLocal
    currentUser.set(userId);
    //设置用户信息之后再查询一次ThreadLocal中的用户信息
    String after  = Thread.currentThread().getName() + ":" + currentUser.get();
    //汇总输出两次查询结果
    Map result = new HashMap();
    result.put("before", before);
    result.put("after", after);
    return result;
}

为了更明显地看到这个BUG,可以将Tomcat线程池的最大连接数设置为1:

server.tomcat.max-threads=1

分别请求两次:

也就是说,由于Tomcat连接池的线程数有限(比如极端情况下max-thread=1),所以必然存在线程复用。如果两个请求复用一个Thread且ThreadLocal没有及时remove,那么上一个请求设置在Thread.ThreadLocalMap中的值就会污染本次请求。

所以应该保证每次使用后及时remove():


@GetMapping("right")
public Map right(@RequestParam("userId") Integer userId) {
    String before  = Thread.currentThread().getName() + ":" + currentUser.get();
    currentUser.set(userId);
    try {
        String after = Thread.currentThread().getName() + ":" + currentUser.get();
        Map result = new HashMap();
        result.put("before", before);
        result.put("after", after);
        return result;
    } finally {
        //在finally代码块中删除ThreadLocal中的数据,确保数据不串
        currentUser.remove();
    }
}

或者像Spring一样放在filter或者interceptor中remove()。

说一个小插曲:

之前有一次面试时,我提到项目中使用了BaseController,里面封装了ThreadLocal,可以获取Interceptor中存入的用户信息,然后面试官问我是否了解分布式场景下ThreadLocal导致的用户信息混乱的问题。我当时有点懵逼,关注点全在分布式场景上是否会产生这种BUG。其实不论是单体应用还是分布式应用,都有可能出现这个BUG。但为什么我们没遇到呢?不是因为项目小、并发低,而是我们根本不会像上面demo那样,上来就获取ThreadLocal里的内容,我们通常是在Interceptor中先设置值,然后Controller/Service中获取值,也就是说每次都是先覆盖、再取值,此时上次的value早就没了。

当然,还是推荐本次请求的值在响应时就remove,不要留到下次请求去覆盖,很容易出错,也容易造成内存泄漏。

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO

进群,大家一起学习,一起进步,一起对抗互联网寒冬

猜你喜欢

转载自blog.csdn.net/smart_an/article/details/134822170