Java并发编程之ConcurrentHashMap详解

简介

我们知道,HashMap并不是线程安全的,要使用线程安全的HashMap,可以用HashTable或者通过Collections.synchronizedMap方法来获取一个线程安全的Map,但是这两种线程安全的Map都是使用synchronized来保证线程安全,因此,在多线程竞争激烈的情况下,它们的效率非常低。因为当一个线程访问synchronized同步方法时,其他线程访问Map,可能会进入阻塞或轮询状态。因此,就有了ConcurrentHashMap的登场机会。

在JDK 1.7及之前的版本中,ConcurrentHashMap的实现采用了锁分段技术

锁分段技术

HashTable容器在竞争激烈的并发环境下表现出效率低下的原因,是因为所有访问HashTable的线程都必须竞争同一把锁,那假如容器里有多把锁,每一把锁用于锁容器其中一部分数据,那么当多线程访问容器里不同数据段的数据时,线程间就不会存在锁竞争,从而可以有效的提高并发访问效率,这就是ConcurrentHashMap所使用的锁分段技术,首先将数据分成一段一段的存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据的时候,其他段的数据也能被其他线程访问。

简单理解就是,ConcurrentHashMap是一个Segment数组,Segment通过继承ReentrantLock来进行加锁,所以每次需要加锁的操作锁住的是一个segment,这样只要保证每个segment是线程安全的,也就实现了全局的线程安全。

而在JDK 1.8中,ConcurrentHashMap的实现摒弃了锁分段技术,而是采用CAS+synchronized来保证线程安全,同时,其底层的数据结构也从Segment数组变为Node数组+链表+红黑树的实现方式。

我们下面先简单介绍一下JDK 1.7中ConcurrentHashMap的实现方式,然后再来仔细分析JDK 1.8的实现。

JDk 1.7版本的ConcurrentHashMap

JDk 1.7版本的ConcurrentHashMap的整体结构图如下:


ConcurrentHashMap中的Segment[]数组默认容量为16,即并发度为16,其实每个Segment内部很像之前介绍的HashMap,但是它只有链表,没有红黑树的实现。

ConcurrentHashMap有如下五个构造方法:

public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
    // 判断初始容量、负载因子和并发度的合法性
    // 初始容量、负载因子和并发度分别默认为16、0.75、16
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    // 判断并发度是否超过了最大值MAX_SEGMENTS,MAX_SEGMENTS = 1 << 16
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    // Find power-of-two sizes best matching arguments
    // 为了能够通过按位与的散列算法来定位Segment数组的索引,必须要保证Segment数组的长度是2的整数次幂,即要保证并发度是2的整数次幂
    int sshift = 0;
    int ssize = 1;
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }

    // sshift是ssize从1向左移位的次数。默认情况下,concurrencyLevel是16,则sshift是4
    // segmentShift用于定位参与散列运算的位数,默认为28
    this.segmentShift = 32 - sshift;
    // segmentMask是散列运算的掩码,segmentMask = 并发度 - 1,默认为15
    this.segmentMask = ssize - 1;
    
    // initialCapacity是整个map的初始大小
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    // 根据initialCapacity计算Segment数组中每个位置可以分到的大小
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    
    // 默认MIN_SEGMENT_TABLE_CAPACITY是2,插入一个元素不至于扩容,插入第二个的时候才会扩容
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    // create segments and segments[0]
    // 创建Segment数组以及segments[0]
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                            (HashEntry<K,V>[])new HashEntry[cap]);
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    // 往数组写入segment[0]
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

public ConcurrentHashMap(int initialCapacity, float loadFactor) {
    this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
}

public ConcurrentHashMap(int initialCapacity) {
    this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

public ConcurrentHashMap() {
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
    this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
                    DEFAULT_INITIAL_CAPACITY),
            DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    putAll(m);
}

从构造方法中可以看到,ConcurrentHashMap初始化的时候会初始化第一个槽segment[0],对于其他槽来说,在插入第一个值的时候进行初始化。初始化其他segment的方法为ensureSegment方法。

ensureSegment初始化其他segment

ensureSegment方法源码如下:

private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        // 使用segment[0]当做原型,根据其数组长度和负载因子来初始化其他segment
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        // 初始化segment[k]内部的数组
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            // 使用while循环,内部用CAS,当前线程成功设值或其他线程成功设值后退出
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                    == null) {
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

这个方法主要作用就是获取下标为k的segment,若是对应槽位不存在,则进行初始化,方法中也注释了为什么需要先初始化segment[0],因为需要根据segment[0]的数组长度和负载因子来初始化其他槽位。那怎么定位元素对应的segment的下标呢?我们下面来看一下。

定位segment

既然ConcurrentHashMap使用分段锁来保护不同段的数据,那么在插入和获取元素的时候,必须先通过散列算法定位到segment,可以看到ConcurrentHashMap会首先使用Wang/Jenkins hash的变种算法对键值key进行一次再散列:

private int hash(Object k) {
    int h = hashSeed;

    if ((0 != h) && (k instanceof String)) {
        return sun.misc.Hashing.stringHash32((String) k);
    }

    h ^= k.hashCode();

    // Spread bits to regularize both segment and index locations,
    // using variant of single-word Wang/Jenkins hash.
    h += (h <<  15) ^ 0xffffcd7d;
    h ^= (h >>> 10);
    h += (h <<   3);
    h ^= (h >>>  6);
    h += (h <<   2) + (h << 14);
    return h ^ (h >>> 16);
}

之所以要进行再散列,是为了减少散列冲突,使元素能够均匀地分布在不同的segment中,从而提高存取效率。计算出key的散列值之后,然后通过下述代码来计算元素对应的segment下标:

int j = (hash >>> segmentShift) & segmentMask;

我们前面已经讲到,在默认情况下,segmentShift为28,segmentMask为15,hash向右无符号移动28位,是为了能够让高四位参与到散列运算中。

扩容操作

在添加元素时,通过上述步骤已经计算出了segment的下标,然后在对应segment里进行插入操作。插入操作需要经历两个步骤,第一步是判断是否需要对segment里的HashEntry数组进行扩容,第二部是定位添加元素的位置,然后将其放在HashEntry数组中。

在插入元素前会先判断HashEntry数组是否超过容量,如果是的话,则要进行扩容,segment的扩容操作比HashMap的扩容操作更加恰当,因为HashMap是在插入元素之后判断是否需要扩容,但是很有可能扩容之后没有新元素插入,那这时HashMap就进行了一次无效的扩容。

segment的扩容操作是通过Segment类中的rehash方法来完成的:

private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    // 容量扩大为以前的2倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    // 创建新数组
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    // 新的掩码
    int sizeMask = newCapacity - 1;
    // 遍历原数组,将原数组位置i处的链表拆分到新数组位置为i和i + oldCap的两个位置
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 计算应该放置在新数组中的位置
            int idx = e.hash & sizeMask;
            // 该位置只有一个元素
            if (next == null)   //  Single node on list
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                HashEntry<K,V> lastRun = e;
                // idx是当前链表的头结点e的新位置
                int lastIdx = idx;

                // 下面这个for循环会找到一个lastRun节点,这个节点之后的所有元素是将要放到一起的
                for (HashEntry<K,V> last = next;
                        last != null;
                        last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }

                // 将lastRun及其之后的所有节点组成的这个链表放到lastIdx这个位置
                newTable[lastIdx] = lastRun;
                // Clone remaining nodes
                // 下面的操作是处理lastRun之前的节点,
                // 这些节点可能分配在另一个链表中,也可能分配到上面的那个链表中
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }

    // 将新来的node放到新数组中两个链表之一的头部
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

