基于Reids实现一个分布式可重入锁

警告
本文最后更新于 2021-01-19,文中内容可能已过时。

思路:基于Redis原子操作、list结构、发布订阅机制、看门狗续命实现一把分布式独占锁。

共有以下几个类

RedisLock

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296

/**
 * 分布式独占锁,分布式CLH队列锁实现类,非公平锁,此锁未实现可重入功能。
 * <p>
 * 同一个业务键获取锁将会互斥,不同的业务键获取锁不互斥。
 */
@Slf4j
@Component
@ConditionalOnBean(RedisTemplate.class)
@ConditionalOnClass(RedisTemplate.class)
public class RedisLock {
    /**
     * 获取锁超时时长。
     */
    public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
    /**
     * 持有锁超时时长。
     */
    public static final Duration DEFAULT_HOLD_LOCK_TIMEOUT = Duration.ofSeconds(30);
    /**
     * 进程key,用来标识一个进程。
     */
    public static final String PROCESS_KEY = UUID.randomUUID().toString();
    /**
     * 释放锁信号发布的channel。
     */
    public static final String LOCK_CHANNEL = "redis-lock-channel";

    @Resource
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * 业务lockKey -> 锁对象 映射关系。
     */
    private final Map<String, Object> blockerMap = new ConcurrentHashMap<>();

    /**
     * 线程id-线程 映射关系。
     */
    private final Map<Long, WeakReferenceThread> threadMap = new ConcurrentHashMap<>();

    /**
     * 锁续命看门狗。
     */
    @Resource
    private WatchDog watchDog;

    /**
     * 进程-线程信息,用来标识是哪个进程的哪个线程。
     */
    @Getter
    @Setter
    public static class ProcessThreadInfo {
        /**
         * 进程的唯一标识。
         */
        private String processKey;

        /**
         * 线程id。
         */
        private long threadId;

        /**
         * 业务键。
         */
        private String lockKey;

        /**
         * 获取锁超时时长。
         */
        private long tryLockTimeoutMills;

        /**
         * 持有锁超时时长。
         */
        private long holdLockTimeoutMills;
    }

    /**
     * 线程引用,这里使用弱引用,当一个线程被销毁时,在下次触发GC时,此类型的对象会被回收。
     */
    static class WeakReferenceThread extends WeakReference<Thread> {

        public WeakReferenceThread(Thread thread) {
            super(thread);
        }
    }

    /**
     * 尝试获取锁。
     *
     * @param lockKey 业务键(相同的业务键互斥)。
     * @param timeout 超时时长。
     * @return 是否获取到锁。
     */
    public boolean tryLock(String lockKey, Duration timeout) {
        String lockValue = getLockValue(lockKey);
        Boolean success = redisTemplate.opsForValue().setIfAbsent(lockKey, lockValue, timeout);
        boolean holdLock = nonNull(success) && success;
        if (holdLock) {
            if (log.isDebugEnabled()) {
                Thread t = Thread.currentThread();
                log.debug("线程id[{}]名称[{}]获取锁成功[{}:{}]。", t.getId(), t.getName(), lockKey, lockValue);
            }
            wakeUpWatchdog(lockKey, timeout);
        }
        return holdLock;
    }


    /**
     * 唤醒看门狗。
     *
     * @param lockKey 业务键。
     * @param timeout 超时时长。
     */
    private void wakeUpWatchdog(String lockKey, Duration timeout) {
        Thread t = Thread.currentThread();
        threadMap.putIfAbsent(t.getId(), new WeakReferenceThread(t));
        ProcessThreadInfo processThreadInfo = generateProcessThreadInfo(lockKey, timeout);
        watchDog.watch(processThreadInfo);
    }

    /**
     * 尝试获取锁。
     *
     * @param lockKey 业务键(相同的业务键互斥)。
     * @return 是否获取到锁。
     */
    public boolean tryLock(String lockKey) {
        return tryLock(lockKey, DEFAULT_TIMEOUT);
    }

