Numerically stable way to multiply log-probability matrices in NumPy

逝去的感伤 2021-02-02 06:55

I need to take the matrix product of two NumPy matrices (or other 2d arrays) containing log probabilities. The naive way np.log(np.dot(np.exp(a), np.exp(b))) is not
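To show why the naive expression fails, here is a minimal sketch with made-up values, assuming float64 arrays. np.exp underflows to 0.0 for arguments below roughly -745, so log-probabilities around -1000 are lost before the product is even formed.

```python
import numpy as np

A = np.full((1, 2), -1000.0)   # hypothetical log-probabilities, each exp(-1000)
B = np.full((2, 1), -1000.0)

# Naive approach: np.exp(-1000.0) underflows to 0.0, so the dot product is 0
# and its log is -inf (NumPy warns about log(0) unless the warning is silenced).
with np.errstate(divide="ignore"):
    naive = np.log(np.dot(np.exp(A), np.exp(B)))

# Exact value of the single entry: log(exp(-2000) + exp(-2000)) = -2000 + log(2).
exact = -2000 + np.log(2)
print(naive, exact)
```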

  •  无人及你
    2021-02-02 07:42

    Suppose A.shape == (n, r) and B.shape == (r, m). Computing the matrix product C = A*B involves n*m separate summations, one for each entry of C. To get stable results in log space, you need the logsumexp trick in each of these summations. Fortunately, NumPy broadcasting makes it easy to give each summation its own offset: shift every row of A and every column of B by its own maximum, so the stability of rows and columns is controlled separately.
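    Entry by entry, the shift works like this (a derivation sketch; the symbols a_i and b_j are introduced here for illustration and correspond to max_A and max_B in the code below):

    $$
    C_{ij} = \log \sum_{k=1}^{r} e^{A_{ik} + B_{kj}}
           = a_i + b_j + \log \sum_{k=1}^{r} e^{A_{ik} - a_i}\, e^{B_{kj} - b_j},
    \qquad a_i = \max_k A_{ik}, \quad b_j = \max_k B_{kj}.
    $$

    Each factor e^{A_ik - a_i} and e^{B_kj - b_j} is at most 1, so nothing can overflow, and the inner sum is exactly entry (i, j) of the ordinary matrix product of exp(A - a) and exp(B - b). That is why a single np.dot call still suffices.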

    Here is the code:

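    import numpy as np

    def logdotexp(A, B):
        # Shift each row of A and each column of B by its own maximum,
        # so exp() of the shifted values is at most 1 and cannot overflow.
        max_A = np.max(A, 1, keepdims=True)   # shape (n, 1)
        max_B = np.max(B, 0, keepdims=True)   # shape (1, m)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        # Undo the shifts; broadcasting adds max_A[i] + max_B[j] to entry (i, j).
        C += max_A + max_B
        return C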
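    A quick way to check the function is to compare it with a brute-force log-space product on random inputs. The sketch below assumes hypothetical test data drawn uniformly from [-1000, -900], where the naive version underflows everywhere; the reference helper logdotexp_reference is my own testing aid built on np.logaddexp.reduce, not part of the answer.

```python
import numpy as np

def logdotexp(A, B):
    # The answer's function: shift each row of A and each column of B by its maximum.
    max_A = np.max(A, 1, keepdims=True)
    max_B = np.max(B, 0, keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

def logdotexp_reference(A, B):
    # Hypothetical brute-force check: C[i, j] = logaddexp over k of A[i, k] + B[k, j].
    # Materializes an (n, r, m) array, so only use it on small test inputs.
    return np.logaddexp.reduce(A[:, :, None] + B[None, :, :], axis=1)

rng = np.random.default_rng(0)
A = rng.uniform(-1000.0, -900.0, size=(3, 4))  # made-up, very small log-probabilities
B = rng.uniform(-1000.0, -900.0, size=(4, 5))

with np.errstate(divide="ignore"):
    naive = np.log(np.dot(np.exp(A), np.exp(B)))

C = logdotexp(A, B)
print(np.isneginf(naive).all())                   # naive version collapses to -inf
print(np.allclose(C, logdotexp_reference(A, B)))  # shifted version matches brute force
```

    The reference is only meant for testing: it needs n*r*m memory and gives up the BLAS-backed np.dot that makes logdotexp fast.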
    

    Note:

    The reasoning behind this is similar to FredFoo's answer, but he used a single maximum value for each matrix. Because one global offset cannot adapt to each of the n*m summations, some entries of the final matrix can still underflow to -inf, as mentioned in one of the comments.
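    To make the "single maximum per matrix" point concrete, the sketch below (the helper name shift_gaps is hypothetical, not from any answer) compares, for every entry (i, j), the offset each method subtracts with the ideal per-entry logsumexp offset max_k(A[i,k] + B[k,j]). A gap of g means the largest term in that sum is exp(-g) before the log, so in float64 a gap beyond roughly 745 makes every term underflow to 0 and the entry becomes -inf.

```python
import numpy as np

def shift_gaps(A, B):
    """Gap between the offset each method subtracts and the ideal one, per entry.

    ideal[i, j]  = max_k (A[i, k] + B[k, j])      (per-entry logsumexp offset)
    per_rowcol   = max_k A[i, k] + max_k B[k, j]  (logdotexp above)
    global_shift = max(A) + max(B)                (one maximum per matrix)
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    ideal = np.max(A[:, :, None] + B[None, :, :], axis=1)
    per_rowcol = np.max(A, axis=1, keepdims=True) + np.max(B, axis=0, keepdims=True)
    global_shift = np.max(A) + np.max(B)
    return per_rowcol - ideal, global_shift - ideal

# @identity-m's counterexample: the per-row/column gap is 0 everywhere,
# while the global gap is 1000 in the first column.
print(shift_gaps([[0, 0], [0, 0]], [[-1000, 0], [-1000, 0]]))

# Caveat: when the row maximum of A and the column maximum of B sit at
# different k, the per-row/column gap can also be large (here 1000).
print(shift_gaps([[0, -1000]], [[-1000], [0]]))
```

    Since a_i + b_j is never below the ideal offset and never above max(A) + max(B), the per-row/column gap is always between 0 and the global gap, which is why the new method is at least as stable as the old one, though not immune to underflow in cases like the second call.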

    Here is a comparison with the currently accepted answer, using @identity-m's counterexample:

    def logdotexp_less_stable(A, B):
        # Accepted answer's approach: one global maximum per matrix.
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C
    
    print('old method:')
    print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
    print('new method:')
    print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
    

    which prints

    old method:
    [[      -inf 0.69314718]
     [      -inf 0.69314718]]
    new method:
    [[-9.99306853e+02  6.93147181e-01]
     [-9.99306853e+02  6.93147181e-01]]
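    As a sanity check that the new method's finite values are also correct, the first column can be worked out by hand: each entry sums two terms of exp(0 + (-1000)), so its exact value is log(2 * exp(-1000)) = -1000 + log(2), about -999.3069. A minimal sketch of that check, reusing the answer's function:

```python
import numpy as np

def logdotexp(A, B):
    # Per-row shift for A, per-column shift for B (as in the answer).
    max_A = np.max(A, 1, keepdims=True)
    max_B = np.max(B, 0, keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

A = np.array([[0.0, 0.0], [0.0, 0.0]])
B = np.array([[-1000.0, 0.0], [-1000.0, 0.0]])

# Hand-computed result: column 0 is -1000 + log(2), column 1 is log(2).
expected = np.array([[-1000 + np.log(2), np.log(2)],
                     [-1000 + np.log(2), np.log(2)]])
print(np.allclose(logdotexp(A, B), expected))  # True
```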
    