可以看到,在扩容的时候,首先会创建一个容量是原来容量两倍的数组,然后将原数组里的元素进行再散列后插入到新的数组里。为了高效,ConcurrentHashMap不会对整个容器进行扩容,而只对某个segment进行扩容。

size操作

如果要统计整个ConcurrentHashMap里元素的数目,就必须统计所有segment里元素的数目,然后求和。Segment类中有一个count变量来记录对应segment里元素的数目,那是不是直接相加所有segment的count就可以了呢?不是的,我们在求和的时候,有可能有segment中元素的数目发生了变化,导致结果不准确。所以,最安全的做法就是在统计时,把所有segment的put、remove和clean方法都锁住,但是这种做法显然非常低效。

因为在累加count的过程中,count发生变化的概率比较小,所以ConcurrentHashMap的做法是先尝试2次通过不加锁的方式来计算各个segment的大小,若统计过程中,容器的count发生了变化,则再采用加锁的方式来统计所有segment的大小。

那ConcurrentHashMap是怎么判断在统计过程中容器是否发生变化呢?答案就是使用modCount变量,在put、remove、clean方法里操作元素前都会将modCount加1,那么在统计size前后比较modCount是否发生变化,就可以知道容器大小是否发生变化。

那我们来看看ConcurrentHashMap的size()方法源码:

public int size() {
    final Segment<K,V>[] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    // 各个segment的modCount的和
    long sum;         // sum of modCounts
    // 上一次计算得出的结果
    long last = 0L;   // previous sum
    // 重试次数
    int retries = -1; // first iteration isn't retry
    try {
        for (;;) {
            // 重试次数是否达到了RETRIES_BEFORE_LOCK,RETRIES_BEFORE_LOCK = 2
            if (retries++ == RETRIES_BEFORE_LOCK) {
                // 重试次数已经达到了RETRIES_BEFORE_LOCK,则依次获取各个segment上的锁
                for (int j = 0; j < segments.length; ++j)
                    ensureSegment(j).lock(); // force creation
            }
            sum = 0L;
            size = 0;
            overflow = false;
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    // 计算modCount的和
                    sum += seg.modCount;
                    int c = seg.count;
                    // 计算count的累加和
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
            }

            // 若计算的结果和上次结果相同,则直接退出循环
            if (sum == last)
                break;
            last = sum;
        }
    } finally {
        // 最终,若在计算size的过程中,获取了所有segment的锁,那在这里要释放所有的锁
        if (retries > RETRIES_BEFORE_LOCK) {
            for (int j = 0; j < segments.length; ++j)
                segmentAt(segments, j).unlock();
        }
    }
    return overflow ? Integer.MAX_VALUE : size;
}

需要注意的是,在计算size的时候,在获取所有segment的锁之前,计算了三次size值,因为第一次是没有历史值与它比较的,而且,后面最多还要2次通过不加锁的方式来计算各个segment的大小,所以在获取所有segment的锁之前,计算了三次size值。

关于JDK 1.7版本ConcurrentHashMap的实现方式就介绍到这里,下面我们来看JDK 1.8版本中实现的ConcurrentHashMap。

JDk 1.8版本的ConcurrentHashMap

JDk 1.8版本的的ConcurrentHashMap整体结构图如下:


JDk 1.8版本的的ConcurrentHashMap整体结构和HashMap几乎一样,不过它要保证线程安全性,所以在源码上确实要复杂一些。

ConcurrentHashMap有如下五个构造方法:

// table默认大小为16
public ConcurrentHashMap() {
}

// 初始化容量为 >= 1.5 * initialCapacity + 1计算出的2的整数次幂
public ConcurrentHashMap(int initialCapacity) {
    if (initialCapacity < 0)
        throw new IllegalArgumentException();
    int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
                MAXIMUM_CAPACITY :
                tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
    this.sizeCtl = cap;
}

// 创建一个和输入参数map映射一样的map
public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
    this.sizeCtl = DEFAULT_CAPACITY;
    putAll(m);
}

public ConcurrentHashMap(int initialCapacity, float loadFactor) {
    this(initialCapacity, loadFactor, 1);
}

public ConcurrentHashMap(int initialCapacity,
                            float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (initialCapacity < concurrencyLevel)   // Use at least as many bins
        initialCapacity = concurrencyLevel;   // as estimated threads
    // 根据initialCapacity和loadFactor来计算size
    long size = (long)(1.0 + (long)initialCapacity / loadFactor);
    // 初始化容量应该为不小于size的2的整数次幂
    int cap = (size >= (long)MAXIMUM_CAPACITY) ?
        MAXIMUM_CAPACITY : tableSizeFor((int)size);
    this.sizeCtl = cap;
}

在ConcurrentHashMap的构造方法中,并没有初始化table(除了第三个构造方法,调用了putAll来初始化),table的初始化发生在第一次插入操作,默认大小为16的数组,在ConcurrentHashMap中,元素都被封装为了Node对象:

static class Node<K,V> implements Map.Entry<K,V> {
    // 节点的哈希值
    final int hash;
    // 键
    final K key;
    // 值
    volatile V val;
    // 下一节点
    volatile Node<K,V> next;

    Node(int hash, K key, V val, Node<K,V> next) {
        this.hash = hash;
        this.key = key;
        this.val = val;
        this.next = next;
    }

