Dynamic axis indexing of Numpy ndarray

梦谈多话 2021-01-18 01:29

I want to obtain the 2D slice in a given direction of a 3D array, where the direction (or the axis from which the slice is extracted) is given by another variable.

2 Answers
  • 2021-01-18 01:49

    Transposing is cheap in time: it returns a view with permuted strides instead of copying the data. Several numpy functions use it to move the operational axis (or axes) to a known location, usually the front or the end of the shape. tensordot is one that comes to mind.
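
    A minimal sketch of the move-the-axis-to-the-front idea, using an illustrative array and illustrative values for axis and idx. It calls np.moveaxis instead of spelling out a transpose permutation by hand; both return a view, so no data is copied before the final integer index is applied.

```python
import numpy as np

A = np.arange(2 * 3 * 4).reshape(2, 3, 4)
axis, idx = 2, 1  # illustrative values, not taken from the question

# Move the chosen axis to the front; this is a view, no data is copied.
front = np.moveaxis(A, axis, 0)
print(front.shape)                 # (4, 2, 3)
print(np.shares_memory(A, front))  # True

# Indexing the leading axis now picks the 2D slice along `axis`.
plane = front[idx]
print(plane.shape)                          # (2, 3)
print(np.array_equal(plane, A[:, :, idx]))  # True
```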

    Other functions construct an indexing tuple. They may start with a list or array for ease of manipulation, and then convert it to a tuple before indexing. For example:

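    import numpy as np
    A = np.arange(24).reshape(2, 3, 4); axis, idx = 1, 2   # example values
    I = [slice(None)]*A.ndim   # one full slice (':') per dimension
    I[axis] = idx              # put the integer index on the chosen axis
    A[tuple(I)]                # same as A[:, 2, :]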
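
    A shorter variant of the same indexing tuple, sketched as a hypothetical helper take_plane: NumPy treats any dimensions not covered by an index tuple as full slices, so only the entries up to and including axis have to be built. As with A[:, idx, :], the result is a view.

```python
import numpy as np

def take_plane(arr, axis, idx):
    """Hypothetical helper: index `idx` on `axis`, keeping every other axis whole."""
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of range for a {arr.ndim}-D array")
    axis %= arr.ndim  # normalise negative axes, e.g. -1 -> arr.ndim - 1
    # Full slices for the axes before `axis`, then the integer index;
    # NumPy fills the trailing axes with full slices implicitly.
    key = (slice(None),) * axis + (idx,)
    return arr[key]

A = np.arange(2 * 3 * 4).reshape(2, 3, 4)
print(np.array_equal(take_plane(A, 0, 1), A[1]))         # True
print(np.array_equal(take_plane(A, 1, 2), A[:, 2, :]))   # True
print(np.array_equal(take_plane(A, -1, 3), A[:, :, 3]))  # True
print(np.shares_memory(take_plane(A, 1, 2), A))          # True (a view)
```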
    

    np.apply_along_axis does something like that. It's instructive to look at the code for functions like this.

    I imagine the writers of the numpy functions cared most about robustness, then about speed, and only lastly about whether the code looks pretty. You can bury all kinds of ugly code in a function!


    tensordot ends with

    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    return res.reshape(olda + oldb)
    

    where the earlier code computes newaxes_a, newaxes_b, newshape_a, newshape_b, olda and oldb.
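
    A stripped-down sketch of that transpose, reshape, dot pattern for a single pair of contracted axes. The helper name contract_one_axis is hypothetical, and it uses np.moveaxis where tensordot builds its own permutations; np.tensordot itself accepts several axis pairs and validates shapes. The result is compared against np.tensordot.

```python
import numpy as np

def contract_one_axis(a, b, axis_a, axis_b):
    """Hypothetical helper: sum over k of a[..., k, ...] * b[..., k, ...] for one axis pair."""
    # Move the contracted axis to the end of `a` and to the front of `b`.
    at = np.moveaxis(a, axis_a, -1)
    bt = np.moveaxis(b, axis_b, 0)
    free_a, free_b = at.shape[:-1], bt.shape[1:]
    k = at.shape[-1]
    # Flatten to 2D so that one matrix product performs the contraction.
    res = np.dot(at.reshape(-1, k), bt.reshape(k, -1))
    # Restore the free axes of `a` followed by the free axes of `b`.
    return res.reshape(free_a + free_b)

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4, 5))
b = rng.standard_normal((4, 6))
out = contract_one_axis(a, b, axis_a=1, axis_b=0)
print(out.shape)                                              # (3, 5, 6)
print(np.allclose(out, np.tensordot(a, b, axes=([1], [0]))))  # True
```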

    apply_along_axis constructs an index tuple of the form (0, ..., :, ..., 0), with a full slice on axis and integers on every other axis:

    i = zeros(nd, 'O')
    i[axis] = slice(None, None)
    i.put(indlist, ind)
    ....arr[tuple(i.tolist())]
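
    A simplified sketch of that index-tuple approach, assuming a non-negative axis and a func1d that returns a scalar. The helper apply_along_axis_simple is hypothetical: it loops over the other axes with np.ndindex, inserts slice(None) at axis, and is checked against np.apply_along_axis without claiming to match NumPy's current implementation.

```python
import numpy as np

def apply_along_axis_simple(func1d, axis, arr):
    """Hypothetical helper: assumes axis >= 0 and that func1d returns a scalar."""
    other_shape = arr.shape[:axis] + arr.shape[axis + 1:]
    out = np.empty(other_shape)  # float64 result, for simplicity
    for ind in np.ndindex(*other_shape):
        # Build (i0, ..., :, ..., ik): integers everywhere except on `axis`.
        i = list(ind)
        i.insert(axis, slice(None))
        out[ind] = func1d(arr[tuple(i)])
    return out

A = np.arange(2 * 3 * 4).reshape(2, 3, 4)
mine = apply_along_axis_simple(np.sum, 1, A)
ref = np.apply_along_axis(np.sum, 1, A)
print(mine.shape)                 # (2, 4)
print(np.array_equal(mine, ref))  # True
```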
    
  • 2021-01-18 02:07

    To index a dimension dynamically, you can use swapaxes, as shown below:

    import numpy as np
    a = np.arange(7 * 8 * 9).reshape((7, 8, 9))
    
    axis = 1
    idx = 2
    
    np.swapaxes(a, 0, axis)[idx]
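
    One caveat with swapaxes: it exchanges two axes, so for axis >= 2 the remaining axes of the result come out in a different order than with the natural index (for axis = 1 on a 3D array the result matches). np.moveaxis, an alternative not used in this answer, keeps the other axes in their original order. A quick check with the answer's array and an illustrative axis = 2:

```python
import numpy as np

a = np.arange(7 * 8 * 9).reshape((7, 8, 9))
axis, idx = 2, 2  # illustrative values

natural = a[:, :, idx]                  # shape (7, 8)
swapped = np.swapaxes(a, 0, axis)[idx]  # shape (8, 7): remaining axes reversed
moved = np.moveaxis(a, axis, 0)[idx]    # shape (7, 8): original order kept

print(natural.shape, swapped.shape, moved.shape)  # (7, 8) (8, 7) (7, 8)
print(np.array_equal(swapped.T, natural))         # True
print(np.array_equal(moved, natural))             # True
```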
    

    Runtime comparison

    Natural method (non-dynamic):

    %timeit a[:, idx, :]
    300 ns ± 1.58 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    

    swapaxes:

    %timeit np.swapaxes(a, 0, axis)[idx]
    752 ns ± 4.54 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    

    Index with list comprehension:

    # tuple(), not a list: NumPy no longer accepts a list of slices as a multidimensional index
    %timeit a[tuple(idx if i == axis else slice(None) for i in range(a.ndim))]
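
    A sketch that checks the tuple-comprehension index against the natural index for every axis of the answer's array, and adds np.take(a, idx, axis=axis) as a further alternative not mentioned in the answer. Note that np.take returns a copy, while the indexing forms return views.

```python
import numpy as np

a = np.arange(7 * 8 * 9).reshape((7, 8, 9))
idx = 2

for axis in range(a.ndim):
    natural = (a[idx], a[:, idx, :], a[:, :, idx])[axis]  # hard-coded per axis
    by_tuple = a[tuple(idx if i == axis else slice(None) for i in range(a.ndim))]
    by_take = np.take(a, idx, axis=axis)  # returns a copy, not a view
    assert np.array_equal(by_tuple, natural)
    assert np.array_equal(by_take, natural)
    print(axis, by_tuple.shape)
# 0 (8, 9)
# 1 (7, 9)
# 2 (7, 8)
```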
    