passing kwargs with multiprocessing.pool.map

天涯浪人 2020-12-24 08:24

I would like to pass keyword arguments to my worker function with Pool.map(). I can't find a clear example of this when searching forums.

Example Code:

2 Answers
  • 2020-12-24 08:41

    If you want to iterate over the other arguments (that is, give them a different value for each job), use @ArcturusB's answer.

    If you just want to pass them with the same value on every call, you can bind them with functools.partial:

    from functools import partial
    pool.map(partial(worker, **kwargs), jobs)
    

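    partial "binds" the keyword arguments to the function, producing a new callable that Pool.map can call with a single argument. Note that old versions of Python cannot pickle partial objects, and Pool.map must pickle the function to send it to the worker processes.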
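    A minimal runnable sketch of the partial approach follows. The worker, its scale and offset keywords, and the run_with_partial helper are hypothetical stand-ins, not the asker's code. Each item of the iterable becomes the single positional argument, so a worker that takes several positionals (like the one in the question) would still need its job tuple unpacked somewhere.

```python
import multiprocessing as mp
from functools import partial


def worker(x, scale=1, offset=0):
    # Hypothetical worker: x comes from the iterable, scale/offset are keywords.
    return x * scale + offset


def run_with_partial(items, **kwargs):
    # partial(worker, **kwargs) is a one-argument callable, which is what
    # Pool.map expects; it is pickled and sent to the worker processes.
    func = partial(worker, **kwargs)
    with mp.Pool(4) as pool:
        return pool.map(func, items)


if __name__ == "__main__":
    print(run_with_partial(range(4), scale=10, offset=1))  # [1, 11, 21, 31]
```

    Every call receives the same scale and offset; only x changes from item to item.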

  • 2020-12-24 09:00

    The multiprocessing.pool.Pool.map doc states:

    A parallel equivalent of the map() built-in function (it supports only one iterable argument though). It blocks until the result is ready.

    So map itself accepts only one iterable argument. Luckily, there is a simple workaround: define a worker_wrapper function that takes a single argument, unpacks it into args and kwargs, and passes them on to worker:

    def worker_wrapper(arg):
        args, kwargs = arg
        return worker(*args, **kwargs)
    

    In your wrapper_process, build these single arguments from jobs (or directly while constructing jobs) and map worker_wrapper over them:

    arg = [(j, kwargs) for j in jobs]
    pool.map(worker_wrapper, arg)
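    Because every item carries its own kwargs dict, this approach also lets each job use different keyword arguments, which a single partial cannot do. A small sketch, assuming a hypothetical factor keyword in place of the question's kwarg_test; run_per_job is likewise a made-up helper name:

```python
import multiprocessing as mp


def worker(x, y, **kwargs):
    # Hypothetical worker with the same shape as the question's:
    # two positional arguments plus optional keyword arguments.
    factor = kwargs.get("factor", 1)
    return x * y * factor


def worker_wrapper(arg):
    # Each item is an (args, kwargs) pair; unpack it and forward it to worker.
    args, kwargs = arg
    return worker(*args, **kwargs)


def run_per_job(items):
    with mp.Pool(4) as pool:
        return pool.map(worker_wrapper, items)


if __name__ == "__main__":
    # Every job carries its own kwargs dict, so the values can differ per job.
    items = [((2, 3), {}), ((2, 3), {"factor": 10}), ((4, 5), {"factor": -1})]
    print(run_per_job(items))  # [6, 60, -20]
```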
    

    Here is a working implementation, kept as close as possible to your original code:

    import multiprocessing as mp
    
    def worker_wrapper(arg):
        args, kwargs = arg
        return worker(*args, **kwargs)
    
    def worker(x, y, **kwargs):
        kwarg_test = kwargs.get('kwarg_test', False)
        # print("kwarg_test = {}".format(kwarg_test))     
        if kwarg_test:
            print("Success")
        else:
            print("Fail")
        return x*y
    
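    def wrapper_process(**kwargs):
        jobs = []
        pool = mp.Pool(4)
        for i, n in enumerate(range(4)):
            jobs.append((n, i))
        # Pair every positional-args tuple with the same kwargs dict
        arg = [(j, kwargs) for j in jobs]
        results = pool.map(worker_wrapper, arg)
        pool.close()  # no more tasks will be submitted
        pool.join()   # wait for the worker processes to exit
        return results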
    
    def main(**kwargs):
        print("=> calling `worker`")
        worker(1, 2, kwarg_test=True)  # accepts kwargs
        print("=> no kwargs")
        wrapper_process()  # no kwargs
        print("=> with `kwarg_test=True`")
        wrapper_process(kwarg_test=True)
    
    if __name__ == "__main__":
        main()
    

    Which passes the test:

    => calling `worker`
    Success
    => no kwargs
    Fail
    Fail
    Fail
    Fail
    => with `kwarg_test=True`
    Success
    Success
    Success
    Success
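    On Python 3.3 and later, Pool.starmap unpacks each job tuple into positional arguments itself, so worker_wrapper is not needed; combined with functools.partial, it forwards the same kwargs to every call. A sketch under that assumption, with worker modified to return the flag instead of printing it so the results can be checked:

```python
import multiprocessing as mp
from functools import partial


def worker(x, y, **kwargs):
    # Modified from the question's worker: return the flag instead of printing.
    kwarg_test = kwargs.get("kwarg_test", False)
    return x * y, kwarg_test


def wrapper_process(**kwargs):
    # Same jobs as in the example above: (0, 0), (1, 1), (2, 2), (3, 3).
    jobs = [(n, i) for i, n in enumerate(range(4))]
    # starmap unpacks each (x, y) tuple into positional arguments;
    # partial binds the same kwargs to every call.
    with mp.Pool(4) as pool:
        return pool.starmap(partial(worker, **kwargs), jobs)


if __name__ == "__main__":
    print(wrapper_process())                 # [(0, False), (1, False), (4, False), (9, False)]
    print(wrapper_process(kwarg_test=True))  # [(0, True), (1, True), (4, True), (9, True)]
```

    Unlike the (args, kwargs) wrapper, this passes identical kwargs to every job.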
    