    public final K getKey()       { return key; }
    public final V getValue()     { return val; }
    public final int hashCode()   { return key.hashCode() ^ val.hashCode(); }
    public final String toString(){ return key + "=" + val; }

    public final V setValue(V value) {
        throw new UnsupportedOperationException();
    }

    public final boolean equals(Object o) {
        Object k, v, u; Map.Entry<?,?> e;
        return ((o instanceof Map.Entry) &&
                (k = (e = (Map.Entry<?,?>)o).getKey()) != null &&
                (v = e.getValue()) != null &&
                (k == key || k.equals(key)) &&
                (v == (u = val) || v.equals(u)));
    }

    Node<K,V> find(int h, Object k) {
        Node<K,V> e = this;
        if (k != null) {
            do {
                K ek;
                if (e.hash == h &&
                    ((ek = e.key) == k || (ek != null && k.equals(ek))))
                    return e;
            } while ((e = e.next) != null);
        }
        return null;
    }
}

我们前面介绍过,ConcurrentHashMap中除了链表,还有红黑树,红黑树节点为TreeNode类型:

static final class TreeNode<K,V> extends Node<K,V> {
    // 父节点
    TreeNode<K,V> parent;  // red-black tree links
    // 左子节点
    TreeNode<K,V> left;
    // 右子节点
    TreeNode<K,V> right;
    // 删除节点时,需要断开链接,这个节点记录了删除节点的前一个节点
    TreeNode<K,V> prev;    // needed to unlink next upon deletion
    boolean red;

    TreeNode(int hash, K key, V val, Node<K,V> next,
                TreeNode<K,V> parent) {
        super(hash, key, val, next);
        this.parent = parent;
    }

    Node<K,V> find(int h, Object k) {
        return findTreeNode(h, k, null);
    }

    final TreeNode<K,V> findTreeNode(int h, Object k, Class<?> kc) {
        if (k != null) {
            TreeNode<K,V> p = this;
            do  {
                int ph, dir; K pk; TreeNode<K,V> q;
                TreeNode<K,V> pl = p.left, pr = p.right;
                if ((ph = p.hash) > h)
                    p = pl;
                else if (ph < h)
                    p = pr;
                else if ((pk = p.key) == k || (pk != null && k.equals(pk)))
                    return p;
                else if (pl == null)
                    p = pr;
                else if (pr == null)
                    p = pl;
                else if ((kc != null ||
                            (kc = comparableClassFor(k)) != null) &&
                            (dir = compareComparables(kc, k, pk)) != 0)
                    p = (dir < 0) ? pl : pr;
                else if ((q = pr.findTreeNode(h, k, kc)) != null)
                    return q;
                else
                    p = pl;
            } while (p != null);
        }
        return null;
    }
}

下面,我们来看ConcurrentHashMap中的一些关键静态属性:

// 最大容量(必须是2的幂且小于2的30次方,传入容量过大将被这个值替换)
private static final int MAXIMUM_CAPACITY = 1 << 30;
// 默认容量16
private static final int DEFAULT_CAPACITY = 16;
// 数组最大长度,在toArray等相关方法中用到
static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
// 默认并发级别,已经不再使用的字段,之所以还存在只是为了和之前的版本兼容
private static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// 此表的加载因子。在构造函数中重写此值只会影响初始表的容量,而不会使用实际的浮点值
private static final float LOAD_FACTOR = 0.75f;
// 添加当前元素,bin中元素个数若为8,则将链表转为红黑树
static final int TREEIFY_THRESHOLD = 8;
// bin中元素个数若为6个,则将红黑树转为链表
static final int UNTREEIFY_THRESHOLD = 6;
// table转为红黑树的阈值,此值最小为4*TREEIFY_THRESHOLD,默认为64
static final int MIN_TREEIFY_CAPACITY = 64;
// table扩容时,bin转移个数,最小为默认的DEFAULT_CAPACITY=16
// 因为扩容时,可以多个线程同时操作,所以16个bin会被分配给多个的线程进行转移
private static final int MIN_TRANSFER_STRIDE = 16;
// 用来控制扩容时,单线程进入的变量
private static int RESIZE_STAMP_BITS = 16;
// resize时的线程最大个数
private static final int MAX_RESIZERS = (1 << (32 - RESIZE_STAMP_BITS)) - 1;
// 用来控制扩容,单线程进入的变量
private static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;
// 节点hash域的编码
static final int MOVED     = -1; // hash for forwarding nodes
static final int TREEBIN   = -2; // hash for roots of trees
static final int RESERVED  = -3; // hash for transient reservations
static final int HASH_BITS = 0x7fffffff; // usable bits of normal node hash
// 当前可用cpu数量
static final int NCPU = Runtime.getRuntime().availableProcessors();

很多静态变量都与HashMap中的变量相似。同时,ConcurrentHashMap还有如下几个成员变量:

// Node数组,该变量只有在第一次插入元素时才会初始化
transient volatile Node<K,V>[] table;
// resize时用到的临时table,只有在resize时才不为null
private transient volatile Node<K,V>[] nextTable;
// 基本计数器值,主要用于没有争用时,也可作为表初始化期间的后备,通过CAS更新。
private transient volatile long baseCount;
/**
* 用于控制table初始化和resize的一个变量
* 值为负数:table正在初始化or正在resize
*     sizeCtl = -1:正在初始化;
*     sizeCtl = -(1 + n):当前有n个线程正在进行resize;
* 当table未初始化时,保存创建时使用的初始表大小,或默认为0
* 初始化后,保存下一个要调整table大小的元素计数值
*/
private transient volatile int sizeCtl;
// resize时,next table的索引+1,用于分割
private transient volatile int transferIndex;
//在调整大小和/或创建CounterCells时使用的自旋锁(通过CAS锁定)
private transient volatile int cellsBusy;
//这是一个计数器数组,用于保存table中每一下标对应的节点个数
private transient volatile CounterCell[] counterCells;

ConcurrentHashMap中的有些变量我们都不知道有什么作用,不过没关系,我们继续来看ConcurrentHashMap的源码,在源码分析过程中来解释部分变量的作用。

我们下面来看ConcurrentHashMap的几个关键方法:put方法、get方法和remove方法。

put方法

该方法很简单:

public V put(K key, V value) {
    return putVal(key, value, false);
}

put方法直接调用了putVal方法:

