From e14636adf0d5658dab933cb9140f7ea6d7f13f0d Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Fri, 10 May 2024 19:23:25 +0000
Subject: [PATCH] Update CI to use cuda plugin

---
Review notes (everything between the "---" line and the first "diff --git"
line is commentary and is not part of the commit message):

* .github/workflows/_triton.yml has been amended relative to the commit named
  in the "From" line above, so the blob id on its "index" line is
  informational only:
  - the job now applies the `timeout-minutes` input, which was declared and
    passed by the caller but never used;
  - the job now exports GCLOUD_SERVICE_KEY (from the `gcloud-service-key`
    secret) and GOOGLE_APPLICATION_CREDENTIALS, both of which the
    "Setup gcloud" step reads but the workflow never defined;
  - the comment above the JAX install no longer says "nightly", because the
    command installs the released `jax[cuda12]` package.
* All other hunks are unchanged.

 .circleci/triton.sh                  |   1 +
 .github/workflows/_build_plugin.yml  |   2 +-
 .github/workflows/_triton.yml        | 118 +++++++++++++++++++++++++++
 .github/workflows/build_and_test.yml |  13 ++++
 WORKSPACE                            |   1 -
 bazel/rules_def.bzl                  |   2 +-
 6 files changed, 134 insertions(+), 3 deletions(-)
 create mode 100644 .github/workflows/_triton.yml

diff --git a/.circleci/triton.sh b/.circleci/triton.sh
index 9d8f51634b8..d27731d9f7e 100755
--- a/.circleci/triton.sh
+++ b/.circleci/triton.sh
@@ -20,6 +20,7 @@ fi
 apply_patches
 
 python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
+pip install --upgrade "jax[cuda12]"
 
 export PATH=$PATH:/usr/local/cuda-12.1/bin
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
diff --git a/.github/workflows/_build_plugin.yml b/.github/workflows/_build_plugin.yml
index e30b88aed1e..09a7c477c8a 100644
--- a/.github/workflows/_build_plugin.yml
+++ b/.github/workflows/_build_plugin.yml
@@ -39,7 +39,7 @@ jobs:
         shell: bash
         run: |
           cd pytorch/xla/infra/ansible
-          ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
+          ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5,8.6 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
       - name: Upload wheel
         uses: actions/upload-artifact@v4
         with:
