Here are a couple of approaches using np.argpartition and np.argsort for ndarrays -
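import numpy as np

def k_largest_index_argpartition_v1(a, k):
    # Negate so the k smallest of -a are the k largest of a (requires k < a.size)
    idx = np.argpartition(-a.ravel(), k)[:k]
    return np.column_stack(np.unravel_index(idx, a.shape))

def k_largest_index_argpartition_v2(a, k):
    # Partition at a.size-k; the last k positions then hold the k largest
    idx = np.argpartition(a.ravel(), a.size-k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))

def k_largest_index_argsort(a, k):
    # Full sort, then take the last k indices in descending order of value
    idx = np.argsort(a.ravel())[:-k-1:-1]
    return np.column_stack(np.unravel_index(idx, a.shape))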
Discussion on two versions with argpartition
The difference between k_largest_index_argpartition_v1 and k_largest_index_argpartition_v2 is how we use argpartition. In the first version, we negate the input array and use argpartition to get the indices of the k smallest elements of the negated array, which are the indices of the k largest elements of the original. In the second version, we partition off the a.size-k smallest elements and keep the indices of the remaining k largest.
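To make the two partitioning strategies concrete, here is a minimal sketch on the flattened 2D array from the sample runs. The variable names (flat, idx_v1, idx_v2) are only for illustration. Both strategies pick out the same set of values, though not necessarily in the same order:

```python
import numpy as np

a = np.array([[38, 14, 81, 50],
              [17, 65, 60, 24],
              [64, 73, 25, 95]])
k = 4
flat = a.ravel()

# v1: the k smallest entries of -flat sit in the first k positions
idx_v1 = np.argpartition(-flat, k)[:k]

# v2: the k largest entries of flat sit in the last k positions
idx_v2 = np.argpartition(flat, flat.size - k)[-k:]

# Sorting the selected values shows that both pick the same elements
print(sorted(flat[idx_v1].tolist()))  # [65, 73, 81, 95]
print(sorted(flat[idx_v2].tolist()))  # [65, 73, 81, 95]
```

Note that negation, as used in the first strategy, is not safe for every dtype: with unsigned integers, for instance, -flat wraps around, so the second strategy is the more general of the two.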
Also, it's worth mentioning that argpartition does not return the indices in sorted order. If sorted order is needed, we need to feed a range array as the kth argument to np.argpartition, as mentioned in this post.
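As a rough sketch of the range-array idea, one way to write it is shown below. The function name k_largest_index_argpartition_sorted is hypothetical (it is not one of the versions above), and the sketch assumes 1 <= k <= a.size. Passing a sequence as kth makes argpartition put every listed position into its final sorted place, so the last k indices come out in ascending order of value and only need reversing:

```python
import numpy as np

def k_largest_index_argpartition_sorted(a, k):
    # Hypothetical helper: like v2, but with the top k in descending order
    flat = a.ravel()
    # Each position listed in kth ends up holding its final sorted element,
    # so the last k positions come out in ascending order of value
    kth = np.arange(flat.size - k, flat.size)
    idx = np.argpartition(flat, kth)[-k:][::-1]
    return np.column_stack(np.unravel_index(idx, a.shape))

a = np.array([[38, 14, 81, 50],
              [17, 65, 60, 24],
              [64, 73, 25, 95]])
print(k_largest_index_argpartition_sorted(a, k=4))
```

For this array the result matches the order given by the argsort version: (2, 3), (0, 2), (2, 1), (1, 1).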
Sample runs -
1) 2D case :
In [42]: a    # 2D array
Out[42]:
array([[38, 14, 81, 50],
       [17, 65, 60, 24],
       [64, 73, 25, 95]])

In [43]: k_largest_index_argsort(a, k=2)
Out[43]:
array([[2, 3],
       [0, 2]])

In [44]: k_largest_index_argsort(a, k=4)
Out[44]:
array([[2, 3],
       [0, 2],
       [2, 1],
       [1, 1]])

In [66]: k_largest_index_argpartition_v1(a, k=4)
Out[66]:
array([[2, 1],    # Notice the order is different
       [2, 3],
       [0, 2],
       [1, 1]])
2) 3D case :
In [46]: a    # 3D array
Out[46]:
array([[[20, 98, 27, 73],
        [33, 78, 48, 59],
        [28, 91, 64, 70]],

       [[47, 34, 51, 19],
        [73, 38, 63, 94],
        [95, 25, 93, 64]]])

In [47]: k_largest_index_argsort(a, k=2)
Out[47]:
array([[0, 0, 1],
       [1, 2, 0]])
Runtime test -
In [56]: a = np.random.randint(0,99999999999999,(3000,4000))
In [57]: %timeit k_largest_index_argsort(a, k=10)
1 loops, best of 3: 2.18 s per loop
In [58]: %timeit k_largest_index_argpartition_v1(a, k=10)
10 loops, best of 3: 178 ms per loop
In [59]: %timeit k_largest_index_argpartition_v2(a, k=10)
10 loops, best of 3: 128 ms per loop
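Outside IPython, a comparable measurement can be sketched with the standard timeit module. The short function names (topk_argsort, topk_argpartition) and the smaller array size are assumptions made to keep the script quick, and no timings are claimed here since they depend on the machine and NumPy version:

```python
import timeit
import numpy as np

def topk_argsort(a, k):
    # Full sort of the flattened array, then the last k in descending order
    idx = np.argsort(a.ravel())[:-k-1:-1]
    return np.column_stack(np.unravel_index(idx, a.shape))

def topk_argpartition(a, k):
    # Partial partition only; the last k positions hold the k largest
    idx = np.argpartition(a.ravel(), a.size - k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))

rng = np.random.default_rng(0)
# A permutation gives distinct values, so the top k is unambiguous
a = rng.permutation(300 * 400).reshape(300, 400)

for func in (topk_argsort, topk_argpartition):
    # Best of 3 repeats, 5 calls each, reported per call
    t = min(timeit.repeat(lambda: func(a, 10), number=5, repeat=3)) / 5
    print(f"{func.__name__}: {t * 1e3:.3f} ms per call")
```

Increasing the array size toward the (3000, 4000) shape used above should make the gap between full sorting and partitioning more visible.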