final V putVal(K key, V value, boolean onlyIfAbsent) {
    // 若key或value为null,则直接抛出NullPointerException异常
    if (key == null || value == null) throw new NullPointerException();
    // 计算节点的hash值
    int hash = spread(key.hashCode());
    // 用来记录相应链表的长度
    int binCount = 0;

    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        // 如果table数组为null或长度为0,则对数组进行初始化
        if (tab == null || (n = tab.length) == 0)
            tab = initTable();
        // 否则,按照hash值对应的数组下标,得到第一个节点f
        // 若f为null,则通过CAS将该节点设置为对应下标的首节点
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
            if (casTabAt(tab, i, null,
                            new Node<K,V>(hash, key, value, null)))
                break;                   // no lock when adding to empty bin
        }
        // 若果f节点的hash值为MOVED,此时表示数组在扩容,则帮助数据迁移
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f);
        else {
            V oldVal = null;
            // 获取数组该位置的头结点的监视器锁
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    // 头结点的hash值大于0,说明是链表,因为红黑树的根节点hash值是TREEBIN(-2)
                    if (fh >= 0) {
                        // 用于累加,记录链表的长度
                        binCount = 1;
                        // 遍历链表
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            // 如果发现了"相等"的key,判断是否要进行值覆盖
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                    (ek != null && key.equals(ek)))) {
                                oldVal = e.val;
                                // put方法传入的onlyIfAbsent默认为false,即可以覆盖
                                if (!onlyIfAbsent)
                                    e.val = value;
                                break;
                            }
                            // 到了链表的末尾,将新节点放到链表的最后面
                            Node<K,V> pred = e;
                            if ((e = e.next) == null) {
                                pred.next = new Node<K,V>(hash, key,
                                                            value, null);
                                break;
                            }
                        }
                    }
                    // 否则,头结点为红黑树节点
                    else if (f instanceof TreeBin) {
                        Node<K,V> p;
                        binCount = 2;
                        // 调用红黑树的putTreeVal方法插入新节点
                        if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                        value)) != null) {
                            oldVal = p.val;
                            if (!onlyIfAbsent)
                                p.val = value;
                        }
                    }
                }
            }
            // 插入节点后,检查数组对应下标的节点个数是否 >= TREEIFY_THRESHOLD,如果是,则由链表转为红黑树
            if (binCount != 0) {
                if (binCount >= TREEIFY_THRESHOLD)
                    treeifyBin(tab, i);
                if (oldVal != null)
                    return oldVal;
                break;
            }
        }
    }

    // 更新元素数目
    addCount(1L, binCount);
    return null;
}

put方法的主流程看完了,但是putVal方法中还调用到了一些其他方法,我们来看一下。

首先是计算节点hash值的spread(int)方法

spread(int)方法

static final int spread(int h) {
    // HASH_BITS = 0x7fffffff
    return (h ^ (h >>> 16)) & HASH_BITS;
}

为什么采用h ^ (h >>> 16)的方式来计算hash值,前面我们介绍HashMap时,已经做了解释,这里不再叙述。计算出来的hash值还要和HASH_BITS进行与运算才是最终结果,为什么要进行与运算呢?HASH_BITS的值是0x7fffffff,一个整形数字与HASH_BITS进行与运算,其实就是将数字二进制表示的第一位设置为0,它这样做的目的是消除符号位的影响,因为在table中,有些节点的hash值是特定的负数,比如前面介绍到的节点的hash域编码:

static final int MOVED     = -1; // hash for forwarding nodes
static final int TREEBIN   = -2; // hash for roots of trees
static final int RESERVED  = -3; // hash for transient reservations

红黑树根节点的hash值是TREEBIN。

initTable()

initTable()是用来对table进行初始化的,我们来看源代码:

private final Node<K,V>[] initTable() {
    Node<K,V>[] tab; int sc;
    // 若table为null或者table的长度为0,则进行初始化操作
    while ((tab = table) == null || tab.length == 0) {
        // 若我们设置了初始容量,则在构造方法会设置sizeCtl的值,否则,sizeCtl为0
        // 若sizeCtl小于0,则表明已经有其他线程在初始化
        if ((sc = sizeCtl) < 0)
            Thread.yield(); // lost initialization race; just spin
        // 通过CAS操作将sizeCtl设置为-1,代表本线程来初始化,其他线程就不要初始化了
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
            try {
                if ((tab = table) == null || tab.length == 0) {
                    // 获取table的初始容量
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    @SuppressWarnings("unchecked")
                    // 初始化数组
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    table = tab = nt;
                    // 计算sc的值,其实sc = 0.75 * n
                    sc = n - (n >>> 2);
                }
            } finally {
                // 设置sizeCtl为sc
                sizeCtl = sc;
            }
            break;
        }
    }
    return tab;
}

helpTransfer方法我们后面再介绍,下面先看treeifyBin方法。

treeifyBin方法

private final void treeifyBin(Node<K,V>[] tab, int index) {
    Node<K,V> b; int n, sc;
    if (tab != null) {
        // MIN_TREEIFY_CAPACITY = 64
        // 如果数组长度小于64,则会进行数组扩容
        if ((n = tab.length) < MIN_TREEIFY_CAPACITY)
            // 扩容
            tryPresize(n << 1);
        // 否则,b为index对应链表的首节点
        else if ((b = tabAt(tab, index)) != null && b.hash >= 0) {
            // 获取b的监视器锁
            synchronized (b) {
                if (tabAt(tab, index) == b) {
                    TreeNode<K,V> hd = null, tl = null;
                    // 下面就是遍历链表,建立一颗红黑树
                    for (Node<K,V> e = b; e != null; e = e.next) {
                        TreeNode<K,V> p =
                            new TreeNode<K,V>(e.hash, e.key, e.val,
                                                null, null);
                        if ((p.prev = tl) == null)
                            hd = p;
                        else
                            tl.next = p;
                        tl = p;
                    }
                    // 将红黑树设置到数组相应位置中
                    setTabAt(tab, index, new TreeBin<K,V>(hd));
                }
            }
        }
    }
}

从源码中我们知道,treeifyBin方法不一定就会进行红黑树转换,也可能是仅仅做数组扩容。扩容是通过tryPresize(int)方法来完成的,int参数就是扩容或的值,我们下面来看扩容操作。

tryPresize方法

