In the Python multiprocessing
library, is there a variant of pool.map
which supports multiple arguments?
text = "test"
def
This might be another option. The trick is in the wrapper
function, which returns another function that is then passed to pool.map.
The code below reads an input array and, for each unique element in it, returns how many times (i.e. the count) that element appears in the array. For example, if the input is
np.eye(3) = [ [1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
then zero appears 6 times and one appears 3 times.
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import cpu_count
def extract_counts(label_array):
    """Count the occurrences of every distinct value in ``label_array``.

    The distinct values are obtained with ``np.unique``; the per-value
    counting is delegated to ``extract_counts_helper``.

    Returns the helper's result: a dict mapping each distinct value to the
    number of times it appears.
    """
    unique_labels = np.unique(label_array)
    # The array is passed inside a one-element list because the worker
    # closure reads it back as ``argsin[0]``.
    return extract_counts_helper([label_array], unique_labels)
def extract_counts_helper(args, labels):
    """Count every label in parallel on a thread pool.

    Parameters:
        args: sequence whose first element is the array to inspect; it is
            forwarded unchanged to ``wrapper``.
        labels: iterable of values to count; one task is mapped per value.

    Returns:
        dict mapping each label to the count stored by the worker closure.
    """
    # Leave one core free, but always keep at least one worker.
    n_workers = max(1, cpu_count() - 1)
    results = {}
    pool = ThreadPool(n_workers)
    try:
        # ``wrapper`` binds ``args`` and ``results`` so that ``pool.map`` only
        # has to supply the single varying argument (the label).
        pool.map(wrapper(args, results), labels)
    finally:
        # Shut the pool down even when a worker raises, so the threads are
        # not leaked on the error path.
        pool.close()
        pool.join()
    return results
def wrapper(argsin, results):
    """Build the single-argument callable that is handed to ``pool.map``.

    ``argsin[0]`` is the array to inspect and ``results`` is the dict the
    returned function writes into. The returned function takes one label,
    counts it with ``get_label_counts`` and stores the count under
    ``results[label]``; it returns nothing itself.
    """
    def _count_label(label):
        # Both the array and the output dict are captured from the enclosing
        # call, which is how extra arguments reach the mapped function.
        results[label] = get_label_counts(argsin[0], label)

    return _count_label
def get_label_counts(label_array, label):
    """Return how many elements of ``label_array`` are equal to ``label``.

    Parameters:
        label_array: array (or array-like, e.g. nested lists) of any shape.
        label: value compared element-wise with ``==``.

    Returns:
        The number of matching elements as a plain Python ``int``
        (0 for an empty array).
    """
    # ``np.asarray`` accepts array-likes without copying real arrays, and
    # ``count_nonzero`` counts the boolean mask in native code; no flattened
    # copy and no Python-level iteration are needed.
    matches = np.asarray(label_array) == label
    return int(np.count_nonzero(matches))
if __name__ == "__main__":
    # Demo: run the counter on an all-ones array, an identity matrix and a
    # small random integer array, printing each input with its counts.
    demo_inputs = (
        np.ones([2, 2]),
        np.eye(3),
        np.random.randint(5, size=(3, 3)),
    )
    for img in demo_inputs:
        out = extract_counts(img)
        print('input array: \n', img)
        print('label counts: ', out)
        print("========")
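You should get something like the following (the third input is random, so its array and counts will vary from run to run):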
input array:
[[1. 1.]
[1. 1.]]
label counts: {1.0: 4}
========
input array:
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
label counts: {0.0: 6, 1.0: 3}
========
input array:
[[4 4 0]
[2 4 3]
[2 3 1]]
label counts: {0: 1, 1: 1, 2: 2, 3: 2, 4: 3}
========