I am trying to apply JAX to the code from another SO question to evaluate its applicability and performance (that question has useful information about what the code does). For this purpose, I replaced the NumPy calls with their jax.numpy (jnp) equivalents. This was not as easy as I expected because I have little experience with JAX, and the code could probably be written better. I checked the results against the previous code (the optimized algorithm) and they are the same, but the JAX version takes 7.5 seconds where the previous one took 0.10 seconds for a sample case (using Colab). I think the long runtime may be related to the for loop in the code, which might be replaced by JAX constructs such as fori_loop or vectorization, but I don't know which changes to make, or how, to get satisfying performance with JAX.
import numpy as np
from scipy.spatial import cKDTree, distance
import jax
from jax import numpy as jnp
jax.config.update("jax_enable_x64", True)
# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')
poss = np.load('b.npy')
"""
rnd = np.random.RandomState(70)
data_volume = 1000
radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()
x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------
# @jax.jit
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = jnp.array([], dtype=np.float64)
    # kdtree = cKDTree(poss)                                                                                              # Using SciPy
    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        # nears_i_ind = jnp.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)      # Using SciPy
        
        # Using NumPy
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max
        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        # assert len(nears_i_ind) > 0
        # if len(nears_i_ind) <= 1:
        #     continue
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        # dist_i = distance.cdist(poss[tuple(nears_i_ind[None, :])], cur_point[None, :]).squeeze()                        # Using SciPy
        dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1)                     # Using NumPy
        contact_check = dist_i - (radii[tuple(nears_i_ind[None, :])] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = jnp.concatenate((particle_corsp_overlaps, connected))
        contacts_ind = jnp.where(contact_check <= 0)[0]
        contacts_sec_ind = jnp.array(nears_i_ind)[contacts_ind]
        sphere_olps_ind = jnp.sort(contacts_sec_ind)
        ends_ind_mod_temp = jnp.array([jnp.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:   # ---> these 4-lines perhaps be better to be substituted by just one-line list appending as "ends_ind.append(ends_ind_mod_temp)"
            ends_ind = jnp.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind = jnp.array(ends_ind_mod_temp, dtype=np.int64)
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = jnp.unique(jnp.sort(ends_ind_org), axis=0, return_index=True)
    gap = jnp.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
I have tried to use @jax.jit on this code, but it raises TracerArrayConversionError or ConcretizationTypeError on a Colab TPU:
Using SciPy:
TracerArrayConversionError: The numpy.ndarray conversion method array() was called on the JAX Tracer object Traced<ShapedArray(float64[1000,3])>with<DynamicJaxprTrace(level=0/1)> While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'poss'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Using NumPy:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'poss' and 'dia_max'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
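As far as I understand the second message, jnp.where with a single argument needs a statically known output length under jit. The following is only a minimal sketch of that idea, not a fix for the full function: near_indices and the toy mask are hypothetical names made up for illustration, and the padding value -1 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

@jax.jit
def near_indices(mask):
    # `size` fixes the output length at trace time, so the result has a
    # static shape; slots beyond the number of True entries get fill_value.
    idx = jnp.where(mask, size=mask.shape[0], fill_value=-1)[0]
    # The number of valid entries is returned separately as a traced scalar.
    count = jnp.sum(mask)
    return idx, count

# Toy mask (hypothetical data, not taken from the question's arrays)
mask = jnp.array([False, True, False, True, True])
idx, count = near_indices(mask)
```

The padded entries then have to be masked out in later steps, which is the part I am unsure how to do cleanly for the rest of the loop body.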
I would appreciate any help in speeding up this code and getting past these problems using JAX (and jax.jit if possible). How can JAX be used to get the best performance on CPU, GPU, or TPU?
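One direction that might be worth testing is dropping the per-particle loop and computing all pairwise gaps at once. The sketch below is an assumption about what a vectorized version could look like, not the original algorithm: pairwise_gaps is a hypothetical name, it skips the bounding-box prefilter, and it needs O(N²) memory, so it may not scale to large data volumes. The variable-length extraction of pairs is done with NumPy outside jit.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.jit
def pairwise_gaps(poss, radii):
    # (N, N, 3) coordinate differences -> (N, N) centre-to-centre distances
    diff = poss[:, None, :] - poss[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    # Gap between sphere surfaces; <= 0 means the two spheres touch or overlap
    gap = dist - (radii[:, None] + radii[None, :])
    # Keep each pair once (i < j), which also drops the self pairs
    n = poss.shape[0]
    upper = jnp.arange(n)[:, None] < jnp.arange(n)[None, :]
    return gap, (gap <= 0) & upper

# Same random sample data as in the question
rnd = np.random.RandomState(70)
data_volume = 1000
radii = rnd.uniform(0.0005, 0.122, data_volume)
x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))

gap_all, overlap = pairwise_gaps(jnp.asarray(poss), jnp.asarray(radii))

# The number of contacts is data dependent, so extract the pairs outside jit
i, j = np.nonzero(np.asarray(overlap))
ends_ind = np.stack([i, j], axis=1)   # contact pairs with i < j
gap = np.asarray(gap_all)[i, j]       # matching overlap values
```

I have not verified that this reproduces the loop version's output ordering in every case, and I do not know whether it is the recommended JAX approach.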
Prepared sample test data:
a.npy = Radii data
b.npy = Poss data
Updates
The main aim of this question is how to modify the code to get the best performance from it using the JAX library.
I have commented out the SciPy-related lines in the code, based on jakevdp's answer, and uncommented the equivalent NumPy-related sections.
To get better answers, I am listing some important points:
- Are the scikit-learn BallTree-related methods compatible with JAX? These methods could be a good alternative to SciPy's cKDTree in terms of memory usage (for possible vectorization).
- How is the loop section in the code best handled: using fori_loop, or by putting the loop body inside a function and then vectorizing, jitting, or …?
- I had problems preparing the code for fori_loop. What I did can be understood from the following line, where particle_corsp_overlaps was the input of the defined function (this function contains just the loop section). It would be useful to show how to do that if using fori_loop is recommended.
particle_corsp_overlaps, ends_ind = jax.lax.fori_loop(0, len(poss), jax_loop, particle_corsp_overlaps)
- I put the NumPy section in a function to jit it with @jax.jit, to check whether this improves performance (I don't know how much it can help). It raised a ConcretizationTypeError (Shape depends on Traced Value) relating to poss. So I tried the @partial(jax.jit, static_argnums=0) decorator, importing partial from functools, but now I get the following error. How can it be solved, if this approach is recommended? For example:
@partial(jax.jit, static_argnums=0)
def ends_gap(poss):
    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max
        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1) 
ValueError: Non-hashable static arguments are not supported. An error occured during a call to 'nearest_neighbors_jax' while trying to hash an object of type <class 'jaxlib.xla_extension.DeviceArray'>, [[ 8.42519143e-01 1.37693422e+00 -7.97775882e-01] [-3.31436445e-01 -1.67346250e+00 -8.61069684e-01] [-1.57500126e-01 -1.17502591e+00 -7.48879998e-01]]. The error was: TypeError: unhashable type: 'DeviceArray'
I did not put the whole loop body into the function because I got stuck on this short function. Creating a function containing the whole loop body, which can be jitted or …, is of interest if possible.
- Can the 4-line ends_ind-related if-else statement be written in a single line using JAX methods, to avoid possible problems with if during jitting or …?
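For the last point, the list-appending idea mentioned in the code comment might look like the sketch below. It is only an illustration under assumptions: collect_pairs and its contacts_per_particle input are hypothetical, and because the chunk lengths vary per particle, it runs as ordinary Python outside jit rather than as a jit-compatible replacement.

```python
import jax.numpy as jnp

def collect_pairs(contacts_per_particle):
    # Collect one (k_i, 2) block per particle in a Python list ...
    chunks = []
    for particle_idx, olps in enumerate(contacts_per_particle):
        olps = jnp.sort(jnp.asarray(olps, dtype=jnp.int32))
        first = jnp.full(olps.shape, particle_idx, dtype=jnp.int32)
        chunks.append(jnp.stack([first, olps], axis=1))
    # ... and concatenate once at the end, so no if-else on the first
    # iteration and no repeated reallocation inside the loop.
    return jnp.concatenate(chunks, axis=0)

# Toy input (hypothetical): contacts of particles 0, 1 and 2
pairs = collect_pairs([[2, 1], [0], [0]])
```

A particle with no contacts contributes an empty (0, 2) block, so it does not need special handling in this form.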