// 首先要说明的是,方法参数size传进来的时候就已经翻倍了
private final void tryPresize(int size) {
    // 尝试将table大小设定为1.5 * size + 1,以容纳元素
    int c = (size >= (MAXIMUM_CAPACITY >>> 1)) ? MAXIMUM_CAPACITY :
        tableSizeFor(size + (size >>> 1) + 1);
    int sc;
    // 若sizeCtl < 0,则表明已经有其他线程在扩容
    // 若sizeCtl >= 0,则本线程进行扩容
    while ((sc = sizeCtl) >= 0) {
        Node<K,V>[] tab = table; int n;
        // 若table为null或者table的长度为0,则进行初始化操作
        if (tab == null || (n = tab.length) == 0) {
            n = (sc > c) ? sc : c;
             // 通过CAS操作将sizeCtl设置为-1,代表本线程来初始化,其他线程就不要初始化了
            if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    if (table == tab) {
                        @SuppressWarnings("unchecked")
                        // 创建新数组
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = nt;
                        // 计算sc的值,其实sc = 0.75 * n
                        sc = n - (n >>> 2);
                    }
                } finally {
                    // 设置sizeCtl为sc
                    sizeCtl = sc;
                }
            }
        }
        // 若扩容值小于原阀值,或现有容量 >= 最大值,则直接退出
        else if (c <= sc || n >= MAXIMUM_CAPACITY)
            break;
        // table不为空,且在此期间,其他线程没有修改table
        else if (tab == table) {
            // 返回table的扩容标记位
            int rs = resizeStamp(n);
            // 已经有线程在进行扩容工作
            if (sc < 0) {
                Node<K,V>[] nt;
                // 条件1检查原容量为n的情况下进行扩容,保证sizeCtl与n是一块修改好的,
                // 条件2与条件3在当前RESIZE_STAMP_BITS情况下应该不会成功。
                // 条件4与条件5确保tranfer()中的nextTable相关初始化逻辑已走完。
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                    transferIndex <= 0)
                    break;
                // 有新线程参与扩容则sizeCtl加1
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                    transfer(tab, nt);
            }
            // 修改sizeCtl的值,开始扩容
            else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                            (rs << RESIZE_STAMP_SHIFT) + 2))
                transfer(tab, null);
        }
    }
}

我们看到tryPresize方法中调用了一个叫做resizeStamp的方法,我们看一看这个方法做了什么事情:

static final int resizeStamp(int n) {
    return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

这个方法会返回一个与table容量n大小有关的扩容标记,它的实现很简单,Integer.numberOfLeadingZeros(int n)方法是计算在n的二进制表示中,前面一共有多少个连续的0,然后将其与1 << (RESIZE_STAMP_BITS - 1)进行或运算。它这么做的意义是什么呢?我写了一段代码来验证该方法的作用:

private static int RESIZE_STAMP_BITS = 16;
private static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;

public static void main(String[] args) throws UnsupportedEncodingException {
    int size = 1 << 4;
    int rs = resizeStamp(size);

    formatPrintNum(rs);
    
    int rs2 = (rs << RESIZE_STAMP_SHIFT) + 2;
    formatPrintNum(rs2);
    formatPrintNum(rs2 + 1);
}

static final int resizeStamp(int n) {
    return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

static void formatPrintNum(int n) {
    String s = Integer.toBinaryString(n);
    
    while (s.length() < 32) {
        s = "0" + s;
    }
    
    s = s.substring(0, 16) + " | " + s.substring(16, 32);
    System.out.println(s);
}

运行结果:

0000000000000000 | 1000000000011011
1000000000011011 | 0000000000000010
1000000000011011 | 0000000000000011

我把结果的输出分为了高16位和低16位。

假设table的容量为16,则通过resizeStamp方法计算出的扩容标记位是“1000000000011011”(只看低16位),(rs << RESIZE_STAMP_SHIFT) + 2(后面称为rs2)的值是“1000000000011011 | 0000000000000010”,rs2+1的值是“1000000000011011 | 0000000000000011”。

这个程序的目的是什么呢?

我们先来看rs2,rs2的值一直为负数,因为resizeStamp方法中(1 << (RESIZE_STAMP_BITS - 1))计算出来的值二进制表示为“1000000000000000”(低16位),通过或运算后得出的扩容标记位rs的二进制表示中,第17位一定为1,而rs2是通过rs左移RESIZE_STAMP_SHIFT计算得到的,则rs2的二进制表示中,最高位一定为1,即rs2的值一直为负数。

前面我们介绍了sizeCtl的作用,若rs2就是sizeCtl,那么sizeCtl表示什么呢?rs2不等于-1,那么sizeCtl的取值就只有下面一种情况了:

sizeCtl = -(1 + n):当前有n个线程正在进行resize;

rs2转成十进制表示是

-2145714174

难道我们有2145714173个线程在做resize操作吗?肯定不是的。其实-(1 + n)中的(1 + n)只是sizeCtl的低16位了,rs2的低16位表示为十进制是2,即表示当前有1个线程在做resize操作,若有其他线程参与进来则,sizeCtl的值加1,。

通过上面的分析,我们知道,在扩容时sizeCtl的意义如下图所示:

高RESIZE_STAMP_BITS位 低RESIZE_STAMP_SHIFT位
扩容标记 并行扩容线程数

看懂上面的分析之后,那tryPresize方法后面的程序大部分都可以看懂了,不过我一直不清楚的是tryPresize方法中为什么还会有sc < 0的情况,外层不是通过while ((sc = sizeCtl) >= 0)循环进入的吗?既然能进入循环,那这个sc < 0是干嘛的。。。??

该方法最后会调用transfer()来进行真正的扩容处理。

transfer方法

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;

    // stride在单核下直接等于n,多核模式下为(n >>> 3) / NCPU
    // stride可以理解为“步长”,表示每个线程处理桶的最小数目,可以看出核数越高步长越小,最小值是最小分割并行段数MIN_TRANSFER_STRIDE(16)
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
    
    // 如果新数组nextTab为null,先进行一次初始化,长度为旧table的2倍
    if (nextTab == null) {            // initiating
        try {
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab;
        transferIndex = n;
    }

    // 新table的长度
    int nextn = nextTab.length;
    // ForwardingNode就是正在被迁移的Node,ForwardingNode的hash值被设置成为了MOVED(-1),这个Node的key、value和next都为null
    // 后面我们会看到,原数组中位置i处的节点完成迁移工作后
    // 就会将位置i处的节点设置为这个ForwardingNode,用来告诉其他线程该位置已经处理过了
    // 所以它其实相当于是一个标志
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);

    // 并发扩容的关键属性,如果advance等于true,说明这个节点已经处理过,可以处理下一个位置的节点了
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    
    // i是位置索引,bound是边界,注意是从后往前
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        // 这个while循环体的作用就是在控制i--
        // 通过i--可以依次遍历原hash表中的节点
        // 可以简单理解为:i指向了transferIndex,bound指向了transferIndex - stride
        while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                        (this, TRANSFERINDEX, nextIndex,
                        nextBound = (nextIndex > stride ?
                                    nextIndex - stride : 0))) {
                
                // 确定当前线程每次分配的待迁移桶的范围[bound, nextIndex)
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
        if (i < 0 || i >= n || i + n >= nextn) {
            int sc;
            // 如果所有的节点都已经完成复制工作,就把nextTable赋值给table
            if (finishing) {
                nextTable = null;
                table = nextTab;
                // 重新计算sizeCtl,n是原数组长度,所以计算得出的值将是新数组长度的0.75倍
                sizeCtl = (n << 1) - (n >>> 1);
                return;
            }

            // 之前我们说过,sizeCtl在迁移前会设置为(rs << RESIZE_STAMP_SHIFT) + 2
            // 然后,每有一个线程参与迁移就会将sizeCtl加1
            // 这里使用CAS操作对sizeCtl减1,代表该线程做完了属于自己的任务
            if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                // 任务结束,方法退出
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return;
                // (sc - 2) == resizeStamp(n) << RESIZE_STAMP_SHIFT
                // 这表明所有的线程都完成了迁移工作,设置finishing为true,下次循环就会运行上面的if(finishing){}分支了
                finishing = advance = true;
                i = n; // recheck before commit
            }
        }
        // 如果位置i处是空的,没有任何节点,那么放入刚刚初始化的ForwardingNode节点
        else if ((f = tabAt(tab, i)) == null)
            advance = casTabAt(tab, i, null, fwd);
        // 该位置处是ForwardingNode节点,代表该位置已经迁移过了
        else if ((fh = f.hash) == MOVED)
            advance = true; // already processed
        else {
            // 对数组该位置处的结点加锁,开始处理数组该位置处的迁移工作
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn;
                    // 头结点的hash值大于0,说明是链表节点
                    if (fh >= 0) {
                        // 下面这一段代码和JDK 1.7中的ConcurrentHashMap迁移差不多
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            int b = p.hash & n;
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        if (runBit == 0) {
                            ln = lastRun;
                            hn = null;
                        }
                        else {
                            hn = lastRun;
                            ln = null;
                        }
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            if ((ph & n) == 0)
                                ln = new Node<K,V>(ph, pk, pv, ln);
                            else
                                hn = new Node<K,V>(ph, pk, pv, hn);
                        }
                        // 将其中的一个链表放在新数组的位置i
                        setTabAt(nextTab, i, ln);
                        // 将另一个链表放在新数组的位置i + n
                        setTabAt(nextTab, i + n, hn);
                        // 将原数组该位置处设置为fwd,代表该位置已经处理完毕
                        // 其他线程一旦看到该位置的hash值为 MOVED,就不会进行迁移了
                        setTabAt(tab, i, fwd);
                        // advance设置为true,代表该位置已经迁移完毕
                        advance = true;
                    }
                    // 若头结点是红黑树节点,则进行红黑树的迁移
                    else if (f instanceof TreeBin) {
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) {
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }

                        // 如果一分为二后,节点数少于UNTREEIFY_THRESHOLD,那么将红黑树转换回链表
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                            (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                            (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        // 将ln放置在新数组的位置i
                        setTabAt(nextTab, i, ln);
                        // 将hn放置在新数组的位置i + n
                        setTabAt(nextTab, i + n, hn);
                        // 将原数组该位置处设置为fwd,代表该位置已经处理完毕
                        // 其他线程一旦看到该位置的hash值为 MOVED,就不会进行迁移了
                        setTabAt(tab, i, fwd);
                        // advance设置为true,代表该位置已经迁移完毕
                        advance = true;
                    }
                }
            }
        }
    }
}

