How to use numpy as_strided (from np.stride_tricks) correctly?

青春惊慌失措 2021-01-20 10:32

I'm trying to reshape a numpy array using numpy.lib.stride_tricks. This is the guide I'm following: https://stackoverflow.com/a/2487551/4909087

My use c

3 Answers
  • 2021-01-20 11:12

    The accepted answer (and discussion) is good, but for the benefit of readers who don't want to run their own test case, I'll try to illustrate what's going on:

    In [374]: a = np.arange(1,10)
    In [375]: as_strided = np.lib.stride_tricks.as_strided
    
    In [376]: a.shape
    Out[376]: (9,)
    In [377]: a.strides 
    Out[377]: (4,)
    

    For a contiguous 1d array, the stride is the size of one element, here 4 bytes for an int32: to go from one element to the next, NumPy steps forward 4 bytes.
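
    A minimal check of that relationship. The explicit dtypes are an assumption added here, since the default integer type varies by platform and NumPy version (the session above evidently used int32):

```python
import numpy as np

# Explicit dtypes so the byte counts do not depend on the platform default
a = np.arange(1, 10, dtype=np.int32)
print(a.dtype.itemsize, a.strides)   # 4 (4,)

b = np.arange(1, 10, dtype=np.int64)
print(b.dtype.itemsize, b.strides)   # 8 (8,)
```
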

    What the OP tried:

    In [380]: as_strided(a, shape=(7,3), strides=(3,3))
    Out[380]: 
    array([[        1,       512,    196608],
           [      512,    196608,  67108864],
           [   196608,  67108864,         4],
           [ 67108864,         4,      1280],
           [        4,      1280,    393216],
           [     1280,    393216, 117440512],
           [   393216, 117440512,         7]])
    

    This steps by 3 bytes, crossing int32 boundaries and giving mostly unintelligible numbers. It might make more sense if the dtype had been bytes or uint8.
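
    To see where those numbers come from, this sketch (assuming little-endian int32, as the output above implies) reads one int32 at each misaligned byte offset where a row of the OP's view starts. For example, bytes 3 to 6 of the buffer are 00 02 00 00, which decodes to 2*256 = 512:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1, 10, dtype='<i4')     # little-endian int32
raw = a.tobytes()                     # the 36-byte data buffer

# Row i of the (7, 3) view with strides (3, 3) starts at byte 3*i,
# which is usually in the middle of an int32.
offsets = range(0, 7 * 3, 3)
col0 = [int(np.frombuffer(raw, dtype='<i4', count=1, offset=o)[0]) for o in offsets]
print(col0)   # [1, 512, 196608, 67108864, 4, 1280, 393216]

# The same values appear in the first column of the misaligned view
bad = as_strided(a, shape=(7, 3), strides=(3, 3))
assert bad[:, 0].tolist() == col0
```
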

    Using a.strides*2 (tuple replication) or (4,4) instead, we get the desired array:

    In [381]: as_strided(a, shape=(7,3), strides=(4,4))
    Out[381]: 
    array([[1, 2, 3],
           [2, 3, 4],
           [3, 4, 5],
           [4, 5, 6],
           [5, 6, 7],
           [6, 7, 8],
           [7, 8, 9]])
    

    Rows and columns both step by one element, producing a moving window with a step of 1. We could also have set shape=(3,7): 3 windows, each 7 elements long.

    In [382]: _.strides
    Out[382]: (4, 4)
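
    The shape=(3,7) variant, plus a cross-check against np.lib.stride_tricks.sliding_window_view (available since NumPy 1.20 and not used in the original answer), which builds the same kind of view with bounds checking:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

a = np.arange(1, 10, dtype=np.int32)
s = a.itemsize

# 7 windows of length 3, each starting one element after the previous one
w1 = as_strided(a, shape=(7, 3), strides=(s, s))
# 3 windows of length 7
w2 = as_strided(a, shape=(3, 7), strides=(s, s))

# sliding_window_view produces the same windows but refuses out-of-bounds shapes
assert np.array_equal(w1, sliding_window_view(a, 3))
assert np.array_equal(w2, sliding_window_view(a, 7))
print(w2)
```
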
    

    Changing strides to (8,4) steps 2 elements from one window to the next.

    In [383]: as_strided(a, shape=(7,3), strides=(8,4))
    Out[383]: 
    array([[          1,           2,           3],
           [          3,           4,           5],
           [          5,           6,           7],
           [          7,           8,           9],
           [          9,          25, -1316948568],
           [-1316948568,   184787224, -1420192452],
           [-1420192452,           0,           0]])
    

    But the shape is now wrong: the last rows read bytes past the end of the original data buffer. That could be dangerous, since we don't know whether those bytes belong to some other object or array. With an array this size, only 4 complete windows fit at a step of 2.
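
    A sketch of sizing the view so it stays inside the buffer, assuming a window of 3 and a step of 2: the number of complete windows is (N - window) // step + 1. The sliding_window_view comparison (NumPy >= 1.20) is an addition for checking:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

a = np.arange(1, 10, dtype=np.int32)
s = a.itemsize
window, step = 3, 2

# Only complete windows: (9 - 3) // 2 + 1 == 4
rows = (a.size - window) // step + 1
w = as_strided(a, shape=(rows, window), strides=(step * s, s))
print(w)
# [[1 2 3]
#  [3 4 5]
#  [5 6 7]
#  [7 8 9]]

# Bounds-checked equivalent: every `step`-th window of step 1
assert np.array_equal(w, sliding_window_view(a, window)[::step])
```
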

    Now step 3 elements for each row (3*4, 4):

    In [384]: as_strided(a, shape=(3,3), strides=(12,4))
    Out[384]: 
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])
    In [385]: a.reshape(3,3).strides
    Out[385]: (12, 4)
    

    This is the same shape and strides as a 3x3 reshape.
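
    A quick check that the hand-built view and the reshape agree, and that both share memory with a rather than copying it:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1, 10, dtype=np.int32)
s = a.itemsize

v = as_strided(a, shape=(3, 3), strides=(3 * s, s))   # one row = 3 elements
r = a.reshape(3, 3)

print(v.strides, r.strides)                            # (12, 4) (12, 4)
print(np.array_equal(v, r))                            # True
print(np.shares_memory(v, a), np.shares_memory(r, a))  # True True
```
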

    Strides can also be negative or 0. In fact, negative-step slicing along a dimension with a positive stride gives a negative stride, and broadcasting works by setting 0 strides:

    In [399]: np.broadcast_to(a, (2,9))
    Out[399]: 
    array([[1, 2, 3, 4, 5, 6, 7, 8, 9],
           [1, 2, 3, 4, 5, 6, 7, 8, 9]])
    In [400]: _.strides
    Out[400]: (0, 4)
    
    In [401]: a.reshape(3,3)[::-1,:]
    Out[401]: 
    array([[7, 8, 9],
           [4, 5, 6],
           [1, 2, 3]])
    In [402]: _.strides
    Out[402]: (-12, 4)
    

    However, negative strides require adjusting which element of the original array is the first element of the view, and as_strided has no parameter for that.
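
    One workaround, sketched here as an assumption rather than a documented feature: pass as_strided a view that already starts at the desired first element, then let the negative stride walk backwards through the buffer. The zero-stride case needs no such trick:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1, 10, dtype=np.int32)
s = a.itemsize

# Zero stride on axis 0 repeats the same row, like np.broadcast_to
b = as_strided(a, shape=(2, 9), strides=(0, s))
assert np.array_equal(b, np.broadcast_to(a, (2, 9)))

# Start at a[6] (the value 7), then step back 3 elements per row.
# This reads memory before a[6:], but still inside a's own buffer.
flipped = as_strided(a[6:], shape=(3, 3), strides=(-3 * s, s))
print(flipped)
# [[7 8 9]
#  [4 5 6]
#  [1 2 3]]
assert np.array_equal(flipped, a.reshape(3, 3)[::-1, :])
```

    In practice, slicing (a.reshape(3, 3)[::-1]) or np.flip gives the same result without manual pointer arithmetic.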

  • 2021-01-20 11:17

    I have no idea why you think you need strides of 3. The strides must be the distance in bytes between one element of a and the next, which you can get from a.strides:

    as_strided(a, (len(a) - 2, 3), a.strides*2)
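
    The same call generalizes to any window length w (len(a) - 2 is just len(a) - w + 1 for w = 3). moving_windows is a hypothetical helper name, and writeable=False is an optional as_strided argument added here because overlapping windows alias the same memory:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def moving_windows(a, w):
    """Hypothetical helper: every length-w window of a 1-D array, step 1.

    Assumes 1 <= w <= len(a).
    """
    a = np.ascontiguousarray(a)          # the strides below assume contiguous data
    n_windows = len(a) - w + 1
    # Both axes advance by one element, i.e. a.strides[0] bytes
    return as_strided(a, shape=(n_windows, w), strides=a.strides * 2,
                      writeable=False)

a = np.arange(1, 10)
print(moving_windows(a, 3))
print(moving_windows(a, 4).shape)   # (6, 4)
```
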
    
  • I was trying to do a similar operation and ran into the same problem.

    In your case, as stated in this comment, the problems were:

    1. You were not taking into account the size of each element in memory (4 bytes for int32, which you can check with a.dtype.itemsize).
    2. You didn't set the strides correctly: in your case both are also 4 bytes, since rows and columns each advance by a single element.

    Based on this answer, I wrote a function that segments a given array using a window of n elements and a given number of overlapping elements (window - number_of_elements_to_skip).

    I'm sharing it here in case someone else needs it, since it took me a while to figure out how stride_tricks works:

    import numpy as np

    def window_signal(signal, window, overlap):
        """
        Windowing function for data segmentation.

        Parameters:
        ------------
        signal: ndarray
                The signal to segment (it is flattened to 1-D).
        window: int
                Window length, in samples.
        overlap: int
                 Number of samples shared by consecutive windows
                 (overlap == window is treated as no overlap).

        Returns:
        --------
        ndarray
                A strided view of the signal (not an independent copy) with
                shape (rows, window), where rows = (N-window)//(window-overlap) + 1
        """
        # as_strided works in bytes, so use a flat, C-contiguous array
        signal = np.ascontiguousarray(signal).reshape(-1)
        N = signal.shape[0]
        if window == overlap:
            # A step of 0 makes no sense: fall back to non-overlapping windows
            rows = N // window
            overlap = 0
        else:
            rows = (N - window) // (window - overlap) + 1
            miss = (N - window) % (window - overlap)
            if miss != 0:
                print('Windowing led to the loss of', miss, 'samples.')
        item_size = signal.dtype.itemsize
        # Bytes between the starts of consecutive windows
        strides = (window - overlap) * item_size
        return np.lib.stride_tricks.as_strided(signal, shape=(rows, window),
                                               strides=(strides, item_size))
    

    Applied to your code, the solution for this case is: as_strided(a, (len(a) - 2, 3), (4, 4))

    Alternatively, using the function window_signal:

    window_signal(a, 3, 2)

    Both return the following array:

    array([[1, 2, 3],
           [2, 3, 4],
           [3, 4, 5],
           [4, 5, 6],
           [5, 6, 7],
           [6, 7, 8],
           [7, 8, 9]])
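
    Note that as_strided returns a view, not a copy: the windows overlap in memory, so a change to one sample shows up in every window that contains it. A small demonstration (the .copy() at the end is a suggestion, not part of the original answer):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(1, 10)
w = as_strided(a, shape=(len(a) - 2, 3), strides=a.strides * 2)

print(np.shares_memory(w, a))   # True: no data was copied
a[2] = 100                      # change the third sample of the signal...
print(w[:3])                    # ...and it appears in the first three windows
# [[  1   2 100]
#  [  2 100   4]
#  [100   4   5]]

safe = w.copy()                 # an independent (7, 3) array, safe to modify
```
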
    