I am trying to understand numpy's argpartition function. I have made the documentation's example as basic as possible.
import numpy as np
x = np.array([3, 4,
Like @Imtinan, I struggled with this. I found it useful to break the function up into its two parts: the arg part and the partition part.
Take the following array:
array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])
The corresponding indices are [0, 1, 2, 3, 4, 5, 6, 7, 8], so index 8 holds the value 5 and index 0 holds the value 9.
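If we call np.partition(array, kth=5), NumPy finds the value that would sit at index 5 if the array were fully sorted (here 6, the sixth-smallest value) and places it at index 5 of a new array. It then puts the elements smaller than that value before it and the elements larger than it after it, like this:
pseudo output: [lower value elements, value at sorted index 5, higher value elements]
If we compute this, we get:
array([3, 5, 1, 4, 2, 6, 8, 7, 9])
This makes sense: sorting the array gives [1, 2, 3, 4, 5, 6, 7, 8, 9], so 6 belongs at index 5; [1, 2, 3, 4, 5] are all lower than 6 and [7, 8, 9] are higher. Note that the elements on each side are not sorted. (In this array the 5th element of the original, at index 4, also happens to be 6, but kth counts positions in the sorted order, not in the original array.)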
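A quick way to check this reading of kth is to test the partition property directly. This is a minimal sketch using only np.partition and np.sort; the second array, small, is a made-up example chosen so that the value at original index kth differs from the value at sorted index kth.

```python
import numpy as np

array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])
kth = 5

part = np.partition(array, kth)
pivot = np.sort(array)[kth]  # the value a full sort would put at index kth

# The pivot lands exactly at index kth ...
assert part[kth] == pivot
# ... everything before it is no larger, everything after it is no smaller,
# but neither side is guaranteed to be sorted.
assert (part[:kth] <= pivot).all()
assert (part[kth + 1:] >= pivot).all()

# kth counts positions in the *sorted* order, not in the original array:
small = np.array([50, 10, 40, 20, 30])  # hypothetical example array
print(small[1])                   # 10, the value at original index 1
print(np.partition(small, 1)[1])  # 20, the second-smallest value
```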
The arg part of np.argpartition() then goes one step further: instead of returning the rearranged values, it returns each value's index in the original array. So if we run:
np.argpartition(array, 5)
we will get:
array([5, 8, 7, 3, 1, 4, 6, 2, 0])
From above, the original array has this structure [index=value]: [0=9, 1=2, 2=7, 3=4, 4=6, 5=3, 6=8, 7=1, 8=5]
If you replace each index in the argpartition() output with the value stored at that index, you recover the partition() output, like this:
[index form] array([5, 8, 7, 3, 1, 4, 6, 2, 0]) becomes
[3, 5, 1, 4, 2, 6, 8, 7, 9]
which is the same as the output of np.partition(array, 5):
array([3, 5, 1, 4, 2, 6, 8, 7, 9])
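The index-to-value mapping is plain NumPy integer-array indexing: array[idx] looks up the value at each index in idx. This sketch checks that the mapped result has the same pivot and the same values on each side as np.partition's output. It compares each side as a sorted copy rather than element by element, because the NumPy documentation leaves the order within each partition undefined, so the two calls are not promised to order the sides identically.

```python
import numpy as np

array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])
kth = 5

idx = np.argpartition(array, kth)  # indices into the original array
mapped = array[idx]                # swap each index for the value it points to
part = np.partition(array, kth)

# Same pivot at position kth ...
assert mapped[kth] == part[kth] == 6
# ... and the same values on each side (order within a side may differ).
assert np.array_equal(np.sort(mapped[:kth]), np.sort(part[:kth]))
assert np.array_equal(np.sort(mapped[kth + 1:]), np.sort(part[kth + 1:]))

# idx is a permutation of 0..8, so no element is lost or duplicated.
assert np.array_equal(np.sort(idx), np.arange(array.size))
```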
Hopefully this makes sense; it was the only way I could get my head around the arg part of the function.