I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it doesn't seem to stick.
Here's an
When reading einsum equations, I've found it most helpful to mentally boil them down to their imperative versions.
Let's start with the following (imposing) statement:
C = np.einsum('bhwi,bhwj->bij', A, B)
Working through the punctuation first, we see two 4-letter comma-separated blobs, bhwi and bhwj, before the arrow, and a single 3-letter blob, bij, after it. Therefore, the equation produces a rank-3 tensor result from two rank-4 tensor inputs.
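To make the rank claim concrete, here is a small sketch. The sizes (b=2, h=3, w=4, i=5, j=6) are arbitrary values picked only for illustration; they are not from the original question.

```python
import numpy as np

# Hypothetical sizes chosen only for illustration: b=2, h=3, w=4, i=5, j=6
A = np.ones((2, 3, 4, 5))  # axes: b, h, w, i
B = np.ones((2, 3, 4, 6))  # axes: b, h, w, j

C = np.einsum('bhwi,bhwj->bij', A, B)

print(A.ndim, B.ndim, C.ndim)  # 4 4 3
print(C.shape)                 # (2, 5, 6)
```

The output has one axis per letter after the arrow, and each axis takes its length from the input axis labelled with the same letter.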
Now, let each letter in each blob be the name of a range variable. The position at which the letter appears in the blob is the index of the axis that it ranges over in that tensor. The imperative summation that produces each element of C, therefore, has to start with three nested for loops, one for each index of C.
for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...
So, essentially, you have a for loop for every output index of C. We'll leave the ranges undetermined for now.
Next we look at the left-hand side: are there any range variables there that don't appear on the right-hand side? In our case, yes: h and w. Add an inner nested for loop for every such variable:
for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...
Inside the innermost loop we now have all indices defined, so we can write the actual summation and the translation is complete:
# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):
            # prepare to sum
            C[b, i, j] = 0
            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
If you've been able to follow the code thus far, then congratulations! This is all you need to be able to read einsum equations. Notice in particular how the original einsum formula maps to the final summation statement in the snippet above. The for-loops and range bounds are just fluff and that final statement is all you really need to understand what's going on.
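As a quick exercise in the same reading technique, here is a sketch that applies it to a different, smaller equation, 'ij,jk->ik'. This equation and the sizes below are my own illustrative choices, not part of the original example. The letters i and k appear after the arrow, so they get the outer loops; j appears only before the arrow, so it is summed over.

```python
import numpy as np

# Hypothetical sizes for the three range variables
I, J, K = 3, 4, 5

rng = np.random.default_rng(0)
A = rng.standard_normal((I, J))  # axes: i, j
B = rng.standard_normal((J, K))  # axes: j, k

C = np.zeros((I, K))
for i in range(I):          # output index
    for k in range(K):      # output index
        for j in range(J):  # appears only on the left-hand side, so it is summed
            # Compare with the formula 'ij,jk->ik'
            C[i, k] += A[i, j] * B[j, k]

# The loops agree with einsum, and this equation is ordinary matrix multiplication
assert np.allclose(C, np.einsum('ij,jk->ik', A, B))
assert np.allclose(C, A @ B)
```

Again, the summation statement alone, C[i, k] += A[i, j] * B[j, k], tells you everything the formula says.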
For the sake of completeness, let's see how to determine the ranges for each range variable. Well, the range of each variable is simply the length of the dimension(s) which it indexes. Obviously, if a variable indexes more than one dimension in one or more tensors, then the lengths of each of those dimensions have to be equal. Here's the code above with the complete ranges:
import numpy as np

# A and B are the rank-4 input arrays from the einsum call above
# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape; the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]):  # range(B.shape[0]) would work equally well, since the two lengths must match
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
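If you want to convince yourself that the translation is faithful, one way is to wrap the loops in a function and compare the result against np.einsum on random inputs. The sketch below does that; the function name einsum_loops and the input sizes are hypothetical, chosen only for this check.

```python
import numpy as np

def einsum_loops(A, B):
    """Loop translation of np.einsum('bhwi,bhwj->bij', A, B) (illustrative helper)."""
    # b, h and w index both inputs, so those axis lengths must agree
    assert A.shape[:3] == B.shape[:3]
    nb, nh, nw, ni = A.shape
    nj = B.shape[3]
    C = np.zeros((nb, ni, nj))
    for b in range(nb):
        for i in range(ni):
            for j in range(nj):
                for h in range(nh):
                    for w in range(nw):
                        C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
    return C

# Hypothetical small inputs: b=2, h=3, w=4, i=5, j=6
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5))
B = rng.standard_normal((2, 3, 4, 6))

expected = np.einsum('bhwi,bhwj->bij', A, B)
result = einsum_loops(A, B)

assert result.shape == expected.shape == (2, 5, 6)
assert np.allclose(result, expected)
```

The pure-Python loops are far slower than np.einsum, so they are only useful as a mental model or a correctness check on small arrays.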