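Set Operations
retainAll
While writing code recently, I needed three operations on collections: intersection, union, and difference.
For the union, my first attempt looked like this: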
@Test
public void should_get_union_with_removeAll_and_addAll() {
    List<Long> result = Lists.newArrayList();
    List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
    List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
    result.addAll(s1);
    result.removeAll(s2);
    result.addAll(s2);
    assertThat(result).isSubsetOf(1L, 2L, 3L, 7L);
}
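This relies on List's removeAll() and addAll(): first remove from the first list every element it shares with the second, then append the second list. The source of removeAll() in ArrayList looks like this: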
public boolean removeAll(Collection<?> c) {
    Objects.requireNonNull(c);
    return batchRemove(c, false);
}
// ...
public boolean retainAll(Collection<?> c) {
    Objects.requireNonNull(c);
    return batchRemove(c, true);
}
private boolean batchRemove(Collection<?> c, boolean complement) {
    final Object[] elementData = this.elementData;
    int r = 0, w = 0;
    boolean modified = false;
    try {
        for (; r < size; r++)
            if (c.contains(elementData[r]) == complement)
                elementData[w++] = elementData[r];
    } finally {
        // Preserve behavioral compatibility with AbstractCollection,
        // even if c.contains() throws.
        if (r != size) {
            System.arraycopy(elementData, r,
                             elementData, w,
                             size - r);
            w += size - r;
        }
        if (w != size) {
            // clear to let GC do its work
            for (int i = w; i < size; i++)
                elementData[i] = null;
            modCount += size - w;
            size = w;
            modified = true;
        }
    }
    return modified;
}
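In other words, removeAll() and retainAll() share one implementation, batchRemove(), which walks the list once. That pass calls c.contains() for every element, however, and when c is itself a List, contains() is another linear scan. With n elements in the list and m in the argument, the total cost is O(n·m), which is quadratic when the two are of similar size.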
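One way to see where the quadratic cost comes from is to change only the type of the argument. The sketch below wraps the second list in a HashSet before calling removeAll(), so each contains() check becomes an expected constant-time hash lookup instead of a scan. The class and method names are hypothetical and not part of the original code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

// Hypothetical helper for illustration; not from the original post.
class RemoveAllWithHashSet {

    // Returns the elements of source that do not occur in toRemove.
    // ArrayList.removeAll() calls contains() on its argument once per element,
    // so passing a HashSet keeps the whole pass roughly linear.
    static List<Long> difference(List<Long> source, List<Long> toRemove) {
        List<Long> result = new ArrayList<>(source);
        result.removeAll(new HashSet<>(toRemove));
        return result;
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(difference(s1, s2)); // prints [1]
    }
}
```

The trade-off is one extra HashSet allocation per call, which is usually cheap next to a nested scan over large lists.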
The set side looks much the same. HashSet inherits retainAll() from AbstractCollection:
// In AbstractCollection
public boolean retainAll(Collection<?> c) {
    Objects.requireNonNull(c);
    boolean modified = false;
    Iterator<E> it = iterator();
    while (it.hasNext()) {
        if (!c.contains(it.next())) {
            it.remove();
            modified = true;
        }
    }
    return modified;
}
So it is natural to build the union, intersection, and difference of two sets from these methods:
// Union
public Set<Integer> getUnion(Set<Integer> set1, Set<Integer> set2) {
    Set<Integer> result = new HashSet<>();
    result.addAll(set1);
    result.removeAll(set2);
    result.addAll(set2);
    return result;
}
// Intersection
public Set<Integer> getIntersection(Set<Integer> set1, Set<Integer> set2) {
    Set<Integer> result = new HashSet<>();
    result.addAll(set1);
    // Keep only the elements that also appear in set2
    result.retainAll(set2);
    return result;
}
// Difference
public Set<Integer> getSubtraction(Set<Integer> set1, Set<Integer> set2) {
    Set<Integer> result = new HashSet<>();
    result.addAll(set1);
    result.removeAll(set2);
    return result;
}
To spell it out:
s1.retainAll(s2): keeps in s1 only the elements that also appear in s2.
s1.removeAll(s2): removes from s1 every element that appears in s2.
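A small sketch of these two calls may help. Both methods modify the receiver in place and return true only if it changed, which is why the helpers above work on a copy. The class name RetainRemoveDemo is made up for this example.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical demo class; the names are illustrative only.
class RetainRemoveDemo {

    // Intersection computed on a copy, so the caller's set is left untouched.
    static Set<Integer> intersection(Set<Integer> a, Set<Integer> b) {
        Set<Integer> copy = new HashSet<>(a);
        copy.retainAll(b);
        return copy;
    }

    // Difference (a minus b), also computed on a copy.
    static Set<Integer> difference(Set<Integer> a, Set<Integer> b) {
        Set<Integer> copy = new HashSet<>(a);
        copy.removeAll(b);
        return copy;
    }

    public static void main(String[] args) {
        Set<Integer> s1 = new HashSet<>(Arrays.asList(1, 2, 3));
        Set<Integer> s2 = new HashSet<>(Arrays.asList(2, 3, 7));

        // The intersection holds 2 and 3; the difference holds only 1.
        System.out.println(intersection(s1, s2).equals(new HashSet<>(Arrays.asList(2, 3)))); // prints true
        System.out.println(difference(s1, s2).equals(new HashSet<>(Arrays.asList(1))));      // prints true

        // Calling retainAll() directly mutates s1 and reports whether it changed.
        boolean changed = s1.retainAll(s2);
        System.out.println(changed + " " + s1.size()); // prints "true 2"
    }
}
```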
Both methods are available on every collection class. Used on lists, however, they run into a problem.
@Test
public void should_get_union_with_removeAll_and_addAll() {
    List<Long> result = Lists.newArrayList();
    List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
    List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
    result.addAll(s1);
    result.removeAll(s2);
    result.addAll(s2);
    assertThat(result).hasSize(4); // Assertion fails: the list has five elements because 3 appears twice
}
In this example, duplicates already present in an input list (the two 3s in s2) survive into the result.
Apache
So I went looking for a utility class. Apache's commons-collections4 package provides collection operations. Add the dependency as follows:
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-collections4 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.2</version>
</dependency>
The code:
@Test
public void should_get_union_with_apache_union() {
    List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
    List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
    List<Long> result = (List<Long>) CollectionUtils.union(s1, s2);
    assertThat(result).hasSize(4); // Assertion fails: the list still contains two 3s
}
The test still fails. Here is the source:
public static <O> Collection<O> union(final Iterable<? extends O> a, final Iterable<? extends O> b) {
    final SetOperationCardinalityHelper<O> helper = new SetOperationCardinalityHelper<>(a, b);
    for (final O obj : helper) {
        helper.setCardinality(obj, helper.max(obj));
    }
    return helper.list();
}
private static class SetOperationCardinalityHelper<O> extends CardinalityHelper<O> implements Iterable<O> {
    // ...
    /**
     * Add the object {@code count} times to the result collection.
     * @param obj the object to add
     * @param count the count
     */
    public void setCardinality(final O obj, final int count) {
        for (int i = 0; i < count; i++) {
            newList.add(obj);
        }
    }
    // ...
}
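This does not remove duplicates either: union() adds each element as many times as it occurs in whichever input contains more copies of it (helper.max(obj)), so 3 still appears twice. As for cost, the helper appears to count occurrences with hash-based maps rather than nested scans, so it should be roughly linear rather than n^2, which is consistent with the timings later in this post.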
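The "maximum cardinality" rule is easier to follow in a stand-alone form. The following is a plain-Java approximation written for this explanation, not the Apache implementation: it counts occurrences per input with a HashMap and emits each element max(countA, countB) times. The class name and the use of LinkedHashSet (chosen only to make the output order predictable) are assumptions of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Illustrative re-creation of a max-cardinality union; not library code.
class MaxCardinalityUnion {

    // Counts how many times each element occurs in items.
    static <T> Map<T, Integer> cardinality(Collection<? extends T> items) {
        Map<T, Integer> counts = new HashMap<>();
        for (T item : items) {
            counts.merge(item, 1, Integer::sum);
        }
        return counts;
    }

    // Emits every distinct element max(count in a, count in b) times.
    static <T> List<T> union(Collection<? extends T> a, Collection<? extends T> b) {
        Map<T, Integer> countA = cardinality(a);
        Map<T, Integer> countB = cardinality(b);
        Set<T> elements = new LinkedHashSet<>();
        elements.addAll(a);
        elements.addAll(b);
        List<T> result = new ArrayList<>();
        for (T element : elements) {
            int count = Math.max(countA.getOrDefault(element, 0),
                                 countB.getOrDefault(element, 0));
            for (int i = 0; i < count; i++) {
                result.add(element);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(union(s1, s2)); // prints [1, 2, 3, 3, 7]
    }
}
```

Under this rule the result is a multiset union, so it only has set semantics when neither input contains duplicates.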
That is not a real obstacle, though. A List can be converted to a Set and back, and a Set's uniqueness guarantee takes care of the duplicates:
@Test
public void should_get_union_with_apache_union_on_sets() {
    List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
    List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
    // Convert to sets first so that duplicates are dropped before the union
    Set<Long> set1 = new HashSet<>(s1);
    Set<Long> set2 = new HashSet<>(s2);
    List<Long> result = (List<Long>) CollectionUtils.union(set1, set2);
    assertThat(result).hasSize(4);
}
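This roughly doubles the number of objects involved: two extra input sets on top of the output collection.
With that understood, I wrote my own version: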
@Test
public void should_get_union_with_self_union() {
    List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
    List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
    List<Long> result = getSelfUnion(s1, s2);
    assertThat(result).hasSize(4);
}
public List<Long> getSelfUnion(List<Long> s1, List<Long> s2) {
    // A single HashSet drops duplicates as elements are added
    Set<Long> result = new HashSet<>(s1.size() + s2.size());
    result.addAll(s1);
    result.addAll(s2);
    return Lists.newArrayList(result);
}
This meets the original requirement, and its cost is linear.
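One caveat worth illustrating: HashSet makes no promise about iteration order, so the resulting list may not follow the order of the inputs. If order matters, a variant built on LinkedHashSet keeps first-seen order at the same linear cost. This is a sketch with hypothetical names, using only java.util rather than Guava.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical order-preserving variant of the hand-written union.
class OrderedUnion {

    // Deduplicated union that keeps the order in which elements first appear.
    static List<Long> union(List<Long> s1, List<Long> s2) {
        Set<Long> seen = new LinkedHashSet<>(s1);
        seen.addAll(s2);
        return new ArrayList<>(seen);
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(union(s1, s2)); // prints [1, 2, 3, 7]
    }
}
```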
Performance Test
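The three ways of computing a union shown above were timed with the following code:
@Test
public void testTime() {
    final int runs = 20;
    long cost = 0;
    for (int i = 0; i < runs; i++) {
        long start = System.currentTimeMillis();
        // Approach 1
        List<Long> result = getSelfUnion(longs1, longs2);
        // Approach 2
        // List<Long> result2 = (List<Long>) CollectionUtils.union(longs1, longs2);
        // Approach 3
        // List<Long> result3 = getUnion(longs1, longs2); // removeAll and addAll
        long end = System.currentTimeMillis();
        cost = cost + (end - start);
    }
    // Divide by the number of runs (the code as first posted looped 20 times but divided by 10)
    System.out.println("longs1: " + longs1.size() + ", longs2: " + longs2.size() + ", average cost: " + cost / runs + "ms");
}
The results are below. If they were recorded with the code as first posted (20 runs, total divided by 10), each figure is about twice the true per-run average; the ranking of the three approaches is unaffected.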
Approach 1: longs1: 36900, longs2: 16035, average cost: 7ms
Approach 2: longs1: 36900, longs2: 16035, average cost: 17ms
Approach 3: longs1: 36900, longs2: 16035, average cost: 3779ms
So when the data set is large, be cautious about using removeAll and retainAll for set operations, especially when the argument is a List.
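For readers who want to reproduce the gap, here is a rough stand-alone sketch. It uses generated data (the original longs1 and longs2 are not available), smaller sizes, and a single unwarmed measurement per approach with System.nanoTime(), so the absolute numbers are only indicative; a harness such as JMH would be the better choice for serious measurement. All names are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

// Rough timing sketch with synthetic data; not the original benchmark.
class UnionTiming {

    // The removeAll() + addAll() approach on lists (keeps duplicates from b).
    static List<Long> unionWithRemoveAll(List<Long> a, List<Long> b) {
        List<Long> result = new ArrayList<>(a);
        result.removeAll(b);
        result.addAll(b);
        return result;
    }

    // The hand-written approach: collect both lists into one HashSet.
    static List<Long> unionWithHashSet(List<Long> a, List<Long> b) {
        Set<Long> result = new HashSet<>(a);
        result.addAll(b);
        return new ArrayList<>(result);
    }

    // Generates size pseudo-random values with a fixed seed for repeatability.
    static List<Long> randomLongs(int size, long seed) {
        Random random = new Random(seed);
        List<Long> values = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            values.add((long) random.nextInt(size * 2));
        }
        return values;
    }

    public static void main(String[] args) {
        List<Long> longs1 = randomLongs(8_000, 1L);
        List<Long> longs2 = randomLongs(4_000, 2L);

        long t0 = System.nanoTime();
        unionWithRemoveAll(longs1, longs2);
        long t1 = System.nanoTime();
        unionWithHashSet(longs1, longs2);
        long t2 = System.nanoTime();

        System.out.println("removeAll + addAll: " + (t1 - t0) / 1_000_000 + " ms");
        System.out.println("HashSet:            " + (t2 - t1) / 1_000_000 + " ms");
    }
}
```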