How do I get indices of N maximum values in a NumPy array?

长情又很酷 2020-11-22 04:25

NumPy provides a way to get the index of the maximum value of an array via np.argmax.

I would like a similar thing, but returning the indices of the N maximum values.
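For reference, here is a minimal illustration of the difference on a small made-up array. np.argmax returns a single index. A plain full sort with np.argsort already gives the N-index analogue. This is only one of several possible approaches, and it sorts the whole array even though only N values are needed.

```python
import numpy as np

a = np.array([1, 3, 2, 4, 5])  # small illustrative array (made up)

# np.argmax gives the index of the single largest value.
print(np.argmax(a))  # 4

# Top-N analogue: argsort orders indices by ascending value, so take the
# last n and reverse them to put the largest value first.
n = 3
top_n = np.argsort(a)[-n:][::-1]
print(top_n)     # [4 3 1]
print(a[top_n])  # [5 4 3]
```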

18 Answers
  •  隐瞒了意图╮
    2020-11-22 04:57

    Three Answers Compared For Coding Ease And Speed

    Speed was important for my needs, so I tested three of the answers to this question.

    I modified the code from each answer as needed for my specific case, then compared the speed of each method.

    Coding-wise:

    1. off99555's answer was the most elegant, but it was the slowest.
    2. NPE's answer was the next most elegant and adequately fast for my needs.
    3. Fred Foo's answer required the most refactoring for my needs but was the fastest. I went with this answer: even though it took more work, it was not too bad, and it was about 12 times faster than NPE's answer in the run below.
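The argpartition approach from Fred Foo's answer can be packaged as a reusable helper. This is a minimal sketch, not code from any of the answers: the function name top_n_indices is hypothetical, it assumes a 1-D array, and among tied values the order is arbitrary because np.argsort's default sort is not stable.

```python
import numpy as np


def top_n_indices(a, n):
    """Return the indices of the n largest values of 1-D array `a`, largest first.

    Hypothetical helper (not from any answer). np.argpartition performs a
    selection rather than a full sort, so only the n candidates it isolates
    need to be sorted afterwards.
    """
    a = np.asarray(a)
    n = min(n, a.size)
    if n <= 0:
        return np.array([], dtype=np.intp)
    # After partitioning around position -n, the last n slots hold the
    # indices of the n largest values, in no particular order.
    candidates = np.argpartition(a, -n)[-n:]
    # Sort only those n candidates by value, then reverse for descending order.
    return candidates[np.argsort(a[candidates])[::-1]]


# Usage: a shuffled 0..999999, similar to the benchmark data below.
rng = np.random.default_rng(0)
data = rng.permutation(1_000_000)
idx = top_n_indices(data, 5)
print(data[idx])  # [999999 999998 999997 999996 999995]
```

Sorting only the n candidates keeps the extra work proportional to n rather than to the length of the array.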

    Complete Code for Test and Comparisons

    import numpy as np
    import time
    import random
    from operator import itemgetter
    from heapq import nlargest
    
    # Fake data setup: the integers 0..999999 in random order
    a1 = list(range(1000000))
    random.shuffle(a1)
    a1 = np.array(a1)
    
    # ################################################
    # NPE's answer, modified a bit for my case
    # (full argsort, reversed, then take the first five indices)
    t0 = time.perf_counter()
    indices = np.flip(np.argsort(a1))[:5]
    results = []
    for index in indices:
        results.append((index, a1[index]))
    t1 = time.perf_counter()
    print("NPE's Answer:")
    print(results)
    print(t1 - t0)
    print()
    
    # Fred Foo's answer, modified a bit for my case
    # (argpartition around -5 puts the indices of the five largest values
    # in the last five slots, unordered; only those five are then sorted)
    t0 = time.perf_counter()
    indices = np.argpartition(a1, -5)[-5:]
    results = []
    for index in indices:
        results.append((a1[index], index))
    results.sort(reverse=True)  # sort by value, largest first
    results = [(b, a) for a, b in results]  # back to (index, value)
    t1 = time.perf_counter()
    print("Fred Foo's Answer:")
    print(results)
    print(t1 - t0)
    print()
    
    # off99555's answer, no modification needed for my needs
    # (heapq.nlargest over (index, value) pairs, keyed on the value)
    t0 = time.perf_counter()
    result = nlargest(5, enumerate(a1), itemgetter(1))
    t1 = time.perf_counter()
    print("off99555's Answer:")
    print(result)
    print(t1 - t0)
    

    Output with Speed Reports (times in seconds)

    NPE's Answer:
    [(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
    0.1349949836730957
    
    Fred Foo's Answer:
    [(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
    0.011161565780639648
    
    off99555's Answer:
    [(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
    0.439760684967041
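Each time above is a single wall-clock measurement from one run, so small differences may be noise. The following sketch shows a somewhat more robust comparison on the same kind of data, using a seeded permutation of 0..999999 in place of random.shuffle. It first checks that the three approaches return the same indices, then reports the best of three timeit repeats for each. The function names via_argsort, via_argpartition and via_heapq are hypothetical wrappers around the core call of each answer. No particular timings are claimed here, since they depend on the machine and the NumPy version.

```python
import functools
import heapq
import timeit
from operator import itemgetter

import numpy as np

rng = np.random.default_rng(0)
a1 = rng.permutation(1_000_000)  # shuffled 0..999999, all values distinct
n = 5


def via_argsort(a, n):
    # Full sort of the whole array, reversed, then the first n indices.
    return np.argsort(a)[::-1][:n]


def via_argpartition(a, n):
    # Partial selection, then sort only the n candidates, largest first.
    idx = np.argpartition(a, -n)[-n:]
    return idx[np.argsort(a[idx])[::-1]]


def via_heapq(a, n):
    # Pure-Python heap over (index, value) pairs; keep only the indices.
    return np.array([i for i, _ in heapq.nlargest(n, enumerate(a), key=itemgetter(1))])


# All three must agree before their timings are worth comparing.
expected = via_argpartition(a1, n)
for fn in (via_argsort, via_heapq):
    assert np.array_equal(fn(a1, n), expected), fn.__name__

for fn in (via_argsort, via_argpartition, via_heapq):
    # Time one call per repeat, three repeats; the minimum is the least
    # noisy estimate of the achievable run time.
    best = min(timeit.repeat(functools.partial(fn, a1, n), number=1, repeat=3))
    print(f"{fn.__name__:>18}: {best:.4f} s")
```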
    
