Java CountDownLatch 实现原理

java.util.concurrent.CountDownLatch

1 CountDownLatch 简介

CountDownLatch 是 JUC 中基于 AQS 实现的倒计时锁存器,通常用于控制一个或多个线程等待其他线程完成某个任务,然后再继续执行。示意图如下:

count-down-latch

2 CountDownLatch 使用示例

接下来我们介绍一种常见的 CountDownLatch 使用场景:从数据库中读取 100 万条数据进行批量处理。在处理过程中,我们会将这一百万条数据分成若干段,然后将每一段提交给线程池进行处理。在这种情况下,我们需要在主线程中进行阻塞等待,确保每一段任务都处理完成后再继续后续的操作,以避免主线程提前退出。实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
// 1. 从库中取出 100 万条数据
List<Dto> list = databaseDao.search("xxx");
// 2. 将 list 分段,每段 1000 条数据
List<List<Dto>> segmentedLists = Lists.partition(list, 1000);
// 3. 创建 cdl 计数器,将数量初始化为 list 分段数
CountDownLatch cdl = new CountDownLatch(segmentedLists.size());
// 4. 将每段提交给线程池处理
segmentedLists.forEach(it -> MY_POOL.submit(() -> {
    // 5. 处理业务逻辑
    ...
    // 6. 处理完这批后,将计数器 - 1
    cdl.countDown();
}));
// 7. 主线程等待所有分段处理完成
cdl.await();
// 8. 继续处理后续逻辑

在该示例中,主线程将所有任务提交到线程池后,就会通过 await() 阻塞等待,直到所有任务处理完成再醒来,继续处理后续任务。

3 CountDownLatch 初始化

由以上示例可知,CountDownLatch 在初始化时会设定资源数量:

1
2
3
4
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

该方法就是将 AQS 中的 state 变量的值设置为 count

4 await() 实现分析

CountDownLatchawait() 是基于 AQS 实现的:

1
2
3
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

acquireSharedInterruptibly() 是 AQS 中获取共享资源的方法,当线程已被中断时会直接抛出异常,否则调用 tryAcquireShared 削减资源数量:

1
2
3
4
5
6
7
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

tryAcquireShared() 是一个 模板方法,其真正实现在 CountDownLatch 的内部类 Sync 中:

1
2
3
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

如果剩余资源数量为 0 了,表明所有任务都处理完了,则 await() 会直接返回;否则表明仍有任务在处理,进而调用 doAcquireSharedInterruptibly() 将自己阻塞:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);// 加到队尾
    boolean failed = true;
    try {
        for (;;) {
            // 获取资源数,如果是 0 则返回
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            // 还有剩余资源,则将线程阻塞,并且响应中断
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

阻塞中的主线程在所有任务线程都调用 countDown() 后,会被唤醒,进而继续执行后续的任务。

5 countDown() 实现分析

所谓的 countDown() 就是将 state 值减 1,因此这里直接调用了 AQS 的 releaseShared() 方法:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public void countDown() {
    sync.releaseShared(1);
}

// AbstractQueuedSynchronizer#releaseShared
public final boolean releaseShared(int arg) {
    // 获取资源数,只有资源数为 0,才进入 if 方法体
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

// CountDownLatch.Sync#tryReleaseShared
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

由上可知,tryReleaseShared() 只有释放资源后,资源剩余数量恰好为 0 的线程会返回 true,进而调用 doReleaseShared() 方法唤醒阻塞在等待队列上的所有线程;否则均返回 false。而且当倒计时归零后再次调用该方法,依旧返回 false。


欢迎关注我的公众号,第一时间获取文章更新:

微信公众号

相关内容