The existing answers are correct, but I wanted to expand on them to provide a self-contained function that behaves exactly like torch.topk using pure NumPy.
Here's the function (the explanation is included as inline comments):
import numpy as np

def topk(array, k, axis=-1, sorted=True):
    # np.argpartition is faster than np.argsort, but it does not return the values in order
    # We use array.take because it lets us specify the axis
    partitioned_ind = (
        np.argpartition(array, -k, axis=axis)
        .take(indices=range(-k, 0), axis=axis)
    )
    # We use the newly selected indices to find the score of the top-k values
    partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
    
    if sorted:
        # Since our top-k scores are not correctly ordered, we sort them when
        # sorted=True (argsort is ascending, hence the flip to get descending order);
        # otherwise we keep them in an arbitrary order
        sorted_trunc_ind = np.flip(
            np.argsort(partitioned_scores, axis=axis), axis=axis
        )
        
        # We again use np.take_along_axis as we have an array of indices that we use to
        # decide which values to select
        ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
        scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
    else:
        ind = partitioned_ind
        scores = partitioned_scores
    
    return scores, ind
To verify the correctness, you can test it against torch:
import torch
import numpy as np
x = np.random.randn(50, 50, 10, 10)
axis = 2  # Change this to any axis and it'll be fine
val_np, ind_np = topk(x, k=10, axis=axis)
val_pt, ind_pt = torch.topk(torch.tensor(x), k=10, dim=axis)
print("Values are same:", np.all(val_np == val_pt.numpy()))
print("Indices are same:", np.all(ind_np == ind_pt.numpy()))
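As a side note on the two different "take" calls used above, here is a minimal sketch of the difference between `ndarray.take` (the same indices applied to every slice along an axis) and `np.take_along_axis` (an index array matched slice by slice):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)

# ndarray.take selects the same positions along an axis for every slice:
same_cols = x.take(indices=[0, 3], axis=1)  # columns 0 and 3 of every row

# np.take_along_axis accepts a broadcast-compatible array of indices,
# so each row can pick a different position:
idx = np.array([[0], [1], [2]])
per_row = np.take_along_axis(x, idx, axis=1)  # a different column from each row

print(same_cols)  # [[ 0  3] [ 4  7] [ 8 11]]
print(per_row)    # [[ 0] [ 5] [10]]
```

This is why the function uses `take` for the fixed slice `range(-k, 0)` of the partition, but `take_along_axis` once it has a full index array of the same dimensionality as the input.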
- To be clear, `np.take_along_axis` is the recommended companion to `np.argpartition` for retrieving the original values in the higher-dimensional case.
- `np.argpartition` is faster than `np.argsort` because it does not sort the entire array. This answer claims it takes `O(n)` instead of `O(n log n)`.
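To illustrate why the `sorted=True` branch is needed, here is a quick sketch of `np.argpartition`'s unordered output (the array is just an example):

```python
import numpy as np

a = np.array([3, 1, 4, 1, 5, 9, 2, 6])

# The last 3 entries of the partition are the indices of the top-3 values,
# but in no guaranteed order:
top3_idx = np.argpartition(a, -3)[-3:]
print(a[top3_idx])  # some ordering of 5, 6 and 9

# Sorting just those k values recovers descending order at only O(k log k) extra cost:
print(np.sort(a[top3_idx])[::-1])  # [9 6 5]
```

So the full top-k costs roughly O(n + k log k) rather than the O(n log n) of a full sort.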