How to use multiprocessing pool.map with multiple arguments?

上瘾入骨i 2020-11-21 11:24

In the Python multiprocessing library, is there a variant of pool.map which supports multiple arguments?

text = "test"
def         


        
20 answers
  •  甜味超标
    2020-11-21 11:41

    This might be another option. The trick is in the wrapper function, which returns another function that is then passed to pool.map. The extra arguments are captured by that inner function, so pool.map only has to supply the one argument that varies. The code below takes an input array and, for each unique element in it, returns how many times that element appears in the array. For example, if the input is

    np.eye(3) = [ [1. 0. 0.]
                  [0. 1. 0.]
                  [0. 0. 1.]]
    

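    then 0 appears 6 times and 1 appears 3 times. Note that this relies on the thread pool from multiprocessing.dummy: the inner function is a nested function that writes into a shared results dictionary, and a process-based multiprocessing.Pool can neither pickle a nested function nor share that dictionary between workers.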
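    Stripped of the label-counting details, the wrapper trick is just a closure: the outer function captures the fixed arguments, and pool.map passes only the varying one. A minimal sketch, assuming a thread pool from multiprocessing.dummy; the names make_worker, run, prefix and suffix are hypothetical and not part of the answer above:

```python
from multiprocessing.dummy import Pool as ThreadPool


def make_worker(prefix, suffix):
    # Hypothetical wrapper: captures the fixed arguments in a closure.
    def worker(item):
        # Only `item` varies between calls; prefix/suffix come from the enclosing scope.
        return f"{prefix}{item}{suffix}"
    return worker


def run(items, prefix, suffix):
    # Threads share memory, so the closure never has to be pickled.
    with ThreadPool(4) as pool:
        return pool.map(make_worker(prefix, suffix), items)


if __name__ == "__main__":
    print(run(["a", "b", "c"], "<", ">"))  # ['<a>', '<b>', '<c>']
```

    Unlike the answer's version, this collects the return values of pool.map, which come back in input order, instead of writing into a shared dictionary.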

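    import numpy as np
    from multiprocessing.dummy import Pool as ThreadPool  # thread pool with the multiprocessing.Pool API
    from multiprocessing import cpu_count
    
    
    def extract_counts(label_array):
        # One task per unique value in the array.
        labels = np.unique(label_array)
        out = extract_counts_helper([label_array], labels)
        return out
    
    def extract_counts_helper(args, labels):
        n = max(1, cpu_count() - 1)
        pool = ThreadPool(n)
        results = {}  # filled in by the worker threads
        # wrapper() binds the extra arguments; pool.map passes only the label.
        pool.map(wrapper(args, results), labels)
        pool.close()
        pool.join()
        return results
    
    def wrapper(argsin, results):
        # Returns a closure over argsin and results.
        def inner_fun(label):
            label_array = argsin[0]
            counts = get_label_counts(label_array, label)
            # .item() turns the NumPy scalar into a plain Python number for printing.
            results[label.item()] = counts
        return inner_fun
    
    def get_label_counts(label_array, label):
        # Number of elements equal to label.
        return int(np.count_nonzero(label_array == label))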
    
    if __name__ == "__main__":
        img = np.ones([2,2])
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
               
        img = np.eye(3)
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
        
        img = np.random.randint(5, size=(3, 3))
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
    

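    You should get output like the following. The third array is random, so its counts will differ from run to run, and the key order can vary because the threads fill the dictionary concurrently: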

    input array: 
     [[1. 1.]
     [1. 1.]]
    label counts:  {1.0: 4}
    ========
    input array: 
     [[1. 0. 0.]
     [0. 1. 0.]
     [0. 0. 1.]]
    label counts:  {0.0: 6, 1.0: 3}
    ========
    input array: 
     [[4 4 0]
     [2 4 3]
     [2 3 1]]
    label counts:  {0: 1, 1: 1, 2: 2, 3: 2, 4: 3}
    ========
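    With a process-based multiprocessing.Pool, where a nested function like inner_fun cannot be pickled, two standard-library alternatives work instead: functools.partial for arguments that stay the same on every call, and Pool.starmap (Python 3.3+) for argument tuples that differ per call. A sketch under those assumptions; count_label, counts_with_partial, counts_with_starmap and the sample data are hypothetical:

```python
from functools import partial
from multiprocessing import Pool


def count_label(values, label):
    # Hypothetical worker: returns (label, number of occurrences of label in values).
    # Defined at module level so it can be pickled and sent to worker processes.
    return label, sum(1 for v in values if v == label)


def counts_with_partial(values, labels):
    # partial() fixes `values`; pool.map then supplies one label per call.
    with Pool(2) as pool:
        return dict(pool.map(partial(count_label, values), labels))


def counts_with_starmap(values, labels):
    # starmap() unpacks each tuple into positional args: count_label(values, label).
    with Pool(2) as pool:
        return dict(pool.starmap(count_label, [(values, label) for label in labels]))


if __name__ == "__main__":
    data = [1, 0, 0, 0, 1, 0, 0, 0, 1]  # the np.eye(3) example, flattened to plain ints
    labels = sorted(set(data))
    print(counts_with_partial(data, labels))  # {0: 6, 1: 3}
    print(counts_with_starmap(data, labels))  # {0: 6, 1: 3}
```

    Both versions need the worker to be a module-level function, and the pool should be created under the `if __name__ == "__main__":` guard so that start methods which re-import the main module do not spawn pools recursively.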
    
