I am working with JAX through numpyro. Specifically, I want to use a B-spline basis (e.g. as implemented in scipy.interpolate.BSpline) to transform points into a spline, where the input points depend on some of the parameters in the model. Thus, I need to be able to differentiate the B-spline in JAX, but only with respect to the input argument and not with respect to the knots or the integer order (of course!).
I can easily use jax.custom_vjp, but not when JIT is used, as it is in numpyro. I looked at the following:
- https://github.com/google/jax/issues/1142
 - https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
 
and it seems like the best hope is to use a callback. However, I cannot entirely figure out how that would work. At https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support
the TensorFlow example with reverse-mode autodiff seems not to use JIT.
The example
Here is Python code that works without JIT (see the b_spline_basis() function):
from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax
doubleArray = npt.NDArray[np.double]
# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    """Apply one step of the B-spline derivative recurrence.

    With ``k = spline.k`` and ``t = spline.t``, output column ``i`` is

        k * (deriv_basis[:, i] / (t[i + k] - t[i])
             - deriv_basis[:, i + 1] / (t[i + k + 1] - t[i + 1]))

    where a term whose knot difference is zero is treated as zero.

    Args:
        spline: Supplies the knot vector ``t`` and the degree ``k``; its
            coefficients are not read.
        deriv_basis: Two-dimensional array with one row per evaluation point.
            It is expected to hold the degree ``k - 1`` basis (or a derivative
            of it) on the same knots, which is how ``_b_spline_eval`` builds it.

    Returns:
        Array with as many rows as ``deriv_basis`` and one column fewer.
    """
    degree = spline.k
    knots = spline.t
    n_rows, n_lower = deriv_basis.shape
    out = np.zeros((n_rows, n_lower - 1))
    # Both terms are applied to every column, including the first and the last.
    # For clamped knot vectors the two boundary differences are zero and are
    # skipped by the guards, so those results are unchanged; for other knot
    # vectors the boundary terms are genuinely needed.
    for col_index in range(out.shape[1]):
        left_scale = knots[col_index + degree] - knots[col_index]
        if left_scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / left_scale
        right_scale = knots[col_index + degree + 1] - knots[col_index + 1]
        if right_scale != 0:
            out[:, col_index] -= deriv_basis[:, col_index + 1] / right_scale
    return float(degree) * out
def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    """Evaluate the ``deriv``-th derivative of every basis function of ``spline`` at ``x``.

    ``deriv == 0`` returns the dense design matrix. Otherwise the result is
    built recursively from the basis of degree ``spline.k - 1`` through
    ``_b_spline_deriv_inner``; once the degree has dropped to zero or below,
    any remaining derivative is an all-zero matrix.

    Args:
        spline: Spline supplying knots ``t``, degree ``k`` and the length of ``c``.
        x: Evaluation points; ``x.shape[0]`` sets the number of rows.
        deriv: Order of the derivative to evaluate.

    Returns:
        Matrix with one row per point and one column per basis function.
    """
    knot_vector, degree = spline.t, spline.k

    # Base case: the basis itself.
    if deriv == 0:
        design = spline.design_matrix(x=x, t=knot_vector, k=degree)
        return design.todense()

    # Nothing left to differentiate once the degree is exhausted.
    if degree <= 0:
        n_basis = knot_vector.shape[0] - degree - 1
        return np.zeros((x.shape[0], n_basis))

    # Recurse on the spline of one degree lower, which has one more basis function.
    lower_spline = BSpline(t=knot_vector, k=degree - 1, c=np.zeros(spline.c.shape[0] + 1))
    lower_basis = _b_spline_eval(lower_spline, x=x, deriv=deriv - 1)
    return _b_spline_deriv_inner(spline=spline, deriv_basis=lower_basis)
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    """Return the ``deriv``-th derivative of the B-spline basis at ``x`` without its first column.

    Only ``x`` is differentiable; ``knots``, ``order`` and ``deriv`` are
    declared non-differentiable in the ``custom_vjp`` decorator.

    Args:
        knots: Knot vector passed to ``BSpline`` as ``t``.
        order: Value passed to ``BSpline`` as ``k``.
        deriv: Order of the derivative to evaluate.
        x: Points at which the basis is evaluated.

    Returns:
        The matrix produced by ``_b_spline_eval`` with column 0 removed.
    """
    coefficients = np.zeros(order + knots.shape[0] - 1)
    spline = BSpline(t=knots, k=order, c=coefficients)
    full_basis = _b_spline_eval(spline=spline, x=x, deriv=deriv)
    return full_basis[:, 1:]
def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    """Forward rule for ``b_spline_basis``.

    Args:
        knots: Knot vector passed to ``BSpline`` as ``t``.
        order: Value passed to ``BSpline`` as ``k``.
        deriv: Order of the derivative evaluated for the primal output.
        x: Points at which the basis is evaluated.

    Returns:
        A pair ``(output, residual)``: the ``deriv``-th derivative of the basis
        and the ``deriv + 1``-th derivative, each with column 0 removed. The
        second entry is what the backward rule receives as ``partials``.
    """
    coefficients = np.zeros(order + knots.shape[0] - 1)
    spline = BSpline(t=knots, k=order, c=coefficients)
    primal = _b_spline_eval(spline=spline, x=x, deriv=deriv)
    next_derivative = _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)
    return primal[:, 1:], next_derivative[:, 1:]
def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    """Backward rule for ``b_spline_basis``.

    Args:
        knots: Unused; present because it is a non-differentiable argument.
        order: Unused; present because it is a non-differentiable argument.
        deriv: Unused; present because it is a non-differentiable argument.
        partials: Residual saved by ``b_spline_basis_fwd``.
        grad: Incoming gradient with respect to the output of ``b_spline_basis``.

    Returns:
        One-element tuple holding the gradient with respect to ``x``: the
        elementwise product of ``partials`` and ``grad`` summed over axis 1.
    """
    weighted = partials * grad
    x_gradient = jax.numpy.sum(weighted, axis=1)
    return (x_gradient,)
# Register the forward and backward rules that JAX uses when differentiating
# b_spline_basis with respect to x (its other arguments are non-differentiable).
b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)
if __name__ == "__main__":
    # Self-check: compare b_spline_basis and its gradient against reference
    # values for a cubic basis on a fixed knot vector.
    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        """Check the weighted basis sum and its gradient in ``x`` for one derivative order."""
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            design = b_spline_basis(knots=knots, order=order, deriv=deriv, x=x)
            return jax.numpy.sum(jax.numpy.dot(design, weights))  # type: ignore[no-any-return]

        assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
        assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))

    # Reference tables: entry d holds the d-th derivative of the basis. Each
    # inner list is one basis function evaluated at the three points in x;
    # transposing gives one row per point.
    reference = [
        np.array(
            [
                [0.684, 0.166666666666667, 0.00133333333333333],
                [0.096, 0.444444444444444, 0.0355555555555555],
                [0.004, 0.351851851851852, 0.312148148148148],
                [0, 0.037037037037037, 0.650962962962963],
            ]
        ).T,
        np.array(
            [
                [2.52, -1, -0.04],
                [1.68, -0.666666666666667, -0.666666666666667],
                [0.12, 1.22222222222222, -2.29777777777778],
                [0, 0.444444444444444, 3.00444444444444],
            ]
        ).T,
        np.array(
            [
                [-69.6, 4, 0.8],
                [9.6, -5.33333333333333, 5.33333333333333],
                [2.4, -2.22222222222222, -15.3777777777778],
                [0, 3.55555555555556, 9.24444444444445],
            ]
        ).T,
        np.array(
            [
                [504, -8, -8],
                [-144, 26.6666666666667, 26.6666666666667],
                [24, -32.8888888888889, -32.8888888888889],
                [0, 14.2222222222222, 14.2222222222222],
            ]
        ).T,
    ]

    # Each table is the expected value for derivative order d, and the next
    # table is the expected gradient for that same order.
    for deriv_order, (expected_basis, expected_partials) in enumerate(zip(reference, reference[1:])):
        test_jax(expected_basis, expected_partials, deriv=deriv_order)