Limiting the number of requests a single user can make to a single endpoint

This is implemented with a custom annotation. The required dependency is:

# gradle
implementation("net.jodah:expiringmap:0.5.8")
or
# maven
<dependency>
   <groupId>net.jodah</groupId>
   <artifactId>expiringmap</artifactId>
   <version>0.5.8</version>
</dependency>
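
ExpiringMap is an in-memory map whose entries can be given a time-to-live, which is what the aspect below relies on to reset its counters. A quick standalone illustration (the key, value and timings are only for demonstration and are not part of the original article):

import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;

import java.util.concurrent.TimeUnit;

public class ExpiringMapDemo {
    public static void main(String[] args) throws InterruptedException {
        // variableExpiration() allows a different expiration time per entry
        ExpiringMap<String, Integer> map = ExpiringMap.builder().variableExpiration().build();

        // this entry disappears 2 seconds after it was created
        map.put("127.0.0.1", 1, ExpirationPolicy.CREATED, 2, TimeUnit.SECONDS);

        System.out.println(map.get("127.0.0.1")); // 1
        Thread.sleep(2500);
        System.out.println(map.get("127.0.0.1")); // null, the entry has expired
    }
}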

Create the annotation class

package com.yulisao.common;

import java.lang.annotation.*;

/**
 * Request count limit annotation
 * author yulisao
 * createDate 2023/5/5
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface LimitRequest {

    long time() default 60 * 1000; // length of the time window in milliseconds, default one minute
    int count() default 10;        // maximum number of requests allowed within the window, default 10
}

Create an aspect class

package com.yulisao.common;

import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Aspect that enforces the request count limit
 * author yulisao
 * createDate 2023/5/5
 */
@Aspect
@Component
public class LimitRequestAspect {

    private Logger log = LoggerFactory.getLogger(this.getClass());

    // outer key: request URL, inner map: requester IP -> request count within the time window
    private static ConcurrentHashMap<String, ExpiringMap<String, Integer>> book = new ConcurrentHashMap<>();

    /**
     * For URLs that carry parameters, configure the fixed URL prefix here if the endpoint
     * needs rate limiting (this could also be read from a configuration table).
     * For example, if the upload URL is '/file/upload?id=3' or '/file/upload/3', configure '/file/upload',
     * because the full URL differs on every request once parameters are appended,
     * and request counts are accumulated per url + ip below.
     */
    private List<String> spcUrlList = Arrays.asList(
            "/file/upload",
            "/file/down",
            "/user/update"
    );

    // Pointcut: every method annotated with @LimitRequest goes through this aspect
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
    
    

        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        ServletRequestAttributes sra = (ServletRequestAttributes) ra;
        HttpServletRequest request = sra.getRequest();
        String ip = getIpAddr(request);
        String url = request.getServletPath();
        log.info("request url is " + url);
        log.info("request ip is " + ip);

        // 带参数的url,取前面固定不变的部分作为url存map的key
        String prefix = getPathPrefix(url);
        if (StringUtils.isNotBlank(prefix)) {
    
    
            url = prefix;
        }

        // 根据请求的url+用户真实ip作为key,记录单位时间内请求次数
        ExpiringMap<String, Integer> uc = book.getOrDefault(url, ExpiringMap.builder().variableExpiration().build());
        Integer uCount = uc.getOrDefault(ip, 0);
        log.info("request uCount is " + uCount);

        if (uCount >= limitRequest.count()) {
    
     // 超过次数,不执行目标方法
            throw new Exception("请求频繁,请稍后在试!");
        } else if (uCount == 0){
    
     // 第一次请求时,设置有效时间
            uc.put(ip, uCount + 1, ExpirationPolicy.CREATED, limitRequest.time(), TimeUnit.MILLISECONDS);
        } else {
    
     // 未超过次数, 记录加一
            uc.put(ip, uCount + 1);
        }
        book.put(url, uc);

        // result的值就是被拦截方法的返回值
        return pjp.proceed();
    }

    /**
     * Resolve the client IP of the request
     * @param request the current HTTP request
     * @return the client IP address
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress = request.getHeader("x-forwarded-for");
        if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
            ipAddress = request.getHeader("Proxy-Client-IP");
        }
        if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
            ipAddress = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
            ipAddress = request.getRemoteAddr();
            if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) {
                // local access: fall back to the IP configured on the network interface
                InetAddress inet = null;
                try {
                    inet = InetAddress.getLocalHost();
                } catch (UnknownHostException e) {
                    e.printStackTrace();
                }
                if (inet != null) {
                    ipAddress = inet.getHostAddress();
                }
            }
        }
        // when the request passed through multiple proxies, the first IP is the real client IP,
        // and the IPs are separated by commas
        if (ipAddress != null && ipAddress.length() > 15) { // longer than "***.***.***.***"
            if (ipAddress.indexOf(",") > 0) {
                ipAddress = ipAddress.substring(0, ipAddress.indexOf(",")); // take the first IP
            }
        }
        return ipAddress;
    }

    private String getPathPrefix(String url) {
        for (int i = 0; i < spcUrlList.size(); i++) {
            Pattern pattern = Pattern.compile(spcUrlList.get(i));
            Matcher matcher = pattern.matcher(url);
            // matcher.find() matches anywhere in the url; matcher.matches() would require an exact match
            if (matcher.find()) {
                return spcUrlList.get(i);
            }
        }
        return null;
    }
}
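
When the limit is hit, the aspect above throws a plain Exception, so what the caller actually sees depends on your global error handling. If you want the client to receive a proper HTTP 429 response instead, one option (not shown in the original article) is to throw a dedicated exception from the aspect and map it in a @RestControllerAdvice. A minimal sketch, with RequestLimitException as a hypothetical name:

package com.yulisao.common;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

/**
 * Hypothetical global handler: if the aspect throws a dedicated RequestLimitException
 * instead of a bare Exception, it can be mapped to HTTP 429 here.
 */
@RestControllerAdvice
public class LimitExceptionHandler {

    @ExceptionHandler(RequestLimitException.class)
    public ResponseEntity<String> handleLimit(RequestLimitException e) {
        // 429 Too Many Requests is the conventional status code for rate limiting
        return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS).body(e.getMessage());
    }
}

// the custom exception the aspect would throw in place of "new Exception(...)"
class RequestLimitException extends RuntimeException {
    public RequestLimitException(String message) {
        super(message);
    }
}

With this in place, the aspect's throw new Exception(...) would simply become throw new RequestLimitException(...).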

Add the custom annotation to the endpoints that need request limiting

@ApiOperation("Upload image")
@GetMapping("/file/upload/{id}")
@LimitRequest(time = 60 * 1000, count = 10) // both attributes can be overridden per endpoint
public void workUpLoadPic(@PathVariable("id") String id) {
    // do something...
}
  • For endpoints whose URL never changes, just add the LimitRequest annotation to the handler method. Different endpoints can use different policies, for example the captcha endpoint once per minute, real-name verification three times per day, and so on.
  • For URLs that carry parameters, the fixed prefix has to be extracted and used as the URL key, so that every request to the same endpoint increments the same counter. Otherwise each request would create a brand-new key in the map whose value is always 1, and the limit would never be reached (unless the request parameters happen to be identical). I extract the prefix with matcher.find, but indexOf, startsWith and similar approaches work just as well; a startsWith variant is sketched below.
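
For reference, a startsWith-based variant of getPathPrefix might look like this (a sketch only; it assumes the entries in spcUrlList are plain path prefixes rather than regular expressions):

    private String getPathPrefix(String url) {
        // return the first configured prefix that the request path starts with
        for (String prefix : spcUrlList) {
            if (url.startsWith(prefix)) {
                return prefix;
            }
        }
        return null;
    }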

Besides the annotation-based approach, the same limit can also be implemented with an interceptor plus a Redis cache or something similar; a rough sketch follows.
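
A minimal sketch of that interceptor + Redis idea, assuming Spring Data Redis is on the classpath; the key layout ("limit:" + path + ":" + ip), the 60-second window and the limit of 10 requests are illustrative values, not part of the original article:

package com.yulisao.common;

import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.concurrent.TimeUnit;

@Component
public class RedisLimitInterceptor implements HandlerInterceptor {

    private final StringRedisTemplate redisTemplate;

    public RedisLimitInterceptor(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String key = "limit:" + request.getServletPath() + ":" + request.getRemoteAddr();
        // INCR is atomic, so concurrent requests are counted correctly
        Long count = redisTemplate.opsForValue().increment(key, 1);
        if (count != null && count == 1L) {
            // first request in the window: start the expiration timer
            redisTemplate.expire(key, 60, TimeUnit.SECONDS);
        }
        if (count != null && count > 10) {
            response.setStatus(429); // Too Many Requests
            return false; // stop the request before it reaches the controller
        }
        return true;
    }
}

The interceptor still needs to be registered via WebMvcConfigurer#addInterceptors. The upside over the in-memory ExpiringMap version is that the counters are shared across application instances and survive restarts.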

Reposted from blog.csdn.net/qq_29539827/article/details/130533204