Slicing a numpy array along a dynamically specified axis

不思量自难忘° 2020-11-29 10:54

I would like to dynamically slice a numpy array along a specific axis. Given this:

axis = 2
start = 5
end = 10

I want to achieve the same r

5 Answers
  • 2020-11-29 11:19

    There is an elegant way to access an arbitrary axis n of array x: Use numpy.moveaxis¹ to move the axis of interest to the front.

    x_move = np.moveaxis(x, n, 0)  # move n-th axis to front
    x_move[start:end]              # access n-th axis
    

    The catch is that you will likely have to apply moveaxis to any other arrays you combine with the output of x_move[start:end], so that the axis order stays consistent. The array x_move is only a view, so every change you make along its front axis is a change to x along its n-th axis (i.e. you can both read from and write to x through x_move).
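
    A minimal sketch of the write-through behaviour described above. The array shape and the values of n, start and end are chosen here purely for illustration:

```python
import numpy as np

# Illustrative values; not taken from the original question.
x = np.zeros((2, 3, 4), dtype=int)
n, start, end = 2, 1, 3

x_move = np.moveaxis(x, n, 0)   # view with axis n moved to the front
x_move[start:end] = 7           # write through the view

# The write shows up in x along axis n, because both share memory.
shares = np.shares_memory(x, x_move)
written = x[:, :, start:end]
```

    Because no data is copied, this should stay cheap even for large arrays, at the cost of having to remember the changed axis order.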


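    1) You could also use swapaxes, which is symmetric in its two axis arguments, so you do not have to worry about the order of n and 0 as you do with moveaxis(x, n, 0). I prefer moveaxis over swapaxes because it changes only the position of axis n and keeps the remaining axes in their original order.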
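
    To illustrate the difference in resulting axis order, here is a small sketch on a 4-D array whose shape is picked only so that every axis is distinguishable:

```python
import numpy as np

# Hypothetical shape, chosen so each axis length is unique.
x = np.empty((2, 3, 4, 5))

moved = np.moveaxis(x, 3, 0)     # axis 3 goes to the front, others keep their order
swapped = np.swapaxes(x, 3, 0)   # axes 0 and 3 trade places

moved_shape = moved.shape        # (5, 2, 3, 4)
swapped_shape = swapped.shape    # (5, 3, 4, 2)
```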

  • 2020-11-29 11:27

    As this was not mentioned clearly enough (and I was looking for it too):

    an equivalent to:

    a = my_array[:, :, :, 8]
    b = my_array[:, :, :, 2:7]
    

    is:

    a = my_array.take(indices=8, axis=3)
    b = my_array.take(indices=range(2, 7), axis=3)
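
    A quick check of that equivalence, as a sketch with a made-up my_array. It also shows one difference worth knowing: basic slicing returns a view, while take returns a copy.

```python
import numpy as np

# my_array is a stand-in; any array with at least 4 axes works.
my_array = np.arange(2 * 3 * 4 * 10).reshape(2, 3, 4, 10)

a_slice = my_array[:, :, :, 8]
a_take = my_array.take(indices=8, axis=3)
b_slice = my_array[:, :, :, 2:7]
b_take = my_array.take(indices=range(2, 7), axis=3)

same_values = (np.array_equal(a_slice, a_take)
               and np.array_equal(b_slice, b_take))
slice_is_view = np.shares_memory(my_array, b_slice)   # basic slicing: view
take_is_view = np.shares_memory(my_array, b_take)     # take: independent copy
```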
    
  • 2020-11-29 11:30

    I think one way would be to use slice(None):

    >>> m = np.arange(2*3*5).reshape((2,3,5))
    >>> axis, start, end = 2, 1, 3
    >>> target = m[:, :, 1:3]
    >>> target
    array([[[ 1,  2],
            [ 6,  7],
            [11, 12]],
    
           [[16, 17],
            [21, 22],
            [26, 27]]])
    >>> slc = [slice(None)] * len(m.shape)
    >>> slc[axis] = slice(start, end)
    >>> np.allclose(m[tuple(slc)], target)
    True
    

    I have a vague feeling I've used a function for this before, but I can't seem to find it now..

  • 2020-11-29 11:32

    This is very late to the party, but I have an alternate slicing function that performs slightly better than those from the other answers:

    def array_slice(a, axis, start, end, step=1):
        return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]
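
    A short usage sketch for the function above, on a small test array. It shows that axis % a.ndim normalises negative axes and that trailing axes can simply be left out of the index tuple:

```python
import numpy as np

def array_slice(a, axis, start, end, step=1):
    # Only the axes before `axis` need an explicit slice(None);
    # NumPy keeps all trailing axes whole by default.
    return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]

m = np.arange(2 * 3 * 5).reshape(2, 3, 5)

first = array_slice(m, 0, 0, 1)        # same as m[0:1]
last = array_slice(m, -1, 1, 3)        # axis -1 becomes 2: same as m[:, :, 1:3]
stepped = array_slice(m, 1, 0, 3, 2)   # same as m[:, 0:3:2]
```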
    

    Here is code that tests each answer. Each version is labeled with the name of the user who posted it:

    import numpy as np
    from timeit import timeit
    
    def answer_dms(a, axis, start, end, step=1):
        slc = [slice(None)] * len(a.shape)
        slc[axis] = slice(start, end, step)
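        # Indexing with a list is deprecated (see the FutureWarning in the
        # results); recent NumPy releases raise an error here instead.
        return a[slc]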
    
    def answer_smiglo(a, axis, start, end, step=1):
        return a.take(indices=range(start, end, step), axis=axis)
    
    def answer_eelkespaak(a, axis, start, end, step=1):
        sl = [slice(None)] * a.ndim
        sl[axis] = slice(start, end, step)
        return a[tuple(sl)]
    
    def answer_clemisch(a, axis, start, end, step=1):
        a = np.moveaxis(a, axis, 0)
        a = a[start:end:step]
        return np.moveaxis(a, 0, axis)
    
    def answer_leland(a, axis, start, end, step=1):
        return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]
    
    if __name__ == '__main__':
        m = np.arange(2*3*5).reshape((2,3,5))
        axis, start, end = 2, 1, 3
        target = m[:, :, 1:3]
        for answer in (answer_dms, answer_smiglo, answer_eelkespaak,
                       answer_clemisch, answer_leland):
            print(answer.__name__)
            m_copy = m.copy()
            m_slice = answer(m_copy, axis, start, end)
            c = np.allclose(target, m_slice)
            print('correct: %s' %c)
            t = timeit('answer(m, axis, start, end)',
                       setup='from __main__ import answer, m, axis, start, end')
            print('time:    %s' %t)
            try:
                m_slice[0,0,0] = 42
            except ValueError:
                # assignment to a read-only result raises ValueError
                print('method:  view_only')
            else:
                if np.allclose(m, m_copy):
                    print('method:  copy')
                else:
                    print('method:  in_place')
            print('')
    

    Here are the results:

    answer_dms
    
    Warning (from warnings module):
      File "C:\Users\leland.hepworth\test_dynamic_slicing.py", line 7
        return a[slc]
    FutureWarning: Using a non-tuple sequence for multidimensional indexing is 
    deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be 
    interpreted as an array index, `arr[np.array(seq)]`, which will result either in an 
    error or a different result.
    correct: True
    time:    2.2048302
    method:  in_place
    
    answer_smiglo
    correct: True
    time:    5.9013344
    method:  copy
    
    answer_eelkespaak
    correct: True
    time:    1.1219435999999998
    method:  in_place
    
    answer_clemisch
    correct: True
    time:    13.707583699999999
    method:  in_place
    
    answer_leland
    correct: True
    time:    0.9781496999999995
    method:  in_place
    
    • DSM's answer includes a few suggestions for improvement in the comments.
    • EelkeSpaak's answer applies those improvements, which avoids the warning and is quicker.
    • Śmigło's answer involving np.take is noticeably slower, and although its result is writable, it is a copy rather than a view of the original array.
    • clemisch's answer involving np.moveaxis takes the longest time to complete, but, perhaps surprisingly, its result is still a view into the original array's memory.
    • My answer removes the need for the intermediary slicing list. It also uses a shorter slicing index when the slicing axis is toward the beginning. This gives the quickest results, with additional improvements as axis is closer to 0.

    I also added a step parameter to each version, in case that is something you need.

  • 2020-11-29 11:43

    This is a bit late to the party, but the default NumPy way to do this is numpy.take. However, take always copies the data (it supports fancy indexing, so it cannot assume the result can be expressed as a view). To avoid the copy (in many cases you will want a view of the data, not a copy), fall back to the slice(None) option already mentioned in another answer, possibly wrapping it in a small function:

    def simple_slice(arr, inds, axis):
        # this does the same as np.take() except only supports simple slicing, not
        # advanced indexing, and thus is much faster
        sl = [slice(None)] * arr.ndim
        sl[axis] = inds
        return arr[tuple(sl)]
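
    As a sketch of how this might be used (the array x and the index values are invented for the example), the result of simple_slice can be compared against np.take and then written through to the original:

```python
import numpy as np

def simple_slice(arr, inds, axis):
    # Same helper as above: basic slicing along a single axis.
    sl = [slice(None)] * arr.ndim
    sl[axis] = inds
    return arr[tuple(sl)]

x = np.arange(2 * 3 * 5).reshape(2, 3, 5)

view = simple_slice(x, slice(1, 3), axis=2)   # view into x
copy = np.take(x, range(1, 3), axis=2)        # independent copy
equal_before = np.array_equal(view, copy)

view[...] = -1   # modifies x in place; `copy` is unaffected
```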
    