    /**
     * 释放锁。
     *
     * @param lockKey 锁业务键。
     */
    public void unLock(String lockKey) {
        // 移除当前线程的的映射关系。
        Thread t = Thread.currentThread();
        threadMap.remove(t.getId());
        String lockValue = getLockValue(lockKey);
        // 如果持有锁的线程是当前线程,则释放锁。
        Object redisLockValue = redisTemplate.opsForValue().get(lockKey);
        if (Objects.equals(lockValue, redisLockValue)) {
            if (log.isDebugEnabled()) {
                log.debug("线程id[{}]名称[{}]释放锁成功[{}:{}]。", t.getId(), t.getName(), lockKey, lockValue);
            }
            redisTemplate.delete(lockKey);
        } else {
            if (log.isDebugEnabled()) {
                log.debug("线程id[{}]名称[{}]释放锁跳过,lockKey:{},lockValue:{},redisLockValue:{}。",
                        t.getId(), t.getName(), lockKey, lockValue, redisLockValue);
                ;
            }
        }
        // 无论是否正常释放了锁,均唤醒队列的下一个节点线程。
        String lockQueueKey = getLockQueueKey(lockKey);

        ProcessThreadInfo processThreadInfo = (ProcessThreadInfo) redisTemplate.opsForList().leftPop(lockQueueKey);
        if (processThreadInfo != null) {
            redisTemplate.convertAndSend(getLockChannel(), processThreadInfo);
        }
    }

    /**
     * 尝试获取锁,直到获取到锁或获取锁超时抛出异常。
     *
     * @param lockKey 锁业务键。
     * @throws RedisLockWaitTimeoutException 当超过超时时间还未获取到锁时,抛出此异常。
     */
    public void lock(String lockKey) throws RedisLockWaitTimeoutException {
        lock(lockKey, DEFAULT_TIMEOUT);
    }

    /**
     * 尝试获取锁,直到获取到锁或获取锁超时抛出异常。
     *
     * @param lockKey 锁业务键。
     * @param timeout 超时时长。
     * @throws RedisLockWaitTimeoutException 当超过超时时间还未获取到锁时,抛出此异常。
     */
    public void lock(String lockKey, Duration timeout) throws RedisLockWaitTimeoutException {
        // 尝试获取锁
        boolean success = tryLock(lockKey, timeout);
        if (!success) {
            // 加入等待队列,阻塞。
            addQueueAndAwait(lockKey, timeout);
        }
    }

    /**
     * 未获取到锁,加入队列等待。以下情况将会终止等待。
     * 1.当在队列最前端且锁被释放时,将会获取到锁。
     * 2.等待超时,将会抛出异常。
     *
     * @param lockKey 锁业务键。
     * @param timeout 超时时长。
     * @throws RedisLockWaitTimeoutException 当超过超时时间还未获取到锁时,抛出此异常。
     */
    private void addQueueAndAwait(String lockKey, Duration timeout) throws RedisLockWaitTimeoutException {
        Thread t = Thread.currentThread();
        threadMap.putIfAbsent(t.getId(), new WeakReferenceThread(t));

        if (log.isDebugEnabled()) {
            log.debug("线程id[{}]名称[{}]获取锁失败[{}:{}],加入等待队列。", t.getId(), t.getName(), lockKey, getLockValue(lockKey));
        }

        // 将会在此时间戳之后的时间获取锁超时。
        long timeoutMills = System.currentTimeMillis() + timeout.toMillis();
        ProcessThreadInfo processThreadInfo = generateProcessThreadInfo(lockKey, timeout);

        // 进程锁对象。
        blockerMap.putIfAbsent(lockKey, new Object());
        Object blocker = blockerMap.get(lockKey);
        // 分布式CLH队列锁。
        String lockQueueKey = getLockQueueKey(lockKey);
        redisTemplate.opsForList().rightPush(lockQueueKey, processThreadInfo);
        redisTemplate.expire(lockQueueKey, timeout);

        LockSupport.parkNanos(blocker, timeout.toNanos());
        // 唤醒之后,如果还没有超时,则继续尝试获取锁。
        long currentTimeMills = System.currentTimeMillis();
        if (currentTimeMills < timeoutMills) {
            lock(lockKey, Duration.ofMillis(timeoutMills - currentTimeMills));
        } else {
            log.warn("线程id[{}]名称[{}]获取锁超时![{}:{}]", t.getId(), t.getName(), lockKey, getLockValue(lockKey));
            throw new RedisLockWaitTimeoutException();
        }
    }

