How to understand NumPy strides for a layman?

Submitted by ◇◆丶佛笑我妖孽 on 2019-11-28 23:37:41

The actual data of a NumPy array is stored in a homogeneous and contiguous block of memory called the data buffer. For more information see NumPy internals. Using the (default) row-major order, the rows of a 2D array are laid out one after another in that buffer.

To map the indices i,j,k,... of a multidimensional array to positions in the data buffer (the offset, in bytes), NumPy uses the notion of strides. Strides are the number of bytes to jump over in memory in order to get from one item to the next item along each direction/dimension of the array. In other words, a stride is the byte separation between consecutive items along a given dimension.
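As a rough sketch of that mapping (the helper name byte_offset and the small demo array are made up for illustration, and the dtype is pinned to int32 so the numbers do not depend on the platform), the offset of an element is the sum of index times stride over all dimensions:

```python
import numpy as np

def byte_offset(indices, strides):
    """Byte offset of an element, measured from the start of the array's data."""
    return sum(i * s for i, s in zip(indices, strides))

# Hypothetical demo array; int32 means every item takes 4 bytes
demo = np.arange(1, 10, dtype=np.int32).reshape(3, 3)

offset = byte_offset((1, 2), demo.strides)   # where demo[1, 2] lives
raw = demo.tobytes()                         # copy of the data in row-major order
value = np.frombuffer(raw, dtype=demo.dtype, count=1, offset=offset)[0]

print(offset, value, demo[1, 2])
```

Reading one item at the computed offset gives back the same value as ordinary indexing, which is all that strides are used for.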

For example:

>>> a = np.arange(1,10).reshape(3,3)
>>> a
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

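This 2D array has two directions, axis-0 (running vertically downwards across rows) and axis-1 (running horizontally across columns), with each item having the size shown below (4 bytes here because the dtype is int32; on platforms where the default integer is int64 it would be 8, and all byte counts below would double):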

>>> a.itemsize  # in bytes
4  

So to go from a[0, 0] -> a[0, 1] (moving horizontally along the 0th row, from the 0th column to the 1st column) the byte-step in the data buffer is 4. Same for a[0, 1] -> a[0, 2], a[1, 0] -> a[1, 1] etc. This means that the stride for the horizontal direction (axis-1) is 4 bytes.

However, to go from a[0, 0] -> a[1, 0] (moving vertically along the 0th column, from the 0th row to the 1st row), you first need to traverse all the remaining items of the 0th row to reach the 1st row, i.e. a[0, 0] -> a[0, 1] -> a[0, 2] -> a[1, 0]. Therefore the stride for the vertical direction (axis-0) is 3*4 = 12 bytes. Note that going from a[0, 2] -> a[1, 0], and in general from the last item of the i-th row to the first item of the (i+1)-th row, is also 4 bytes because the array a is stored in row-major order.

That's why

>>> a.strides  # (strides[0], strides[1])
(12, 4)  
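The same reasoning can be written down for any shape. The following is only a sketch for the default row-major layout (the function name c_order_strides is invented here, and the dtype is fixed to int32 to match the numbers above): the last axis steps by one item, and every earlier axis steps by the size of everything to its right.

```python
import numpy as np

def c_order_strides(shape, itemsize):
    """Strides (in bytes) of a row-major, contiguous array of this shape."""
    strides = []
    step = itemsize                 # the last axis moves one item at a time
    for dim in reversed(shape):
        strides.append(step)
        step *= dim                 # the next axis out skips a whole block
    return tuple(reversed(strides))

a = np.arange(1, 10, dtype=np.int32).reshape(3, 3)
print(c_order_strides(a.shape, a.itemsize))
print(a.strides)
```

Both lines print the same tuple for this array. The formula only holds for contiguous row-major arrays; views such as transposes or slices have their own strides.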

Here's another example showing that the stride in the horizontal direction (axis-1), strides[1], of a 2D array is not necessarily equal to the item size (e.g. an array with column-major order):

>>> b = np.array([[1, 4, 7],
                  [2, 5, 8],
                  [3, 6, 9]]).T
>>> b
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

>>> b.strides
(4, 12)

Here strides[1] is a multiple of the item size. Although the array b looks identical to the array a, it's a different array: internally b is stored as |1|4|7|2|5|8|3|6|9| (because transposing doesn't affect the data buffer but only swaps the strides and the shape), whereas a is stored as |1|2|3|4|5|6|7|8|9|. What makes them look alike is the different strides. That is, the byte-step for b[0, 0] -> b[0, 1] is 3*4 = 12 bytes and for b[0, 0] -> b[1, 0] it is 4 bytes, whereas for a[0, 0] -> a[0, 1] it is 4 bytes and for a[0, 0] -> a[1, 0] it is 12 bytes.
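A quick way to check this yourself is sketched below. The variable name base is introduced only for this example, and the dtype is pinned to int32 so the strides match the ones quoted above.

```python
import numpy as np

base = np.array([[1, 4, 7],
                 [2, 5, 8],
                 [3, 6, 9]], dtype=np.int32)
b = base.T                              # a view; no data is copied

print(np.shares_memory(base, b))        # do both arrays use the same buffer?
print(b.shape, b.strides)               # shape and strides of the transposed view
print(b.flags['F_CONTIGUOUS'])          # is b laid out column-major?
print(b.ravel(order='K').tolist())      # the elements in the order they sit in memory
```

The last line should list the elements as 1, 4, 7, 2, 5, 8, 3, 6, 9, i.e. the buffer is untouched and only the shape and strides differ.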

Last but not least, NumPy allows you to create views of existing arrays while modifying the strides and the shape, see stride tricks. For example:

>>> np.lib.stride_tricks.as_strided(a, shape=a.shape[::-1], strides=a.strides[::-1])
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

which is equivalent to transposing the array a.
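To convince yourself, something like the following can be used. It is a sketch rather than a recommendation: as_strided does no bounds checking, so passing writeable=False (an optional argument of as_strided) is a cautious choice made here, not part of the original example.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1, 10, dtype=np.int32).reshape(3, 3)

# Reverse shape and strides by hand; read-only to avoid accidental writes
t = as_strided(a, shape=a.shape[::-1], strides=a.strides[::-1], writeable=False)

print(np.array_equal(t, a.T))     # same values as the transpose
print(np.shares_memory(t, a))     # and still the same buffer
print(t.strides == a.T.strides)   # with the same strides
```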

Let me just add, without going into much detail, that one can even define strides that are not multiples of the item size. Here's an example (the output shown assumes a little-endian machine):

>>> a = np.lib.stride_tricks.as_strided(np.array([1, 512, 0, 3], dtype=np.int16), 
                                        shape=(3,), strides=(3,))
>>> a
array([1, 2, 3], dtype=int16)

>>> a.strides[0]
3

>>> a.itemsize
2
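Why this gives 1, 2, 3 becomes clearer when looking at the raw bytes. In the sketch below the dtype is written as '<i2' (little-endian int16) so the byte layout is the same on every machine; the name src is used only for this illustration.

```python
import numpy as np

src = np.array([1, 512, 0, 3], dtype='<i2')
raw = src.tobytes()
print(list(raw))      # [1, 0, 0, 2, 0, 0, 3, 0]

# Read one 2-byte item at byte offsets 0, 3 and 6, i.e. with a stride of 3 bytes
values = [int.from_bytes(raw[off:off + 2], 'little', signed=True)
          for off in (0, 3, 6)]
print(values)         # [1, 2, 3]
```

The item read at offset 3 straddles the second and third int16 values of the source: it takes the high byte of 512 and the low byte of 0, which together decode to 2.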

Just to add to the great answer by @AndyK: I learnt about NumPy strides from Numpy MedKit. There they demonstrate their use with the following problem:

Given input:

>>> x = np.arange(20).reshape([4, 5])
>>> x
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

Expected Output:

array([[[  0,  1,  2,  3,  4],
        [  5,  6,  7,  8,  9]],

       [[  5,  6,  7,  8,  9],
        [ 10, 11, 12, 13, 14]],

       [[ 10, 11, 12, 13, 14],
        [ 15, 16, 17, 18, 19]]])

To do this, we need to know the following terms:

shape - The dimensions of the array along each axis.

strides - The number of bytes of memory that must be skipped to progress to the next item along a certain dimension.

>>> x.strides
(20, 4)

>>> np.int32().itemsize
4

Now, if we look at the Expected Output:

array([[[  0,  1,  2,  3,  4],
        [  5,  6,  7,  8,  9]],

       [[  5,  6,  7,  8,  9],
        [ 10, 11, 12, 13, 14]],

       [[ 10, 11, 12, 13, 14],
        [ 15, 16, 17, 18, 19]]])

We need to manipulate the array shape and strides. The output shape must be (3, 2, 5), i.e. 3 items, each containing two rows (m == 2) and each row having 5 elements.

The strides need to change from (20, 4) to (20, 20, 4). Each item in the new output array starts one row further on, each row consists of 20 bytes (5 elements of 4 bytes each), and each element occupies 4 bytes (int32).

So:

>>> from numpy.lib import stride_tricks
>>> stride_tricks.as_strided(x, shape=(3, 2, 5),
                                strides=(20, 20, 4))
...
array([[[  0,  1,  2,  3,  4],
        [  5,  6,  7,  8,  9]],

       [[  5,  6,  7,  8,  9],
        [ 10, 11, 12, 13, 14]],

       [[ 10, 11, 12, 13, 14],
        [ 15, 16, 17, 18, 19]]])
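The hard-coded (20, 20, 4) only works when x is int32. A sketch that takes the steps from x.strides instead, and cross-checks the result against sliding_window_view (available in NumPy 1.20 and later), could look like this. The names m, row_step, col_step, windows and checked are chosen for this illustration, and writeable=False is a defensive choice since the windows overlap.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

x = np.arange(20).reshape(4, 5)      # default integer dtype, whatever its size
m = 2                                # rows per window

row_step, col_step = x.strides
windows = as_strided(x,
                     shape=(x.shape[0] - m + 1, m, x.shape[1]),
                     strides=(row_step, row_step, col_step),
                     writeable=False)

# sliding_window_view puts the window axis last, so it is moved
# to the middle to get the same (3, 2, 5) layout
checked = sliding_window_view(x, m, axis=0).transpose(0, 2, 1)

print(windows.shape)
print(np.array_equal(windows, checked))
```

For everyday use sliding_window_view is probably the safer route, because it works out the strides itself and returns a read-only view by default.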

An alternative would be:

>>> d = dict(x.__array_interface__)
>>> d['shape'] = (3, 2, 5)
>>> d['strides'] = (20, 20, 4)

>>> class Arr:
...     __array_interface__ = d
...     base = x

>>> np.array(Arr())
array([[[  0,  1,  2,  3,  4],
        [  5,  6,  7,  8,  9]],

       [[  5,  6,  7,  8,  9],
        [ 10, 11, 12, 13, 14]],

       [[ 10, 11, 12, 13, 14],
        [ 15, 16, 17, 18, 19]]])

I use as_strided very often instead of numpy.hstack or numpy.vstack and, trust me, computationally it is much faster, because it returns a view on the existing buffer instead of copying the data into a new array.
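The difference between the view and a stacked copy can be seen without timing anything. This is just a sketch, with x pinned to int32 so that the strides (20, 20, 4) from above apply, and with np.stack standing in for the copying approach:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(20, dtype=np.int32).reshape(4, 5)

view = as_strided(x, shape=(3, 2, 5), strides=(20, 20, 4), writeable=False)
copy = np.stack([x[i:i + 2] for i in range(3)])   # builds a new array

print(np.array_equal(view, copy))    # same values either way
print(np.shares_memory(view, x))     # the view reuses the buffer of x
print(np.shares_memory(copy, x))     # the stacked array has its own buffer
print(x.nbytes, copy.nbytes)         # 80 120
```

The view adds no extra data beyond the 80 bytes of x, while the stacked copy allocates 120 bytes of its own. For large arrays and many windows that gap grows accordingly.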

Note:

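When using very large arrays with this trick, calculating the exact strides is not so trivial. I usually make a numpy.zeros array of the desired shape, read its strides from array.strides, and pass them to stride_tricks.as_strided. Keep in mind that such a template only gives the strides of a contiguous array of that shape; for overlapping windows like the example above, the strides have to be derived from the strides of the source array.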
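A small check of what the template approach returns, as a sketch with int32 data and made-up variable names (template, overlapping):

```python
import numpy as np

x = np.arange(20, dtype=np.int32).reshape(4, 5)

# Strides of a freshly allocated, contiguous array of the target shape
template = np.zeros((3, 2, 5), dtype=x.dtype)
print(template.strides)              # (40, 20, 4)

# Strides for overlapping two-row windows, taken from x itself
overlapping = (x.strides[0],) + x.strides
print(overlapping)                   # (20, 20, 4)
```

The template steps 40 bytes between windows (two full rows), so it describes non-overlapping blocks. It is a handy shortcut when reinterpreting a buffer as a plain contiguous shape, but not for the sliding windows in the example above.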

Hope it helps!
