Skip to content

Commit

Permalink
Update CI to use cuda plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavya01 committed May 10, 2024
1 parent a7b94c6 commit e14636a
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 3 deletions.
1 change: 1 addition & 0 deletions .circleci/triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fi
apply_patches

python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
pip install --upgrade "jax[cuda12]"

export PATH=$PATH:/usr/local/cuda-12.1/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/_build_plugin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
shell: bash
run: |
cd pytorch/xla/infra/ansible
ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5,8.6 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
- name: Upload wheel
uses: actions/upload-artifact@v4
with:
Expand Down
112 changes: 112 additions & 0 deletions .github/workflows/_triton.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Reusable (workflow_call) workflow: runs the PyTorch/XLA Triton tests,
# optionally installing the CUDA PJRT plugin and exposing GPUs to the
# container when `install-cuda-plugin` is true.
name: xla-test
on:
  workflow_call:
    inputs:
      dev-image:
        required: true
        type: string
        description: Base image for builds
      runner:
        required: false
        type: string
        description: Runner type for the test
        default: linux.12xlarge
      timeout-minutes:
        required: false
        type: number
        default: 270
        description: |
          Maximum time (in minutes) the workflow may take to finish
      install-cuda-plugin:
        required: false
        type: boolean
        default: false
        description: Whether to install CUDA plugin package

    secrets:
      gcloud-service-key:
        required: true
        description: Secret to access Bazel build cache
jobs:
  test:
    runs-on: ${{ inputs.runner }}
    # Apply the caller-provided timeout. Previously the input was declared
    # but never referenced, so the job used GitHub's default job timeout.
    timeout-minutes: ${{ inputs.timeout-minutes }}
    container:
      image: ${{ inputs.dev-image }}
      # GPUs are only exposed to the container when the CUDA plugin is requested.
      options: "${{ inputs.install-cuda-plugin && '--gpus all' || '' }} --shm-size 16g"
    steps:
      # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802
      - name: Clean up workspace
        run: |
          ls -la
          rm -rvf ${GITHUB_WORKSPACE}/*
      - name: Setup gcloud
        shell: bash
        run: |
          echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS
      - name: Fetch wheels
        uses: actions/download-artifact@v4
        with:
          name: torch-xla-wheels
          path: /tmp/wheels/
      - name: Fetch CUDA plugin
        uses: actions/download-artifact@v4
        with:
          name: cuda-plugin
          path: /tmp/wheels/
        if: ${{ inputs.install-cuda-plugin }}
      - name: Setup CUDA environment
        shell: bash
        run: |
          # TODO: Make PJRT_DEVICE=CPU work with XLA_REGISTER_INSTALLED_PLUGINS=1
          echo "XLA_REGISTER_INSTALLED_PLUGINS=1" >> $GITHUB_ENV

          echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
        if: ${{ inputs.install-cuda-plugin }}
      - name: Check GPU
        run: nvidia-smi
        if: ${{ inputs.install-cuda-plugin }}
      - name: Install wheels
        shell: bash
        run: |
          pip install /tmp/wheels/*.whl
          # TODO: Add these in setup.py
          pip install fsspec
          pip install rich

          echo "Import check..."
          python -c "import torch_xla"
      - name: Record PyTorch commit
        run: |
          # Don't just pipe output in shell because imports may do extra logging
          python -c "
          import torch_xla.version
          with open('$GITHUB_ENV', 'a') as f:
            f.write(f'PYTORCH_COMMIT={torch_xla.version.__torch_gitrev__}\n')
          "
      - name: Checkout PyTorch Repo
        uses: actions/checkout@v4
        with:
          repository: pytorch/pytorch
          path: pytorch
          ref: ${{ env.PYTORCH_COMMIT }}
      - name: Checkout PyTorch/XLA Repo
        uses: actions/checkout@v4
        with:
          path: pytorch/xla
      - name: Install test dependencies
        shell: bash
        run: |
          # fsspec/rich are already installed by the "Install wheels" step above.
          # Jax nightly is needed for Triton tests.
          pip install --upgrade "jax[cuda12]"
      - name: Run Tests
        env:
          PJRT_DEVICE: CUDA
          TRITON_PTXAS_PATH: /usr/local/cuda/bin/ptxas
        run: |
          cd pytorch/xla
          python test/test_triton.py
13 changes: 13 additions & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ jobs:
install-cuda-plugin: true
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-triton:
name: "Triton tests"
uses: ./.github/workflows/_triton.yml
needs: [build-torch-xla, build-cuda-plugin]
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
runner: linux.g5.4xlarge.nvidia.gpu
timeout-minutes: 300
install-cuda-plugin: true
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}


test-tpu:
name: "TPU tests"
Expand Down
1 change: 0 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ http_archive(
url = "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz",
)


load("@pybind11_bazel//:python_configure.bzl", "python_configure")

# This is required for setting up the linkopts for -lpython.q
Expand Down
2 changes: 1 addition & 1 deletion bazel/rules_def.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def ptxla_cc_test(
"@torch//:libtorch_python",
],
**kwargs
)
)

0 comments on commit e14636a

Please sign in to comment.