    private ProcessThreadInfo generateProcessThreadInfo(String lockKey, Duration timeout) {
        ProcessThreadInfo processThreadInfo = new ProcessThreadInfo();
        processThreadInfo.setProcessKey(PROCESS_KEY);
        processThreadInfo.setThreadId(Thread.currentThread().getId());
        processThreadInfo.setLockKey(lockKey);
        processThreadInfo.setTryLockTimeoutMills(timeout.toMillis());
        processThreadInfo.setHoldLockTimeoutMills(DEFAULT_HOLD_LOCK_TIMEOUT.toMillis());
        return processThreadInfo;
    }

    /**
     * 获取等待锁队列键。
     *
     * @param lockKey 业务键。
     * @return 等待锁队列键。
     */
    private String getLockQueueKey(String lockKey) {
        return lockKey + "-queue";
    }

    /**
     * 获取释放锁后消息发送的channel。
     *
     * @return 释放锁后消息发送的channel。
     */
    public static String getLockChannel() {
        return LOCK_CHANNEL;
    }

    /**
     * 获取当前线程锁应存储的值。
     * 如果后面此锁支持可重入,则这个值是有用的。
     *
     * @param lockKey 业务键。
     * @return 锁存储值。
     */
    private String getLockValue(String lockKey) {
        return getLockValue(lockKey, Thread.currentThread());
    }

    /**
     * 获取锁存储值。
     * 如果后面此锁支持可重入,则这个值是有用的。
     *
     * @param lockKey 业务键。
     * @param thread  线程。
     * @return 锁存储值。
     */
    String getLockValue(String lockKey, Thread thread) {
        return PROCESS_KEY + ":" + thread.getId();
    }

    /**
     * 根绝线程id获取等待锁的线程。
     *
     * @param threadId 线程id。
     * @return 正在等待获取锁的线程。
     */
    Thread getThread(long threadId) {
        WeakReferenceThread weakReferenceThread = threadMap.get(threadId);
        return weakReferenceThread == null ? null : weakReferenceThread.get();
    }
}

RedisLockChannelListener

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
 * redis分布式锁消息监听器。
 */
@Slf4j
@Component
@ConditionalOnBean(RedisTemplate.class)
@ConditionalOnClass(RedisTemplate.class)
public class RedisLockChannelListener implements MessageListener {
    @Resource
    private RedisLock redisLock;

    @Override
    public void onMessage(Message message, byte[] pattern) {
        byte[] body = message.getBody();
        // 解析消息,如果解析结果为null、或者需要唤醒的进程中的线程,不在当前进程中,则直接返回,不做处理。
        RedisLock.ProcessThreadInfo processThreadInfo = JSON.parseObject(new String(body), RedisLock.ProcessThreadInfo.class);
        if (!Objects.equals(processThreadInfo.getProcessKey(), RedisLock.PROCESS_KEY)) {
            return;
        }
        // 获取消息中的线程id,从线程缓存映射关系中,获取到线程。
        long threadId = processThreadInfo.getThreadId();
        Thread thread = redisLock.getThread(threadId);
        if (thread == null) {
            if (log.isDebugEnabled()) {
                log.debug("下一个节点[{}]线程未存活,将继续向下寻找是否有等待锁的线程", threadId);
            }
            //如果线程为空,获取队列中等待锁的下一个进程及线程信息。如果存在,尝试唤醒对应的进程中的线程。
            redisLock.unLock(processThreadInfo.getLockKey());
        } else {
            if (log.isDebugEnabled()) {
                log.debug("线程id[{}]名称[{}]正在阻塞,唤醒此线程。", thread.getId(), thread.getName());
            }
            // 如果线程不为空,唤醒线程。
            LockSupport.unpark(thread);
            Thread.yield();
        }
    }
}

RedisLockConfig

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

@Configuration
@ConditionalOnBean(RedisTemplate.class)
@ConditionalOnClass(RedisTemplate.class)
public class RedisLockConfig {

    @Autowired(required = false)
    private RedisMessageListenerContainer redisMessageListenerContainer;
    @Resource
    private RedisLockChannelListener redisLockChannelListener;
    @Resource
    private RedisConnectionFactory connectionFactory;

    @PostConstruct
    public void init() {
        if (redisMessageListenerContainer != null) {
            ChannelTopic channelTopic = new ChannelTopic(RedisLock.getLockChannel());
            redisMessageListenerContainer.addMessageListener(redisLockChannelListener, channelTopic);
        }
    }

