【防止重复提交】Redis + AOP + 注解的方式实现分布式锁

工作原理

分布式环境下,可能会遇到用户重复点击、导致同一个接口被重复调用的场景。为了防止接口重复提交造成的问题,可以用 Redis 实现一个简单的分布式锁来解决。

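在 Redis 中,SETNX 命令可以帮助我们实现互斥。SETNX 即 SET if Not eXists(对应 Java 中的 setIfAbsent 方法):只有当 key 不存在时,才会设置 key 的值;如果 key 已经存在,SETNX 什么也不做。本文使用带过期时间的 setIfAbsent(key, value, timeout, unit),它对应 Redis 的 SET 命令加 NX 和过期时间选项,加锁与设置过期时间在同一条命令中原子完成,避免加锁后服务异常导致锁永不过期。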

需求实现

  1. 自定义一个防止重复提交的注解,注解中可以设置锁的过期时间,以及一个参与拼接 key 的请求参数名
  2. 为需要防止重复提交的接口添加注解
  3. AOP 切面拦截加了此注解的请求,进行加锁/解锁处理,并按注解上设置的时间为 key 设置过期时间
  4. Redis 中的 key = token + "-" + path + "-" + param_value,其中 token 取自请求头 uid,path 为请求的 servlet 路径(例如:17800000001-/api/subscribe/{channel}-zhangsan)
  5. 如果在上一次调用尚未执行完毕(锁尚未被释放)且 key 还未过期时重复调用某个加了注解的接口,就会直接返回表示“重复提交”的 Result,不再执行接口方法。

1)自定义防重复提交注解

自定义防止重复提交注解,注解中可设置超时时间和要扫描的参数名(取请求中该参数的值,拼接后作为 Redis 中 key 的一部分)。

package com.lihw.lihwtestboot.noRepeatSubmit;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a handler method whose repeated invocation should be rejected while an earlier call
 * with the same lock key is still holding the Redis lock.
 *
 * <p>Retained at runtime so that the AOP aspect can read {@link #seconds()} and
 * {@link #scanParam()} from the intercepted method.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface NoRepeatSubmit {

    /**
     * Name of the request parameter whose value is appended to the lock key.
     * An empty string means the key is built without any parameter value.
     */
    String scanParam() default "";

    /**
     * Lock expiry, in seconds.
     */
    int seconds() default 5;
}

2)定义防重复提交AOP切面

@Pointcut("@annotation(noRepeatSubmit)") 是切点表达式,它通过注解匹配的方式选择被 @NoRepeatSubmit 标记的方法;表达式中的 noRepeatSubmit 对应切点方法的参数名,匹配到的注解实例会绑定到该参数上,因此在 around 方法中可以直接读取 seconds 和 scanParam。

package com.lihw.lihwtestboot.noRepeatSubmit;

import com.alibaba.fastjson.JSONObject;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.UUID;
/**
 * Aspect that rejects duplicate submissions of methods annotated with {@link NoRepeatSubmit}.
 *
 * <p>For every intercepted call a Redis key {@code uid-servletPath[-paramValue]} is locked via
 * {@link RedisLock#tryLock} for {@code seconds}. If the key is already held, the call is
 * short-circuited with an {@code ApiResult} carrying "重复提交" (duplicate submission); otherwise
 * the method runs and the lock is released afterwards.
 */
@Aspect
@Component
public class RepeatSubmitAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(RepeatSubmitAspect.class);

    @Autowired
    private RedisLock redisLock;

    /**
     * Pointcut for methods carrying {@link NoRepeatSubmit}; the matched annotation instance is
     * bound to the {@code noRepeatSubmit} parameter.
     */
    @Pointcut("@annotation(noRepeatSubmit)")
    public void pointCut(NoRepeatSubmit noRepeatSubmit) {
        // Pointcut marker only; the body is never executed.
    }

    /**
     * Acquires the duplicate-submission lock, runs the intercepted method, then releases the lock.
     *
     * @param pjp            the intercepted invocation
     * @param noRepeatSubmit the annotation on the intercepted method
     * @return the method's own result, or an {@code ApiResult} with data "重复提交" when the lock
     *         is already held
     * @throws IllegalStateException if the method is invoked outside of an HTTP request
     * @throws Throwable             anything thrown by the intercepted method
     */
    @Around("pointCut(noRepeatSubmit)")
    public Object around(ProceedingJoinPoint pjp, NoRepeatSubmit noRepeatSubmit) throws Throwable {
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        // getRequestAttributes() returns null when no request is bound to the thread (e.g. the
        // annotated method is called from a background task); fail clearly instead of an NPE.
        Assert.state(attributes != null, "@NoRepeatSubmit requires an active HTTP request");
        HttpServletRequest request = attributes.getRequest();
        Assert.notNull(request, "request can not null");

        int lockSeconds = noRepeatSubmit.seconds();
        String threadName = Thread.currentThread().getName();
        String param = noRepeatSubmit.scanParam();
        String path = request.getServletPath();
        // NOTE(review): without a "uid" header this is null, so all anonymous callers of the same
        // path share one lock key — confirm this is acceptable.
        String token = request.getHeader("uid");

        String key;
        if (param != null && !param.isEmpty()) {
            // Only inspect the request when a parameter is actually configured; parsing an empty
            // or non-JSON body unconditionally used to throw for every annotated POST.
            key = token + "-" + path + "-" + resolveParamValue(request, param);
        } else {
            key = token + "-" + path;
        }
        LOGGER.info("线程:{}, 接口:{},重复提交验证", threadName, path);

        // Random per-call value so that only this call can release the lock it acquired.
        String clientId = getClientId();

        if (!redisLock.tryLock(key, clientId, lockSeconds)) {
            // Lock already held: treat as a duplicate submission.
            LOGGER.info("重复请求:接口 = {}, key = {}", path, key);
            // NOTE(review): this value replaces the method's own return value, so the annotated
            // method presumably must return ApiResult (or a supertype) — confirm for each use.
            ApiResult result = new ApiResult();
            result.setData("重复提交");
            return result;
        }

        LOGGER.info("加锁成功:接口 = {}, key = {}", path, key);
        try {
            return pjp.proceed();
        } finally {
            // releaseLock compares the stored value with this call's clientId before deleting, so
            // a lock that already expired and was taken by another call is normally left alone.
            if (redisLock.releaseLock(key, clientId)) {
                LOGGER.info("解锁成功:接口={}, key = {},", path, key);
            }
        }
    }

    /**
     * Extracts the value of {@code param} from the request.
     *
     * <p>POST: JSON body field (falling back to form/query parameters when the body is empty or
     * not a JSON object). GET: query parameter. Any other method yields an empty string. For POST
     * and GET, a value that is still missing falls back to a request header with the same name,
     * since callers such as the demo controller send the scanned value as a header.
     *
     * @return the value, {@code ""} for non-GET/POST methods, or {@code null} if not found
     */
    private String resolveParamValue(HttpServletRequest request, String param) throws IOException {
        String method = request.getMethod();
        String value;
        if ("POST".equals(method)) {
            value = readJsonBodyField(request, param);
        } else if ("GET".equals(method)) {
            value = request.getParameter(param);
        } else {
            return "";
        }
        return value != null ? value : request.getHeader(param);
    }

    /**
     * Reads {@code param} from the JSON request body; returns the form/query parameter of the same
     * name when the body is blank or cannot be parsed as a JSON object.
     */
    private String readJsonBodyField(HttpServletRequest request, String param) throws IOException {
        // Reuse the cached body when the filter already wrapped the request; otherwise wrap it here.
        BodyReaderHttpServletRequestWrapper wrapper = request instanceof BodyReaderHttpServletRequestWrapper
                ? (BodyReaderHttpServletRequestWrapper) request
                : new BodyReaderHttpServletRequestWrapper(request);
        String body = wrapper.getBodyString();
        if (body == null || body.trim().isEmpty()) {
            // Empty body, or already consumed upstream (e.g. by form-parameter parsing).
            return request.getParameter(param);
        }
        try {
            JSONObject json = JSONObject.parseObject(body);
            return json == null ? request.getParameter(param) : json.getString(param);
        } catch (RuntimeException e) {
            LOGGER.warn("Request body is not a JSON object, cannot extract param {}", param, e);
            return request.getParameter(param);
        }
    }

    /** Generates a random lock-owner value for a single invocation. */
    private String getClientId() {
        return UUID.randomUUID().toString();
    }

    /**
     * Reads the whole request body through {@link HttpServletRequest#getReader()}, concatenating
     * lines without line separators.
     *
     * @throws IOException if the body cannot be read
     */
    public static String getRequestBodyData(HttpServletRequest request) throws IOException {
        BufferedReader bufferReader = new BufferedReader(request.getReader());
        StringBuilder sb = new StringBuilder();
        String line = null;
        while ((line = bufferReader.readLine()) != null) {
            sb.append(line);
        }
        return sb.toString();
    }
}

3)RedisLock 工具类

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;


/**
 * Thin helper over {@link StringRedisTemplate} providing a SET-if-absent lock plus get/delete
 * helpers.
 */
@Service
public class RedisLock {

    private static final Logger logger = LoggerFactory.getLogger(RedisLock.class);

    /** Pass as {@code expire} to {@link #get(String, long)} to leave the key's TTL untouched. */
    public final static long NOT_EXPIRE = -1;

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * Tries to acquire the lock by setting {@code lockKey} to {@code clientId} only if the key does
     * not exist yet, with a TTL of {@code seconds}.
     *
     * @param lockKey  lock key
     * @param clientId unique identifier of the lock owner (a UUID in this project)
     * @param seconds  lock expiry in seconds
     * @return {@code true} if the lock was acquired, {@code false} otherwise
     */
    public boolean tryLock(String lockKey, String clientId, long seconds) {
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(lockKey, clientId, seconds, TimeUnit.SECONDS);
        // setIfAbsent returns a boxed Boolean that may be null (e.g. inside a pipeline or
        // transaction); unboxing it directly in an if-condition would throw an NPE.
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock acquired with {@link #tryLock}, but only while it is still owned by
     * {@code clientId}.
     *
     * @param lockKey  lock key
     * @param clientId owner identifier that was passed to {@link #tryLock}
     * @return {@code true} if the key held {@code clientId} and was deleted; {@code false} if the
     *         key was missing, held another value, or a Redis error occurred (the error is logged)
     */
    public boolean releaseLock(String lockKey, String clientId) {
        try {
            // Read inside the try so Redis failures are handled like the delete below.
            String currentValue = redisTemplate.opsForValue().get(lockKey);
            if (StringUtils.isEmpty(currentValue) || !currentValue.equals(clientId)) {
                return false;
            }
            // NOTE(review): GET and DEL are two separate commands, so the lock could expire and be
            // re-acquired by another client in between; a Lua compare-and-delete script would make
            // this check-and-release atomic.
            redisTemplate.delete(lockKey);
            return true;
        } catch (Exception e) {
            logger.error("解锁异常,,{}", e);
            return false;
        }
    }

    /**
     * Returns the value stored at {@code key} without changing its TTL.
     *
     * @param key Redis key
     * @return the stored value, or {@code null} if absent
     */
    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    /**
     * Returns the value stored at {@code key} and, unless {@code expire} is {@link #NOT_EXPIRE},
     * resets the key's TTL to {@code expire} seconds.
     *
     * @param key    Redis key
     * @param expire new TTL in seconds, or {@link #NOT_EXPIRE}
     * @return the stored value, or {@code null} if absent
     */
    public String get(String key, long expire) {
        String value = redisTemplate.opsForValue().get(key);
        if (expire != NOT_EXPIRE) {
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    /**
     * Deletes {@code key} unconditionally.
     *
     * @param key Redis key
     */
    public void delete(String key) {
        redisTemplate.delete(key);
    }
}

4)过滤器 + 请求工具类

Filter类

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.ServletComponentScan;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;


/**
 * Servlet filter that wraps requests in {@link BodyReaderHttpServletRequestWrapper} so the body
 * can be read more than once (e.g. by the duplicate-submission aspect and by the controller).
 *
 * <p>NOTE(review): {@code @ServletComponentScan} is normally declared on the
 * {@code @SpringBootApplication}/{@code @Configuration} class; on this non-configuration class it
 * may have no effect — confirm the filter is actually registered.
 */
@ServletComponentScan
@WebFilter(urlPatterns = "/*",filterName = "channelFilter")
public class ChannelFilter implements Filter {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialization needed.
    }

    /**
     * Replaces the request with a body-caching wrapper when that is safe, then continues the chain.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        logger.info("-----------------------Execute filter start---------------------");
        if (!(servletRequest instanceof HttpServletRequest) || !shouldCacheBody(servletRequest)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // The servlet input stream can only be read once, so buffer it for later consumers.
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        ServletRequest requestWrapper = new BodyReaderHttpServletRequestWrapper(httpServletRequest);
        filterChain.doFilter(requestWrapper, servletResponse);
    }

    /**
     * Returns whether the body may be eagerly buffered. Form posts are excluded because consuming
     * their body first leaves {@code getParameter()} empty; multipart requests are excluded to
     * avoid loading whole uploads into memory.
     */
    private static boolean shouldCacheBody(ServletRequest request) {
        String contentType = request.getContentType();
        if (contentType == null) {
            return true;
        }
        String normalized = contentType.toLowerCase(java.util.Locale.ROOT);
        return !normalized.startsWith("multipart/")
                && !normalized.startsWith("application/x-www-form-urlencoded");
    }
}

BodyReaderHttpServletRequestWrapper

对请求体进行缓存:在构造时一次性读出请求输入流并保存,之后每次调用 getInputStream()/getReader() 都返回基于缓存数据的新流,从而解决请求流只能读取一次的问题。

package com.lihw.lihwtestboot.noRepeatSubmit;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper{

    /**
     * Raw request body, read once in the constructor and replayed by every
     * {@link #getInputStream()} / {@link #getReader()} call.
     */
    private final byte[] body;

    /**
     * Buffers the complete body of {@code request}.
     *
     * @param request the request to wrap; its input stream is fully consumed and closed
     * @throws IOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Keep the raw bytes. Decoding line by line and re-encoding dropped line breaks and
        // corrupted non-UTF-8 or binary payloads.
        try (InputStream in = request.getInputStream()) {
            body = readAllBytes(in);
        }
    }

    /** Reads {@code in} to the end and returns its contents. */
    private static byte[] readAllBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        while ((len = in.read(buffer)) != -1) {
            out.write(buffer, 0, len);
        }
        return out.toByteArray();
    }

    /**
     * Returns the cached body decoded as UTF-8.
     *
     * @return the body text ({@code ""} for an empty body)
     */
    public String getBodyString() {
        return new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Copies {@code inputStream} into memory and returns a stream over the copy.
     * On an {@link IOException} the data read so far is returned (best effort).
     *
     * @param inputStream stream to copy
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = inputStream.read(buffer)) > -1) {
                byteArrayOutputStream.write(buffer, 0, len);
            }
            byteArrayOutputStream.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
    }

    /**
     * Returns a reader over the cached body, decoded with the request's declared character
     * encoding, or UTF-8 when none (or an unsupported one) is declared.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /** Picks the charset for {@link #getReader()}; never uses the platform default. */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException e) {
                // Illegal or unsupported charset name: fall back to UTF-8 below.
            }
        }
        return StandardCharsets.UTF_8;
    }

    /** Returns a fresh stream over the cached body on every call. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been consumed.
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // NOTE(review): non-blocking reads are not supported; the listener is ignored.
            }
        };
    }
}

5)测试Controller

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import javax.validation.constraints.NotEmpty;

/**
 * Demo controller exercising {@link NoRepeatSubmit}.
 */
@RestController
@RequestMapping("/api")
@Validated
public class noRepeatSubmitController {

    /**
     * Demo endpoint guarded against duplicate submission for 10 seconds. The scanned value
     * "username" is received by this endpoint as a request header.
     *
     * @param phone    caller id from the "uid" header
     * @param username user name from the "username" header
     * @param channel  non-empty channel path variable
     * @return a fixed success result after a simulated 5-second delay
     */
    @GetMapping("/subscribe/{channel}")
    @NoRepeatSubmit(seconds = 10,scanParam = "username")
    public ApiResult subscribe(@RequestHeader(name = "uid") String phone,@RequestHeader(name = "username") String username,@PathVariable("channel") @NotEmpty(message = "channel不能为空") String channel) {
        System.out.println("phone=" + phone);
        System.out.println("username=" + username);
        System.out.println("channel=" + channel);

        try {
            Thread.sleep(5000); // simulate a slow operation
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of swallowing it, so the caller can react.
            Thread.currentThread().interrupt();
        }

        return new ApiResult("success","data");
    }
}

6)测试结果

重复点击

猜你喜欢

转载自blog.csdn.net/weixin_44783506/article/details/136041290