How to use multiprocessing pool.map with multiple arguments?

Asked 2020-11-21 11:24

In the Python multiprocessing library, is there a variant of pool.map which supports multiple arguments?

text = "test"

def harvester(text, case):
    X = case[0]
    text + str(X)

20 answers
  • 2020-11-21 11:39

    Here is another way to do it that IMHO is simpler and more elegant than any of the other answers provided.

    This program has a function that takes two parameters, prints them out and also prints the sum:

    import multiprocessing
    
    def main():
    
        with multiprocessing.Pool(10) as pool:
            params = [ (2, 2), (3, 3), (4, 4) ]
            pool.starmap(printSum, params)
        # end with
    
    # end function
    
    def printSum(num1, num2):
        mySum = num1 + num2
        print('num1 = ' + str(num1) + ', num2 = ' + str(num2) + ', sum = ' + str(mySum))
    # end function
    
    if __name__ == '__main__':
        main()
    

    The output is:

    num1 = 2, num2 = 2, sum = 4
    num1 = 3, num2 = 3, sum = 6
    num1 = 4, num2 = 4, sum = 8
    

    See the Python docs for more info:

    https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool

    In particular, be sure to check out the starmap function.

    I'm using Python 3.6; starmap was added in Python 3.3, so it won't work with older versions.

    I'm not sure why there isn't a straightforward example like this in the docs.
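
    If one argument stays constant across calls (like text in the question), you can still use starmap by pairing each varying item with the constant via itertools.repeat. A minimal sketch, assuming a hypothetical harvester variant that returns its result:

    import itertools
    import multiprocessing

    text = "test"

    def harvester(text, case):
        # variant of the question's function, adapted to return its result
        return text + str(case[0])

    if __name__ == '__main__':
        cases = [(1,), (2,), (3,)]  # stand-in for RAW_DATASET
        with multiprocessing.Pool(4) as pool:
            # itertools.repeat supplies the constant text for every case
            results = pool.starmap(harvester, zip(itertools.repeat(text), cases))
        print(results)  # ['test1', 'test2', 'test3']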

  • 2020-11-21 11:40

    The official documentation states that it supports only one iterable argument. I like to use apply_async in such cases. In your case I would do:

    from multiprocessing import Process, Pool, Manager

    text = "test"

    def harvester(text, case, q=None):
        X = case[0]
        res = text + str(X)
        if q:
            q.put(res)
        return res


    def block_until(q, results_queue, until_counter=0):
        # forward until_counter results from q into results_queue, then exit
        i = 0
        while i < until_counter:
            results_queue.put(q.get())
            i += 1

    if __name__ == '__main__':
        pool = Pool(processes=6)
        case = RAW_DATASET
        m = Manager()
        q = m.Queue()
        results_queue = m.Queue()  # when it completes, results will reside in this queue
        blocking_process = Process(target=block_until, args=(q, results_queue, len(case)))
        blocking_process.start()
        for c in case:
            try:
                # pass the queue so harvester can report its result
                res = pool.apply_async(harvester, (text, c, q))
                res.get(timeout=0.1)
            except Exception:
                pass
        blocking_process.join()
    
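    A simpler variant of the same apply_async idea (a sketch; cases is a stand-in, since RAW_DATASET comes from the question) is to collect the AsyncResult handles and fetch them all at the end:

    from multiprocessing import Pool

    def harvester(text, case):
        return text + str(case[0])

    if __name__ == '__main__':
        cases = [(1,), (2,), (3,)]  # stand-in for RAW_DATASET
        with Pool(processes=6) as pool:
            handles = [pool.apply_async(harvester, ("test", c)) for c in cases]
            results = [h.get() for h in handles]  # .get() blocks until each result is ready
        print(results)  # ['test1', 'test2', 'test3']
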
  • 2020-11-21 11:41

    This might be another option. The trick is in the wrapper function that returns another function, which is passed in to pool.map. The code below reads an input array and, for each (unique) element in it, returns how many times (i.e., counts) that element appears in the array. For example, if the input is

    np.eye(3) = [ [1. 0. 0.]
                  [0. 1. 0.]
                  [0. 0. 1.]]
    

    then zero appears six times and one three times. Note that this relies on multiprocessing.dummy's thread-based Pool: inner_fun is a closure and results is filled in by mutation, neither of which would work with a process-based Pool (closures can't be pickled, and each worker process would get its own copy of results).

    import numpy as np
    from multiprocessing.dummy import Pool as ThreadPool
    from multiprocessing import cpu_count
    
    
    def extract_counts(label_array):
        labels = np.unique(label_array)
        out = extract_counts_helper([label_array], labels)
        return out
    
    def extract_counts_helper(args, labels):
        n = max(1, cpu_count() - 1)
        pool = ThreadPool(n)
        results = {}
        pool.map(wrapper(args, results), labels)
        pool.close()
        pool.join()
        return results
    
    def wrapper(argsin, results):
        # return a one-argument closure over argsin and results that
        # pool.map can call with each label
        def inner_fun(label):
            label_array = argsin[0]
            counts = get_label_counts(label_array, label)
            results[label] = counts
        return inner_fun
    
    def get_label_counts(label_array, label):
        return sum(label_array.flatten() == label)
    
    if __name__ == "__main__":
        img = np.ones([2,2])
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
               
        img = np.eye(3)
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
        
        img = np.random.randint(5, size=(3, 3))
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
    

    You should get:

    input array: 
     [[1. 1.]
     [1. 1.]]
    label counts:  {1.0: 4}
    ========
    input array: 
     [[1. 0. 0.]
     [0. 1. 0.]
     [0. 0. 1.]]
    label counts:  {0.0: 6, 1.0: 3}
    ========
    input array: 
     [[4 4 0]
     [2 4 3]
     [2 3 1]]
    label counts:  {0: 1, 1: 1, 2: 2, 3: 2, 4: 3}
    ========
    
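    If you need a process-based Pool instead, a picklable alternative (a sketch, not part of the original answer) is to bind the array with functools.partial and assemble the dict from returned (label, count) pairs:

    import functools
    import numpy as np
    from multiprocessing import Pool

    def count_label(label_array, label):
        # return a (label, count) pair so the parent process can build the dict
        return label, int((label_array.flatten() == label).sum())

    if __name__ == "__main__":
        img = np.eye(3)
        with Pool(2) as pool:
            pairs = pool.map(functools.partial(count_label, img), np.unique(img))
        print(dict(pairs))  # {0.0: 6, 1.0: 3}
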
  • 2020-11-21 11:42

    Having learned about itertools in J.F. Sebastian's answer, I decided to take it a step further and write a parmap package that takes care of parallelization, offering map and starmap functions in Python 2.7 and Python 3.2 (and later) that can take any number of positional arguments.

    Installation

    pip install parmap
    

    How to parallelize:

    import parmap
    # If you want to do:
    y = [myfunction(x, argument1, argument2) for x in mylist]
    # In parallel:
    y = parmap.map(myfunction, mylist, argument1, argument2)
    
    # If you want to do:
    z = [myfunction(x, y, argument1, argument2) for (x,y) in mylist]
    # In parallel:
    z = parmap.starmap(myfunction, mylist, argument1, argument2)
    
    # If you want to do:
    listx = [1, 2, 3, 4, 5, 6]
    listy = [2, 3, 4, 5, 6, 7]
    param1 = 3.14
    param2 = 42
    listz = []
    for (x, y) in zip(listx, listy):
            listz.append(myfunction(x, y, param1, param2))
    # In parallel:
    listz = parmap.starmap(myfunction, zip(listx, listy), param1, param2)
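
    A concrete, runnable version of the first pattern above, using only the parmap.map call already shown (myfunction here is a hypothetical stand-in):

    import parmap

    def myfunction(x, a, b):
        return x * a + b

    if __name__ == "__main__":
        mylist = [1, 2, 3]
        # parallel equivalent of [myfunction(x, 10, 5) for x in mylist]
        print(parmap.map(myfunction, mylist, 10, 5))  # [15, 25, 35]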
    

    I have uploaded parmap to PyPI and to a GitHub repository.

    As an example, the question can be answered as follows:

    import parmap
    
    def harvester(case, text):
        X = case[0]
        text + str(X)
    
    if __name__ == "__main__":
        case = RAW_DATASET  # assuming this is an iterable
        parmap.map(harvester, case, "test", chunksize=1)
    
  • 2020-11-21 11:46

    A better way is to use a decorator instead of writing a wrapper function by hand. Especially when you have a lot of functions to map, a decorator will save you time by avoiding writing a wrapper for every function. Usually a decorated function is not picklable; however, we may use functools to get around it. More discussion can be found here.

    Here is the example:

    def unpack_args(func):
        from functools import wraps
        @wraps(func)
        def wrapper(args):
            if isinstance(args, dict):
                return func(**args)
            else:
                return func(*args)
        return wrapper
    
    @unpack_args
    def func(x, y):
        return x + y
    

    Then you may map it with zipped arguments:

    from multiprocessing import Pool

    n_proc, xlist, ylist = 2, range(10), range(10)
    pool = Pool(n_proc)
    res = pool.map(func, zip(xlist, ylist))
    pool.close()
    pool.join()
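
    Because the decorator also unpacks dicts, keyword-style argument sets work too (a small usage sketch, reusing func and Pool from above):

    kwargs_list = [{'x': 1, 'y': 2}, {'x': 3, 'y': 4}]
    pool = Pool(2)
    print(pool.map(func, kwargs_list))  # [3, 7]
    pool.close()
    pool.join()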
    

    Of course, you may always use Pool.starmap in Python 3 (>=3.3) as mentioned in other answers.

  • 2020-11-21 11:47

    I think the version below will be better:

    def multi_run_wrapper(args):
        # unpack the argument tuple and forward to add
        return add(*args)

    def add(x, y):
        return x + y

    if __name__ == "__main__":
        from multiprocessing import Pool
        pool = Pool(4)
        results = pool.map(multi_run_wrapper, [(1, 2), (2, 3), (3, 4)])
        print(results)
    

    The output is:

    [3, 5, 7]
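
    In Python 3.3+ the wrapper is unnecessary, since pool.starmap does the unpacking for you:

    from multiprocessing import Pool

    def add(x, y):
        return x + y

    if __name__ == "__main__":
        with Pool(4) as pool:
            print(pool.starmap(add, [(1, 2), (2, 3), (3, 4)]))  # [3, 5, 7]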
    