Celery: Rate limit on tasks with the same parameters


Question


I am looking for a way to restrict how often a function is called, but with the limit applied separately for each distinct set of input parameters, that is:

@app.task(rate_limit="60/s")
def api_call(user):
    do_the_api_call()

for i in range(100):
    api_call.delay("antoine")
    api_call.delay("oscar")

So I would like api_call("antoine") to run at most 60 times per second, and api_call("oscar") at most 60 times per second as well.

Any help on how I can do that?

--EDIT 27/04/2015: I have tried calling a subtask with rate_limit from within a task, but it does not work either: the rate_limit is applied across all instances of the subtask, regardless of arguments (which is logical).

@app.task(rate_limit="60/s")
def sub_api_call(user):
    do_the_api_call()

@app.task
def api_call(user):
    sub_api_call.delay(user)

for i in range(100):
    api_call.delay("antoine")
    api_call.delay("oscar")

Best!


Answer 1:


I don't think this is possible to achieve with Celery's built-in rate_limit option.

Assuming you're using some kind of cache for your API, the best solution might be to create a hash of the task name and arguments and use that as the key for a cache-based throttle.

If you're using Redis, you could either set a lock with a 60-second timeout, or use an incremental counter to track the calls per window.
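
As a minimal sketch of the counter approach (the throttle_key and allow_call helpers here are hypothetical names for illustration, assuming the redis-py client):

import hashlib
import json

import redis

r = redis.Redis()

def throttle_key(task_name, args):
    # Hash the task name plus its serialized arguments into a stable key.
    arg_hash = hashlib.sha1(json.dumps(args, sort_keys=True).encode()).hexdigest()
    return f"throttle:{task_name}:{arg_hash}"

def allow_call(task_name, args, limit=60, window=1):
    """Return True if fewer than `limit` calls ran in the current window."""
    key = throttle_key(task_name, args)
    count = r.incr(key)        # atomically count this attempt
    if count == 1:
        r.expire(key, window)  # the first call in a window starts the clock
    return count <= limit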

This post might give you some pointers on distributed throttling of Celery tasks with Redis:

https://callhub.io/blog/2014/02/03/distributed-rate-limiting-with-redis-and-celery/




Answer 2:


I spent some time on this today and came up with a nice solution. All the other solutions I've seen have one of the following problems:

  • They require tasks to have infinite retries, thereby making Celery's retry mechanism useless.
  • They don't throttle based on the task's parameters.
  • They fail with multiple workers or queues.
  • They're clunky, etc.

Basically, you wrap your task like this:

@app.task(bind=True, max_retries=10)
@throttle_task("2/s", key="domain", jitter=(2, 15))
def scrape_domain(self, domain):
    do_stuff()

And the result is that you'll throttle the task to 2 runs per second per domain parameter, with a random retry jitter of 2-15s. The key parameter is optional, but corresponds to a parameter of your task. If the key parameter is not given, the throttle simply applies to the task as a whole. If it is given, the throttle applies to the (task, key) pair.
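
With the decorator in place, you enqueue the task like any other Celery task, and each distinct value of the keyed argument gets its own 2/s budget (this usage sketch is illustrative, not from the original answer):

for domain in ["example.com", "example.org"]:
    for _ in range(10):
        scrape_domain.delay(domain)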

The other way to look at this is without the decorator. This gives a little more flexibility, but counts on you to do the retrying yourself. Instead of the above, you could do:

@app.task(bind=True, max_retries=10)
def scrape_domain(self, domain):
    proceed = is_rate_okay(self, "2/s", key=domain)
    if proceed:
        do_stuff()
    else:
        # Don't count this attempt against max_retries.
        self.request.retries = self.request.retries - 1
        return self.retry(countdown=random.uniform(2, 15))

I think that's identical to the first example. A bit longer and more branchy, but it shows how it works a bit more clearly. I'm hoping to always use the decorator, myself.

This all works by keeping a tally in Redis. The implementation is very simple. You create a key in Redis for the task (and for the key parameter, if given), and you expire that key according to the schedule implied by the rate. If the user sets a rate of 10/m, you make a Redis key with a 60s TTL and increment it every time a task with the matching name is attempted. If the counter gets too high, retry the task; otherwise, run it.

import functools
import inspect
import logging
import random
from typing import Any, Callable, Tuple

from celery import Task

# The original answer assumes a module-level logger; a standard one works.
logger = logging.getLogger(__name__)


def parse_rate(rate: str) -> Tuple[int, int]:
    """Given the request rate string, return a two-tuple of:

    <allowed number of requests>, <period of time in seconds>

    (Stolen from Django Rest Framework.)
    """
    num, period = rate.split("/")
    num_requests = int(num)
    if len(period) > 1:
        # It takes the form of a 5d, or 10s, or whatever
        duration_multiplier = int(period[0:-1])
        duration_unit = period[-1]
    else:
        duration_multiplier = 1
        duration_unit = period[-1]
    duration_base = {"s": 1, "m": 60, "h": 3600, "d": 86400}[duration_unit]
    duration = duration_base * duration_multiplier
    return num_requests, duration
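
# For example, given the parsing rules above (these assertions are added
# for illustration; they are not part of the original answer):
assert parse_rate("2/s") == (2, 1)
assert parse_rate("10/2h") == (10, 7200)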


def throttle_task(
    rate: str,
    jitter: Tuple[float, float] = (1, 10),
    key: Any = None,
) -> Callable:
    """A decorator for throttling tasks to a given rate.

    :param rate: The maximum rate that you want your task to run. Takes the
    form of '1/m', or '10/2h' or similar.
    :param jitter: A tuple of the range of backoff times you want for throttled
    tasks. If the task is throttled, it will wait a random amount of time
    between these values before being tried again.
    :param key: An argument name whose value should be used as part of the
    throttle key in redis. This allows you to create per-argument throttles by
    simply passing the name of the argument you wish to key on.
    :return: The decorated function
    """

    def decorator_func(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Inspect the decorated function's parameters to get the task
            # itself and the value of the parameter referenced by key.
            sig = inspect.signature(func)
            bound_args = sig.bind(*args, **kwargs)
            task = bound_args.arguments["self"]
            key_value = None
            if key:
                try:
                    key_value = bound_args.arguments[key]
                except KeyError:
                    raise KeyError(
                        f"Unknown parameter '{key}' in throttle_task "
                        f"decorator of function {task.name}. "
                        f"`key` parameter must match a parameter "
                        f"name from function signature: '{sig}'"
                    )
            proceed = is_rate_okay(task, rate, key=key_value)
            if not proceed:
                logger.info(
                    "Throttling task %s (%s) via decorator.",
                    task.name,
                    task.request.id,
                )
                # Decrement the number of times the task has retried. If you
                # fail to do this, it gets auto-incremented, and you'll expend
                # retries during the backoff.
                task.request.retries = task.request.retries - 1
                return task.retry(countdown=random.uniform(*jitter))
            else:
                # All set. Run the task.
                return func(*args, **kwargs)

        return wrapper

    return decorator_func


def is_rate_okay(task: Task, rate: str = "1/s", key=None) -> bool:
    """Keep a global throttle for tasks

    Can be used via the `throttle_task` decorator above.

    This implements the timestamp-based algorithm detailed here:

        https://www.figma.com/blog/an-alternative-approach-to-rate-limiting/

    Basically, you keep track of the number of requests and use the key
    expiration as a reset of the counter.

    So you have a rate of 5/m, and your first task comes in. You create a key:

        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60

    Another task comes in a few seconds later:

        celery_throttle:task_name = 2
        Do not update the TTL; it now has 58s remaining

    And so forth, until:

        celery_throttle:task_name = 5
        (10s remaining)

    The next task to come in is over the threshold. Re-queue it for later. 10s later:

        The key expires because its TTL runs out.

    Another task comes in:

        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60

    And so forth.

    :param task: The task that is being checked
    :param rate: How many times the task can be run during the time period.
    Something like, 1/s, 2/h or similar.
    :param key: If given, add this to the key placed in Redis for the item.
    Typically, this will correspond to the value of an argument passed to the
    throttled task.
    :return: True if the task may run now; False if it is over the rate limit
    and should be throttled.
    """
    key = f"celery_throttle:{task.name}{':' + str(key) if key else ''}"

    # make_redis_interface is a project-specific helper that returns a
    # redis.Redis client; a minimal stand-in is sketched after this listing.
    r = make_redis_interface("CACHE")

    num_tasks, duration = parse_rate(rate)

    # Check the count in Redis. Note that this check-then-act sequence is
    # not atomic, so under heavy concurrency a few extra runs can slip
    # through the throttle.
    count = r.get(key)
    if count is None:
        # No key. Set the value to 1 and set the TTL of the key.
        r.set(key, 1)
        r.expire(key, duration)
        return True
    else:
        # Key found. Use a strict comparison so that exactly num_tasks runs
        # are allowed per window (<= here would allow one extra run).
        if int(count) < num_tasks:
            # We're OK, run it.
            r.incr(key, 1)
            return True
        else:
            return False
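
The make_redis_interface helper above is project-specific and isn't shown in the answer. A minimal stand-in, assuming the redis-py package and a Redis server on localhost (the name and signature are inferred from the call above; everything else is an assumption):

import redis

def make_redis_interface(name: str) -> redis.Redis:
    # Hypothetical stand-in: ignore `name` and return a client for a local
    # Redis. The original helper presumably maps names like "CACHE" to
    # configured connection settings.
    return redis.Redis(host="localhost", port=6379, db=0)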


Source: https://stackoverflow.com/questions/29854102/celery-rate-limit-on-tasks-with-the-same-parameters
