In numpy, how to efficiently list all fixed-size submatrices?

旧时难觅i 2021-01-12 16:12

I have an arbitrary NxM matrix, for example:

1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4

I want to get a list of all 3x3 submatrices in this matrix.

1 answer
  •  醉梦人生
    2021-01-12 16:27

    You want a windowed view:

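    import numpy as np
    from numpy.lib.stride_tricks import as_strided
    
    arr = np.arange(1, 25).reshape(4, 6) % 10
    sub_shape = (3, 3)
    # One window per valid top-left corner: (4-3+1, 6-3+1) + (3, 3) = (2, 4, 3, 3)
    view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
    # Reuse arr's strides for both the window-position axes and the in-window axes
    arr_view = as_strided(arr, view_shape, arr.strides * 2)
    # Flatten to a list of 3x3 submatrices (this step copies, since windows overlap)
    arr_view = arr_view.reshape((-1,) + sub_shape)
    
    >>> arr_view
    array([[[1, 2, 3],
            [7, 8, 9],
            [3, 4, 5]],
    
           [[2, 3, 4],
            [8, 9, 0],
            [4, 5, 6]],
    
           ...
    
           [[9, 0, 1],
            [5, 6, 7],
            [1, 2, 3]],
    
           [[0, 1, 2],
            [6, 7, 8],
            [2, 3, 4]]])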
    

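    The advantage of this approach is that as_strided does not copy any data: the 4-D view it returns simply reads the original array's buffer with different strides, so for large arrays it can save a great deal of memory. The final reshape into a flat list of submatrices is different: because the windows overlap, it cannot be expressed with strides, and NumPy makes a copy. If memory matters, skip the reshape, keep the 4-D view and index it as arr_view[i, j] to get the submatrix whose top-left corner is at row i, column j.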
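    To check the memory behaviour yourself, here is a minimal sketch, assuming NumPy 1.20 or later. It rebuilds the same 4-D view with as_strided, compares it with NumPy's bounds-checked sliding_window_view (an alternative not used in the answer above), and uses np.shares_memory to show that both 4-D arrays are views of arr while the flattened (8, 3, 3) array is a copy. The names windows, safe_windows and flat are just illustrative.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

# Same 4x6 example matrix as in the question
arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)

# Windowed view built with as_strided: shape (2, 4, 3, 3), no data copied
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
windows = as_strided(arr, view_shape, arr.strides * 2)

# Equivalent view from sliding_window_view (NumPy >= 1.20); read-only by default
safe_windows = sliding_window_view(arr, sub_shape)

# Both 4-D arrays share arr's memory and contain the same windows
print(windows.shape, np.shares_memory(windows, arr))
print(np.array_equal(windows, safe_windows), np.shares_memory(safe_windows, arr))

# Overlapping windows cannot be merged into one axis with a single stride,
# so reshape falls back to returning a copy
flat = windows.reshape((-1,) + sub_shape)
print(flat.shape, np.shares_memory(flat, arr))
```

    sliding_window_view checks the window shape against the array and returns a read-only view unless you pass writeable=True, whereas a wrong shape or stride passed to as_strided can make the view point at memory outside the array.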
