Install JAX for ROCm
This page provides the setup instructions and commands needed to build, test, and run JAX with ROCm support using a pip or Docker install, suitable for both runtime and CI workflows.
Note
These instructions are for JAX installation on Radeon GPUs.
To install ROCm on Instinct GPUs, refer to the ROCm Instinct documentation.
Install JAX
You can install JAX with either pip or Docker.
pip installation
Follow these instructions to install JAX with pip.
Important
The packages must be installed in the following order:
1. Install the pjrt wheel.
2. Install the plugin wheel.
3. Install the jaxlib wheel.
4. Install the jax wheel.
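The required order can be scripted so that a failed install stops before later packages are attempted. This is a minimal sketch assuming bash 4 or later; `wheel_urls`, `install_in_order`, `BASE_URL`, `JAX_VERSION`, and `DRY_RUN` are hypothetical names, and the URLs only reproduce the pattern of the wheel links on this page.

```shell
#!/usr/bin/env bash
# Minimal sketch: install the four ROCm JAX wheels in the required order
# (pjrt -> plugin -> jaxlib -> jax). All names here are hypothetical; the URL
# layout mirrors the wheel links listed on this page.

BASE_URL="https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1"
JAX_VERSION="0.4.35"

# Print the wheel URLs in install order for a CPython tag (cp312 or cp310).
wheel_urls() {
  local py_tag="$1"
  echo "${BASE_URL}/jax_rocm60_pjrt-${JAX_VERSION}-py3-none-manylinux_2_28_x86_64.whl"
  echo "${BASE_URL}/jax_rocm60_plugin-${JAX_VERSION}-${py_tag}-${py_tag}-manylinux_2_28_x86_64.whl"
  echo "${BASE_URL}/jaxlib-${JAX_VERSION}-${py_tag}-${py_tag}-manylinux_2_28_x86_64.whl"
  echo "${BASE_URL}/jax-${JAX_VERSION}-py3-none-any.whl"
}

# Install one wheel at a time and stop at the first failure.
# With DRY_RUN=1 the pip commands are printed instead of executed.
install_in_order() {
  local py_tag="$1" url
  local -a urls
  mapfile -t urls < <(wheel_urls "${py_tag}")
  for url in "${urls[@]}"; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "pip install ${url}"
    else
      pip install "${url}" || return 1
    fi
  done
}
```

For example, `DRY_RUN=1 install_in_order cp312` prints the four pip commands for Ubuntu 24.04 or RHEL 9.6 in order; pass `cp310` for Ubuntu 22.04.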
Install JAX for Ubuntu 24.04
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
5. Install the jax wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
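The Ubuntu 24.04 plugin and jaxlib wheels carry the cp312 tag, so pip accepts them only for CPython 3.12 and rejects them as unsupported wheels on other interpreters. A quick pre-check can compare the interpreter version against that tag; this is a minimal bash sketch, and `abi_tag_from_version` and `require_abi_tag` are hypothetical helper names.

```shell
#!/usr/bin/env bash
# Minimal sketch: confirm that the Python interpreter matches the CPython tag
# in the plugin and jaxlib wheel names before installing them.
# Function names are hypothetical, not part of any ROCm tooling.

# Convert a version string such as "3.12.3" into a wheel tag such as "cp312".
abi_tag_from_version() {
  local major minor rest
  IFS=. read -r major minor rest <<< "$1"
  echo "cp${major}${minor}"
}

# Succeed only if the given (or detected) Python version matches the wanted tag.
# The second argument is optional; without it, python3 is queried.
require_abi_tag() {
  local wanted="$1"
  local version="${2:-$(python3 -c 'import platform; print(platform.python_version())')}"
  local actual
  actual="$(abi_tag_from_version "${version}")"
  if [ "${actual}" != "${wanted}" ]; then
    echo "Python ${version} (${actual}) does not match required ${wanted}" >&2
    return 1
  fi
}
```

For example, run `require_abi_tag cp312 && pip install <plugin wheel URL>`. If `python3` and `pip` point at different interpreters, `python3 -m pip install` keeps them consistent.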
Install JAX for Ubuntu 22.04
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp310-cp310-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp310-cp310-manylinux_2_28_x86_64.whl
5. Install the jax wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
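The Ubuntu 22.04 wheels differ from the Ubuntu 24.04 and RHEL 9.6 wheels only in their CPython tag (cp310 instead of cp312). A script that serves all three systems could pick the tag from the `ID` and `VERSION_ID` fields of `/etc/os-release`; in this minimal sketch, the function name and its optional file argument are hypothetical, and the mapping is taken from the wheel names on this page.

```shell
#!/usr/bin/env bash
# Minimal sketch: choose the CPython wheel tag for the distributions covered
# on this page. Mapping taken from the wheel names listed here:
#   Ubuntu 24.04 -> cp312, Ubuntu 22.04 -> cp310, RHEL 9.6 -> cp312
# The function name and optional file argument are hypothetical.
wheel_tag_for_os() {
  local os_release="${1:-/etc/os-release}"
  local id version_id
  # os-release is shell-sourceable; source it inside command substitutions
  # (subshells) so the caller's environment is left untouched.
  id="$(. "${os_release}" && echo "${ID}")"
  version_id="$(. "${os_release}" && echo "${VERSION_ID}")"
  case "${id}:${version_id}" in
    ubuntu:24.04) echo cp312 ;;
    ubuntu:22.04) echo cp310 ;;
    rhel:9.6)     echo cp312 ;;
    *)
      echo "unsupported OS: ${id} ${version_id}" >&2
      return 1
      ;;
  esac
}
```

On Ubuntu 22.04, `wheel_tag_for_os` prints `cp310`, which matches the plugin and jaxlib wheels above.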
Install JAX for RHEL 9.6
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
5. Install the jax wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
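After the last wheel is installed, it can help to confirm that all four packages report the same version. `pip show` prints a `Name:`/`Version:` record per package, separated by `---` lines; the hypothetical `check_versions` helper in this minimal sketch parses that text with awk and fails if a version differs or a package is missing.

```shell
#!/usr/bin/env bash
# Minimal sketch: read `pip show` output on stdin and check that exactly four
# packages were reported, all with the expected version.
# The function name is hypothetical.
check_versions() {
  local expected="$1"
  awk -v want="${expected}" '
    BEGIN        { seen = 0; bad = 0 }
    /^Name: /    { name = $2 }
    /^Version: / {
      seen++
      if ($2 != want) { print name " is " $2 ", expected " want; bad = 1 }
    }
    END {
      if (seen != 4) { print "expected 4 packages, found " seen; bad = 1 }
      exit bad
    }
  '
}

# Usage against the real environment:
#   pip show jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax | check_versions 0.4.35
```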
Docker installation
The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX already configured for ROCm.
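The container commands pass the host's `/dev/kfd` and `/dev/dri` device nodes into the container, so those nodes must exist on the host. This minimal bash sketch checks for them and for the docker CLI before you pull an image; `check_rocm_host` and its optional root-prefix argument are hypothetical and exist only so the check can be tested, not as Docker or ROCm options.

```shell
#!/usr/bin/env bash
# Minimal sketch: verify that the host exposes the device nodes used by the
# docker run command on this page (--device=/dev/kfd, --device=/dev/dri) and
# that the docker CLI is on PATH. The optional root prefix is hypothetical.
check_rocm_host() {
  local root="${1:-}"
  local missing=0
  if [ ! -e "${root}/dev/kfd" ]; then
    echo "missing ${root}/dev/kfd (is the amdgpu driver loaded?)" >&2
    missing=1
  fi
  if [ ! -d "${root}/dev/dri" ]; then
    echo "missing ${root}/dev/dri" >&2
    missing=1
  fi
  if ! command -v docker >/dev/null 2>&1; then
    echo "docker CLI not found in PATH" >&2
    missing=1
  fi
  return "${missing}"
}
```

On a real host, call `check_rocm_host` with no arguments.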
Install JAX for Ubuntu 24.04
To pull the ROCm JAX Docker image for ROCm 6.4.1, JAX 0.4.35, and Python 3.12, run:
docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.12
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container; adjust it based on your system’s resources if needed. Replace $(pwd) with the absolute path of the directory you want to mount inside the container at /jax_dir. To use the rocm/jax image pulled above instead of rocm/jax-community, replace rocm/jax-community:latest with that image, for example rocm/jax:rocm6.4.1-jax0.4.35-py3.12.
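The substitutions described in the note (mount directory, image, shared-memory size) can be pulled out into variables so the same launch command works for either image. In this minimal sketch, `start_rocm_jax`, `IMAGE`, `HOST_DIR`, `SHM_SIZE`, and `DRY_RUN` are hypothetical names; the docker flags are exactly the ones used in the command above.

```shell
#!/usr/bin/env bash
# Minimal sketch: the docker run command from this page with the image, mount
# directory, and shared-memory size as overridable variables.
# Names are hypothetical; set DRY_RUN=1 to print the command instead.
start_rocm_jax() {
  local image="${IMAGE:-rocm/jax-community:latest}"
  local host_dir="${HOST_DIR:-$PWD}"   # absolute path mounted at /jax_dir
  local shm_size="${SHM_SIZE:-64G}"
  local -a cmd=(
    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri
    --ipc=host --shm-size "${shm_size}" --group-add video
    --cap-add=SYS_PTRACE --security-opt seccomp=unconfined
    -v "${host_dir}:/jax_dir" --name rocm_jax "${image}" /bin/bash
  )
  if [ "${DRY_RUN:-0}" = "1" ]; then
    printf '%s ' "${cmd[@]}"; echo
  else
    "${cmd[@]}"
  fi
}
```

For example, `IMAGE=rocm/jax:rocm6.4.1-jax0.4.35-py3.12 HOST_DIR=/path/to/project start_rocm_jax && docker attach rocm_jax`, where `/path/to/project` stands for your own directory. Quoting `${host_dir}` inside an array keeps paths with spaces intact.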
Install JAX for Ubuntu 22.04
To pull the ROCm JAX Docker image for ROCm 6.4.1, JAX 0.4.35, and Python 3.10, run:
docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.10
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container; adjust it based on your system’s resources if needed. Replace $(pwd) with the absolute path of the directory you want to mount inside the container at /jax_dir. To use the rocm/jax image pulled above instead of rocm/jax-community, replace rocm/jax-community:latest with that image, for example rocm/jax:rocm6.4.1-jax0.4.35-py3.10.
Install JAX for RHEL 9.6
To pull the ROCm JAX Docker image for ROCm 6.4.1, JAX 0.4.35, and Python 3.12, run:
docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.12
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container; adjust it based on your system’s resources if needed. Replace $(pwd) with the absolute path of the directory you want to mount inside the container at /jax_dir. To use the rocm/jax image pulled above instead of rocm/jax-community, replace rocm/jax-community:latest with that image, for example rocm/jax:rocm6.4.1-jax0.4.35-py3.12.
Verify installation
Refer to Testing your JAX installation with ROCm for verification steps.
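Before running the full verification steps, a quick first check is to ask JAX which backend it selected. This minimal sketch calls `jax.default_backend()` and fails if JAX cannot be imported or fell back to CPU. Treating any non-`cpu` backend as success, the `jax_smoke_test` name, and the `PYTHON` override (used only so the logic can be tested) are assumptions of this sketch, not part of the official verification procedure.

```shell
#!/usr/bin/env bash
# Minimal sketch of a quick smoke test, not a replacement for the full
# verification steps. Assumption: any backend other than "cpu" means the
# ROCm plugin was picked up. PYTHON is a hypothetical override.
jax_smoke_test() {
  local py="${PYTHON:-python3}" backend
  # Print only the backend name; an import error leaves it empty.
  backend="$("${py}" -c 'import jax; print(jax.default_backend())' 2>/dev/null)"
  case "${backend}" in
    ""|cpu)
      echo "JAX is not using a GPU backend (got: '${backend:-none}')" >&2
      return 1
      ;;
    *)
      echo "JAX backend: ${backend}"
      ;;
  esac
}
```

Run `jax_smoke_test` in the environment where JAX was installed, for example inside the rocm_jax container.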