Suppose I have a 3D array:
>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])
I can get its argmax along axis=1 using
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])
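To pin down what m means, here is a plain-loop sketch (slow, for illustration only). It assumes a is built with np.array from the values shown above. For every pair (i, k), m[i, k] is the position j along axis 1 where a[i, j, k] is largest, so reading a at that position reproduces the maximum.

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)  # shape (2, 2): an index into axis 1 for each (i, k)

# Reference loop: for each (i, k), read a at the argmax position along axis 1.
out = np.empty_like(m)
for i in range(a.shape[0]):
    for k in range(a.shape[2]):
        out[i, k] = a[i, m[i, k], k]

print(out)  # same values as a.max(axis=1)
```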
How can I use m as an index into a so that the result is the same as calling max directly?
>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])
(This is useful when applying m to other arrays of the same shape.)
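A minimal sketch of two vectorized ways to do this, assuming NumPy 1.15 or later for np.take_along_axis. The names idx, via_take, via_fancy, i, k and the second array b are hypothetical, chosen only for illustration. np.take_along_axis requires the index array to have the same number of dimensions as a, so the reduced axis is reinserted with np.expand_dims and removed again afterwards. The alternative uses explicit integer-array indexing, where np.ogrid supplies broadcastable indices for the axes that were not reduced.

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]],
              [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)  # shape (2, 2)

# Option 1: take_along_axis needs an index array with the same ndim as `a`,
# so put the reduced axis back as a length-1 dimension, then drop it again.
idx = np.expand_dims(m, axis=1)                                # shape (2, 1, 2)
via_take = np.take_along_axis(a, idx, axis=1).squeeze(axis=1)  # shape (2, 2)

# Option 2: integer-array indexing. Open grids cover the axes that were not
# reduced (0 and 2); i has shape (2, 1), k has shape (1, 2), both broadcast with m.
i, k = np.ogrid[:a.shape[0], :a.shape[2]]
via_fancy = a[i, m, k]  # via_fancy[p, q] == a[p, m[p, q], q]

# Hypothetical second array with the same shape as a: the same indices pick
# b's values at the positions where a is largest along axis 1.
b = np.arange(a.size).reshape(a.shape)
b_at = np.take_along_axis(b, idx, axis=1).squeeze(axis=1)
```

Both forms give the same result. take_along_axis adapts to any axis by changing the axis argument, while the ogrid version makes the broadcasting explicit and is easy to reuse on any array of the same shape as a, for example b[i, m, k].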