Install JAX for ROCm

This page provides setup instructions to build, test, and run JAX with ROCm support using a pip or Docker installation, suitable for both runtime and CI workflows.

Note
These instructions are for JAX installation on Radeon GPUs.
To install ROCm on Instinct GPUs, refer to ROCm Instinct documentation.

Install JAX

Follow these instructions to install JAX using either pip or Docker.

pip installation

Follow these instructions to install JAX with pip.

Important
The packages must be installed in the following order:

  1. Install the pjrt wheel.

  2. Install the plugin wheel.

  3. Install the jaxlib wheel.

  4. Install the jax wheel.
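The four wheel URLs on this page differ only in the CPython tag (cp312 or cp310), so the ordered install can be sketched as a small shell script. This is a minimal sketch that assumes the rocm-rel-6.4.1 path and 0.4.35 file names listed below; jax_rocm_wheel_urls and install_jax_rocm are hypothetical helper names, not part of any ROCm tooling.

```shell
#!/usr/bin/env bash
# Hypothetical helpers; assumes the rocm-rel-6.4.1 / JAX 0.4.35 file names
# listed on this page. Other releases may use different paths or names.
BASE_URL="https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1"
JAX_VERSION="0.4.35"
PLATFORM="manylinux_2_28_x86_64"

# Print the wheel URLs in the required order: pjrt, plugin, jaxlib, jax.
# $1 is the CPython tag of your interpreter, for example cp312 or cp310.
jax_rocm_wheel_urls() {
  local py_tag="$1"
  echo "${BASE_URL}/jax_rocm60_pjrt-${JAX_VERSION}-py3-none-${PLATFORM}.whl"
  echo "${BASE_URL}/jax_rocm60_plugin-${JAX_VERSION}-${py_tag}-${py_tag}-${PLATFORM}.whl"
  echo "${BASE_URL}/jaxlib-${JAX_VERSION}-${py_tag}-${py_tag}-${PLATFORM}.whl"
  echo "${BASE_URL}/jax-${JAX_VERSION}-py3-none-any.whl"
}

# Remove any previous version, then install the wheels one at a time,
# stopping at the first failure so a later wheel never goes in early.
install_jax_rocm() {
  local py_tag="$1" url
  pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
  for url in $(jax_rocm_wheel_urls "$py_tag"); do
    pip3 install "$url" || return 1
  done
}
```

For example, install_jax_rocm cp312 would install the Ubuntu 24.04 and RHEL 9.6 wheels, and install_jax_rocm cp310 the Ubuntu 22.04 wheels. Stopping on the first failed install keeps the pjrt, plugin, jaxlib, jax order intact.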

Install JAX for Ubuntu 24.04. These wheels are built for Python 3.12 (cp312).

  1. Uninstall any previous version.

    pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    
  5. Install the jax wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
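On Ubuntu 24.04 the system Python 3.12 is marked as externally managed (PEP 668), so pip install against it is refused with an externally-managed-environment error. One common workaround is to run the uninstall and install steps above inside a virtual environment. A minimal sketch, assuming the python3-venv package is installed and using a hypothetical location ~/jax-rocm-venv:

```shell
# Hypothetical location for the environment; any writable path works.
VENV_DIR="${VENV_DIR:-$HOME/jax-rocm-venv}"

# Create the environment (needs the python3-venv package on Ubuntu).
python3 -m venv "$VENV_DIR"

# Activate it: python, pip, and pip3 on PATH now point into $VENV_DIR,
# so the pip commands above install there instead of into the system Python.
. "$VENV_DIR/bin/activate"
python -m pip --version
```

Activate the environment again in each new shell (. ~/jax-rocm-venv/bin/activate) before running JAX.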
    

Install JAX for Ubuntu 22.04. These wheels are built for Python 3.10 (cp310).

  1. Uninstall any previous version.

    pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp310-cp310-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp310-cp310-manylinux_2_28_x86_64.whl
    
  5. Install the jax wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
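The Ubuntu 22.04 wheels are tagged cp310, matching that release's default Python 3.10; under any other interpreter version pip rejects them as not a supported wheel on this platform. A quick pre-check can be sketched as follows, assuming python3 is the interpreter you will install into (EXPECTED_TAG and py_tag are illustrative names only):

```shell
# Illustrative pre-check: compare python3's CPython tag with the wheel tag.
EXPECTED_TAG="cp310"
py_tag="$(python3 -c 'import sys; print("cp%d%d" % sys.version_info[:2])')"

if [ "$py_tag" = "$EXPECTED_TAG" ]; then
  echo "python3 matches $EXPECTED_TAG"
else
  # Warn only; the cp310 wheels above will not install on this interpreter.
  echo "python3 is $py_tag, but these wheels require $EXPECTED_TAG" >&2
fi
```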
    

Install JAX for RHEL 9.6. These wheels are built for Python 3.12 (cp312).

  1. Uninstall any previous version.

    pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    
  5. Install the jax wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/jax-0.4.35-py3-none-any.whl
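The default python3 on RHEL 9 is Python 3.9, while the RHEL 9.6 wheels above are cp312. A hedged sketch for locating a Python 3.12 interpreter, assuming it comes from the AppStream packages python3.12 and python3.12-pip (for example, sudo dnf install python3.12 python3.12-pip); PY is an illustrative variable name:

```shell
# Illustrative: pick python3.12 explicitly, since python3 is 3.9 on RHEL 9.
if command -v python3.12 >/dev/null 2>&1; then
  PY=python3.12
  "$PY" --version
else
  PY=""
  echo "python3.12 not found; install it before using the cp312 wheels" >&2
fi
```

With PY set, each step can then be run as "$PY" -m pip install followed by the wheel URL, which guarantees the wheels go into the 3.12 interpreter rather than the default 3.9 one.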
    

Docker installation

The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX preconfigured for ROCm.

Install JAX for Ubuntu 24.04.

  1. To pull the ROCm 6.4.1 JAX 0.4.35 Docker image for Python 3.12, run:

    docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.12
    

    Note
    For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub.

    Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.

  2. Once the image is downloaded, launch a container using the following command:

    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v "$(pwd)":/jax_dir \
    --name rocm_jax rocm/jax:rocm6.4.1-jax0.4.35-py3.12 /bin/bash
    
    docker attach rocm_jax
    

    Note
    The --shm-size parameter allocates shared memory for the container; adjust it to your system's resources if needed. "$(pwd)" mounts the current directory at /jax_dir inside the container; replace it with the absolute path of another directory to mount that one instead. To use an image from ROCm JAX Community on Docker Hub, replace rocm/jax:rocm6.4.1-jax0.4.35-py3.12 with the rocm/jax-community image you pulled.
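The rocm/jax image tags on this page follow a rocmX-jaxY-pyZ pattern. If you need another combination, a small helper can assemble the tag. This is a sketch that assumes the pattern also holds for other releases; rocm_jax_image is a hypothetical name, and only tags actually listed on Docker Hub can be pulled.

```shell
# Hypothetical helper: build a rocm/jax tag from ROCm, JAX, and Python versions,
# following the pattern of rocm/jax:rocm6.4.1-jax0.4.35-py3.12.
rocm_jax_image() {
  local rocm_ver="$1" jax_ver="$2" py_ver="$3"
  echo "rocm/jax:rocm${rocm_ver}-jax${jax_ver}-py${py_ver}"
}
```

For example, docker pull "$(rocm_jax_image 6.4.1 0.4.35 3.10)" pulls the Ubuntu 22.04 image used on this page.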

Install JAX for Ubuntu 22.04.

  1. To pull the ROCm 6.4.1 JAX 0.4.35 Docker image for Python 3.10, run:

    docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.10
    

    Note
    For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub.

    Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.

  2. Once the image is downloaded, launch a container using the following command:

    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v "$(pwd)":/jax_dir \
    --name rocm_jax rocm/jax:rocm6.4.1-jax0.4.35-py3.10 /bin/bash
    
    docker attach rocm_jax
    

    Note
    The --shm-size parameter allocates shared memory for the container; adjust it to your system's resources if needed. "$(pwd)" mounts the current directory at /jax_dir inside the container; replace it with the absolute path of another directory to mount that one instead. To use an image from ROCm JAX Community on Docker Hub, replace rocm/jax:rocm6.4.1-jax0.4.35-py3.10 with the rocm/jax-community image you pulled.
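The --device=/dev/kfd and --device=/dev/dri flags pass the host's GPU device nodes into the container, so docker run fails if either is missing. A hedged pre-flight sketch that reports which nodes exist; check_rocm_host is a hypothetical helper:

```shell
# Hypothetical pre-flight check: report whether the device nodes that the
# docker run command passes through exist on the host.
# Returns 0 if all are present, 1 otherwise.
check_rocm_host() {
  local status=0 dev
  for dev in /dev/kfd /dev/dri; do
    if [ -e "$dev" ]; then
      echo "found $dev"
    else
      echo "missing $dev (check that the amdgpu driver is loaded)" >&2
      status=1
    fi
  done
  return "$status"
}
```

Chaining it as check_rocm_host && docker run ... launches the container only when both nodes are present.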

Install JAX for RHEL 9.6.

  1. To pull the ROCm 6.4.1 JAX 0.4.35 Docker image for Python 3.12, run:

    docker pull rocm/jax:rocm6.4.1-jax0.4.35-py3.12
    

    Note
    For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub.

    Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.

  2. Once the image is downloaded, launch a container using the following command:

    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v "$(pwd)":/jax_dir \
    --name rocm_jax rocm/jax:rocm6.4.1-jax0.4.35-py3.12 /bin/bash
    
    docker attach rocm_jax
    

    Note
    The --shm-size parameter allocates shared memory for the container; adjust it to your system's resources if needed. "$(pwd)" mounts the current directory at /jax_dir inside the container; replace it with the absolute path of another directory to mount that one instead. To use an image from ROCm JAX Community on Docker Hub, replace rocm/jax:rocm6.4.1-jax0.4.35-py3.12 with the rocm/jax-community image you pulled.
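The image name and the mounted directory are the parts of the docker run command you are most likely to change, so the command can be wrapped in a function. A minimal bash sketch, assuming the flags shown above; run_rocm_jax and the DRY_RUN variable are hypothetical, and DRY_RUN=1 prints the command instead of running it.

```shell
# Hypothetical wrapper around the docker run command shown above.
# $1: image (defaults to the Python 3.12 image pulled in step 1)
# $2: host directory to mount at /jax_dir (defaults to the current directory)
run_rocm_jax() {
  local image="${1:-rocm/jax:rocm6.4.1-jax0.4.35-py3.12}"
  local host_dir="${2:-$PWD}"
  local cmd=(docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri
    --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE
    --security-opt seccomp=unconfined -v "${host_dir}:/jax_dir"
    --name rocm_jax "$image" /bin/bash)
  if [ "${DRY_RUN:-0}" = "1" ]; then
    # Print the command for review instead of executing it.
    echo "${cmd[*]}"
  else
    "${cmd[@]}"
  fi
}
```

For example, run_rocm_jax rocm/jax:rocm6.4.1-jax0.4.35-py3.12 /home/user/project (a hypothetical path) starts the container, after which docker attach rocm_jax works as before. Building the command as an array keeps a mount path containing spaces intact.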

Verify installation

Refer to Testing your JAX installation with ROCm for verification steps.
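As a quick smoke test before following that guide, a short script can confirm that JAX imports, lists devices, and runs a computation. This is a sketch only: verify_jax_rocm is a hypothetical helper, and it assumes python3 is the interpreter (or activated environment, or container shell) where JAX was installed.

```shell
# Hypothetical smoke test: import JAX, list devices, and run one computation.
verify_jax_rocm() {
  python3 - <<'EOF'
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Devices:", jax.devices())
# A tiny computation placed on the default device.
x = jnp.arange(4.0)
print("Sum:", float(jnp.sum(x)))
EOF
}
```

If the device list contains only CPU devices, the ROCm plugin was most likely not loaded; recheck that the wheels were installed in the required order.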