Understanding NumPy's einsum

死守一世寂寞 2020-11-22 14:36

I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it's not seeming to stick.

Here's an

6 answers
  •  隐瞒了意图╮
    2020-11-22 15:17

    When reading einsum equations, I've found it most helpful to be able to mentally boil them down to their imperative versions.

    Let's start with the following (imposing) statement:

    C = np.einsum('bhwi,bhwj->bij', A, B)
    

    Working through the punctuation first, we see two 4-letter comma-separated blobs before the arrow, bhwi and bhwj, and a single 3-letter blob after it, bij. Therefore, the equation produces a rank-3 tensor result from two rank-4 tensor inputs.
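To make the rank bookkeeping concrete, here is a quick shape check. The sizes (b=2, h=3, w=4, i=5, j=6) are arbitrary values chosen only for illustration; the question does not specify any.

```python
import numpy as np

# Arbitrary example sizes: b=2, h=3, w=4, i=5, j=6
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5))  # axes: b, h, w, i
B = rng.standard_normal((2, 3, 4, 6))  # axes: b, h, w, j

C = np.einsum('bhwi,bhwj->bij', A, B)

print(A.ndim, B.ndim, C.ndim)  # 4 4 3
print(C.shape)                 # (2, 5, 6)
```

Each letter of the output blob contributes one axis to C, with the length of the axis that letter indexes in the inputs.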

    Now, let each letter in each blob be the name of a range variable. The position at which a letter appears in a blob is the index of the axis it ranges over in that tensor. Since the output blob bij has three letters, the imperative summation that produces each element of C has to start with three nested for loops, one for each index of C.

    for b in range(...):
        for i in range(...):
            for j in range(...):
                # the variables b, i and j index C in the order of their appearance in the equation
                C[b, i, j] = ...
    

    So, essentially, you have a for loop for every output index of C. We'll leave the ranges undetermined for now.
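Before any summation enters the picture, the "one loop per output index" rule already describes equations in which every input letter also appears in the output. A small sketch using a batched outer product, 'bi,bj->bij' (my own illustrative equation, not from the question): each element of the result is a single product, and reordering the output letters only reorders the result's axes.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((2, 3))  # axes: b, i
Y = rng.standard_normal((2, 4))  # axes: b, j

# One for loop per output index; no letter is missing from the output, so nothing is summed
D = np.empty((2, 3, 4))
for b in range(2):
    for i in range(3):
        for j in range(4):
            D[b, i, j] = X[b, i] * Y[b, j]

assert np.allclose(D, np.einsum('bi,bj->bij', X, Y))

# Swapping i and j in the output blob just swaps the last two axes of the result
assert np.allclose(np.einsum('bi,bj->bji', X, Y), D.transpose(0, 2, 1))
```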

    Next, we look at the left-hand side: are there any range variables there that don't appear on the right-hand side? In our case, yes: h and w. Add an inner nested for loop for every such variable:

    for b in range(...):
        for i in range(...):
            for j in range(...):
                C[b, i, j] = 0
                for h in range(...):
                    for w in range(...):
                        ...
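A letter that appears on the left but not on the right is summed over. A minimal illustration on a 2-D array (again my own example, not from the question): in 'ij->i' the letter j is dropped, so the inner loop over j accumulates, which is the same as summing along axis 1. Dropping every letter, as in 'ij->', sums the whole array.

```python
import numpy as np

M = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# 'ij->i': i is an output index (outer loop), j is missing from the output (inner, summed loop)
row_sums = np.zeros(M.shape[0], dtype=M.dtype)
for i in range(M.shape[0]):
    for j in range(M.shape[1]):
        row_sums[i] += M[i, j]

print(row_sums)               # [ 3 12]
print(np.einsum('ij->i', M))  # [ 3 12]
print(np.einsum('ij->', M))   # 15 (every letter is summed away)
```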
    

    Inside the innermost loop we now have all indices defined, so we can write the actual summation and the translation is complete:

    # three nested for-loops that index the elements of C
    for b in range(...):
        for i in range(...):
            for j in range(...):
    
                # prepare to sum
                C[b, i, j] = 0
    
                # two nested for-loops for the two indexes that don't appear on the right-hand side
                for h in range(...):
                    for w in range(...):
                        # Sum! Compare the statement below with the original einsum formula
                        # 'bhwi,bhwj->bij'
    
                        C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
    

    If you've been able to follow the code thus far, then congratulations! This is all you need to be able to read einsum equations. Notice in particular how the original einsum formula maps to the final summation statement in the snippet above. The for loops and range bounds are just fluff, and that final statement is all you really need to understand what's going on.
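To practice the same reading on a more familiar equation (my own example, not from the answer above): in 'ik,kj->ij', the letter k appears only on the left, so it becomes the summed inner loop, and the summation statement is C[i, j] += A[i, k] * B[k, j], which is exactly matrix multiplication.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((3, 4))  # axes: i, k
B = rng.standard_normal((4, 5))  # axes: k, j

C = np.zeros((3, 5))
for i in range(3):
    for j in range(5):
        for k in range(4):  # k is missing from the output, so it is summed
            C[i, j] += A[i, k] * B[k, j]

assert np.allclose(C, np.einsum('ik,kj->ij', A, B))
assert np.allclose(C, A @ B)
```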

    For the sake of completeness, let's see how to determine the range of each range variable. The range of each variable is simply the length of the dimension(s) it indexes. If a variable indexes more than one dimension, in one or more tensors, then all of those dimensions must have the same length. Here's the code above with the complete ranges:

    import numpy as np

    # C's shape is determined by the shapes of the inputs
    # b indexes both A and B, so its range can come from either A.shape or B.shape
    # i indexes only A, so its range can only come from A.shape; the same is true for j and B
    # b, h and w index both A and B, so the corresponding dimensions must match
    assert A.shape[0] == B.shape[0]
    assert A.shape[1] == B.shape[1]
    assert A.shape[2] == B.shape[2]
    # start from zeros so += accumulates the sum; result_type keeps integer inputs integer, like np.einsum
    C = np.zeros((A.shape[0], A.shape[3], B.shape[3]), dtype=np.result_type(A, B))
    for b in range(A.shape[0]):  # B.shape[0] would work too; the assert above guarantees they are equal
        for i in range(A.shape[3]):
            for j in range(B.shape[3]):
                # h and w can come from either A or B
                for h in range(A.shape[1]):
                    for w in range(A.shape[2]):
                        C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
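The whole procedure (one loop per output letter, one inner loop per summed letter, each range taken from the dimension the letter indexes, shared letters required to have equal lengths) is mechanical enough to write once for any explicit equation. The function below, naive_einsum, is a hypothetical teaching helper, not part of NumPy. It is a sketch that assumes an explicit '->' output, no '...' and no letter repeated within a single operand, and it is far slower than np.einsum.

```python
import itertools
import numpy as np

def naive_einsum(equation, *operands):
    """Hypothetical loop-based einsum for explicit equations such as 'bhwi,bhwj->bij'.

    Assumes: an explicit '->' output, no '...', and no letter repeated within one operand.
    """
    lhs, out = equation.replace(' ', '').split('->')
    in_subs = lhs.split(',')
    assert len(in_subs) == len(operands)

    # Range of each letter = length of the axis it indexes; shared letters must agree
    sizes = {}
    for subs, op in zip(in_subs, operands):
        assert len(subs) == op.ndim, f"{subs!r} does not match an array with {op.ndim} dims"
        for letter, n in zip(subs, op.shape):
            if sizes.setdefault(letter, n) != n:
                raise ValueError(f"size mismatch for index {letter!r}: {sizes[letter]} vs {n}")

    summed = [letter for letter in sizes if letter not in out]  # letters missing on the right
    result = np.zeros([sizes[c] for c in out], dtype=np.result_type(*operands))

    # Outer loops: one per output letter
    for out_idx in itertools.product(*(range(sizes[c]) for c in out)):
        env = dict(zip(out, out_idx))
        # Inner loops: one per summed letter
        for sum_idx in itertools.product(*(range(sizes[c]) for c in summed)):
            env.update(zip(summed, sum_idx))
            term = 1
            for subs, op in zip(in_subs, operands):
                term = term * op[tuple(env[c] for c in subs)]
            result[out_idx] += term
    return result

# Usage: the equation from the question, checked against NumPy
rng = np.random.default_rng(3)
A = rng.standard_normal((2, 3, 4, 5))  # axes: b, h, w, i
B = rng.standard_normal((2, 3, 4, 6))  # axes: b, h, w, j
C = naive_einsum('bhwi,bhwj->bij', A, B)
print(np.allclose(C, np.einsum('bhwi,bhwj->bij', A, B)))  # True

# Mismatched lengths for a shared letter (w) are rejected
try:
    naive_einsum('bhwi,bhwj->bij', A, rng.standard_normal((2, 3, 7, 6)))
except ValueError as err:
    print(err)  # size mismatch for index 'w': 4 vs 7
```

The itertools.product calls stand in for the hand-written nested loops, so the same few lines cover any number of output and summed letters.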
    
