Python: pass a pandas DataFrame, parameters, and functions to scipy.optimize.minimize

醉话见心 2021-01-13 01:20

I am trying to use SciPy's scipy.optimize.minimize function to minimize a function I have created. However, the function I am trying to optimize over is itself constructed …

1 Answer
  • 2021-01-13 01:55

    The problem was not the data format; you defined loglik_total with its arguments in the wrong order. minimize evaluates fun(x, *args), so the parameter vector (params) has to be the first argument, followed by the additional arguments in the same order as in the args tuple of your minimize call. Here is the modified version:

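    import numpy as np

    def loglik_total(params, data, id_list):
        # minimize() calls this as loglik_total(x, *args), so the parameter
        # vector comes first, followed by data and id_list (the args tuple).

        # Extract parameters: one delta per id, then sigma as the last entry.
        delta_params = list(params[0:len(id_list)])
        sigma_param = params[-1]

        # Evaluate loglik_row (defined in the question) on every row of data,
        # sum the values, and negate the total so that minimizing lt
        # maximizes the log-likelihood.
        lt = -np.sum(data.apply(lambda row: loglik_row(row, delta_params, sigma_param, id_list), axis=1))

        return lt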
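To see the calling convention in isolation: minimize evaluates fun(x, *args), where x is the current parameter vector as a 1-D NumPy array and the entries of args are unpacked after it. The toy sketch below uses a made-up objective, a hypothetical column name y, and a hypothetical extra argument scale, and it records what the objective receives on its first call.

```python
import numpy as np
import pandas as pd
from scipy.optimize import minimize

# Toy data (hypothetical column name "y").
data = pd.DataFrame({"y": [1.0, 2.0, 3.0, 4.0]})

calls = []  # records the argument types minimize() passes to the objective

def objective(params, df, scale):
    # minimize() calls objective(x, *args): x (a 1-D ndarray) comes first,
    # then the entries of args in order, here df and scale.
    calls.append((type(params).__name__, type(df).__name__, scale))
    # Scaled sum of squared deviations; minimized at the mean of y.
    return scale * np.sum((df["y"].to_numpy() - params[0]) ** 2)

res = minimize(objective, x0=np.array([1.0]), args=(data, 2.0), method="nelder-mead")
print(res.x)     # close to the mean of y, 2.5
print(calls[0])  # ('ndarray', 'DataFrame', 2.0)
```

If params were moved to a later position in the signature, the DataFrame would be bound to the wrong name, which is the ordering mismatch described above.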
    

    If you then call

    from scipy.optimize import minimize

    res = minimize(fun=loglik_total, x0=init_params,
                   args=(data, id_list), method='nelder-mead')
    

    it runs without errors (note that the order is x, data, id_list, which matches the signature of loglik_total), and res looks as follows:

    final_simplex: (array([[  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09],
           [  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
              3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
              6.43380010e+04,  -1.85436851e+09]]), array([-0., -0., -0., -0., -0., -0., -0., -0., -0.]))
               fun: -0.0
           message: 'Optimization terminated successfully.'
              nfev: 930
               nit: 377
            status: 0
           success: True
                 x: array([  2.55758096e+05,   6.99890451e+04,  -1.41860117e+05,
             3.88586258e+05,   3.19488400e+05,   4.90209168e+04,
             6.43380010e+04,  -1.85436851e+09])
    

    Whether this output makes any sense, I cannot judge, though. :)
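One thing worth checking in the output above: the last entry of x, which loglik_total treats as sigma_param, ended up at about -1.85e9, and fun is -0.0 at every vertex of the final simplex. If loglik_row expects a positive scale (the question does not show its body, so this is an assumption), one common option is to optimize log(sigma) and exponentiate it inside the objective. The sketch below uses a hypothetical Gaussian loglik_row and made-up data with columns id and y; only the structure of loglik_total and the minimize call follows the answer.

```python
import numpy as np
import pandas as pd
from scipy.optimize import minimize

# Made-up data: a group id and an observed value per row.
rng = np.random.default_rng(0)
id_list = ["a", "b", "c"]
true_delta = {"a": 1.0, "b": -2.0, "c": 0.5}
ids = np.repeat(id_list, 20)
data = pd.DataFrame({
    "id": ids,
    "y": np.array([true_delta[i] for i in ids]) + rng.normal(0.0, 0.3, size=ids.size),
})

def loglik_row(row, delta_params, sigma_param, id_list):
    # Hypothetical stand-in for the question's loglik_row: the Gaussian
    # log-density of row["y"] with a per-id mean and a shared scale.
    mean = delta_params[id_list.index(row["id"])]
    z = (row["y"] - mean) / sigma_param
    return -0.5 * np.log(2 * np.pi) - np.log(sigma_param) - 0.5 * z * z

def loglik_total(params, data, id_list):
    # params = [delta_1, ..., delta_k, log_sigma]. Optimizing log(sigma)
    # keeps sigma = exp(log_sigma) strictly positive.
    delta_params = list(params[:len(id_list)])
    sigma_param = np.exp(params[-1])
    # Negative total log-likelihood, because minimize() minimizes.
    return -np.sum(data.apply(
        lambda row: loglik_row(row, delta_params, sigma_param, id_list), axis=1))

init_params = np.zeros(len(id_list) + 1)  # log_sigma = 0, i.e. sigma = 1
res = minimize(fun=loglik_total, x0=init_params,
               args=(data, id_list), method="nelder-mead")
print(res.x[:-1])         # estimated per-id means
print(np.exp(res.x[-1]))  # estimated sigma
```

The exp transform is only one way to keep sigma valid; passing bounds to minimize (supported for Nelder-Mead in recent SciPy versions) is another.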
