From c07ffe4d95af591da2e8df067ebeba407c253416 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Tue, 26 Nov 2024 01:17:31 -0800 Subject: [PATCH] New JAX build from Google (#1172) --- .github/container/build-jax.sh | 36 +++++++++++++++++----------------- .github/container/test-jax.sh | 4 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index db01bccff..95cf5246b 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -274,25 +274,25 @@ if [[ ! -e "/usr/local/cuda/lib" ]]; then fi if ! grep 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then - echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc" + echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc" fi cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF -build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda" -build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn" -build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl" +build --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda" +build --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn" +build --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl" EOF -time python "${SRC_PATH_JAX}/build/build.py" \ + +pushd ${SRC_PATH_JAX} +time python "${SRC_PATH_JAX}/build/build.py" build \ --editable \ --use_clang \ - --enable_cuda \ - --build_gpu_plugin \ - --gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \ + --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \ --cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \ - --enable_nccl=true \ --bazel_options=--linkopt=-fuse-ld=lld \ - --bazel_options=--override_repository=xla=$SRC_PATH_XLA \ + --local_xla_path=$SRC_PATH_XLA \ --output_path=${BUILD_PATH_JAXLIB} \ $BUILD_PARAM +popd # Make sure that JAX depends on the local jaxlib installation # https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels @@ -300,8 +300,8 @@ line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib" if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then pushd "${SRC_PATH_JAX}" echo "${line}" >> build/requirements.in - echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in - echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in + echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" >> build/requirements.in + echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax-cuda-plugin" >> build/requirements.in PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))') bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}" popd @@ -316,13 +316,13 @@ else fi # install jax and jaxlib -pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}" +pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}" -# after installation (example) -# jax 0.4.32.dev20240808+9c2caedab /opt/jax -# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_pjrt -# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_plugin -# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib +## after installation (example) +# jax 0.4.36.dev20241125+f828f2d7d /opt/jax +# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt +# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin +# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib pip list | grep jax # Ensure directories are readable by all for non-root users diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index a5fea5365..3a14c7a72 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -26,7 +26,7 @@ jax_source_dir() { query_tests() { cd `jax_source_dir` - python build/build.py --configure_only + python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel query tests/... 2>&1 | grep -F '//tests:' exit } @@ -191,5 +191,5 @@ pip install matplotlib ## Run tests cd `jax_source_dir` -python build/build.py --configure_only +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}