Recently, all my Google Colab notebooks were updated to Python 3.10, and I can no longer run my old code. I was using jax 0.2.17 and numpyro 0.7.1, but these versions don't work with Python 3.10, so I tried switching to Python 3.9 with the following commands:
!sudo apt-get update -y
!sudo apt-get install python3.9
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
!sudo update-alternatives --config python3
!apt-get install python3-pip
However, after installing it and running !pip install jax==0.2.17, checking jax.__version__ still gives 0.4.8, the version of jax preinstalled in Google Colab. Is there any way to fix this? I want to use the old versions of jax and numpyro. pip appears to install them, but when I import them I seem to get the packages from the Python 3.10 installation instead. Thanks in advance!
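One way to confirm where the mismatch comes from is to ask the running kernel directly. The minimal diagnostic sketch below (the helper name `describe_environment` is hypothetical, not part of any library) reports which interpreter the notebook kernel is actually running and which file `import jax` would load. It uses only the standard library (`sys`, `importlib.util.find_spec`), so it runs even if jax is missing.

```python
import importlib.util
import sys


def describe_environment(package="jax"):
    """Report the running interpreter and where `package` would be imported from.

    Hypothetical helper for diagnosis only.
    """
    # find_spec returns None when the package cannot be imported at all
    spec = importlib.util.find_spec(package)
    return {
        # the interpreter this kernel process was started with
        "executable": sys.executable,
        "python_version": "{}.{}.{}".format(*sys.version_info[:3]),
        # the file that `import <package>` would load, or None if absent
        "package_origin": spec.origin if spec is not None else None,
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If this still reports 3.10 and a path under a Python 3.10 `site-packages` directory, the kernel itself has not switched interpreters. Changing the `python3` alternative only affects what new shell commands launch, not the kernel process that is already running. Comparing the output with `!python3 --version` and `!pip --version` shows whether the shell commands and the kernel are using the same interpreter.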