Skip to content

Commit

Permalink
Fix for jax build including cuda plugin (#987)
Browse files Browse the repository at this point in the history
Google recently migrated away from jaxlib[cuda] (that is, including the
CUDA backend in jaxlib itself) and abandoned it completely in favor of a
CUDA plugin, which has to be enabled separately along with --enable_cuda,
for example:
```
python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=<CUDA_VERSION>
```

This builds three wheels:
1. jaxlib without CUDA
2. jax-cuda-plugin (for CUDA support)
3. jax-cuda-pjrt

They all need to be installed and added as build requirements.
  • Loading branch information
DwarKapex authored Aug 9, 2024
1 parent 4696e4d commit 955177d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
6 changes: 4 additions & 2 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax:base
ARG BUILD_PATH_JAXLIB=/opt/jaxlib
ARG BUILD_PATH_JAXLIB=/opt/jax/dist
ARG URLREF_JAX=https://github.com/google/jax.git#main
ARG URLREF_XLA=https://github.com/openxla/xla.git#main
ARG URLREF_FLAX=https://github.com/google/flax.git#main
Expand Down Expand Up @@ -89,7 +89,9 @@ RUN mkdir -p /opt/pip-tools.d

## Editable installations of jax and jaxlib
RUN <<"EOF" bash -ex
echo "-e file://${BUILD_PATH_JAXLIB}" > /opt/pip-tools.d/requirements-jax.in
for component in $(ls ${BUILD_PATH_JAXLIB}); do
echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in;
done
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
echo "numpy<2.0.0" >> /opt/pip-tools.d/requirements-jax.in
EOF
Expand Down
30 changes: 21 additions & 9 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ usage() {

# Set defaults
BAZEL_CACHE=""
BUILD_PATH_JAXLIB="/opt/jaxlib"
BUILD_PATH_JAXLIB="/opt/jax/dist"
BUILD_PARAM=""
CLEAN=0
CLEANONLY=0
Expand Down Expand Up @@ -133,7 +133,7 @@ while [ : ]; do
;;
--)
shift;
break
break
;;
*)
echo "UNKNOWN OPTION $1"
Expand Down Expand Up @@ -164,6 +164,7 @@ export TF_NEED_TENSORRT=0
export TF_CUDA_PATHS=/usr,/usr/local/cuda
export TF_CUDNN_PATHS=/usr/lib/$(uname -p)-linux-gnu
export TF_CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3-4)
export TF_CUDA_MAJOR_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3)
export TF_CUBLAS_VERSION=$(ls /usr/local/cuda/lib64/libcublas.so.*.*.* | cut -d . -f 3)
export TF_CUDNN_VERSION=$(echo "${NV_CUDNN_VERSION}" | cut -d . -f 1)
export TF_NCCL_VERSION=$(echo "${NCCL_VERSION}" | cut -d . -f 1)
Expand Down Expand Up @@ -217,6 +218,7 @@ print_var SRC_PATH_JAX
print_var SRC_PATH_XLA

print_var TF_CUDA_VERSION
print_var TF_CUDA_MAJOR_VERSION
print_var TF_CUDA_COMPUTE_CAPABILITIES
print_var TF_CUBLAS_VERSION
print_var TF_CUDNN_VERSION
Expand Down Expand Up @@ -258,6 +260,8 @@ time python "${SRC_PATH_JAX}/build/build.py" \
--editable \
--use_clang \
--enable_cuda \
--build_gpu_plugin \
--gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \
--cuda_path=$TF_CUDA_PATHS \
--cudnn_path=$TF_CUDNN_PATHS \
--cuda_version=$TF_CUDA_VERSION \
Expand All @@ -271,32 +275,40 @@ time python "${SRC_PATH_JAX}/build/build.py" \

# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jaxlib @ file://${BUILD_PATH_JAXLIB}"
line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib"
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
pushd "${SRC_PATH_JAX}"
echo "${line}" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
#bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
python build/build.py --requirements_update --python_version=${PYTHON_VERSION}
popd
fi

## Install the built packages

# Uninstall jaxlib in case this script was used before.
if [[ "$JAXLIB_ONLY" == "0" ]]; then
pip uninstall -y jax jaxlib
pip uninstall -y jax jaxlib jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin
else
pip uninstall -y jaxlib
pip uninstall -y jaxlib jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin
fi

# install jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}

pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin
# install jax
if [[ "$JAXLIB_ONLY" == "0" ]]; then
pip --disable-pip-version-check install -e "${SRC_PATH_JAX}"
fi

# after installation (example)
# jax 0.4.32.dev20240808+9c2caedab /opt/jax
# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jax/dist/jax_gpu_pjrt
# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jax/dist/jax_gpu_plugin
# jaxlib 0.4.32.dev20240808 /opt/jax/dist/jaxlib
pip list | grep jax

## Cleanup

pushd $SRC_PATH_JAX
Expand Down
1 change: 1 addition & 0 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ for t in $*; do
done

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
Expand Down

0 comments on commit 955177d

Please sign in to comment.