    @Bean
    @ConditionalOnMissingBean(RedisMessageListenerContainer.class)
    public RedisMessageListenerContainer redisMessageListenerContainer() {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        ChannelTopic channelTopic = new ChannelTopic(RedisLock.getLockChannel());
        container.addMessageListener(redisLockChannelListener, channelTopic);
        return container;
    }
}

RedisLockWaitTimeoutException

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
/**
 * 当尝试获取分布式独占锁超时的时候,抛出此异常。
 *
 * @see RedisLock
 */
public class RedisLockWaitTimeoutException extends BusinessException {
    public RedisLockWaitTimeoutException() {
        super(ResponseCode.REDIS_LOCK_AWAIT_TIME_OUT);
    }
}

WatchDog

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

@Slf4j
@Component
@ConditionalOnBean(RedisLock.class)
@ConditionalOnClass(RedisTemplate.class)
public class WatchDog {
    /**
     * 线程id -> lockKey 映射关系。
     */
    private final Map<Long, ProcessThreadInfo> threadLockKeyMap = new ConcurrentHashMap<>();

    /**
     * 看门狗线程。守护线程。
     */
    private Thread watchDogThread;
    /**
     * 看门狗下次睡眠时间。
     */
    private volatile long nextSleepTimeMills;

    @Resource
    private RedisTemplate<String, Object> redisTemplate;
    @Resource
    private RedisLock redisLock;

    public void watch(ProcessThreadInfo processThreadInfo) {
        Thread t = Thread.currentThread();
        threadLockKeyMap.putIfAbsent(t.getId(), processThreadInfo);
        if (nextSleepTimeMills > processThreadInfo.getTryLockTimeoutMills() - 100) {
            if (log.isDebugEnabled()) {
                log.debug("线程id[{}]名称[{}]抢到锁了,唤醒看门狗", t.getId(), t.getName());
            }
            LockSupport.unpark(watchDogThread);
        }
    }

    @PostConstruct
    public void init() {
        watchDogThread = new Thread(() -> {
            while (true) {
                Set<Map.Entry<Long, ProcessThreadInfo>> entries = threadLockKeyMap.entrySet();
                List<Long> removeKeys = new ArrayList<>();
                long minTimeoutMills = Integer.MAX_VALUE;
                for (Map.Entry<Long, ProcessThreadInfo> entry : entries) {
                    ProcessThreadInfo processThreadInfo = entry.getValue();
                    String lockKey = processThreadInfo.getLockKey();
                    long holdLockTimeoutMills = processThreadInfo.getHoldLockTimeoutMills();
                    Thread thread = redisLock.getThread(processThreadInfo.getThreadId());
                    minTimeoutMills = Math.min(minTimeoutMills, holdLockTimeoutMills);
                    Object value = redisTemplate.opsForValue().get(lockKey);

                    // 如果key依旧存在,并且使用锁的线程依然在运行,则为锁续命。
                    if (Objects.nonNull(value) && thread != null
                            && Objects.equals(redisLock.getLockValue(lockKey, thread), value)) {
                        if (log.isDebugEnabled()) {
                            log.debug("看门狗开始为线程id[{}]占有的锁续命:{}", processThreadInfo.getThreadId(), JSON.toJSONString(processThreadInfo));
                        }
                        redisTemplate.expire(lockKey, holdLockTimeoutMills, TimeUnit.MILLISECONDS);
                    } else {
                        // 如果key不存在,说明锁已经释放了。此时应该移除监测。
                        removeKeys.add(entry.getKey());
                    }
                }
                if (log.isDebugEnabled() && !removeKeys.isEmpty()) {
                    log.debug("看门狗将不再监测以下key[{}]", JSON.toJSONString(removeKeys));
                }
                for (Long lockKey : removeKeys) {
                    threadLockKeyMap.remove(lockKey);
                }
                this.nextSleepTimeMills = Math.max(minTimeoutMills / 3, 1000);
                if (log.isDebugEnabled()) {
                    log.debug("看门狗将睡眠{}ms", nextSleepTimeMills);
                }
                LockSupport.parkNanos(this, Duration.ofMillis(nextSleepTimeMills).toNanos());
            }
        });
        watchDogThread.setName("redis-lock-watch-dog");
        watchDogThread.setDaemon(true);
        watchDogThread.start();
    }
}