频率(限流)

本秂侑毒 提交于 2020-02-17 23:19:28

问题:控制访问频率,在访问的时候加上一定的次数限制

基本实现

views.py

class VisitThrottle(object):
    def allow_request(self, request, view):
        return True       # 可以继续访问
        # return False    # 访问频率太高, 被限制

    def wait(self):
        pass

可以进一步的升级,限制 10s 内只能访问3次

import time
VISIT_RECORD = {}

class VisitThrottle(object):
    '''
    10s内只能访问3次
    '''
    def allow_request(self, request, view):
        # 1. 获取用户IP
        remote_addr = request.META.get('REMOTE_ADDR')
        ctime = time.time()
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime, ]
            return True
        history = VISIT_RECORD.get(remote_addr)

        while history and history[-1] < ctime - 10:
            history.pop()

        if len(history) < 3:
            history.insert(0, ctime)
            return True
            # return True       # 可以继续访问
            # return False      # 访问频率太高, 被限制

    def wait(self):
    '''
    还需要等待的时间
    '''
    ctime = time.time()
    return 60 - (ctime - self.history[-1])

源码流程

和前面一样,也是从 dispatch 开始,到 initial

def initial(self, request, *args, **kwargs):
    
    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    # 控制访问频率
    self.check_throttles(request)
def check_throttles(self, request):
    # get_throttles 里面是一个列表生成式
    for throttle in self.get_throttles():
        if not throttle.allow_request(request, self):
            self.throttled(request, throttle.wait())
def get_throttles(self):
    """
    Instantiates and returns the list of throttles that this view uses.
    """
    return [throttle() for throttle in self.throttle_classes]

throttle_classes 默认使用配置文件

class APIView(View):
    ...
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
    ...

可以添加到全局使用,首先在 utils 下新建 throttle.py,将视图文件中的类移至 throttle.py,这里修改了 60s内能访问3次

# throttle.py

import time
VISIT_RECORD = {}

class VisitThrottle(object):
    '''
    60s内只能访问3次
    '''
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        # 1. 获取用户IP
        remote_addr = request.META.get('REMOTE_ADDR')
        ctime = time.time()
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime, ]
            return True
        history = VISIT_RECORD.get(remote_addr)
        self.history = history

        while history and history[-1] < ctime - 60:
            history.pop()

        if len(history) < 3:
            history.insert(0, ctime)
            return True
            # return True       # 可以继续访问
            # return False      # 访问频率太高, 被限制

    def wait(self):
        '''
        还需要等待的时间
        '''
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

然后在配置文件 settings.py 中添加路径

REST_FRAMEWORK = {
    ...
    'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle']
}

最后将视图中的局部配置删除即可。

回到 check_throttles

def check_throttles(self, request):
    
    for throttle in self.get_throttles():
        # throttle.allow_request 为 False,走下一步,throttled 抛出异常,表示访问频率过多
        if not throttle.allow_request(request, self):
            self.throttled(request, throttle.wait())
def throttled(self, request, wait):
    """
    If request is throttled, determine what kind of exception to raise.
    """
    raise exceptions.Throttled(wait)

频率的内置类

在自定义频率的时候,为了更加规范,需要继承,并且父类有获取 IP 的方法(可以在 BaseThrottle 中查看),因此这里直接调用父类的方法即可

from rest_framework.throttling import BaseThrottle

import time
VISIT_RECORD = {}

class VisitThrottle(BaseThrottle):
    '''
    60s内只能访问3次
    '''
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        # 1. 获取用户IP,调用父类的方法
        remote_addr = self.get_ident(request)
        
        ctime = time.time()
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime, ]
            return True
        history = VISIT_RECORD.get(remote_addr)
        self.history = history

        while history and history[-1] < ctime - 60:
            history.pop()

        if len(history) < 3:
            history.insert(0, ctime)
            return True
            # return True       # 可以继续访问
            # return False      # 访问频率太高, 被限制

    def wait(self):
        '''
        还需要等待的时间
        '''
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