我们这里介绍一下ForwardingNode的作用,它主要有两个:

  • 标明此节点已完成迁移
  • 为方便扩容期间的元素查找需求,里面有find()方法可以从nextTable查找元素

下面来看helpTransfer方法。

helpTransfer方法

helpTransfer方法是用来帮助其他线程进行数据迁移的,源代码如下:

final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
    Node<K,V>[] nextTab; int sc;
    // 若tab不为null且首节点f是ForwardingNode节点,且f的nextTable不为null,即已经有其他线程在进行resize操作
    if (tab != null && (f instanceof ForwardingNode) &&
        (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
        // 计算扩容标记位
        int rs = resizeStamp(tab.length);
        // 扩容还没有完成
        while (nextTab == nextTable && table == tab &&
                (sc = sizeCtl) < 0) {
            // 扩容结束
            if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                sc == rs + MAX_RESIZERS || transferIndex <= 0)
                break;
            // 帮助数据迁移,因为多了一个迁移线程,所以要将sizeCtl加1
            if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
                transfer(tab, nextTab);
                break;
            }
        }
        return nextTab;
    }
    return table;
}

如果已经了解了前面介绍的方法,那这个方法就比较简单了。

扩容操作总结

  • 单线程新建nextTable,扩容为原table容量的两倍。
  • 每个线程想增/删元素时,如果访问的桶是ForwardingNode节点,则表明当前正处于扩容状态,协助一起扩容完成后再完成相应的数据更改操作。
  • 扩容时将原table的所有桶倒序分配,每个线程每次最小分16个桶进行处理,防止资源竞争导致的效率下降, 每个桶的迁移是单线程的,但桶范围处理分配可以多线程,在没有迁移完成所有桶之前每个线程需要重复获取迁移桶范围,直至所有桶迁移完成。
  • 一个旧桶内的数据迁移完成但迁移工作没有全部完成时,查询数据委托给ForwardingNode结点查询nextTable完成(这个后面看find()分析)。
  • 迁移过程中sizeCtl用于记录参与扩容线程的数量,全部迁移完成后sizeCtl更新为新table的扩容阈值。

将元素添加到table中之后,put方法最后还要更新元素的个数,我们来看一下addCount方法。

addCount方法

