How to sort each row of a 3D numpy array by another 2D array?

误落风尘 2021-01-15 00:36

I have a 3D numpy array of 2D points (3 rows of 4 points each):

import numpy as np

np.random.seed(0)
a = np.random.rand(3, 4, 2)  # shape (3, 4, 2): each value is a 2D point (x, y)

I would like to sort each row …

1 answer
  • 2021-01-15 01:13

    If I understand you correctly, you want to sort the points in each row by their Euclidean norm (distance from the origin):

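    import numpy as np

    norms = np.linalg.norm(a, axis=2)    # Euclidean norm of each point, shape (3, 4)
    indices = np.argsort(norms, axis=1)  # per-row order of the points, smallest norm first
    sorted_a = np.take_along_axis(a, indices[:, :, None], axis=1)  # trailing None gives indices a's 3 dims
    print(sorted_a)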
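For readers unfamiliar with `np.take_along_axis`, the reordering step can also be written with plain fancy indexing: pair each row number with that row's sort order. This comparison sketch is not part of the original answer; it rebuilds `a` from the question so it runs on its own.

```python
import numpy as np

np.random.seed(0)
a = np.random.rand(3, 4, 2)  # 3 rows of 4 two-dimensional points

norms = np.linalg.norm(a, axis=2)     # shape (3, 4)
indices = np.argsort(norms, axis=1)   # shape (3, 4), per-row ordering

# take_along_axis needs indices with as many dimensions as `a`, hence the
# trailing None (shape (3, 4, 1)), which broadcasts over the x/y axis.
via_take = np.take_along_axis(a, indices[:, :, None], axis=1)

# Equivalent fancy indexing: row numbers of shape (3, 1) broadcast against
# indices of shape (3, 4); the last axis (x, y) is carried along untouched.
rows = np.arange(a.shape[0])[:, None]
via_fancy = a[rows, indices]          # shape (3, 4, 2)

print(np.array_equal(via_take, via_fancy))  # True
```

Both forms select whole points, so the x and y of each point always stay together.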
    

    Output for your example:

    [[[0.4236548  0.64589411]
      [0.60276338 0.54488318]
      [0.5488135  0.71518937]
      [0.43758721 0.891773  ]]
    
     [[0.07103606 0.0871293 ]
      [0.79172504 0.52889492]
      [0.96366276 0.38344152]
      [0.56804456 0.92559664]]
    
     [[0.0202184  0.83261985]
      [0.46147936 0.78052918]
      [0.77815675 0.87001215]
      [0.97861834 0.79915856]]]
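The title asks about sorting by *another* 2D array; the norms above are just one such key, and the same `argsort` plus `take_along_axis` pattern works for any key of shape (rows, points). A minimal sketch, assuming a hypothetical helper named `sort_rows_by_key` (not a NumPy function):

```python
import numpy as np


def sort_rows_by_key(points, key):
    """Hypothetical helper: reorder points[i, :] by ascending key[i, :].

    points: array of shape (R, C, D), R rows of C points in D dimensions.
    key:    array of shape (R, C), one sort key per point.
    """
    points = np.asarray(points)
    key = np.asarray(key)
    if key.shape != points.shape[:2]:
        raise ValueError(f"key shape {key.shape} must equal {points.shape[:2]}")
    order = np.argsort(key, axis=1, kind="stable")  # stable: tied keys keep their original order
    return np.take_along_axis(points, order[:, :, None], axis=1)


np.random.seed(0)
a = np.random.rand(3, 4, 2)

by_norm = sort_rows_by_key(a, np.linalg.norm(a, axis=2))  # same ordering as the answer
by_x = sort_rows_by_key(a, a[:, :, 0])                    # sort by x-coordinate instead
print(by_x[:, :, 0])  # each row's x values are now non-decreasing
```

The key may come from anywhere (a separate array, a distance to some reference point, and so on) as long as its shape matches the first two axes of `points`.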
    