进入 BaseThrottle ,发现在其下方有个 SimpleRateThrottle ,也是继承 BaseThrottle 。首先看 SimpleRateThrottle__init__ 方法

class SimpleRateThrottle(BaseThrottle):
    ... # 省略的内容
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            # 这里执行了 get_rate 方法
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
def get_rate(self):
    """
    Determine the string representation of the allowed request rate.
    """
    if not getattr(self, 'scope', None):
        msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
               self.__class__.__name__)
        raise ImproperlyConfigured(msg)

    try:
        # scope实际上是一个字典的 key,这里在 THROTTLE_RATES 中取值
        # 在上面的代码中看到 THROTTLE_RATES 是一个配置项,获取用户自定义的配置
        return self.THROTTLE_RATES[self.scope]
    except KeyError:
        msg = "No default throttle rate set for '%s' scope" % self.scope
        raise ImproperlyConfigured(msg)

至此,就可以在配置文件中写一个 60s内能访问3次 的程序,让它自动完成,无需自定义写

throttle.py

class VisitThrottle(SimpleRateThrottle):
    scope = "xi"    # scope作为key使用

settings.py

REST_FRAMEWORK = {
    ... # 省略
    'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle'],
    'DEFAULT_THROTTLE_RATES' : {
        'xi': '3/m'     # m是分钟,每分钟访问3次
    }
}

这时,配置了访问次数,就会在 return self.THROTTLE_RATES[self.scope] 中获取到,返回给 get_rate 方法,然后 __init__ 中的 rate 拿到的就是 3/m

class SimpleRateThrottle(BaseThrottle):
    ... # 省略的内容
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            # '3/m'
            self.rate = self.get_rate()
        # 将字符串 '3/m' 当做参数传递给 parse_rate
        # 走完 parse_rate,num_requests代表3次,duration代表60s
        self.num_requests, self.duration = self.parse_rate(self.rate)
        
    .... # 省略
    
    def parse_rate(self, rate):
    """
    Given the request rate string, return a two tuple of:
    <allowed number of requests>, <period of time in seconds>
    """
    
    # rate就是 '3/m'
    if rate is None:
        return (None, None)
    num, period = rate.split('/')
    num_requests = int(num)
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
    return (num_requests, duration)

此时,构造函数走完,接着查看 allow_request

def allow_request(self, request, view):
    
    if self.rate is None:
        return True
    
    # 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
    self.key = self.get_cache_key(request, view)
    if self.key is None:
        return True
    
    self.history = self.cache.get(self.key, [])
    self.now = self.timer()
# 来到 get_cache_key,源码上并没有写什么,这表示是让我们自己写的

def get_cache_key(self, request, view):
    
    raise NotImplementedError('.get_cache_key() must be overridden')
# get_cache_key 实际上是表示能够唯一标识的方法,所以返回值可以是获取IP,用来表示谁的访问记录
# throttle.py

class VisitThrottle(SimpleRateThrottle):
    scope = "xi"

    def get_cache_key(self, request, view):
        return self.get_ident(request)  # 获取IP

回到 allow_request

def allow_request(self, request, view):
    
    if self.rate is None:
        return True
    
    # 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
    self.key = self.get_cache_key(request, view)
    if self.key is None:
        return True
    
    # 去缓存中取出所有记录
    # cache = default_cache,是django内置的缓存
    self.history = self.cache.get(self.key, [])
    self.now = self.timer()  # timer() = time.time(),获取当前时间
    
    # Drop any requests from the history which have now passed the
    # throttle duration
    # 这里与上面自定义的相同
    while self.history and self.history[-1] <= self.now - self.duration:
        self.history.pop()
    if len(self.history) >= self.num_requests:
        return self.throttle_failure()
    return self.throttle_success()


def throttle_success(self):
    """
    Inserts the current request's timestamp along with the key
    into the cache.
    """
    # 如果成功,加到历史记录中
    self.history.insert(0, self.now)
    self.cache.set(self.key, self.history, self.duration)
    return True

