numerically stable way to multiply log probability matrices in numpy

逝去的感伤 2021-02-02 06:55

I need to take the matrix product of two NumPy matrices (or other 2d arrays) containing log probabilities. The naive way np.log(np.dot(np.exp(a), np.exp(b))) is not

4 answers
  • 2021-02-02 07:42

    Suppose A.shape==(n,r) and B.shape==(r,m). Computing the matrix product C=A*B actually involves n*m separate summations. To get stable results in log-space, you need the logsumexp trick in each of these summations. Fortunately, NumPy broadcasting makes it easy to control the stability of the rows of A and the columns of B separately.

    Here is the code:

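    import numpy as np

    def logdotexp(A, B):
        # Per-row maximum of A (shape n×1) and per-column maximum of B (shape 1×m).
        max_A = np.max(A, 1, keepdims=True)
        max_B = np.max(B, 0, keepdims=True)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        # Broadcasting adds max_A[i] + max_B[j] back onto entry C[i, j].
        C += max_A + max_B
        return C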
    

    Note:

    The reasoning behind this is similar to Fred Foo's answer, but he used a single maximum value for each whole matrix. That does not stabilize each of the n*m summations individually, so some elements of the final matrix can still be unstable, as mentioned in one of the comments.

    Here is a comparison with the currently accepted answer, using @identity-m's counterexample:

    def logdotexp_less_stable(A, B):
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C
    
    print('old method:')
    print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
    print('new method:')
    print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
    

    which prints

    old method:
    [[      -inf 0.69314718]
     [      -inf 0.69314718]]
    new method:
    [[-9.99306853e+02  6.93147181e-01]
     [-9.99306853e+02  6.93147181e-01]]
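
The "n*m summations" view at the start of this answer can be turned into a slow but straightforward reference implementation: apply scipy.special.logsumexp to each entry C[i, j] separately. This is a sketch for testing faster versions, not a replacement for them. The name logdotexp_reference is hypothetical.

```python
import numpy as np
from scipy.special import logsumexp

def logdotexp_reference(A, B):
    """Hypothetical reference: C[i, j] = logsumexp(A[i, :] + B[:, j]), one entry at a time."""
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    n, r = A.shape
    r2, m = B.shape
    assert r == r2, "inner dimensions must match"
    C = np.empty((n, m))
    for i in range(n):
        for j in range(m):
            # One of the n*m summations of the matrix product, done with logsumexp.
            C[i, j] = logsumexp(A[i, :] + B[:, j])
    return C

# Same input as the comparison above; the entries should agree with the new method.
print(logdotexp_reference([[0, 0], [0, 0]], [[-1000, 0], [-1000, 0]]))
```

Because it loops in Python, this is only practical for small matrices. It is still handy as a ground truth when testing vectorized versions.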
    
  • 2021-02-02 07:53

    logsumexp works by evaluating the right-hand side of the equation

    log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])
    

    I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:

    log(exp[a] ⋅ exp[b])
     = log(∑ exp[a] × exp[b])
     = log(∑ exp[a + b])
     = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }
    

    but by taking a different turn in the derivation, we obtain

    log(∑ exp[a] × exp[b])
     = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
     = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])
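
Here is a quick numerical check of the two derivations above for a single pair of vectors. It assumes NumPy and SciPy (scipy.special.logsumexp) are available. The vectors are arbitrary illustrative values, chosen so that the naive formula underflows.

```python
import numpy as np
from scipy.special import logsumexp

# Illustrative log-space vectors, far too negative for a plain exp().
a = np.array([-1000.0, -1001.0, -1002.0])
b = np.array([-2000.0, -2000.5, -1999.0])

# Naive: every exp() underflows to 0.0, so the log of the dot product is -inf.
with np.errstate(divide="ignore"):
    naive = np.log(np.dot(np.exp(a), np.exp(b)))

# First derivation: logsumexp(a + b).
via_logsumexp = logsumexp(a + b)

# Second derivation: pull out max(a) and max(b) separately, then take a dot product.
max_a, max_b = a.max(), b.max()
via_dot = max_a + max_b + np.log(np.dot(np.exp(a - max_a), np.exp(b - max_b)))

print(naive)                                # -inf
print(np.isclose(via_logsumexp, via_dot))   # True
```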
    

    The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm

    import numpy as np

    def logdotexp(A, B):
        # A single global maximum per matrix.
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C
    

    This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

    exp_A = A - max_A
    np.exp(exp_A, out=exp_A)
    

    and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
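
Here is a sketch of the function above rewritten with the in-place np.exp trick for both inputs. The name logdotexp_fewer_temps is hypothetical. It keeps the single-maximum-per-matrix approach, so it is no more stable than the original; only the memory use changes.

```python
import numpy as np

def logdotexp_fewer_temps(A, B):
    """Global-max logdotexp with the in-place np.exp trick (sketch).

    For float64 ndarray inputs, only one A-sized and one B-sized temporary
    are created, and A and B themselves are not modified.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    max_A = np.max(A)
    max_B = np.max(B)
    exp_A = A - max_A            # the one new A-sized array
    np.exp(exp_A, out=exp_A)     # exponentiate in place, no second temporary
    exp_B = B - max_B
    np.exp(exp_B, out=exp_B)
    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)
    C += max_A + max_B
    return C

# Usage: log-probability matrices whose product is easy to check in linear space.
A = np.log(np.array([[0.2, 0.8], [0.5, 0.5]]))
B = np.log(np.array([[0.1, 0.9], [0.6, 0.4]]))
print(np.allclose(np.exp(logdotexp_fewer_temps(A, B)), np.exp(A) @ np.exp(B)))  # True
```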

  • 2021-02-02 07:55

    The currently accepted answer by Fred Foo and Hassan's answer are both numerically unstable (Hassan's answer is better). An input on which Hassan's answer fails is given at the end of this answer. My implementation is as follows:

    import numpy as np
    from scipy.special import logsumexp
    
    def logmatmulexp(log_A: np.ndarray, log_B: np.ndarray) -> np.ndarray:
        """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
        np.log(np.exp(log_A) @ np.exp(log_B)) in a numerically stable way.
        Has O(ϴRI) time complexity and space complexity."""
        ϴ, R = log_A.shape
        I = log_B.shape[1]
        assert log_B.shape == (R, I)
        log_A_expanded = np.broadcast_to(np.expand_dims(log_A, 2), (ϴ, R, I))
        log_B_expanded = np.broadcast_to(np.expand_dims(log_B, 0), (ϴ, R, I))
        log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)
        return logsumexp(log_pairwise_products, axis=1)
    

    Like Hassan's answer and Fred Foo's answer, my answer has time complexity O(ϴRI). Their answers have space complexity O(ϴR+RI) (I am not actually sure about this), while mine unfortunately has space complexity O(ϴRI). The reason is that NumPy can multiply a ϴ×R matrix by an R×I matrix without allocating an additional array of size ϴ×R×I, whereas my code materializes the ϴ×R×I array log_pairwise_products. Having O(ϴRI) space complexity is not an inherent property of my method. I think you can avoid it by writing the computation out with loops, but unfortunately I don't think you can do that using stock NumPy functions alone.
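
Here is a minimal sketch of the loop-based idea: a Python-level loop over the rows of log_A. It assumes SciPy's logsumexp, and the name logmatmulexp_rowwise is hypothetical. It computes the same result in the same O(ϴRI) time, with n, r, m standing for ϴ, R, I. Only an R×I temporary exists at any moment, instead of a ϴ×R×I one.

```python
import numpy as np
from scipy.special import logsumexp

def logmatmulexp_rowwise(log_A, log_B):
    """Hypothetical row-by-row variant of logmatmulexp (sketch).

    Same result and O(n*r*m) time, but peak extra memory is one r×m array.
    """
    log_A = np.asarray(log_A, dtype=np.float64)
    log_B = np.asarray(log_B, dtype=np.float64)
    n, r = log_A.shape
    m = log_B.shape[1]
    assert log_B.shape == (r, m)
    out = np.empty((n, m))
    for i in range(n):
        # (r, 1) + (r, m) broadcasts to the (r, m) pairwise log-products for row i;
        # logsumexp over axis 0 performs the r-term summation for every column at once.
        out[i] = logsumexp(log_A[i, :, None] + log_B, axis=0)
    return out

# The counterexample from the end of this answer; the exact answer is 400 + log(2).
print(logmatmulexp_rowwise([[-500.0, 900.0]], [[900.0], [-500.0]]))  # [[400.69314718]]
```

The Python loop adds per-row overhead, so this trades some speed for memory. It pays off mainly when ϴ×R×I would not fit in memory.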

    I measured the actual running time: my code is 20 times slower than regular matrix multiplication.

    Here's how you can know that my answer is numerically stable:

    1. Clearly, all lines other than the return line are numerically stable.
    2. The logsumexp function is known to be numerically stable.
    3. Therefore, my logmatmulexp function is numerically stable.
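
Point 2 is easy to see on a tiny input. The naive log(sum(exp(x))) overflows or underflows, while scipy.special.logsumexp stays finite. The input values below are illustrative.

```python
import numpy as np
from scipy.special import logsumexp

big = np.array([1000.0, 1000.0])      # exp(1000) overflows float64 to inf
small = np.array([-1000.0, -1000.0])  # exp(-1000) underflows to 0.0

with np.errstate(over="ignore", divide="ignore"):
    naive_big = np.log(np.sum(np.exp(big)))      # log(inf)
    naive_small = np.log(np.sum(np.exp(small)))  # log(0.0)

print(naive_big, naive_small)             # inf -inf
print(logsumexp(big), logsumexp(small))   # finite: 1000 + log(2) and -1000 + log(2)
```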

    My implementation has another nice property. If you write the same code in PyTorch, or in another library with automatic differentiation, instead of NumPy, you automatically get a numerically stable backward pass. Here's how we can know the backward pass will be numerically stable:

    1. All functions in my code are differentiable everywhere (unlike np.max).
    2. Clearly, backpropagating through all lines except the return line is numerically stable, because nothing unusual happens there.
    3. The PyTorch developers generally know what they are doing, so it is reasonable to trust that they implemented the backward pass of logsumexp in a numerically stable way.
    4. In fact, the gradient of logsumexp is the softmax function (for reference, search for "softmax is gradient of logsumexp" or see https://arxiv.org/abs/1704.00805, Proposition 1). Softmax is known to be computable in a numerically stable way, so the PyTorch devs probably just use softmax there (I haven't actually checked).
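
Point 4 can be checked numerically without any autodiff library. The sketch below compares central finite differences of scipy.special.logsumexp against scipy.special.softmax. The input vector and step size are illustrative choices.

```python
import numpy as np
from scipy.special import logsumexp, softmax

x = np.array([0.5, -1.0, 2.0, 3.0])  # illustrative input
eps = 1e-6                           # finite-difference step

# Central finite differences of logsumexp, one coordinate at a time.
grad_fd = np.array([
    (logsumexp(x + eps * e) - logsumexp(x - eps * e)) / (2 * eps)
    for e in np.eye(len(x))
])

print(np.allclose(grad_fd, softmax(x), atol=1e-6))  # True

# Softmax is shift-invariant, and SciPy computes it stably, so a huge offset
# still gives finite values equal to the unshifted result.
print(np.allclose(softmax(x + 1000.0), softmax(x)))  # True
```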

    Below is the same code in PyTorch (in case you need backpropagation). Because of how PyTorch backpropagation works, the forward pass saves the log_pairwise_products tensor for the backward pass. This tensor is large, and you probably don't want it to be saved; you can simply recompute it during the backward pass. In that case I suggest using checkpointing, which is really easy; see the second function below.

    import torch
    from torch.utils.checkpoint import checkpoint
    
    def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
        """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
        (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
        ϴ, R = log_A.shape
        I = log_B.shape[1]
        assert log_B.shape == (R, I)
        log_A_expanded = log_A.unsqueeze(2).expand((ϴ, R, I))
        log_B_expanded = log_B.unsqueeze(0).expand((ϴ, R, I))
        log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)
        return torch.logsumexp(log_pairwise_products, dim=1)
    
    
    def logmatmulexp_lowmem(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
        """Same as logmatmulexp, but doesn't save a (ϴ, R, I)-shaped tensor for the backward pass.

        Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
        (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
        return checkpoint(logmatmulexp, log_A, log_B)
    

    Here's an input on which Hassan's implementation fails but my implementation gives the correct output:

    def logmatmulexp_hassan(A, B):
        max_A = np.max(A,1,keepdims=True)
        max_B = np.max(B,0,keepdims=True)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C
    
    log_A = np.array([[-500., 900.]], dtype=np.float64)
    log_B = np.array([[900.], [-500.]], dtype=np.float64)
    print(logmatmulexp_hassan(log_A, log_B)) # prints -inf, while the correct answer is approximately 400.69.
    
  • 2021-02-02 07:58

    You are accessing columns of res and b, which has poor locality of reference. One thing to try is to store these in column-major order.
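
The res and b arrays this answer refers to are not shown on this page, so the arrays below are hypothetical stand-ins. In NumPy, np.asfortranarray, or order="F" when allocating, gives column-major storage, in which each column is one contiguous block of memory.

```python
import numpy as np

# Hypothetical stand-in for `b`; NumPy arrays are row-major (C order) by default.
b = np.arange(12, dtype=np.float64).reshape(3, 4)
b_f = np.asfortranarray(b)  # same values, column-major (Fortran order) storage

print(b.flags["F_CONTIGUOUS"], b_f.flags["F_CONTIGUOUS"])  # False True
# In column-major storage, a single column is contiguous in memory.
print(b[:, 1].flags["C_CONTIGUOUS"], b_f[:, 1].flags["C_CONTIGUOUS"])  # False True

# An output array such as `res` can be allocated column-major from the start.
res = np.zeros((3, 4), order="F")
```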
