I have a function compute(x), where x is a jnp.ndarray. I want to use vmap to turn it into a function that operates on a batch of arrays x[i], and then apply jit to speed it up. compute(x) looks something like this:

    def compute(x):
        # ... some code
        y = very_expensive_function(x)
        return y

However, each array x[i] has a different length. I can easily work around this by padding the arrays with trailing zeros so that they all have the same length N, which lets vmap(compute) be applied to batches of shape (batch_size, N).
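A minimal sketch of the padding workaround, assuming 1-D inputs and a hypothetical elementwise stand-in for very_expensive_function (the real one is not shown). pad_batch is a hypothetical helper. The mask step illustrates the problem: results for the padding entries are still computed and only discarded afterwards.

```python
import jax
import jax.numpy as jnp
import numpy as np

def very_expensive_function(x):
    # Hypothetical stand-in for the real expensive computation (elementwise).
    return jnp.sin(x) ** 2 + 1.0

def compute(x):
    return very_expensive_function(x)

def pad_batch(arrays):
    """Pad 1-D arrays with trailing zeros to a common length N.

    Returns the (batch_size, N) padded batch and each array's true length.
    """
    n = max(len(a) for a in arrays)
    padded = np.zeros((len(arrays), n), dtype=np.float32)
    for i, a in enumerate(arrays):
        padded[i, : len(a)] = a
    lengths = np.array([len(a) for a in arrays], dtype=np.int32)
    return jnp.asarray(padded), jnp.asarray(lengths)

# vmap maps compute over the leading (batch) axis; jit compiles the result.
batched_compute = jax.jit(jax.vmap(compute))

arrays = [np.arange(3.0), np.arange(5.0), np.arange(2.0)]
xs, lengths = pad_batch(arrays)    # xs has shape (3, 5)
ys = batched_compute(xs)           # very_expensive_function also ran on the padding
mask = jnp.arange(xs.shape[1]) < lengths[:, None]  # True for real entries
ys_masked = jnp.where(mask, ys, 0.0)               # discard padding results
```

Masking keeps the outputs correct, but it does not save any work: every padded position still goes through very_expensive_function.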
The drawback is that very_expensive_function() is then also called on the trailing zeros of each array x[i]. Is there a way to modify compute() so that very_expensive_function() is called only on the unpadded slice of x, without interfering with vmap and jit?
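For context, the sketch below first shows why the direct approach fails: under jit the length n is a tracer, so x[:n] would have a data-dependent shape, which XLA does not support. It then shows one commonly used workaround, grouping arrays by exact length ("bucketing") so each vmapped call sees only real data. This is an illustrative alternative to padding, not a definitive answer; very_expensive_function and compute_bucketed are hypothetical.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp
import numpy as np

def very_expensive_function(x):
    # Hypothetical stand-in for the real expensive computation (elementwise).
    return jnp.sin(x) ** 2 + 1.0

def compute(x):
    return very_expensive_function(x)

def slice_then_compute(x, n):
    # Under jit, n is traced, so x[:n] would need a data-dependent output
    # shape; JAX requires static shapes, so tracing raises an error.
    return compute(x[:n])

try:
    jax.jit(slice_then_compute)(jnp.arange(5.0), 3)
    slicing_failed = False
except (IndexError, TypeError):
    slicing_failed = True

batched_compute = jax.jit(jax.vmap(compute))

def compute_bucketed(arrays):
    """Group arrays by exact length and run one vmapped call per group.

    jit compiles once per distinct length, but no padding entries are
    ever passed to very_expensive_function.
    """
    groups = defaultdict(list)
    for i, a in enumerate(arrays):
        groups[len(a)].append(i)
    results = [None] * len(arrays)
    for idxs in groups.values():
        batch = jnp.stack([jnp.asarray(arrays[i]) for i in idxs])  # (k, length)
        out = batched_compute(batch)
        for row, i in zip(out, idxs):
            results[i] = row
    return results

arrays = [np.arange(3.0), np.arange(5.0), np.arange(3.0) + 1.0]
results = compute_bucketed(arrays)  # outputs with lengths 3, 5 and 3
```

The trade-off is compilation: each distinct length triggers a new trace and compile. Rounding lengths up to a few fixed bucket sizes (padding only within a bucket) reduces the number of compilations at the cost of some wasted work. Note also that wrapping the expensive call in jax.lax.cond does not help inside vmap, because a batched predicate makes both branches execute.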