def throttle_failure(self):
    """
    Called when a request to the API has failed due to throttling.
    """
    return False

def wait(self):
    """
    Returns the recommended next request time in seconds.
    """
    if self.history:
        remaining_duration = self.duration - (self.now - self.history[-1])
    else:
        remaining_duration = self.duration

    available_requests = self.num_requests - len(self.history) + 1
    if available_requests <= 0:
        return None

    return remaining_duration / float(available_requests)

照样是前三次可以访问,后面再访问需要等一分钟,这是对匿名用户的控制

也可以对登录的用户进行控制,但在全局的设置中,不能既有匿名的,还有登录的。这时,就可以将登录用户的访问控制设为全局,匿名用户使用局部的设置。

settings.py

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', 'api.utils.auth.Authentication'],
    # 'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', ],
    'UNAUTHENTICATED_USER': None,
    'UNAUTHENTICATED_TOKEN': None,
    'DEFAULT_PERMISSION_CLASSES': ['api.utils.permission.SVIPPermission'],
    'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.UserThrottle'],    # 登录用户
    'DEFAULT_THROTTLE_RATES' : {
        'xi': '3/m',
        'xiUser': '10/m'
    }
}

throttle.py

# 匿名用户
class VisitThrottle(SimpleRateThrottle):
    scope = "xi"

    def get_cache_key(self, request, view):
        return self.get_ident(request)

# 登录用户
class UserThrottle(SimpleRateThrottle):
    scope = "xiUser"

    def get_cache_key(self, request, view):
        return request.user.username

views.py

from django.shortcuts import render, HttpResponse
from django.http import JsonResponse
from rest_framework.views import APIView
from api import models
from api.utils.permission import SVIPPermission, MyPermission
from api.utils.throttle import VisitThrottle

ORDER_DICT = {
    1: {
        'name': 'qiu',
        'age': 18,
        'gender': '男',
        'content': '...'
    },

    2: {
        'name': 'xi',
        'age': 19,
        'gender': '男',
        'content': '.....'
    }
}

def md5(user):
    import hashlib
    import time

    ctime = str(time.time())

    m = hashlib.md5(bytes(user, encoding='utf-8'))
    m.update(bytes(ctime, encoding='utf-8'))
    return m.hexdigest()


class AuthView(APIView):
    authentication_classes = []
    permission_classes = []
    throttle_classes = [VisitThrottle]  # 为匿名用户设置频率控制

    def post(self, request, *args, **kwargs):
        ret = {'code': 1000, 'msg': None}
        try:
            user = request._request.POST.get('username')
            pwd = request._request.POST.get('password')
            obj = models.UerInfo.objects.filter(username=user, password=pwd).first()
            if not obj:
                ret['code'] = 1001
                ret['msg'] = '用户名或密码错误'
            # 为登录用户创建token
            else:
                token = md5(user)
                # 存在就更新, 不存在就创建
                models.UserToken.objects.update_or_create(user=obj, defaults={'token': token})
                ret['token'] = token
        except Exception as e:
            ret['code'] = 1002
            ret['msg'] = '请求异常'

        return JsonResponse(ret)

class OrderView(APIView):
    '''
    订单相关业务(只有SVIP用户有权限)
    '''
    def get(self, request, *args, **kwargs):
        ret = {'code': 1000, 'msg': None, 'data': None}
        try:
            ret['data'] = ORDER_DICT
        except Exception as e:
            pass
        return JsonResponse(ret)

class UserInfoView(APIView):
    '''
    用户中心(普通用户、VIP有权限)
    '''
    permission_classes = [MyPermission]
    def get(self, request, *args, **kwargs):
        return HttpResponse('用户信息')

总结

使用

  • 类,继承 BaseThrottle ,实现 allow_requestwait

  • 类,继承 SimpleRateThrottle ,实现 get_cache_keyscope = "xi"(配置文件中的key)

  • 局部:throttle_classes = [VisitThrottle]

  • 全局:配置 settings.py

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!