Question
I am very much an amateur when it comes to SciPy. I am trying to use SciPy's fmin function on a system with many variables. For the sake of simplicity I am storing them as a list of lists of lists. My data is 12-dimensional (12 values in total): np.shape(DATA) returns (3,2,2). I am not even sure whether SciPy can handle that many dimensions; if not, no problem, I can reduce them. The point is that optimize.fmin() does not seem to accept list-based arrays as the x0 initial parameters, so I need help either rewriting the x0 array into a NumPy-compatible one or turning the entire DATA array into a 12-dimensional matrix or something like that.
Here is a simpler example illustrating the issue:
from scipy import optimize
import numpy as np
def f(x): return(x[0][0]*1.5-x[0][1]*2.0+x[1][0]*2.5-x[1][1]*3.0)
result = optimize.fmin(f,[[0.1,0.1],[0.1,0.1]])
print(result)
It raises "IndexError: invalid index to scalar variable.", which probably comes from fmin not understanding the [[],[]] list-of-lists structure; it seems to only understand NumPy array formats.
How can I rewrite this so that it works, both here and for my (3,2,2)-shaped list of lists of lists?
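A minimal diagnostic sketch to see what fmin actually hands to the objective. The names show_shape and seen_shapes are hypothetical, and the objective is a simple sum of squares rather than the original linear f, so that nothing crashes and only the argument's shape is inspected:

```python
import numpy as np
from scipy import optimize

seen_shapes = []  # shape of every x that fmin passes to the objective

def show_shape(x):
    # Hypothetical diagnostic objective: record the shape, return a simple bowl
    seen_shapes.append(np.shape(x))
    return float(np.sum(np.asarray(x) ** 2))

# Same nested-list guess as in the question; only a few iterations are needed
optimize.fmin(show_shape, [[0.1, 0.1], [0.1, 0.1]], maxiter=5, disp=False)
print(set(seen_shapes))  # {(4,)}
```

Every call receives a flat array of length 4, so inside the original f, x[0] is already a single float and x[0][0] tries to index a scalar, which is exactly the IndexError above.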
Answer 1:
scipy.optimize.fmin needs the initial guess for the function parameters to be a 1D array with one element per parameter of the function being optimized; it flattens whatever x0 you pass, so your f always receives a 1D array. In your case, you can use flatten and reshape if you just need the output to be in the same shape as your input parameters. An example based on your illustration code:
from scipy import optimize
import numpy as np
def f(x):
    return x[0]*1.5 - x[1]*2.0 + x[2]*2.5 - x[3]*3.0
guess = np.array([[0.1, 0.1],
                  [0.1, 0.1]])  # guess.shape is (2,2)
out = optimize.fmin(f, guess.flatten())  # flatten upon input
# note: this f is linear, hence unbounded below, so fmin stops at its evaluation limit
# out.shape is (4,)
# reshape output according to guess
out = out.reshape(guess.shape)  # out.shape is (2,2) again
Or, in one line:
out = optimize.fmin(f, guess.flatten()).reshape(guess.shape)
Note that flattening and reshaping also work for a 3-dimensional array like the one you describe:
guess = np.arange(12).reshape(3, 2, 2)
# array([[[ 0,  1],
#         [ 2,  3]],
#
#        [[ 4,  5],
#         [ 6,  7]],
#
#        [[ 8,  9],
#         [10, 11]]])
guess = guess.flatten()
# array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
guess = guess.reshape(3, 2, 2)
# array([[[ 0,  1],
#         [ 2,  3]],
#
#        [[ 4,  5],
#         [ 6,  7]],
#
#        [[ 8,  9],
#         [10, 11]]])
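Putting the pieces together for a (3,2,2) list of lists of lists, one option is to keep the objective written with nested indexing and wrap it so that the flat vector from fmin is reshaped before use. This is a sketch under assumptions: DATA, target, objective and flat_objective are hypothetical stand-ins (a least-squares fit to made-up values), not the asker's real data or function.

```python
import numpy as np
from scipy import optimize

# Hypothetical (3, 2, 2) list of lists of lists standing in for the real DATA
DATA = [[[0.1, 0.1], [0.1, 0.1]],
        [[0.1, 0.1], [0.1, 0.1]],
        [[0.1, 0.1], [0.1, 0.1]]]

guess = np.asarray(DATA, dtype=float)  # nested lists -> ndarray of shape (3, 2, 2)
target = np.linspace(-1.0, 1.0, guess.size).reshape(guess.shape)  # hypothetical values to fit

def objective(x):
    # Hypothetical objective using nested indexing x[i][j][k],
    # in the same style as the original f's x[0][0]
    total = 0.0
    for i in range(3):
        for j in range(2):
            for k in range(2):
                total += (x[i][j][k] - target[i][j][k]) ** 2
    return total

def flat_objective(x_flat):
    # fmin always passes a 1-D array; restore the 3-D shape before indexing
    return objective(x_flat.reshape(guess.shape))

out = optimize.fmin(flat_objective, guess.ravel(), maxiter=5000, maxfun=5000, disp=False)
out = out.reshape(guess.shape)  # back to (3, 2, 2)
print(out.shape)  # (3, 2, 2)
```

The wrapper keeps the reshaping in one place, so the objective itself never has to know that the optimizer works on a flat vector.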
Source: https://stackoverflow.com/questions/59912426/how-to-organize-list-of-list-of-lists-to-be-compatible-with-scipy-optimize-fmin