api接口限流 防止恶意刷接口

元气小坏坏 提交于 2019-12-26 20:42:03

api限流的场景

限流的需求出现在许多常见的场景中

1.秒杀活动,有人使用软件恶意刷单抢货,需要限流防止机器参与活动
2.某api被各式各样系统广泛调用,严重消耗网络、内存等资源,需要合理限流
3.淘宝获取ip所在城市接口、微信公众号识别微信用户等开发接口,免费提供给用户时需要限流,更具有实时性和准确性的接口需要付费。

api限流实战

首先我们编写注解类AccessLimit,使用注解方式在方法上限流更优雅更方便!三个参数分别代表有效时间、最大访问次数、是否需要登录,可以理解为 seconds 内最多访问 maxCount 次。

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface AccessLimit {
    int seconds();
    int maxCount();
    boolean needLogin() default true;
}


限流的思路

1.通过路径:ip的作为key,访问次数为value的方式对某一用户的某一请求进行唯一标识
2.每次访问的时候判断key是否存在,是否count超过了限制的访问次数
3.若访问超出限制,则应response返回msg:请求过于频繁给前端予以展示
使用spring AOP进行注解拦截

import com.learn.springbootredis.annotation.AccessLimit;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.concurrent.TimeUnit;

@Aspect
@Component
public class AccessLimitAop {

    @Autowired
    private RedisTemplate<String,Integer> redisTemplate;

    @Pointcut(value = "@annotation(com.learn.springbootredis.annotation.AccessLimit)")
    public void cutLimit(){}

    @Around("cutLimit()")
    public Object recordSysLog(ProceedingJoinPoint point)throws Throwable{
        MethodSignature ms = (MethodSignature) point.getSignature();
        Method method = ms.getMethod();
        AccessLimit accessLimit = method.getAnnotation(AccessLimit.class);
        ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = servletRequestAttributes.getRequest();
        if (null == accessLimit) {
            return point.proceed();
        }
        int seconds = accessLimit.seconds();
        int maxCount = accessLimit.maxCount();
        boolean needLogin = accessLimit.needLogin();

        if (needLogin) {
            //判断是否登录
        }
        String key = ms.getName() + ":" + getIpAddress( request) ;

        Integer count = redisTemplate.opsForValue().get(key);
        Long expire = redisTemplate.getExpire(key);
        if (null == count || -1 == count) {
            redisTemplate.opsForValue().set(key, 1,seconds, TimeUnit.SECONDS);
            return point.proceed();
        }

        if (count < maxCount) {
            redisTemplate.opsForValue().increment(key);
            return point.proceed();
        }

        if (count >= maxCount) {
// response 返回 json 请求过于频繁请稍后再试
            String str = "{\n" +
                    "\t\"success\": 1,\n" +
                    "\t\"message\": \"频繁访问\"\n" +
                    "}";
            return str;
        }
        return point.proceed();
    }

    /**
     * 获取用户真实IP地址,不使用request.getRemoteAddr();的原因是有可能用户使用了代理软件方式避免真实IP地址。
     * 可是,如果通过了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP值,究竟哪个才是真正的用户端的真实IP呢?
     * 答案是取X-Forwarded-For中第一个非unknown的有效IP字符串
     * @param request
     * @return
     */
    private static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
            if("127.0.0.1".equals(ip)||"0:0:0:0:0:0:0:1".equals(ip)){
                //根据网卡取本机配置的IP
                InetAddress inet=null;
                try {
                    inet = InetAddress.getLocalHost();
                } catch (UnknownHostException e) {
                    e.printStackTrace();
                }
                ip= inet.getHostAddress();
            }
        }
        return ip;
    }
}


在controller层方法上直接使用注解@AccessLimit

import com.learn.springbootredis.annotation.AccessLimit;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.HashMap;
import java.util.Map;

@RestController
public class TestController {

    @RequestMapping("questTest")
    @AccessLimit(seconds = 10,maxCount = 1)
    public Object questTest(){
        Map map = new HashMap();
        map.put("success",0);
        map.put("message","成功");
        return map;
    }
}

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!