numerically stable way to multiply log probability matrices in numpy

逝去的感伤 2021-02-02 06:55

I need to take the matrix product of two NumPy matrices (or other 2d arrays) containing log probabilities. The naive way, np.log(np.dot(np.exp(a), np.exp(b))), is not numerically stable.
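A small illustration of the failure (the values are made up): log probabilities near -1000 are routine when many probabilities are multiplied together, but `np.exp(-1000.0)` underflows to 0.0 in float64. The naive formula then returns `-inf`, even though the true value of each entry here is -2000 + log(3), about -1998.90.

```python
import numpy as np

# Hypothetical log-probability matrices with very negative entries.
a = np.full((2, 3), -1000.0)
b = np.full((3, 2), -1000.0)

# Naive approach: exponentiate, multiply, take the log again.
with np.errstate(divide="ignore"):  # silence the log(0) warning
    naive = np.log(np.dot(np.exp(a), np.exp(b)))

print(naive)  # every entry is -inf because exp(-1000) underflows to 0.0
```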

  •  有刺的猬
    2021-02-02 07:53

    logsumexp works by evaluating the right-hand side of the equation

    log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])
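    A minimal NumPy sketch of that identity for a 1-D input. The name `logsumexp_1d` is hypothetical; SciPy provides a full-featured `scipy.special.logsumexp` with `axis` and `b` (weight) parameters.

```python
import numpy as np

def logsumexp_1d(a):
    """Compute log(sum(exp(a))) as max(a) + log(sum(exp(a - max(a))))."""
    a = np.asarray(a, dtype=float)
    m = np.max(a)
    if not np.isfinite(m):
        # If m is -inf (every term is exp(-inf) = 0) or +inf, m itself is the
        # result; subtracting it from a would produce nan instead.
        return m
    # Every argument to exp is now <= 0, so exp cannot overflow.
    return m + np.log(np.sum(np.exp(a - m)))

print(logsumexp_1d([1000.0, 1000.0]))  # 1000 + log(2), about 1000.6931
with np.errstate(over="ignore"):
    print(np.log(np.sum(np.exp([1000.0, 1000.0]))))  # inf: exp(1000) overflows
```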

    I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:

    log(exp[a] ⋅ exp[b])
     = log(∑ exp[a] × exp[b])
     = log(∑ exp[a + b])
     = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }
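    The first derivation is already an exact algorithm: entry (i, j) of the log matrix product is logsumexp(A[i, :] + B[:, j]). The sketch here evaluates all entries at once with broadcasting. `logdotexp_broadcast` is a hypothetical name, and the approach materializes an n×k×m array of sums, which is the memory cost the second derivation avoids by reducing to an ordinary dot product.

```python
import numpy as np

def logdotexp_broadcast(A, B):
    """Entry (i, j) is logsumexp(A[i, :] + B[:, j]), i.e. log(exp(A) @ exp(B))[i, j].

    Assumes finite inputs; uses O(n * k * m) memory for the (n, k, m) sum array.
    """
    S = A[:, :, np.newaxis] + B[np.newaxis, :, :]  # S[i, k, j] = A[i, k] + B[k, j]
    M = S.max(axis=1)                               # shape (n, m): max over shared index k
    return M + np.log(np.exp(S - M[:, np.newaxis, :]).sum(axis=1))

# Made-up probability matrices, moved far into underflow territory in log space.
P = np.array([[0.2, 0.8], [0.5, 0.5]])
Q = np.array([[0.1, 0.9], [0.6, 0.4]])
A = np.log(P) - 1000.0
B = np.log(Q) - 1000.0
result = logdotexp_broadcast(A, B)
print(np.allclose(result + 2000.0, np.log(P @ Q)))  # True
```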

    but by taking a different turn in the derivation, we obtain

    log(∑ exp[a] × exp[b])
     = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
     = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])
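    A quick numeric check (values chosen arbitrarily) that the two final forms agree for a pair of vectors whose plain exponentials would underflow:

```python
import numpy as np

a = np.array([-1000.0, -1001.0, -1002.0])
b = np.array([-2000.0, -1999.0, -2001.0])

# First form: logsumexp(a + b), shifting by max(a + b).
s = a + b
first = np.max(s) + np.log(np.sum(np.exp(s - np.max(s))))

# Second form: shift a and b separately, then take an ordinary dot product.
inner = np.dot(np.exp(a - np.max(a)), np.exp(b - np.max(b)))
second = np.max(a) + np.max(b) + np.log(inner)

print(first, second)  # the two agree to floating-point precision
```

    Because max(a) + max(b) can exceed max(a + b) when the two maxima sit at different positions, the shifted dot product may be below 1 (here it is 2/e + e^-4, about 0.754). One of its terms equals exp(max(a + b) - max(a) - max(b)), so it can only underflow if that gap reaches roughly 745 in float64.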

    The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm

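    import numpy as np

    def logdotexp(A, B):
        """Compute log(exp(A) @ exp(B)) with the maxima factored out of exp()."""
        # Pull the largest entry out of each factor so exp() only sees values <= 0.
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        # Return to log space in place, then add back the factored-out maxima.
        np.log(C, out=C)
        C += max_A + max_B
        return C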
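    A usage sketch for `logdotexp` as written above (repeated so the block runs on its own); the matrices `P` and `Q` are made-up examples. On well-scaled inputs it agrees with the naive formula, and after shifting all log probabilities by -1000 the naive result is `-inf` while `logdotexp` simply shifts its output by -2000.

```python
import numpy as np

def logdotexp(A, B):
    # Same algorithm as the answer: shift by the global maxima, multiply, shift back.
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

# Two small stochastic matrices (rows sum to 1), moved to log space.
P = np.array([[0.9, 0.1], [0.4, 0.6]])
Q = np.array([[0.7, 0.3], [0.2, 0.8]])
logP, logQ = np.log(P), np.log(Q)

# On well-scaled inputs it matches the naive formula.
print(np.allclose(logdotexp(logP, logQ), np.log(P @ Q)))  # True

# Shifting every log probability by -1000 makes the naive formula underflow,
# but logdotexp just shifts its result by -2000.
with np.errstate(divide="ignore"):
    naive = np.log(np.dot(np.exp(logP - 1000), np.exp(logQ - 1000)))
stable = logdotexp(logP - 1000, logQ - 1000)
print(np.isneginf(naive).all(), np.allclose(stable + 2000, np.log(P @ Q)))  # True True
```

    Note that `max_A` and `max_B` are single global maxima. If the entries of A span a range much wider than about 745 (roughly where `np.exp` underflows in float64), rows far below the maximum can still underflow; shifting A by its per-row maxima and B by its per-column maxima is one possible refinement, not part of the answer above.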

    This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

    exp_A = A - max_A
    np.exp(exp_A, out=exp_A)

    and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
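    A sketch of both memory-saving variants described above. The names `logdotexp_fewer_temps` and `logdotexp_destructive` are hypothetical, and both assume floating-point input arrays, since `np.exp(..., out=...)` cannot write float results into an integer array.

```python
import numpy as np

def logdotexp_fewer_temps(A, B):
    """Like logdotexp, but exponentiates the shifted copies in place.

    Allocates one A-sized and one B-sized temporary (plus the result C).
    """
    max_A = np.max(A)
    max_B = np.max(B)
    exp_A = A - max_A          # the only A-sized temporary
    np.exp(exp_A, out=exp_A)   # overwrite it with exp(A - max_A)
    exp_B = B - max_B          # the only B-sized temporary
    np.exp(exp_B, out=exp_B)
    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)
    C += max_A + max_B
    return C

def logdotexp_destructive(A, B):
    """Variant that overwrites A and B, so no A- or B-sized temporaries are made.

    A and B must be distinct arrays: passing the same array twice would
    exponentiate it twice.
    """
    max_A = np.max(A)
    max_B = np.max(B)
    A -= max_A
    np.exp(A, out=A)
    B -= max_B
    np.exp(B, out=B)
    C = np.dot(A, B)
    np.log(C, out=C)
    C += max_A + max_B
    return C
```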
