Question
When using map() from multiprocessing.Pool() on a list of instances of a numpy.ndarray subclass, the new attributes of my own class are dropped.
The following minimal example based on the numpy docs subclassing example reproduces the problem:
from multiprocessing import Pool
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

def sum_worker(x):
    return sum(x), x.info

if __name__ == '__main__':
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
    with Pool() as p:
        p.map(sum_worker, arr_list)
The attribute info is dropped:
AttributeError: 'MyArray' object has no attribute 'info'
Using the builtin map() works fine:
arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
list(map(sum_worker, arr_list))
The purpose of the method __array_finalize__() is to make the object keep the attribute after slicing:
arr = MyArray([1,2,3], info='foo')
subarr = arr[:2]
print(subarr.info)
But with Pool.map() this mechanism somehow does not work...
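One way to see why slicing works but Pool.map() does not is to log what __array_finalize__ receives. The sketch below uses a hypothetical TracingArray (MyArray plus a log). With slicing, numpy passes the parent array as obj, so info is copied. With the explicit constructor np.ndarray.__new__, obj is None and the early return leaves info unset. As far as I can tell, numpy also rebuilds a subclass instance this way when it is unpickled.

```python
import numpy as np

class TracingArray(np.ndarray):
    """Hypothetical variant of MyArray that logs what __array_finalize__ receives."""
    calls = []  # type name of `obj` for every __array_finalize__ call

    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        TracingArray.calls.append(type(obj).__name__)
        if obj is None:
            return
        self.info = getattr(obj, 'info', None)

arr = TracingArray([1, 2, 3], info='foo')

# Slicing: the parent array is passed as `obj`, so `info` is copied over.
TracingArray.calls.clear()
sub = arr[:2]
slice_calls = list(TracingArray.calls)

# Explicit constructor: `obj` is None, so the early return skips `info`.
TracingArray.calls.clear()
bare = np.ndarray.__new__(TracingArray, (2,))
ctor_calls = list(TracingArray.calls)

print(slice_calls, sub.info)
print(ctor_calls, hasattr(bare, 'info'))
```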
Answer 1:
Because multiprocessing uses pickle to serialize data to and from the separate processes, this is essentially a duplicate of this question.
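The problem can be reproduced without a pool by round-tripping an instance of the original MyArray through pickle, which is what happens to the arguments Pool.map() sends to its workers. ndarray's default pickling keeps the subclass, shape, dtype and data, but not extra instance attributes such as info:

```python
import pickle
import numpy as np

class MyArray(np.ndarray):
    # Same class as in the question: no custom pickling support.
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.info = getattr(obj, 'info', None)

arr = MyArray([1.0, 2.0, 3.0], info='foo')

# Serialize and rebuild, roughly as Pool.map does for each task argument.
restored = pickle.loads(pickle.dumps(arr))

print(type(restored).__name__)    # MyArray: the subclass survives
print(hasattr(restored, 'info'))  # False: the extra attribute does not
```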
Adapting the accepted solution from that question, your example becomes:
from multiprocessing import Pool
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'info', None)

    def __reduce__(self):
        # Get ndarray's pickling recipe and append our attribute to its state tuple
        pickled_state = super(MyArray, self).__reduce__()
        new_state = pickled_state[2] + (self.info,)
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        # Restore our attribute, then hand the remaining state back to ndarray
        self.info = state[-1]
        super(MyArray, self).__setstate__(state[0:-1])

def sum_worker(x):
    return sum(x), x.info

if __name__ == '__main__':
    arr_list = [MyArray(np.random.rand(3), info=f'foo_{i}') for i in range(10)]
    with Pool() as p:
        p.map(sum_worker, arr_list)
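A quick way to check the fix without starting a pool is to round-trip instances through pickle directly. The class below is the answer's MyArray, written with zero-argument super() and tuple unpacking. I assume these behave the same as the original. The check relies on Pool.map() moving arguments through the same pickle machinery.

```python
import pickle
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.info = getattr(obj, 'info', None)

    def __reduce__(self):
        # ndarray.__reduce__ returns (reconstruct, args, state); append `info` to state.
        reconstruct, args, state = super().__reduce__()
        return (reconstruct, args, state + (self.info,))

    def __setstate__(self, state):
        # Peel `info` off the end and pass the rest back to ndarray.
        self.info = state[-1]
        super().__setstate__(state[:-1])

arr = MyArray([1.0, 2.0, 3.0], info='foo')
restored = pickle.loads(pickle.dumps(arr))
print(restored.info)      # foo
print(restored[:2].info)  # foo (slicing still goes through __array_finalize__)

# A list of arrays, similar to the iterable handed to Pool.map.
batch = [MyArray(np.arange(3.0), info=f'foo_{i}') for i in range(3)]
restored_batch = pickle.loads(pickle.dumps(batch))
print([a.info for a in restored_batch])  # ['foo_0', 'foo_1', 'foo_2']
```

If these round trips keep info, the workers started by Pool.map() should also see it, because they receive the same pickled data.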
Note that the second answer there suggests you might be able to use pathos.multiprocessing with your unmodified original code, since pathos uses dill instead of pickle. However, this did not work when I tested it.
Source: https://stackoverflow.com/questions/46813375/multiprocessing-pool-map-drops-attribute-of-subclassed-ndarray