diff --git a/.github/workflows/_triton.yml b/.github/workflows/_triton.yml
new file mode 100644
index 00000000000..1d40181385f
--- /dev/null
+++ b/.github/workflows/_triton.yml
@@ -0,0 +1,118 @@
+name: xla-test
+on:
+  workflow_call:
+    inputs:
+      dev-image:
+        required: true
+        type: string
+        description: Base image for builds
+      runner:
+        required: false
+        type: string
+        description: Runner type for the test
+        default: linux.12xlarge
+      timeout-minutes:
+        required: false
+        type: number
+        default: 270
+        description: |
+          Set the maximum (in minutes) how long the workflow should take to finish
+            timeout-minutes:
+      install-cuda-plugin:
+        required: false
+        type: boolean
+        default: false
+        description: Whether to install CUDA plugin package
+
+    secrets:
+      gcloud-service-key:
+        required: true
+        description: Secret to access Bazel build cache
+jobs:
+  test:
+    runs-on: ${{ inputs.runner }}
+    # Apply the caller-supplied limit; without this line the input is ignored.
+    timeout-minutes: ${{ inputs.timeout-minutes }}
+    container:
+      image: ${{ inputs.dev-image }}
+      options: "${{ inputs.install-cuda-plugin && '--gpus all' || '' }} --shm-size 16g"
+    env:
+      # Both variables are read by the "Setup gcloud" step below.
+      GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
+      GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
+    steps:
+      # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802
+      - name: Clean up workspace
+        run: |
+          ls -la
+          rm -rvf ${GITHUB_WORKSPACE}/*
+      - name: Setup gcloud
+        shell: bash
+        run: |
+          echo "${GCLOUD_SERVICE_KEY}" > "${GOOGLE_APPLICATION_CREDENTIALS}"
+      - name: Fetch wheels
+        uses: actions/download-artifact@v4
+        with:
+          name: torch-xla-wheels
+          path: /tmp/wheels/
+      - name: Fetch CUDA plugin
+        uses: actions/download-artifact@v4
+        with:
+          name: cuda-plugin
+          path: /tmp/wheels/
+        if: ${{ inputs.install-cuda-plugin }}
+      - name: Setup CUDA environment
+        shell: bash
+        run: |
+          # TODO: Make PJRT_DEVICE=CPU work with XLA_REGISTER_INSTALLED_PLUGINS=1
+          echo "XLA_REGISTER_INSTALLED_PLUGINS=1" >> $GITHUB_ENV
+
+          echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
+          echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
+        if: ${{ inputs.install-cuda-plugin }}
+      - name: Check GPU
+        run: nvidia-smi
+        if: ${{ inputs.install-cuda-plugin }}
+      - name: Install wheels
+        shell: bash
+        run: |
+          pip install /tmp/wheels/*.whl
+          # TODO: Add these in setup.py
+          pip install fsspec
+          pip install rich
+
+          echo "Import check..."
+          python -c "import torch_xla"
+      - name: Record PyTorch commit
+        run: |
+          # Don't just pipe output in shell because imports may do extra logging
+          python -c "
+          import torch_xla.version
+          with open('$GITHUB_ENV', 'a') as f:
+              f.write(f'PYTORCH_COMMIT={torch_xla.version.__torch_gitrev__}\n')
+          "
+      - name: Checkout PyTorch Repo
+        uses: actions/checkout@v4
+        with:
+          repository: pytorch/pytorch
+          path: pytorch
+          ref: ${{ env.PYTORCH_COMMIT }}
+      - name: Checkout PyTorch/XLA Repo
+        uses: actions/checkout@v4
+        with:
+          path: pytorch/xla
+      - name: Install test dependencies
+        shell: bash
+        run: |
+          # TODO: Add these in setup.py
+          pip install fsspec
+          pip install rich
+          # JAX with CUDA 12 support is needed for Triton tests.
+          pip install --upgrade "jax[cuda12]"
+      - name: Run Tests
+        env:
+          PJRT_DEVICE: CUDA
+          TRITON_PTXAS_PATH: /usr/local/cuda/bin/ptxas
+        run: |
+          cd pytorch/xla
+          python test/test_triton.py
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 60a2eda44cd..a46a79b5da3 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -59,6 +59,19 @@ jobs:
       install-cuda-plugin: true
     secrets:
       gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
+  test-triton:
+    name: "Triton tests"
+    uses: ./.github/workflows/_triton.yml
+    needs: [build-torch-xla, build-cuda-plugin]
+    with:
+      dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
+      runner: linux.g5.4xlarge.nvidia.gpu
+      timeout-minutes: 300
+      install-cuda-plugin: true
+    secrets:
+      gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
 
   test-tpu:
     name: "TPU tests"
diff --git a/WORKSPACE b/WORKSPACE
index 86c78c57bda..c73231dd402 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -24,7 +24,6 @@ http_archive(
     url = "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz",
 )
 
-
 load("@pybind11_bazel//:python_configure.bzl", "python_configure")
 
 # This is required for setting up the linkopts for -lpython.q
diff --git a/bazel/rules_def.bzl b/bazel/rules_def.bzl
index 92a5860cac1..08c7d237f65 100644
--- a/bazel/rules_def.bzl
+++ b/bazel/rules_def.bzl
@@ -37,4 +37,4 @@ def ptxla_cc_test(
             "@torch//:libtorch_python",
         ],
         **kwargs
-    )
\ No newline at end of file
+    )