// 增加节点个数,如果table太小而没有resize,则检查是否需要resize。如果已经调整大小,则可以帮助复制转移节点
// 转移后重新检查占用情况,以确定是否还需要调整大小,因为resize总是比put操作滞后
private final void addCount(long x, int check) {
    CounterCell[] as; long b, s;
    // 通过CAS操作更新baseCount
    if ((as = counterCells) != null ||
        !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
        // 若counterCells不为null或者更新baseCount失败
        CounterCell a; long v; int m;
        boolean uncontended = true;
        if (as == null || (m = as.length - 1) < 0 ||
            (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
            !(uncontended =
                U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
            // 调用fullAddCount方法进行初始化
            fullAddCount(x, uncontended);
            return;
        }
        if (check <= 1)
            return;
        s = sumCount();
    }

    // check就是binCount,有新元素加入成功才检查是否要扩容
    if (check >= 0) {
        Node<K,V>[] tab, nt; int n, sc;
        // 元素数目大于当前扩容阈值并且小于最大扩容值才扩容,如果table还未初始化则等待初始化完成
        while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
                (n = tab.length) < MAXIMUM_CAPACITY) {
            // 返回table的扩容标记位
            int rs = resizeStamp(n);
            // 如果已经有其他线程在进行扩容
            if (sc < 0) {
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                    transferIndex <= 0)
                    break;
                // 该线程参与扩容,则将sizeCtl的值加1
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                    transfer(tab, nt);
            }

            // 没有线程在进行扩容,则该线程开始扩容,设置sizeCtl的值
            else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                            (rs << RESIZE_STAMP_SHIFT) + 2))
                transfer(tab, null);
            s = sumCount();
        }
    }
}

看完这个方法,其实还是不是很懂baseCount和counterCells的含义。。

看注释中写道,baseCount是在没有竞争时使用的变量,所以,我感觉在计算元素数目时,如果没有产生竞争,则用baseCount来记录,否则用counterCells记录了。欢迎指正~

那这样的话,JDK 1.8实现的ConcurrentHashMap再求所有元素数目时,就比较简单了:

public int size() {
    long n = sumCount();
    return ((n < 0L) ? 0 :
            (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
            (int)n);
}

final long sumCount() {
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
        for (int i = 0; i < as.length; ++i) {
            if ((a = as[i]) != null)
                sum += a.value;
        }
    }
    return sum;
}

通过累加baseCount和CounterCell数组中的值,即可得到元素的总数目。

好了,put方法到这就算是讲完了吧。。下面来看get方法。

get方法

public V get(Object key) {
    Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    // 计算key对应节点的hash值
    int h = spread(key.hashCode());
    // 如果table不为null,table的长度不为0且对应下标存在节点
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (e = tabAt(tab, (n - 1) & h)) != null) {
        // 判断头结点是否就是我们查找的节点
        if ((eh = e.hash) == h) {
            if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                return e.val;
        }
        // 如果头结点的hash值小于0,说明table正在扩容,或者e是红黑树节点
        else if (eh < 0)
            return (p = e.find(h, key)) != null ? p.val : null;
        // 否则,遍历链表寻找匹配的节点
        while ((e = e.next) != null) {
            if (e.hash == h &&
                ((ek = e.key) == key || (ek != null && key.equals(ek))))
                return e.val;
        }
    }
    return null;
}

get方法整体还是比较简单的,如果头结点的hash值小于0,说明table正在扩容,或者e是红黑树节点,那我们来看一下,若table正在扩容时查找节点的代码:

Node<K,V> find(int h, Object k) {
    // loop to avoid arbitrarily deep recursion on forwarding nodes
    outer: for (Node<K,V>[] tab = nextTable;;) {
        Node<K,V> e; int n;
        if (k == null || tab == null || (n = tab.length) == 0 ||
            (e = tabAt(tab, (n - 1) & h)) == null)
            return null;
        for (;;) {
            int eh; K ek;
            // 如果找到了匹配节点,则返回
            if ((eh = e.hash) == h &&
                ((ek = e.key) == k || (ek != null && k.equals(ek))))
                return e;
            // 若节点的hash值小于0
            if (eh < 0) {
                // 若节点是ForwardingNode节点
                if (e instanceof ForwardingNode) {
                    // 将新创建的table赋值给tab,在新table中查找
                    tab = ((ForwardingNode<K,V>)e).nextTable;
                    continue outer;
                }
                // 否则,节点是红黑树节点
                else
                    return e.find(h, k);
            }
            if ((e = e.next) == null)
                return null;
        }
    }
}

remove方法

public V remove(Object key) {
    return replaceNode(key, null, null);
}

final V replaceNode(Object key, V value, Object cv) {
    // 计算key对应节点的hash值
    int hash = spread(key.hashCode());
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0 ||
            (f = tabAt(tab, i = (n - 1) & hash)) == null)
            break;
        // 若首节点对应hash值是MOVED,则表明该节点是ForwardingNode节点,帮助table扩容
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f);
        else {
            V oldVal = null;
            boolean validated = false;
            // 获取首节点的监视器锁
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    // 节点的hash值 >= 0,则表明节点是链表节点
                    if (fh >= 0) {
                        validated = true;
                        // 查找与key匹配的节点
                        for (Node<K,V> e = f, pred = null;;) {
                            K ek;
                            // 找到了匹配节点
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                    (ek != null && key.equals(ek)))) {
                                V ev = e.val;

                                // 如果cv为null或者cv与key对应的旧值ev“相等”
                                if (cv == null || cv == ev ||
                                    (ev != null && cv.equals(ev))) {
                                    oldVal = ev;

                                    // 若value不等于null,则更新key节点对应的val
                                    if (value != null)
                                        e.val = value;
                                    // 否则,将该节点删除
                                    else if (pred != null)
                                        pred.next = e.next;
                                    else
                                        setTabAt(tab, i, e.next);
                                }
                                break;
                            }
                            pred = e;
                            if ((e = e.next) == null)
                                break;
                        }
                    }
                    // 节点是红黑树节点
                    else if (f instanceof TreeBin) {
                        validated = true;
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> r, p;
                        if ((r = t.root) != null &&
                            (p = r.findTreeNode(hash, key, null)) != null) {
                            V pv = p.val;
                            if (cv == null || cv == pv ||
                                (pv != null && cv.equals(pv))) {
                                oldVal = pv;
                                if (value != null)
                                    p.val = value;
                                else if (t.removeTreeNode(p))
                                    setTabAt(tab, i, untreeify(t.first));
                            }
                        }
                    }
                }
            }
            if (validated) {
                if (oldVal != null) {
                    // 若删除了一个节点,则更新元素数目
                    if (value == null)
                        addCount(-1L, -1);
                    return oldVal;
                }
                break;
            }
        }
    }
    return null;
}

remove方法也比较简单,不再叙述。

其他相关方法

要判断一个key在ConcurrentHashMap中是否存在,可以用get(Object)方法来判断的,因为ConcurrentHashMap中节点的key和value都不允许为null,而且,我们可以用containsKey(Object)方法来判断:

