I have an array of shape (2, 10), such as:
import jax.numpy as jnp
arr = jnp.ones(shape=(2, 10)) * 2
which prints as
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
and a second array with one index per row, for example [2, 4].
I want the second array to specify, for each row, the index from which the elements of arr should be masked (replaced with -1). Here the result would be:
[[ 2.  2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [ 2.  2.  2.  2. -1. -1. -1. -1. -1. -1.]]
I need to use jax.numpy, and the solution should be vectorized and fast if possible, i.e. without loops.
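
To make the intended behavior concrete, here is a minimal sketch of one way it could be expressed without loops: compare a row of column indices against the per-row start indices using broadcasting, then select with jnp.where. The names start, cols, keep and masked are made up for illustration, and -1 is assumed to be the mask value, as in the expected output above.

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
start = jnp.array([2, 4])  # hypothetical name for the second array

# Column indices 0..9, shape (10,).
cols = jnp.arange(arr.shape[1])

# Broadcast (10,) against (2, 1) to get a boolean array of shape (2, 10):
# True where the element comes before the row's start index.
keep = cols < start[:, None]

# Keep the original value where keep is True, otherwise write -1.
masked = jnp.where(keep, arr, -1.0)
print(masked)
```

Since this only uses elementwise operations on fixed-shape arrays, it should also be usable inside jax.jit, though I have not benchmarked it.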