Sharing numpy arrays in python multiprocessing pool

礼貌的吻别 2021-01-12 13:25

I'm working on some code that does some fairly heavy numerical work on a large (tens to hundreds of thousands of numerical integrations) set of problems. Fortunately, thes

2 Answers
  • 2021-01-12 13:49

    To make your last idea work, I think you can simply make X, param_1, and param_2 global variables by using the global keyword before modifying them inside the if statement. So add the following:

    global X
    global param_1
    global param_2
    

    directly after the if __name__ == '__main__': line.

  • 2021-01-12 14:02

    I had a similar problem. If you just want my solution, skip a few lines :) I had to:

    • share a numpy.array between worker processes, each operating on a different part of it, and...
    • pass Pool.map a function with more than one argument.

    I noticed that:

    • the data of the numpy.array was correctly read but...
    • changes made to the numpy.array inside the workers were not kept in the parent process
    • Pool.map had problems handling lambda functions, or so it appeared to me (if this point is not clear to you, just ignore it)
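    The lambda problem has a concrete cause: Pool.map pickles the function to send it to the workers, and pickle stores a function by its importable name, which a lambda does not have. A sketch of one common workaround (not the approach this answer goes on to use) is functools.partial over a module-level function, which does pickle; scale and can_pickle are hypothetical names for illustration.

```python
import pickle
from functools import partial
from multiprocessing import Pool


def scale(factor, offset, value):
    # Hypothetical three-argument target function.
    return factor * value + offset


def can_pickle(obj):
    """Return True if obj survives pickling, which Pool.map needs for its function."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


if __name__ == '__main__':
    print(can_pickle(lambda v: scale(2, 1, v)))  # False: no importable name
    bound = partial(scale, 2, 1)                  # fixes factor=2, offset=1
    print(can_pickle(bound))                      # True
    with Pool(2) as pool:
        print(pool.map(bound, [0, 1, 2]))         # [1, 3, 5]
```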

    My solution was to:

    • make the target function's only argument a list
    • make the target function return the modified data instead of trying to write directly to the numpy.array
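    A small sketch of why writes made inside the workers are lost while returned data survives: each Pool worker is a separate process holding its own copy of the array, so assignments there never reach the parent. The array data and the functions write_in_place and compute are hypothetical.

```python
from multiprocessing import Pool

import numpy as np

data = np.zeros(4)  # hypothetical array the workers try to update


def write_in_place(i):
    # Runs in a worker process: this changes the worker's copy of `data` only.
    data[i] = i * 10.0


def compute(i):
    # Return the value instead; the parent assembles the results.
    return i * 10.0


if __name__ == '__main__':
    with Pool(2) as pool:
        pool.map(write_in_place, range(4))
        print(data)                       # still all zeros in the parent
        results = pool.map(compute, range(4))
    data[:] = results                     # the parent writes the returned values
    print(data)                           # [ 0. 10. 20. 30.]
```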

    I understand that your do_work function already returns the computed data, so you would just have to modify do_work to accept a single list (containing X, param_1, param_2 and arg) as its argument, and pack the inputs to the target function in this format before passing them to Pool.map.

    Here is a sample implementation:

    def do_work2(args):
        # unpack the single list argument (heavy_computation is the question's function)
        X, param_1, param_2, arg = args
        return heavy_computation(X, param_1, param_2, arg)
    

    Now you have to pack the inputs for do_work2 before passing them to Pool.map. Your main block becomes:

    from multiprocessing import Pool
    import numpy as np

    if __name__ == '__main__':
        filename = input("Filename> ")
        param_1 = float(input("Parameter 1: "))
        param_2 = float(input("Parameter 2: "))
        X = parse_numpy_array(filename)  # parsing helper from the question
        # now you pack the input arguments
        arglist = [[X, param_1, param_2, n] for n in np.linspace(0.0, 1.0, 100)]
        # arglist holds 100 references to the same X, but Pool.map pickles each
        # task, so X is still copied into the worker processes
        with Pool() as pool:
            results = pool.map(do_work2, arglist)
        # save X and the results in a .npz file for analysis
        np.savez("Results.npz", X=X, results=results)
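    On Python 3.3 and later, Pool.starmap offers an alternative to the list-packing wrapper: it unpacks each tuple into positional arguments, so heavy_computation can be passed directly. A minimal sketch, with a hypothetical heavy_computation and made-up inputs; X is still pickled into every task, just as with Pool.map.

```python
from multiprocessing import Pool

import numpy as np


def heavy_computation(X, param_1, param_2, arg):
    # Hypothetical stand-in for the question's function.
    return float(X.sum() * param_1 + param_2 * arg)


def run_all(X, param_1, param_2, args):
    # starmap unpacks each tuple into positional arguments,
    # so no do_work2-style wrapper is needed.
    tasks = [(X, param_1, param_2, a) for a in args]
    with Pool(2) as pool:
        return pool.starmap(heavy_computation, tasks)


if __name__ == '__main__':
    X = np.ones((2, 2))
    print(run_all(X, 2.0, 1.0, np.linspace(0.0, 1.0, 3)))  # [8.0, 8.5, 9.0]
```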
    