public boolean containsKey(Object key) {
    return get(key) != null;
}

可以看到,containsKey(Object)方法就是通过get(Object)方法来判断的。

clear()方法

public void clear() {
    long delta = 0L; // negative number of deletions
    int i = 0;
    Node<K,V>[] tab = table;
    // 遍历table数组
    while (tab != null && i < tab.length) {
        int fh;
        // 获取table数组中下标为i的首节点
        Node<K,V> f = tabAt(tab, i);
        // 首节点为null,则更新下标i
        if (f == null)
            ++i;
        // 首节点hash值是MOVED,则帮助迁移
        else if ((fh = f.hash) == MOVED) {
            tab = helpTransfer(tab, f);
            // 迁移完成之后,将i置为0,重新开始清除table
            i = 0; // restart
        }
        else {
            // 获取首节点的监视器锁
            synchronized (f) {
                // 根据Node节点的next属性,删除对应链表或者红黑树
                if (tabAt(tab, i) == f) {
                    Node<K,V> p = (fh >= 0 ? f :
                                    (f instanceof TreeBin) ?
                                    ((TreeBin<K,V>)f).first : null);
                    while (p != null) {
                        --delta;
                        p = p.next;
                    }
                    setTabAt(tab, i++, null);
                }
            }
        }
    }
    // 更新元素数目
    if (delta != 0L)
        addCount(delta, -1);
}

相关问题

1、JDK 1.8为什么要放弃Segment?

锁的粒度

首先锁的粒度并没有变粗,甚至变得更细了。每当扩容一次,ConcurrentHashMap的并发度就扩大一倍。

Hash冲突

JDK1.7中,ConcurrentHashMap从过二次hash的方式(Segment -> HashEntry)能够快速的找到查找的元素。在1.8中通过链表加红黑树的形式弥补了put、get时的性能差距。

扩容

JDK1.8中,在ConcurrentHashmap进行扩容时,其他线程可以通过检测数组中的节点决定是否对这条链表(红黑树)进行扩容,减小了扩容的粒度,提高了扩容的效率。

2、JDK 1.8为什么要使用synchronized而不是可重入锁?

减少内存开销

假设使用可重入锁来获得同步支持,那么每个节点都需要通过继承AQS来获得同步支持。但并不是每个节点都需要获得同步支持的,只有链表的头节点(红黑树的根节点)需要同步,这无疑带来了巨大内存浪费。 

获得JVM的支持

可重入锁毕竟是API这个级别的,后续的性能优化空间很小。 synchronized则是JVM直接支持的,JVM能够在运行时作出相应的优化措施:锁粗化、锁消除、锁自旋等等。这就使得synchronized能够随着JDK版本的升级而不改动代码的前提下获得性能上的提升。

3、ConcurrentHashMap能完全替代HashTable吗?

hash table虽然性能上不如ConcurrentHashMap,但并不能完全被取代,两者的迭代器的一致性不同的,hash table的迭代器是强一致性的,而ConcurrentHashMap是弱一致的。 ConcurrentHashMap的迭代器方法都是弱一致性的。关于弱一致性的解释可以看这篇博客

在JDK1.8的ConcurrentHashMap实现中,它的迭代器有KeySetView、ValuesView和EntrySetView这三种,我们来看获取KeySetView迭代器方的法:

/**
* <p>The view's iterators and spliterators are
* <a href="package-summary.html#Weakly"><i>weakly consistent</i></a>.
*
* @return the set view
*/
public KeySetView<K,V> keySet() {
    KeySetView<K,V> ks;
    return (ks = keySet) != null ? ks : (keySet = new KeySetView<K,V>(this, null));
}

可以看到,注释中也说明该迭代器是弱一致性的,我们来看一下KeySetView类的iterator方法:

public Iterator<K> iterator() {
    Node<K,V>[] t;
    ConcurrentHashMap<K,V> m = map;
    int f = (t = m.table) == null ? 0 : t.length;
    return new KeyIterator<K,V>(t, f, 0, f, m);
}

最终是返回了一个KeyIterator类对象,在KeyIterator上调用next方法时,最终实际调用到了Traverser.advance()方法,我们来看一下Traverser的构造方法以及advance()方法:

Traverser(Node<K,V>[] tab, int size, int index, int limit) {
    this.tab = tab;
    this.baseSize = size;
    this.baseIndex = this.index = index;
    this.baseLimit = limit;
    this.next = null;
}
final Node<K,V> advance() {
    Node<K,V> e;
    if ((e = next) != null)
        e = e.next;
    for (;;) {
        Node<K,V>[] t; int i, n;  // must use locals in checks
        if (e != null)
            return next = e;
        if (baseIndex >= baseLimit || (t = tab) == null ||
            (n = t.length) <= (i = index) || i < 0)
            return next = null;
        if ((e = tabAt(t, i)) != null && e.hash < 0) {
            if (e instanceof ForwardingNode) {
                tab = ((ForwardingNode<K,V>)e).nextTable;
                e = null;
                pushState(t, i, n);
                continue;
            }
            else if (e instanceof TreeBin)
                e = ((TreeBin<K,V>)e).first;
            else
                e = null;
        }
        if (stack != null)
            recoverState(n);
        else if ((index = i + baseSize) >= n)
            index = ++baseIndex; // visit upper slots if present
    }
}

这个方法在遍历底层数组。在遍历过程中,如果已经遍历的数组上的内容变化了,迭代器不会抛出ConcurrentModificationException异常。如果未遍历的数组上的内容发生了变化,则有可能反映到迭代过程中。这就是ConcurrentHashMap迭代器弱一致的表现。

但是Hashtable的任何操作都会把整个表锁住,是阻塞的。好处是总能获取最实时的更新,比如说线程A调用putAll写入大量数据,期间线程B调用迭代器方法,线程B就会被阻塞,直到线程A完成putAll,因此线程B肯定能获取到线程A写入的完整数据。坏处是所有调用都要排队,效率较低。 

所以,ConcurrentHashMap并不能完全替代HashTable。

相关博客

Java集合之HashMap详解

Java集合之HashTable详解

Java并发编程之synchronized详解

参考资料

Java7/8 中的 HashMap 和 ConcurrentHashMap 全解析

ConcurrentHashMap(JDK1.8)为什么要放弃Segment

为什么ConcurrentHashMap是弱一致的

猜你喜欢

转载自blog.csdn.net/qq_38293564/article/details/80781266