Skip to content

Commit

Permalink
Install jaxlib while setting up triton tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavya01 committed May 10, 2024
1 parent e14636a commit 256d819
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
4 changes: 3 additions & 1 deletion .circleci/triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ fi
apply_patches

python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
pip install --upgrade "jax[cuda12]"
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

export PATH=$PATH:/usr/local/cuda-12.1/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
Expand Down
12 changes: 10 additions & 2 deletions .github/workflows/_triton.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: xla-test
name: triton-test
on:
workflow_call:
inputs:
Expand Down Expand Up @@ -34,6 +34,12 @@ jobs:
container:
image: ${{ inputs.dev-image }}
options: "${{ inputs.install-cuda-plugin && '--gpus all' || '' }} --shm-size 16g"
timeout-minutes: ${{ inputs.timeout-minutes }}
env:
GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
BAZEL_JOBS: 16
BAZEL_REMOTE_CACHE: 1
steps:
# See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802
- name: Clean up workspace
Expand Down Expand Up @@ -102,7 +108,9 @@ jobs:
pip install fsspec
pip install rich
# Jax nightly is needed for Triton tests.
pip install --upgrade "jax[cuda12]"
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
- name: Run Tests
env:
PJRT_DEVICE: CUDA
Expand Down
2 changes: 1 addition & 1 deletion infra/tpu-pytorch/infra_triggers.tf
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module "terraform_apply" {
source = "../terraform_modules/apply_terraform_trigger"

included_files = ["infra/**"]
branch = "main"
branch = "master"
config_directory = "infra/tpu-pytorch"

worker_pool_id = module.worker_pool.id
Expand Down

0 comments on commit 256d819

Please sign in to comment.