diff --git a/.circleci/README.md b/.circleci/README.md deleted file mode 100644 index d01e6138317..00000000000 --- a/.circleci/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# CircleCI Overview -PyTorch and PyTorch/XLA use CircleCI to lint, build, and test each PR that is submitted. All CircleCI tests should succeed before the PR is merged into master. PyTorch CircleCI pins PyTorch/XLA to a specific commit. On the other hand, PyTorch/XLA CircleCI pulls PyTorch from master unless a pin is manually provided. This README will go through the reasons of these pins, how to pin a PyTorch/XLA PR to an upstream PyTorch PR, and how to coordinate a merge for breaking PyTorch changes. - -## Why does PyTorch CircleCI pin PyTorch/XLA? -As mentioned above, [PyTorch CircleCI pins PyTorch/XLA](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/common_utils.sh#L119) to a "known good" commit to prevent accidental changes from PyTorch/XLA to break PyTorch CircleCI without warning. PyTorch has hundreds of commits each week, and this pin ensures that PyTorch/XLA as a downstream package does not cause failures in PyTorch CircleCI. - -## Why does PyTorch/XLA CircleCI pull from PyTorch master? -[PyTorch/XLA CircleCI pulls PyTorch from master](https://github.com/pytorch/xla/blob/f3415929683880192b63b285921c72439af55bf0/.circleci/common.sh#L15) unless a PyTorch pin is manually provided. PyTorch/XLA is a downstream package to PyTorch, and pulling from master ensures that PyTorch/XLA will stay up-to-date and works with the latest PyTorch changes. - -## Pinning PyTorch PR in PyTorch/XLA PR -Sometimes a PyTorch/XLA PR needs to be pinned to a specific PyTorch PR to test new featurues, fix breaking changes, etc. Since PyTorch/XLA CircleCI pulls from PyTorch master by default, we need to manually provided a PyTorch pin. In a PyTorch/XLA PR, PyTorch an be manually pinned by creating a `.torch_pin` under `/torch_patches`. The `.torch_pin` should have the corresponding PyTorch PR number prefixed by "#". Take a look at [example here](https://github.com/pytorch/xla/pull/3792/commits/40f41fb98b0f2386d287eeac0bae86e873d4a9d8). Before the PyTorch/XLA PR gets merged, the `.torch_pin` must be deleted. - -## Coodinating merges for breaking PyTorch PRs -When PyTorch PR introduces a breaking change, its PyTorch/XLA CircleCI tests will fail. Steps for fixing and merging such breaking PyTorch change is as following: -1. Create a PyTorch/XLA PR to fix this issue with `.torch_pin` and rebase with master to ensure the PR is up-to-date with the latest commit on PyTorch/XLA. Once this PR is created, it'll create a commit hash that will be used in step 2. If you have multiple commits in the PR, use the last one's hash. **Important note: When you rebase this PR, it'll create a new commit hash and make the old hash obsolete. Be cautious about rebasing, and if you rebase, make sure you inform the PyTorch PR's author.** -2. Rebase (or ask the PR owner to rebase) the PyTorch PR with master. Update the PyTorch PR to pin the PyTorch/XLA to the commit hash created in step 1 by updating `pytorch/.github/ci_commit_pins/xla.txt`. -3. Once CircleCI tests are green on both ends, merge PyTorch PR. -4. Remove the `.torch_pin` in PyTorch/XLA PR and merge. To be noted, `git commit --amend` should be avoided in this step as PyTorch CI will keep using the commit hash created in step 1 until other PRs update that manually or the nightly buildbot updates that automatically. -5. Finally, don't delete your branch until 2 days later. See step 4 for explanations. diff --git a/.circleci/doc_push.sh b/.circleci/doc_push.sh deleted file mode 100755 index 72b4a44f6e7..00000000000 --- a/.circleci/doc_push.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash - -set -ex - -cd /tmp/pytorch/xla - -source ./xla_env -source .circleci/common.sh - -echo "Building docs" -pushd docs -./docs_build.sh -popd - -echo "Pushing to public" -git config --global user.email "pytorchxla@gmail.com" -git config --global user.name "torchxlabot2" -GH_PAGES_BRANCH=gh-pages -GH_PAGES_DIR=gh-pages-tmp -CURRENT_COMMIT=`git rev-parse HEAD` -BRANCH_NAME=`git rev-parse --abbrev-ref HEAD` -if [[ "$BRANCH_NAME" == release/* ]]; then - SUBDIR_NAME=$BRANCH_NAME -else - SUBDIR_NAME="master" -fi -pushd /tmp -git clone --quiet -b "$GH_PAGES_BRANCH" https://github.com/pytorch/xla.git "$GH_PAGES_DIR" -pushd $GH_PAGES_DIR -rm -rf $SUBDIR_NAME -mkdir -p $SUBDIR_NAME -cp -fR /tmp/pytorch/xla/docs/build/* $SUBDIR_NAME -git_status=$(git status --porcelain) -if [[ $git_status ]]; then - echo "Doc is updated... Pushing to public" - echo "${git_status}" - sudo apt-get -qq update - export DEBIAN_FRONTEND=noninteractive - sudo ln -snf /usr/share/zoneinfo/Etc/UTC /etc/localtime - sudo sh -c "echo Etc/UTC > /etc/timezone" - sudo apt-get -qq -y install tzdata - sudo apt-get -qq install expect - git add . - - COMMIT_MSG="Update doc from commit $CURRENT_COMMIT" - git commit -m "$COMMIT_MSG" - set +x -/usr/bin/expect < /dev/null 2>&1 ; then - VER="buster" - else - VER=$(lsb_release -c -s) - fi - echo "$VER" -} - -function install_llvm_clang() { - local DEBVER=$(debian_version) - if ! apt-get install -y -s clang-8 > /dev/null 2>&1 ; then - maybe_append "deb http://apt.llvm.org/${DEBVER}/ llvm-toolchain-${DEBVER}-8 main" /etc/apt/sources.list - maybe_append "deb-src http://apt.llvm.org/${DEBVER}/ llvm-toolchain-${DEBVER}-8 main" /etc/apt/sources.list - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - - sudo apt-get update - fi - # Build config also sets CC=clang-8, CXX=clang++-8 - sudo apt-get install -y clang-8 clang++-8 - sudo apt-get install -y llvm-8 llvm-8-dev llvm-8-tools - sudo ln -s /usr/bin/clang-8 /usr/bin/clang - sudo ln -s /usr/bin/clang++-8 /usr/bin/clang++ - export CC=clang-8 CXX=clang++-8 -} - -install_llvm_clang diff --git a/.circleci/setup_ci_environment.sh b/.circleci/setup_ci_environment.sh index eba2c373b8a..87a61524e7e 100755 --- a/.circleci/setup_ci_environment.sh +++ b/.circleci/setup_ci_environment.sh @@ -58,7 +58,7 @@ sudo apt-get -y remove linux-image-generic linux-headers-generic linux-generic d # How to figure out what the correct versions of these packages are? # My preferred method is to start a Docker instance of the correct # Ubuntu version (e.g., docker run -it ubuntu:16.04) and then ask -# apt what the packages you need are. Note that the CircleCI image +# apt what the packages you need are. Note that the CI image # comes with Docker. # # Using 'retry' here as belt-and-suspenders even though we are diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 87072a65bce..bfff4ef8422 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -/infra @will-cromar @JackCaoG @yeounoh @mateuszlewko @stgpetrovic +/infra @will-cromar @JackCaoG @lsy323 diff --git a/.github/ci.md b/.github/ci.md new file mode 100644 index 00000000000..c2f4d37426c --- /dev/null +++ b/.github/ci.md @@ -0,0 +1,19 @@ +# CI Overview +PyTorch and PyTorch/XLA use CI to lint, build, and test each PR that is submitted. All CI tests should succeed before the PR is merged into master. PyTorch CI pins PyTorch/XLA to a specific commit. On the other hand, PyTorch/XLA CI pulls PyTorch from master unless a pin is manually provided. This README will go through the reasons of these pins, how to pin a PyTorch/XLA PR to an upstream PyTorch PR, and how to coordinate a merge for breaking PyTorch changes. + +## Why does PyTorch CI pin PyTorch/XLA? +As mentioned above, [PyTorch CI pins PyTorch/XLA](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/common_utils.sh#L119) to a "known good" commit to prevent accidental changes from PyTorch/XLA to break PyTorch CI without warning. PyTorch has hundreds of commits each week, and this pin ensures that PyTorch/XLA as a downstream package does not cause failures in PyTorch CI. + +## Why does PyTorch/XLA CI pull from PyTorch master? +[PyTorch/XLA CI pulls PyTorch from master](https://github.com/pytorch/xla/blob/f3415929683880192b63b285921c72439af55bf0/.circleci/common.sh#L15) unless a PyTorch pin is manually provided. PyTorch/XLA is a downstream package to PyTorch, and pulling from master ensures that PyTorch/XLA will stay up-to-date and works with the latest PyTorch changes. + +## Pinning PyTorch PR in PyTorch/XLA PR +Sometimes a PyTorch/XLA PR needs to be pinned to a specific PyTorch PR to test new featurues, fix breaking changes, etc. Since PyTorch/XLA CI pulls from PyTorch master by default, we need to manually provided a PyTorch pin. In a PyTorch/XLA PR, PyTorch an be manually pinned by creating a `.torch_pin` file at the root of the repository. The `.torch_pin` should have the corresponding PyTorch PR number prefixed by "#". Take a look at [example here](https://github.com/pytorch/xla/pull/3792/commits/40f41fb98b0f2386d287eeac0bae86e873d4a9d8). Before the PyTorch/XLA PR gets merged, the `.torch_pin` must be deleted. + +## Coodinating merges for breaking PyTorch PRs +When PyTorch PR introduces a breaking change, its PyTorch/XLA CI tests will fail. Steps for fixing and merging such breaking PyTorch change is as following: +1. Create a PyTorch/XLA PR to fix this issue with `.torch_pin` and rebase with master to ensure the PR is up-to-date with the latest commit on PyTorch/XLA. Once this PR is created, it'll create a commit hash that will be used in step 2. If you have multiple commits in the PR, use the last one's hash. **Important note: When you rebase this PR, it'll create a new commit hash and make the old hash obsolete. Be cautious about rebasing, and if you rebase, make sure you inform the PyTorch PR's author.** +2. Rebase (or ask the PR owner to rebase) the PyTorch PR with master. Update the PyTorch PR to pin the PyTorch/XLA to the commit hash created in step 1 by updating `pytorch/.github/ci_commit_pins/xla.txt`. +3. Once CI tests are green on both ends, merge PyTorch PR. +4. Remove the `.torch_pin` in PyTorch/XLA PR and merge. To be noted, `git commit --amend` should be avoided in this step as PyTorch CI will keep using the commit hash created in step 1 until other PRs update that manually or the nightly buildbot updates that automatically. +5. Finally, don't delete your branch until 2 days later. See step 4 for explanations. diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh new file mode 100755 index 00000000000..ae59a51490d --- /dev/null +++ b/.github/scripts/run_tests.sh @@ -0,0 +1,108 @@ +set -ex + +function run_torch_xla_python_tests() { + PYTORCH_DIR=$1 + XLA_DIR=$2 + USE_COVERAGE="${3:-0}" + + pushd $XLA_DIR + echo "Running Python Tests" + if [ "$USE_COVERAGE" != "0" ]; then + pip install coverage==6.5.0 --upgrade + pip install coverage-lcov + pip install toml + ./test/run_tests.sh + coverage combine + mkdir lcov && cp .coverage lcov/ + coverage-lcov --data_file_path lcov/.coverage + coverage html + cp lcov.info htmlcov/ + mv htmlcov ~/ + chmod -R 755 ~/htmlcov + else + ./test/run_tests.sh + fi + popd +} + +function run_torch_xla_cpp_tests() { + PYTORCH_DIR=$1 + XLA_DIR=$2 + USE_COVERAGE="${3:-0}" + + TORCH_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch').get_filename()))") + export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TORCH_DIR}/lib + if [ -x "$(command -v nvidia-smi)" ]; then + CUDA_PLUGIN_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch_xla_cuda_plugin').get_filename()))") + export PJRT_LIBRARY_PATH=$CUDA_PLUGIN_DIR/lib/pjrt_c_api_gpu_plugin.so + export PJRT_DEVICE=LIBRARY + export PJRT_DYNAMIC_PLUGINS=1 + else + export PJRT_DEVICE=CPU + fi + export XLA_EXPERIMENTAL="nonzero:masked_select:nms" + + test_names1=("test_aten_xla_tensor_1" + "test_aten_xla_tensor_2" + "test_aten_xla_tensor_3" + "test_aten_xla_tensor_4" + "pjrt_computation_client_test" + "ifrt_computation_client_test") + test_names2=("test_aten_xla_tensor_5" + "test_aten_xla_tensor_6" + "test_ir" + "test_lazy" + "test_replication" + "test_tensor" + # disable test_xla_backend_intf since it is flaky on upstream + #"test_xla_backend_intf" + "test_xla_sharding") + if [[ "$RUN_CPP_TESTS1" == "cpp_tests1" ]]; then + test_names=("${test_names1[@]}") + elif [[ "$RUN_CPP_TESTS2" == "cpp_tests2" ]]; then + test_names=("${test_names2[@]}") + else + test_names=("${test_names1[@]}" "${test_names2[@]}") + fi + + for name in "${test_names[@]}"; do + echo "Running $name cpp test..." + /tmp/test/bin/${name} + done +} + +function run_torch_xla_benchmark_tests() { + XLA_DIR=$1 + pushd $XLA_DIR + echo "Running Benchmark Tests" + test/benchmarks/run_tests.sh -L"" +} + +PYTORCH_DIR=$1 +XLA_DIR=$2 +USE_COVERAGE="${3:-0}" +RUN_CPP="${RUN_CPP_TESTS:0}" +RUN_PYTHON="${RUN_PYTHON_TESTS:0}" + +if [ -x "$(command -v nvidia-smi)" ]; then + num_devices=$(nvidia-smi --list-gpus | wc -l) + echo "Found $num_devices GPU devices..." + export GPU_NUM_DEVICES=$num_devices +fi +export PYTORCH_TESTING_DEVICE_ONLY_FOR="xla" +export CXX_ABI=$(python -c "import torch;print(int(torch._C._GLIBCXX_USE_CXX11_ABI))") + +if [[ -z "$RUN_BENCHMARK_TESTS" && -z "$RUN_CPP_TESTS1" && -z "$RUN_CPP_TESTS2" && -z "$RUN_PYTHON_TESTS" ]]; then + run_torch_xla_python_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + run_torch_xla_cpp_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + run_torch_xla_benchmark_tests $XLA_DIR +else + # run tests separately. + if [[ "$RUN_PYTHON_TESTS" == "python_tests" ]]; then + run_torch_xla_python_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + elif [[ "$RUN_BENCHMARK_TESTS" == "benchmark_tests" ]]; then + run_torch_xla_benchmark_tests $XLA_DIR + else + run_torch_xla_cpp_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + fi +fi diff --git a/.circleci/docker/Dockerfile b/.github/upstream/Dockerfile similarity index 98% rename from .circleci/docker/Dockerfile rename to .github/upstream/Dockerfile index f0cd196511c..006460c2477 100644 --- a/.circleci/docker/Dockerfile +++ b/.github/upstream/Dockerfile @@ -1,3 +1,4 @@ +# Dockerfile for image used by upstream CI # This requires cuda & cudnn packages pre-installed in the base image. # Other available cuda images are listed at https://hub.docker.com/r/nvidia/cuda ARG base_image="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1" diff --git a/.circleci/docker/install_conda.sh b/.github/upstream/install_conda.sh similarity index 100% rename from .circleci/docker/install_conda.sh rename to .github/upstream/install_conda.sh diff --git a/.circleci/docker/install_valgrind.sh b/.github/upstream/install_valgrind.sh similarity index 100% rename from .circleci/docker/install_valgrind.sh rename to .github/upstream/install_valgrind.sh diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml deleted file mode 100644 index 789d0579272..00000000000 --- a/.github/workflows/_build.yml +++ /dev/null @@ -1,111 +0,0 @@ -name: xla-buld -on: - workflow_call: - inputs: - gcr-docker-image: - required: true - type: string - description: Base image for builds - ecr-docker-image-base: - required: true - type: string - description: Container registry to upload image to - runner: - required: false - type: string - description: Runner type for the test - default: linux.12xlarge - cuda: - required: false - type: string - description: Whether to build XLA with CUDA - default: 1 - - secrets: - gcloud-service-key: - required: true - description: Secret to access Bazel build cache - - outputs: - docker-image: - value: ${{ jobs.build.outputs.docker-image }} - description: The docker image containing the built PyTorch. -jobs: - build: - runs-on: ${{ inputs.runner }} - timeout-minutes: 240 - outputs: - docker-image: ${{ steps.upload-docker-image.outputs.docker-image }} - env: - ECR_DOCKER_IMAGE_BASE: ${{ inputs.ecr-docker-image-base }} - GCR_DOCKER_IMAGE: ${{ inputs.gcr-docker-image }} - WORKDIR: /var/lib/jenkins/workspace - SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 - GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} - XLA_CUDA: ${{ inputs.cuda }} - BAZEL_JOBS: 16 - steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Build is done inside the container, to start an interactive session run: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Checkout repo - uses: actions/checkout@v3 - - name: Download docker image from GCR - shell: bash - run: docker pull "${GCR_DOCKER_IMAGE}" - - name: Stage image to ECR - shell: bash - run: | - # This is to stage PyTorch/XLA base image for use in the upstream. - # To allow the upstream workflow to access PyTorch/XLA build images, we - # need to have them in the ECR. This is not expensive, and only pushes it - # if image layers are not present in the repo. - # Note: disable the following 2 lines while testing a new image, so we do not - # push to the upstream. - docker tag "${GCR_DOCKER_IMAGE}" "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null - docker push "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null - - name: Start the container - shell: bash - run: | - pid=$(docker run --privileged -t -d -w "$WORKDIR" "${GCR_DOCKER_IMAGE}") - docker exec -u jenkins "${pid}" sudo chown -R jenkins "${WORKDIR}" - docker cp "${GITHUB_WORKSPACE}/." "$pid:$WORKDIR" - echo "pid=${pid}" >> "${GITHUB_ENV}" - - - name: Prepare build env - shell: bash - run: | - echo "declare -x SCCACHE_BUCKET=${SCCACHE_BUCKET}" | docker exec -i "${pid}" sh -c "cat >> env" - echo "declare -x XLA_CUDA=${XLA_CUDA}" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_JOBS=${BAZEL_JOBS}" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" - - - name: Build - shell: bash - run: | - docker exec --privileged -u jenkins "${pid}" bash -c ".circleci/build.sh" - - name: Cleanup build env - shell: bash - run: | - docker exec "${pid}" rm default_credentials.json /tmp/pytorch/xla/default_credentials.json - - - name: Push built docker image to ECR - id: upload-docker-image - shell: bash - run: | - export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:latest-${GITHUB_SHA}" - time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}" - time docker push "${COMMIT_DOCKER_IMAGE}" - echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}" - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() - diff --git a/.github/workflows/_build_plugin.yml b/.github/workflows/_build_plugin.yml index 5f773971430..e30b88aed1e 100644 --- a/.github/workflows/_build_plugin.yml +++ b/.github/workflows/_build_plugin.yml @@ -39,7 +39,7 @@ jobs: shell: bash run: | cd pytorch/xla/infra/ansible - ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda src_root=${GITHUB_WORKSPACE}" --skip-tags=fetch_srcs,install_deps + ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps - name: Upload wheel uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/_build_torch_xla.yml b/.github/workflows/_build_torch_xla.yml index 969fb3b5dc9..58a783216e4 100644 --- a/.github/workflows/_build_torch_xla.yml +++ b/.github/workflows/_build_torch_xla.yml @@ -26,6 +26,7 @@ jobs: GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json BAZEL_JOBS: 16 BAZEL_REMOTE_CACHE: 1 + BUILD_CPP_TESTS: 1 steps: - name: Setup gcloud shell: bash @@ -36,7 +37,7 @@ jobs: with: repository: pytorch/pytorch path: pytorch - # TODO: correct pin + submodules: recursive - name: Checkout PyTorch/XLA Repo uses: actions/checkout@v4 with: @@ -45,9 +46,14 @@ jobs: shell: bash run: | cd pytorch/xla/infra/ansible - ansible-playbook playbook.yaml -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0" --skip-tags=fetch_srcs,install_deps + ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 build_cpp_tests=1 git_versioned_xla_build=1 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps - name: Upload wheel uses: actions/upload-artifact@v4 with: name: torch-xla-wheels path: /dist/*.whl + - name: Upload CPP test binaries + uses: actions/upload-artifact@v4 + with: + name: cpp-test-bin + path: /tmp/test/bin diff --git a/.github/workflows/_build_upstream_image.yml b/.github/workflows/_build_upstream_image.yml new file mode 100644 index 00000000000..ef0975b6abf --- /dev/null +++ b/.github/workflows/_build_upstream_image.yml @@ -0,0 +1,44 @@ +name: xla-buld +on: + workflow_call: + inputs: + ecr-docker-image-base: + required: true + type: string + description: Container registry to upload image to + runner: + required: false + type: string + description: Runner type for the test + default: linux.12xlarge +jobs: + build: + runs-on: ${{ inputs.runner }} + timeout-minutes: 240 + env: + ECR_DOCKER_IMAGE_BASE: ${{ inputs.ecr-docker-image-base }} + BAZEL_JOBS: 16 + steps: + # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 + - name: Clean up workspace + run: | + ls -la + sudo rm -rvf ${GITHUB_WORKSPACE}/* + - name: Setup Linux + uses: pytorch/test-infra/.github/actions/setup-linux@main + - name: Checkout repo + uses: actions/checkout@v3 + - name: Download docker image from GCR + shell: bash + run: | + docker build -t "${ECR_DOCKER_IMAGE_BASE}:v1.2-lite" .github/upstream + - name: Stage image to ECR + shell: bash + run: | + # This is to stage PyTorch/XLA base image for use in the upstream. + # To allow the upstream workflow to access PyTorch/XLA build images, we + # need to have them in the ECR. This is not expensive, and only pushes it + # if image layers are not present in the repo. + # Note: disable the following line while testing a new image, so we do not + # push to the upstream. + docker push "${ECR_DOCKER_IMAGE_BASE}:v1.2-lite" diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index ed9a4ab0ea9..378dec9697a 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -2,10 +2,10 @@ name: xla-docs-build on: workflow_call: inputs: - docker-image: + dev-image: required: true type: string - description: Image to build docs in + description: Base image for builds runner: required: false type: string @@ -15,35 +15,57 @@ on: torchxla-bot-token: required: true jobs: - push-docs: - runs-on: ${{ inputs.runner }} + build-docs: + runs-on: ubuntu-latest timeout-minutes: 45 + container: + image: ${{ inputs.dev-image }} env: - DOCKER_IMAGE: ${{ inputs.docker-image }} - WORKDIR: /var/lib/jenkins/workspace + BRANCH_NAME: ${{ github.ref_name }} steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Doc builds are done inside container. Interactive session can be started by following: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Download and run docker image from GCR - shell: bash - env: - GITHUB_TORCH_XLA_BOT_TOKEN: ${{ secrets. torchxla-bot-token }} - run: | - echo "DOCKER_IMAGE: ${DOCKER_IMAGE}" - docker pull "${DOCKER_IMAGE}" - pid=$(docker run -e GITHUB_TORCH_XLA_BOT_TOKEN -t -d -w "$WORKDIR" "${DOCKER_IMAGE}") - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> /tmp/pytorch/xla/default_credentials.json" - echo "pid=${pid}" >> "${GITHUB_ENV}" - - name: Build & publish docs - shell: bash - run: docker exec -u jenkins "${pid}" bash -c '.circleci/doc_push.sh' - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() + - name: Fetch wheels + uses: actions/download-artifact@v4 + with: + name: torch-xla-wheels + path: /tmp/wheels/ + - name: Install wheels + shell: bash + run: | + pip install /tmp/wheels/*.whl + - name: Checkout PyTorch/XLA Repo + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Build docs + shell: bash + run: | + cd pytorch/xla/docs + pip install -r requirements.txt + sphinx-build -b html source build + - name: Checkout GitHub Pages + uses: actions/checkout@v4 + with: + path: gh-pages + ref: gh-pages + token: ${{ secrets.torchxla-bot-token }} + - name: Merge changes + shell: bash + run: | + subdir=${{ env.BRANCH_NAME == 'master' && 'master' || format('{0}/{1}', 'release', env.BRANCH_NAME) }} + mkdir -p gh-pages/$subdir + cp -fR pytorch/xla/docs/build/* gh-pages/$subdir + - name: Upload preview as artifact + uses: actions/upload-artifact@v4 + with: + name: github-pages + path: pytorch/xla/docs/build/ + - name: Deploy + shell: bash + run: | + cd gh-pages + git config user.email "pytorchxla@gmail.com" + git config user.name "torchxlabot2" + git add . -v + git diff --cached --exit-code || git commit -m "Update doc from commit ${{ github.sha }}" + git push origin gh-pages + if: github.event_name == 'push' diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 0f9e96e31e5..8a454cc075b 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -2,10 +2,10 @@ name: xla-test on: workflow_call: inputs: - docker-image: + dev-image: required: true type: string - description: Image to test on + description: Base image for builds runner: required: false type: string @@ -22,16 +22,12 @@ on: default: 270 description: | Set the maximum (in minutes) how long the workflow should take to finish - disable-pjrt: + timeout-minutes: + install-cuda-plugin: required: false - type: string - default: 0 - description: Whether to disable PJRT tests - test-script: - required: false - type: string - default: test.sh - description: Which test script to run + type: boolean + default: false + description: Whether to install CUDA plugin package secrets: gcloud-service-key: @@ -40,14 +36,15 @@ on: jobs: test: runs-on: ${{ inputs.runner }} + container: + image: ${{ inputs.dev-image }} + options: "${{ inputs.install-cuda-plugin && '--gpus all' || '' }} --shm-size 16g" strategy: fail-fast: false matrix: include: # Use readable strings as they define the workflow titles. - run_benchmark_tests: 'benchmark_tests' - - run_cpp_tests1: 'cpp_tests1' - - run_cpp_tests2: 'cpp_tests2' - run_python_tests: 'python_tests' run_xla_op_tests1: 'xla_op1' - run_python_tests: 'python_tests' @@ -56,63 +53,112 @@ jobs: run_xla_op_tests3: 'xla_op3' - run_python_tests: 'python_tests' run_torch_mp_op_tests: 'torch_mp_op' + - run_cpp_tests: 'cpp_tests' + run_cpp_tests1: 'cpp_tests1' + - run_cpp_tests: 'cpp_tests' + run_cpp_tests2: 'cpp_tests2' timeout-minutes: ${{ inputs.timeout-minutes }} env: - DOCKER_IMAGE: ${{ inputs.docker-image }} - WORKDIR: /var/lib/jenkins/workspace GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} + GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }} - XLA_SKIP_TORCH_OP_TESTS: ${{ inputs.disable-pjrt }} - XLA_SKIP_MP_OP_TESTS: ${{ inputs.disable-pjrt }} RUN_BENCHMARK_TESTS: ${{ matrix.run_benchmark_tests }} - RUN_CPP_TESTS1: ${{ matrix.run_cpp_tests1 }} - RUN_CPP_TESTS2: ${{ matrix.run_cpp_tests2 }} RUN_PYTHON_TESTS: ${{ matrix.run_python_tests }} RUN_XLA_OP_TESTS1: ${{ matrix.run_xla_op_tests1 }} RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }} RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }} RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }} + RUN_CPP_TESTS1: ${{ matrix.run_cpp_tests1 }} + RUN_CPP_TESTS2: ${{ matrix.run_cpp_tests2 }} + BAZEL_JOBS: 16 + BAZEL_REMOTE_CACHE: 1 steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main + # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 + - name: Clean up workspace + run: | + ls -la + rm -rvf ${GITHUB_WORKSPACE}/* + - name: Setup gcloud + shell: bash + run: | + echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS + - name: Fetch wheels + uses: actions/download-artifact@v4 with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Tests are done inside the container, to start an interactive session run: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Install gcloud CLI - if: ${{ inputs.collect-coverage }} + name: torch-xla-wheels + path: /tmp/wheels/ + - name: Fetch CPP test binaries + uses: actions/download-artifact@v4 + with: + name: cpp-test-bin + path: /tmp/test/bin + if: ${{ matrix.run_cpp_tests }} + # GitHub Actions doesn't preserve executable permissions + # https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss + - name: Set CPP test permissions + run: | + chmod +x /tmp/test/bin/* + ls -l /tmp/test/bin + if: ${{ matrix.run_cpp_tests }} + - name: Fetch CUDA plugin + uses: actions/download-artifact@v4 + with: + name: cuda-plugin + path: /tmp/wheels/ + if: ${{ inputs.install-cuda-plugin }} + - name: Setup CUDA environment shell: bash run: | - sudo tee -a /etc/yum.repos.d/google-cloud-sdk.repo << EOM - [google-cloud-cli] - name=Google Cloud CLI - baseurl=https://packages.cloud.google.com/yum/repos/cloud-sdk-el8-x86_64 - enabled=1 - gpgcheck=1 - repo_gpgcheck=0 - gpgkey=https://packages.cloud.google.com/yum/doc/rpm-package-key.gpg - EOM - sudo yum install -y google-cloud-cli - - name: Auth to GCR - if: ${{ inputs.collect-coverage }} + # TODO: Make PJRT_DEVICE=CPU work with XLA_REGISTER_INSTALLED_PLUGINS=1 + echo "XLA_REGISTER_INSTALLED_PLUGINS=1" >> $GITHUB_ENV + + echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV + if: ${{ inputs.install-cuda-plugin }} + - name: Check GPU + run: nvidia-smi + if: ${{ inputs.install-cuda-plugin }} + - name: Install wheels shell: bash run: | - echo "${GCLOUD_SERVICE_KEY}" | gcloud auth activate-service-account --key-file=- - - name: Download and run docker image from GCR + pip install /tmp/wheels/*.whl + # TODO: Add these in setup.py + pip install fsspec + pip install rich + + echo "Import check..." + python -c "import torch_xla" + - name: Record PyTorch commit + run: | + # Don't just pipe output in shell because imports may do extra logging + python -c " + import torch_xla.version + with open('$GITHUB_ENV', 'a') as f: + f.write(f'PYTORCH_COMMIT={torch_xla.version.__torch_gitrev__}\n') + " + - name: Checkout PyTorch Repo + uses: actions/checkout@v4 + with: + repository: pytorch/pytorch + path: pytorch + ref: ${{ env.PYTORCH_COMMIT }} + - name: Checkout PyTorch/XLA Repo + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Extra CI deps shell: bash run: | - echo "DOCKER_IMAGE: ${DOCKER_IMAGE}" - docker pull "${DOCKER_IMAGE}" - pid=$(docker run --shm-size=16g ${GPU_FLAG:-} -e USE_COVERAGE -e XLA_SKIP_TORCH_OP_TESTS -e XLA_SKIP_MP_OP_TESTS -e RUN_BENCHMARK_TESTS -e RUN_CPP_TESTS1 -e RUN_CPP_TESTS2 -e RUN_PYTHON_TESTS -e RUN_XLA_OP_TESTS1 -e RUN_XLA_OP_TESTS2 -e RUN_XLA_OP_TESTS3 -e RUN_TORCH_MP_OP_TESTS -t -d -w "$WORKDIR" "${DOCKER_IMAGE}") - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> /tmp/pytorch/xla/default_credentials.json" - echo "pid=${pid}" >> "${GITHUB_ENV}" + set -x + + pip install expecttest unittest-xml-reporting + + if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then + pip install -r pytorch/xla/benchmarks/requirements.txt + fi - name: Test shell: bash - run: | - docker exec --privileged -u jenkins "${pid}" bash -c '.circleci/${{ inputs.test-script }}' + run: pytorch/xla/.github/scripts/run_tests.sh pytorch/ pytorch/xla/ $USE_COVERAGE - name: Upload coverage results if: ${{ inputs.collect-coverage }} shell: bash @@ -158,8 +204,3 @@ jobs: gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json fi fi - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() - diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 38203f57580..e040884b5ef 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -19,18 +19,15 @@ concurrency: cancel-in-progress: true jobs: - build: - name: "Build PyTorch/XLA (GPU)" - uses: ./.github/workflows/_build.yml + build-upstream-image: + name: "Build upstream Docker image" + uses: ./.github/workflows/_build_upstream_image.yml with: ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base - gcr-docker-image: gcr.io/tpu-pytorch/xla_base:dev-3.8_cuda_12.1 - cuda: 1 - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} + if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' build-torch-xla: - name: "Build PyTorch/XLA (TPU)" + name: "Build PyTorch/XLA" uses: ./.github/workflows/_build_torch_xla.yml with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm @@ -41,16 +38,16 @@ jobs: name: "Build XLA CUDA plugin" uses: ./.github/workflows/_build_plugin.yml with: - dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1 + dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - test-cpu: + test-python-cpu: name: "CPU tests" uses: ./.github/workflows/_test.yml - needs: build + needs: build-torch-xla with: - docker-image: ${{ needs.build.outputs.docker-image }} + dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm timeout-minutes: 120 collect-coverage: false secrets: @@ -59,12 +56,13 @@ jobs: test-cuda: name: "GPU tests" uses: ./.github/workflows/_test.yml - needs: build + needs: [build-torch-xla, build-cuda-plugin] with: - docker-image: ${{ needs.build.outputs.docker-image }} + dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1 runner: linux.8xlarge.nvidia.gpu timeout-minutes: 300 - collect-coverage: false # TODO(yeounoh) separate from CPU coverage metrics + collect-coverage: false + install-cuda-plugin: true secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} @@ -73,14 +71,13 @@ jobs: uses: ./.github/workflows/_tpu_ci.yml needs: build-torch-xla # Only run this for HEAD and releases - if: github.event_name == 'push' + if: github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'tpuci') push-docs: - name: "Build & publish docs" - if: github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || startsWith(github.event.ref, 'refs/tags/r')) + name: "Build docs" uses: ./.github/workflows/_docs.yml - needs: build + needs: build-torch-xla with: - docker-image: ${{ needs.build.outputs.docker-image }} + dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm secrets: torchxla-bot-token: ${{ secrets.TORCH_XLA_BOT_TOKEN }} diff --git a/.github/workflows/lintercheck.yml b/.github/workflows/lintercheck.yml index 6598b98da32..b17c608f883 100644 --- a/.github/workflows/lintercheck.yml +++ b/.github/workflows/lintercheck.yml @@ -24,7 +24,7 @@ jobs: if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' shell: bash run: | - TORCH_PIN=./torch_patches/.torch_pin + TORCH_PIN=./.torch_pin if [[ -f "${TORCH_PIN}" ]]; then echo "Please remove ${TORCH_PIN} before landing." exit 1 diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml index 7c5a88bf430..441addad422 100644 --- a/.github/workflows/torch_xla2.yml +++ b/.github/workflows/torch_xla2.yml @@ -34,10 +34,8 @@ jobs: shell: bash working-directory: experimental/torch_xla2 run: | - pip install pytest absl-py jax[cpu] flatbuffers tensorflow - pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -r test_requirements.txt - pip install -e . + pip install -r test-requirements.txt + pip install -e .[cpu] - name: Run tests working-directory: experimental/torch_xla2 shell: bash diff --git a/BUILD b/BUILD index 6949f6dc748..60b601240fc 100644 --- a/BUILD +++ b/BUILD @@ -30,3 +30,23 @@ cc_binary( "@xla//xla/stream_executor:cuda_platform", ]), ) + +test_suite( + name = "cpp_tests", + # testonly = True, + tests = [ + "//test/cpp:test_aten_xla_tensor_1", + "//test/cpp:test_aten_xla_tensor_2", + "//test/cpp:test_aten_xla_tensor_3", + "//test/cpp:test_aten_xla_tensor_4", + "//test/cpp:test_aten_xla_tensor_5", + "//test/cpp:test_aten_xla_tensor_6", + "//test/cpp:test_ir", + "//test/cpp:test_lazy", + "//test/cpp:test_replication", + "//test/cpp:test_tensor", + "//test/cpp:test_xla_sharding", + "//torch_xla/csrc/runtime:pjrt_computation_client_test", + "//torch_xla/csrc/runtime:ifrt_computation_client_test", + ], +) diff --git a/OP_LOWERING_GUIDE.md b/OP_LOWERING_GUIDE.md index b445a1d8998..535d7cf596c 100644 --- a/OP_LOWERING_GUIDE.md +++ b/OP_LOWERING_GUIDE.md @@ -25,7 +25,7 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e 7. `ops/` directory contains all `ir::ops` declaration and definition. Smaller nodes can be put in `ops/ops.h/.cpp`. More complicated nodes can be put into a separate file. All ops inherit from `ir::ops::Node` and provide a way to lower input `ir::Value` to a sequence of `XlaOp`. ## Unit Test -Our CircleCI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters. +Our CI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters. ## Tips The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/2972). diff --git a/README.md b/README.md index d1653eb7b53..70bdcfd57d9 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ started: To install PyTorch/XLA a new TPU VM: ``` -pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html ``` To update your existing training loop, make the following changes: @@ -132,31 +132,36 @@ Our comprehensive user guides are available at: PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You can now install the main build with `pip install torch_xla`. To also install the -Cloud TPU plugin, install the optional `tpu` dependencies: +Cloud TPU plugin, install the optional `tpu` dependencies after installing the main build with ``` pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html ``` -GPU, XRT (legacy runtime), and nightly builds are available in our public GCS -bucket. +GPU and nightly builds are available in our public GCS bucket. | Version | Cloud TPU/GPU VMs Wheel | | --- | ----------- | -| 2.2 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.3 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.3 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.3 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.3 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` | | nightly (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | | nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` | | nightly (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | +You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel of a specified date. To get the companion pytorch nightly wheel, replace the `torch_xla` with `torch` on above wheel links. +
older versions | Version | Cloud TPU VMs Wheel | |---------|-------------------| +| 2.2 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | | 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` | | 2.1 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl` | | 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` | @@ -202,25 +207,29 @@ wheels for `torch` and `torch_xla` at | --- | ----------- | | 2.0 | `https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl` | -You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel -of a specified date. To get the companion pytorch and torchvision nightly wheel, -replace the `torch_xla` with `torch` or `torchvision` on above wheel links.
### Docker | Version | Cloud TPU VMs Docker | | --- | ----------- | +| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm` | | 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm` | | 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm` | | 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | | 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | | nightly python | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` | +To use the above dockers, please pass `--privileged --net host --shm-size=16G` along. Here is an example: +```bash +docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash +``` +
| Version | GPU CUDA 12.1 Docker | | --- | ----------- | +| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1` | | 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1` | | 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1` | | nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1` | diff --git a/WORKSPACE b/WORKSPACE index e4d8a73fdc0..9fe770bedff 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,9 +50,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", ], - strip_prefix = "xla-1acf05ef0d41181caaf0cd691aa9d453ffc41a73", + strip_prefix = "xla-fe08041b23d8baa0d00967913a1d6e8a0c348df3", urls = [ - "https://github.com/openxla/xla/archive/1acf05ef0d41181caaf0cd691aa9d453ffc41a73.tar.gz", + "https://github.com/openxla/xla/archive/fe08041b23d8baa0d00967913a1d6e8a0c348df3.tar.gz", ], ) diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index 443a4067ac1..da952f3c079 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -319,8 +319,8 @@ def loop(pytorch_profile=None, iter_fn=None): self._args.profile_cuda_cpu or \ self._args.profile_cuda_cpu_individual_ops enable_xla_profiling = self._args.profile_xla - assert not (enable_pytorch_profiling and enable_pytorch_profiling - ), "More than one profiling path enabled." + assert not (enable_pytorch_profiling and + enable_xla_profiling), "More than one profiling path enabled." if enable_xla_profiling: logdir = self._get_results_dir_path(experiment_config, model_config, diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 00000000000..14e2549fec3 --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,3 @@ +tabulate +scipy +pandas diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh index fd8a055bccc..e4e483947d9 100644 --- a/benchmarks/run_benchmark.sh +++ b/benchmarks/run_benchmark.sh @@ -5,7 +5,7 @@ LOGFILE=/tmp/benchmark_test.log # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" diff --git a/build_util.py b/build_util.py index 78e4bd5e453..487f5116323 100644 --- a/build_util.py +++ b/build_util.py @@ -36,10 +36,6 @@ def bazel_options_from_env() -> Iterable[str]: bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' % cache_silo_name) - if check_env_flag('BUILD_CPP_TESTS', default='0'): - bazel_flags.append('//test/cpp:all') - bazel_flags.append('//torch_xla/csrc/runtime:all') - bazel_jobs = os.getenv('BAZEL_JOBS', default='') if bazel_jobs: bazel_flags.append('--jobs=%s' % bazel_jobs) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 199025dc7e1..de5500a0c5b 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -361,6 +361,7 @@ supported: - zero_ - _native_batch_norm_legit - _native_batch_norm_legit.no_stats + - _embedding_bag_forward_only # Note: [functionalization and CompositeExplicitAutograd] # Below are all operators that are "composite" in core, # but require us to explicitly re-enable functionalization in order to use them. diff --git a/docs/README.md b/docs/README.md index 33a0ce5bc36..a405597c798 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,12 +1,12 @@ ## Publish documentation for a new release. -CircleCI job `pytorch_xla_linux_debian11_and_push_doc` is specified to run on `release/*` branches, but it was not +CI job `pytorch_xla_linux_debian11_and_push_doc` is specified to run on `release/*` branches, but it was not run on release branches due to "Only build pull requests" setting. Turning off "Only build pull requests" will result in much larger volumes in jobs which is often unnecessary. We're waiting for [this feature request](https://ideas.circleci.com/ideas/CCI-I-215) to be implemented so that we could override this setting on some branches. Before the feature is available on CircleCi side, we'll use a manual process to publish documentation for release. -[Documentation for master branch](http://pytorch.org/xla/master/) is still updated automatically by the CircleCI job. +[Documentation for master branch](http://pytorch.org/xla/master/) is still updated automatically by the CI job. But we'll need to manually commit the new versioned doc and point http://pytorch.org/xla to the documentation of new stable release. @@ -22,4 +22,4 @@ cd /tmp/xla git add . git commit -m "Publish 1.5 documentation." git push origin gh-pages -``` \ No newline at end of file +``` diff --git a/docs/fori_loop.md b/docs/fori_loop.md new file mode 100644 index 00000000000..0c9f85af399 --- /dev/null +++ b/docs/fori_loop.md @@ -0,0 +1,114 @@ +# Fori_loop +`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation +like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps +of each iteration. `fori_loop` might help memory utilization and might help faster compilation. + +User could use `fori_loop` like this: +```python +from torch_xla.experimental.fori_loop import fori_loop +res = fori_loop(upper, lower, /*user defined*/body_fun, init) +``` + +current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too. + +For detailed implementation: +- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop), +like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the +native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only. + +- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan), +like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator. +This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference. + +# while_loop +`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in +[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69). +PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While. + +User could use `while_loop` like this: +```python +import torch_xla.experimental.fori_loop +from torch._higher_order_ops.while_loop import while_loop +res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init) +``` +current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too. + + +# [WIP]scan +like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd. +`scan` is WIP. + + +# Simple user guide +User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times: + +### simple example with pure python for loop +```bash +# python +>>> import torch +>>> init = torch.tensor([0], dtype=torch.int32) +>>> one_value = torch.ones(1, dtype=torch.int32) +>>> +>>> for i in range(10): +... init = init + one_value +... +>>> init +tensor([10], dtype=torch.int32) +``` + +### simple example with `while_loop`: +```bash +# PJRT_DEVICE=TPU python +>>> import torch +>>> import torch_xla +>>> import torch_xla.experimental.fori_loop +>>> from torch_xla.experimental.fori_loop import fori_loop +>>> from torch._higher_order_ops.while_loop import while_loop +>>> import torch_xla.core.xla_model as xm +>>> import torch_xla.core.xla_builder as xb +>>> +>>> device = xm.xla_device() +>>> +>>> def cond_fn(init, limit_value): +... return limit_value[0] >= init[0] +... +>>> def body_fn(init, limit_value): +... one_value = torch.ones(1, dtype=torch.int32, device=device) +... return (torch.add(init, one_value), limit_value.clone()) +... +>>> init = torch.tensor([0], dtype=torch.int32, device=device) +>>> limit_value = torch.tensor([10], dtype=torch.int32, device=device) +>>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value)) +>>> res_ +FunctionalTensor(lvl=0, value=\ +tensor([11], device='xla:0', dtype=torch.int32)) +``` + +### simple example with `fori_loop`: +```bash +# PJRT_DEVICE=TPU python +>>> import torch +>>> import torch_xla +>>> import torch_xla.experimental.fori_loop +>>> from torch_xla.experimental.fori_loop import fori_loop +>>> from torch._higher_order_ops.while_loop import while_loop +>>> import torch_xla.core.xla_model as xm +>>> import torch_xla.core.xla_builder as xb +>>> +>>> device = xm.xla_device() +>>> +>>> lower = torch.tensor([2], dtype=torch.int32, device=device) +>>> upper = torch.tensor([52], dtype=torch.int32, device=device) +>>> plus_value = torch.tensor([1], dtype=torch.int32, device=device) +>>> init_val = torch.tensor([1], dtype=torch.int32, device=device) +>>> +>>> def body_fun(*argus): +... plus_value, init_val = argus +... return plus_value, torch.add(plus_value, init_val) +... +>>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val) +>>> res_ +tensor([51], device='xla:0', dtype=torch.int32) +``` + +For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3 diff --git a/docs/gpu.md b/docs/gpu.md index c2678164f4e..de1cf807361 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -71,9 +71,12 @@ source ~/.bashrc ### Wheel ``` -pip3 install torch==2.2.0 -pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl +pip3 install torch==2.3.0 +# GPU whl for python 3.10 + cuda 12.1 +pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl ``` +Wheels for other Python version and CUDA version can be found [here](https://github.com/pytorch/xla?tab=readme-ov-file#available-docker-images-and-wheels). + ## Run a simple model In order to run below examples, you need to clone the pytorch/xla repo to access the imagenet example(We already clone it in our docker). diff --git a/docs/pallas.md b/docs/pallas.md new file mode 100644 index 00000000000..46c80b79f2e --- /dev/null +++ b/docs/pallas.md @@ -0,0 +1,57 @@ +# Custom Kernels via Pallas + +With the rise of OpenAI [triton](https://openai.com/research/triton), custom kernels become more and more popular in the GPU community, for instance, the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to provide the feature parity in the TPU world, Google has introduced [Pallas](http://go/jax-pallas) and [Mosaic](http://go/mosaic-tpu). For PyTorch/XLA to continue pushing the performance in TPU, we have to support custom kernels, and the best way is through Pallas and Mosaic. The design doc is [TBA](). + +Let's assume you have a Pallas kernel defined as follow: +```python3 +import jax +from jax.experimental import pallas as pl +import jax.numpy as jnp + +def add_vectors_kernel(x_ref, y_ref, o_ref): + x, y = x_ref[...], y_ref[...] + o_ref[...] = x + y + +@jax.jit +def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: + return pl.pallas_call(add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) +``` + +## Adopt the above kernel to be compatible with PyTorch/XLA + +Example usage: +```python3 +q = torch.randn(3, 2, 128, 4).to("xla") +k = torch.randn(3, 2, 128, 4).to("xla") +v = torch.randn(3, 2, 128, 4).to("xla") + +# Adopts any Pallas kernel +from torch_xla.experimental.custom_kernel import make_kernel_from_pallas +pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)]) +output = pt_kernel(q, k) +``` +For simple kernels, the adoption is just as simple as one liner. For more complicated kernels, you can refer to our Flash Attention implementation for details. + +## Use built-in kernels + +Besides manually wrapping external Pallas kernels, there are built-in kernels where the adoptions are done by PyTorch/XLA already. + +Example usage: +```python3 +# Use built-in kernels +from torch_xla.experimental.custom_kernel import flash_attention +output = flash_attention(q, k, v) +``` + +You can just use it like any other torch.ops. + +## HuggingFace Llama 3 Example +We have a fork of HF Llama 3 to demonstrate a potential integration [here](https://github.com/pytorch-tpu/transformers/tree/alanwaketan/flash_attention). + +## Dependencies +The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX: +```bash +pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index 26f491f6c15..411e6642ff7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ mistune==0.8.4 -sphinx==2.4.4 +sphinx==5.0.0 docutils==0.16 -Jinja2<3.1 +Jinja2==3.1.3 m2r -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index f30be7ff1da..594d5380882 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -4,7 +4,8 @@ Currently this is only source-installable. Requires Python version >= 3.10. -### NOTE: +### NOTE: + Please don't install torch-xla from instructions in https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md . In particular, the following are not needed: @@ -18,71 +19,58 @@ TorchXLA2 and torch-xla have different installation instructions, please follow the instructions below from scratch (fresh venv / conda environment.) -### 1. Install dependencies - -#### 1.0 (optional) Make a virtualenv / conda env, and activate it. - -```bash -conda create --name python=3.10 -conda activate -``` -Or, -```bash -python -m venv create my_venv -source my_venv/bin/activate -``` - -#### 1.1 Install torch CPU, even if your device has GPU or TPU: +### 1. Installing `torch_xla2` -```bash -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -``` +#### 1.0 (recommended) Make a virtualenv / conda env -Or, follow official instructions in [pytorch.org](https://pytorch.org/get-started/locally/) to install for your OS. +If you are using VSCode, then [you can create a new environment from +UI](https://code.visualstudio.com/docs/python/environments). Select the +`dev-requirements.txt` when asked to install project dependencies. -#### 1.2 Install Jax for either GPU or TPU +Otherwise create a new environment from the command line. -If you are using Google Cloud TPU, then ```bash -pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` +# Option 1: venv +python -m venv create my_venv +source my_venv/bin/activate -If you are using a machine with NVidia GPU: +# Option 2: conda +conda create --name python=3.10 +conda activate -```bash -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# Either way, install the dev requirements. +pip install -r dev-requirements.txt ``` -If you are using a CPU-only machine: -```bash -pip install --upgrade "jax[cpu]" -``` +Note: `dev-requirements.txt` will install the CPU-only version of PyTorch. -Or, follow the official instructions in https://jax.readthedocs.io/en/latest/installation.html to install for your OS or Device. +#### 1.1 Install this package -#### 1.3 Install this package +Install `torch_xla2` from source for your platform: ```bash -pip install -e . +pip install -e .[cpu] +pip install -e .[cuda] +pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html ``` -#### 1.4 (optional) verify installation by running tests +#### 1.2 (optional) verify installation by running tests ```bash -pip install -r test_requirements.txt +pip install -r test-requirements.txt pytest test ``` - ## Run a model Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model it can be in theory any instance of `torch.nn.Module`. ```python +import torch +import torch.nn as nn +import torch.nn.functional as F -import torch_xla2 -from torch import nn class MyModel(nn.Module): def __init__(self): @@ -101,8 +89,8 @@ class MyModel(nn.Module): m = MyModel() # Execute this model using torch -inputs = (torch.randn(3, 3, 28, 28), ) -print(m(*inputs)) +inputs = torch.randn(3, 3, 28, 28) +print(m(inputs)) ``` This model `m` contains 2 parts: the weights that is stored inside of the model @@ -114,6 +102,7 @@ to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_devic We need move both the weights and the input to xla devices: ```python +import torch_xla2 from torch.utils import _pytree as pytree from torch_xla2.tensor import move_to_device @@ -121,7 +110,7 @@ inputs = move_to_device(inputs) new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict()) m.load_state_dict(new_state_dict, assign=True) -res = m(*inputs) +res = m(inputs) print(type(res)) # outputs XLATensor2 ``` @@ -164,5 +153,3 @@ from torch_xla2.extra import jax_jit model_func_jitted = jax_jit(model_func) print(model_func_jitted(new_state_dict, inputs)) ``` - - diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt index 4a32310fbda..208f70d5fef 100644 --- a/experimental/torch_xla2/dev-requirements.txt +++ b/experimental/torch_xla2/dev-requirements.txt @@ -1,9 +1,3 @@ -absl-py==2.0.0 -flatbuffers==23.5.26 -jax==0.4.23 -jaxlib==0.4.23 -pytest -tensorflow -torch==2.2.1+cpu -immutabledict -sentencepiece \ No newline at end of file +-f https://download.pytorch.org/whl/torch +torch==2.3.0+cpu +ruff~=0.3.5 diff --git a/experimental/torch_xla2/docs/dispatch.png b/experimental/torch_xla2/docs/dispatch.png new file mode 100644 index 00000000000..fcdd5e9e58a Binary files /dev/null and b/experimental/torch_xla2/docs/dispatch.png differ diff --git a/experimental/torch_xla2/docs/fixing_op_info_test.md b/experimental/torch_xla2/docs/fixing_op_info_test.md new file mode 100644 index 00000000000..03624f9487e --- /dev/null +++ b/experimental/torch_xla2/docs/fixing_op_info_test.md @@ -0,0 +1,211 @@ +# How to fix an op info test. + +## What is OpInfo test + +PyTorch created a list of python objects (OpInfo) to keep +track how to test each op. This is useful to us because it +ensures that the ops we implement produces the same results +pytorch would produce. + +Context: +* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253 +* https://github.com/pytorch/pytorch/issues/54261 + + +## How to fix one + +### Remove one op from skiplist + +Open [test/test_ops.py](../test/test_ops.py) with your +favorite text editor. +Remove one line from the `skiplist` set. + +i.e. + +```bash +(base) hanq-macbookpro:torch_xla2 hanq$ git diff +diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py +index 72a39ae85..2a156cbce 100644 +--- a/experimental/torch_xla2/test/test_ops.py ++++ b/experimental/torch_xla2/test/test_ops.py +@@ -15,7 +15,6 @@ skiplist = { + "_native_batch_norm_legit", + "_segment_reduce", + "_upsample_bilinear2d_aa", +- "addbmm", + "addmm", + "addmv", + "addr", +``` + +### Run test to see what failure + +Error gotten: + +``` +E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm') +``` + +From here we have 2 strategies for fixing this test: + +1. Add an implementation to `aten::addbmm` operator using Jax ops. Or, +2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions"). + +Either way works for torch_xla2. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of +upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) +so other projects can benefit from it. + +For illustration purposes, let's implement this op in Jax. + +(NOTE: this doesn't stop us from upstreaming a decomposition later if we want) + +### First Impl + +To implement this op using jax ops, we first find what +is the exact semantics in this page: +https://pytorch.org/docs/stable/generated/torch.addbmm.html + +From it's math formula: we can implement it as follows. + +``` ++@op(torch.ops.aten.addbmm.default) ++def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): ++ ++ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) ++ return beta * input + alpha * mm +``` + +Now running test again: + +``` +python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64 +``` + +(NOTE: the exact test command is printed out when we run +`pytest test/test_ops.py` so we can only run the failed test instead of running all tests.) + +We now see this error: + +``` +FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001] +---------------------------------------------------------------------- +Traceback (most recent call last): + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare + diff_output( + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output + testcase.assertTrue( +AssertionError: False is not true +``` + +This is telling me that our implementation did not produce +the same result as the ops in PyTorch. + +To debug this, let's figure out what exact input caused this. +We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. Here we can +inspect values of `res` and `res2`, as well as the `sample_input`. + +The sample input we get is +``` +SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2], + [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5], + [-2, -1, 5, -2, -3, 0, 5, -4, 9, -6], + [-1, -7, 6, 3, 8, 3, 8, 9, -5, 7], + [-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8], + [-6, -2, 5, 7, 7], + [-8, -3, 2, 5, -3], + [-4, 7, 0, -9, 8], + [ 3, 9, -9, -2, 0]], + + [[-7, 1, -3, 7, -4], + [ 3, 5, 4, 6, 5], + [-2, 8, 3, 5, 7], + [ 8, -2, -8, 2, 0], + [ 6, 1, -8, 8, 0]], + + [[ 2, -1, -5, -8, -9], + [ 5, 0, -4, -1, -6], + [-6, 2, -5, -2, -5], + [-5, -3, -5, -4, 9], + [-3, 4, -9, -9, 7]], + + [[ 2, 5, -7, -3, 8], + [-5, -7, -8, -4, 4], + [-4, -6, -3, 0, 6], + [ 8, 0, -3, -8, 2], + [-4, 3, -9, -6, 7]], + + [[ 2, 1, -6, 2, 8], + [ 2, 6, 4, 1, 8], + [-9, 9, -5, 8, 3], + [-5, 0, -2, 4, 0], + [ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5], + [-4, -7, 2, 2, 1, -9, 2, 7, -1, -1], + [ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4], + [-4, 1, -9, 3, 4, 6, 0, -2, -2, -7], + [ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]], + + [[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5], + [-5, -3, -2, 8, 1, -2, 4, -7, 5, 3], + [-4, 4, 1, -4, -8, 2, -5, 2, 9, -7], + [ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4], + [-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]], + + [[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3], + [-3, 3, -9, -7, -9, -8, 1, -3, 7, -2], + [ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7], + [-1, 6, -8, 7, -1, -5, -8, 6, -2, 8], + [-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]], + + [[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8], + [-6, -4, 3, 9, -9, -8, -7, 3, 9, 0], + [ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7], + [-6, 9, 5, -1, 7, 7, 8, -3, -8, 0], + [-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]], + + [[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8], + [-8, -8, -2, -5, -7, -8, 1, 0, 0, -6], + [ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8], + [-4, 0, 9, 6, -1, -6, 6, -6, -2, -1], + [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='') +``` + +And the `res` from torch is + +``` +tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) +``` + +So few observation is: +1. Input tensor are of type int64 +2. alpha and beta are both floats. + +So one can suspect that it has to do with rounding. +Reading the doc more carefully, we can find this sentence + + For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers. + +So likely torch first casted the float alpha and beta to integer, which yields 0, then used them in math to get a matrix with all zeros. + +### Second Impl + +```python ++@op(torch.ops.aten.addbmm.default) ++def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): ++ alpha = jnp.array(alpha).astype(batch1.dtype) ++ beta = jnp.array(beta).astype(batch1.dtype) ++ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) ++ return jax.lax.cond(beta == 0, ++ lambda: alpha * mm, ++ lambda: beta*input + alpha*mm) ++ +``` + +Adding type casts makes the tests passes. + +### Submit +Now, let's remove the pdb and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993 + diff --git a/experimental/torch_xla2/docs/how_it_works.md b/experimental/torch_xla2/docs/how_it_works.md new file mode 100644 index 00000000000..e4098ca0096 --- /dev/null +++ b/experimental/torch_xla2/docs/how_it_works.md @@ -0,0 +1,134 @@ +How it works +============ + + +## Tensor subclass and eager mode + +The class `XLATensor2` is a `torch.Tensor` subclass +that overrides `__torch_dispatch__`. + +It roughly looks like this (with some details removed): + +The complete class impl is at [tensor.py](../torch_xla2/tensor.py). + +```python +class XLATensor2(torch.Tensor): + + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_wrapper_subclass( + cls, + shape, + dtype=dtype, + device='meta', + requires_grad=False, + ) + + def __init__(self, elem: jax.Array): + super().__init__() + self._elem = elem + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # here assumes ALL tensors in args / kwargs are + # instances of XLATensor2 + args, kwargs = unwrap((args, kwargs)) + jax_func = some_registry[func] + res = jax_func(*args, **kwargs) + return wrap(res) + +def wrap(tree): + # wrap jax.Array with XLATensor2 + return pytree.tree_map_only( + jax.Array, XLATensor2, tree) + +def unwrap(tree): + # get jax.Array out ofXLATensor2 + return pytree.tree_map_only( + XLATensor2, lambda x: x._elem, tree) +``` + +In other words, assuming that we have a function +that takes `jax.Array` as input and returns `jax.Array` +but otherwise implement the same semantics +as a `ATen` op; then, using this tensor we would +be able to route the call to this jax function. + +[_ops.py](../torch_xla2/_ops.py) files defines some of those ops. + +Let's take `aten::add` as example: + +```python +@op(torch.ops.aten.add) +def _aten_add(x, y, *, alpha=1): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + return x + y * alpha +``` + +The `@op` decorator just puts this function into `some_registry` dictionary. + +`_aten_add` has same signature as `torch.ops.aten.add` but takes `jax.Array` as +input. + +![](dispatch.png) + + +## fx Interpreter and dynamo mode + +Now, assuming we have this `some_registry` dict with key core Aten ops, +and value the equivalent python Jax functions. We can also build a `fx.Interpreter` +subclass that executes the jax function given a `fx.GraphModule`. + + +```python +class JaxInterpreter(torch.fx.Interpreter): + + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: + if not isinstance(target, + (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): + return super().call_function(target, args, kwargs) + + op = some_registry[target] + return op.func(*args, **kwargs) +``` + +There is no wrapping and unwrapping needed because `args` and `kwargs` are +already `jax.Array`'s. + +Using this interpreter we can build a dynamo backend: + +```python +def backend(fxgraph): + + def tojit(*args, *kwargs): + return JaxInterpreter(fxgraph).run(*args, **kwargs) + jitted = jax.jit(to_jit) + + def f(*torchtensor): + jaxarrays = unwrap(torchtensors) + res = jitted(jax_array) + return wrap(res) + + return f +``` + +The inner function `tojit` is a function that takes and returns +`jax.Array`'s. So it's suitable to be jitted with `jax.jit`. + +`f` is returned callable that takes `XLATensor2`; so can interop with +other torch codes. + +## nn.Modules and state management + +See [README.md](../README.md) for using `torch.func.functional_call` to +make `nn.Module`s interact well with `jax.jit`. + +See [Examples](../examples/README.md) for training using torch's optimizers or jax's +optimizers. + +[def]: dispatch.png \ No newline at end of file diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml index d0d2a42dec8..0c2101dbcb9 100644 --- a/experimental/torch_xla2/pyproject.toml +++ b/experimental/torch_xla2/pyproject.toml @@ -2,29 +2,30 @@ requires = ["hatchling"] build-backend = "hatchling.build" - [project] version = "0.0.1" name = "torch_xla2" dependencies = [ "absl-py", - "flatbuffers", + "immutabledict", + "jax>=0.4.24", "pytest", - "tensorflow", - - # Note: Exclude these because otherwise on pip install . - # pip will install libs from pypi which is the GPU version - # of these libs. - # We most likely need CPU version of torch and TPU version of - # jax. So it's best for users to install them by hand - # See more at README.md - # "jax>=0.4.24", - # "jaxlib>=0.4.24", - # "torch", + "tensorflow-cpu", + # Developers should install `dev-requirements.txt` first + "torch>=2.2.1", ] - requires-python = ">=3.10" license = {file = "LICENSE"} +[project.optional-dependencies] +cpu = ["jax[cpu]"] +# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html` +tpu = ["jax[tpu]"] +cuda = ["jax[cuda12]"] + [tool.pytest.ini_options] addopts="-n auto" + +[tool.ruff] +line-length = 80 +indent-width = 2 diff --git a/experimental/torch_xla2/test-requirements.txt b/experimental/torch_xla2/test-requirements.txt new file mode 100644 index 00000000000..1deead455a1 --- /dev/null +++ b/experimental/torch_xla2/test-requirements.txt @@ -0,0 +1,5 @@ +-r dev-requirements.txt +pytest +pytest-xdist +sentencepiece +expecttest diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index ed14e636e5c..5f6fdbbeab2 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -7,6 +7,8 @@ instantiate_device_type_tests, ops) from torch.utils import _pytree as pytree from torch_xla2 import tensor +import torch_xla2 + skiplist = { "__getitem__", @@ -15,19 +17,6 @@ "_native_batch_norm_legit", "_segment_reduce", "_upsample_bilinear2d_aa", - "addbmm", - "addmm", - "addmv", - "addr", - "all", - "allclose", - "amax", - "amin", - "aminmax", - "angle", - "any", - "argmax", - "argmin", "argsort", "as_strided", "as_strided_scatter", @@ -570,6 +559,7 @@ "special.xlog1py", "split", "split_with_sizes", + "split_with_sizes_copy", "sqrt", "square", "stack", @@ -639,7 +629,8 @@ def run_export_and_compare(testcase, input2, args2, kwargs2 = pytree.tree_map_only( torch.Tensor, tensor.move_to_device, (sample_input.input, sample_input.args, sample_input.kwargs)) - res2 = func(input2, *args2, **kwargs2) + with torch_xla2.mode(): + res2 = func(input2, *args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) with testcase.subTest("torch_xla2_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: diff --git a/experimental/torch_xla2/test_requirements.txt b/experimental/torch_xla2/test_requirements.txt deleted file mode 100644 index c8596327236..00000000000 --- a/experimental/torch_xla2/test_requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -pytest -immutabledict -sentencepiece -pytest-xdist -expecttest \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py index fe0f97a0f01..e3650234372 100644 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ b/experimental/torch_xla2/torch_xla2/_ops.py @@ -410,11 +410,23 @@ def _aten_native_layer_norm(input, # - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor @op(torch.ops.aten.addmm) +@op(torch.ops.aten.addmv) def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) self *= beta self += alpha * jnp.matmul(mat1, mat2) return self +@op(torch.ops.aten.addbmm.default) +def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): + alpha = jnp.array(alpha).astype(batch1.dtype) + beta = jnp.array(beta).astype(batch1.dtype) + mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) + return jax.lax.cond(beta == 0, + lambda: alpha * mm, + lambda: beta*input + alpha*mm) + @op(torch.ops.aten.gelu) def _aten_gelu(self, *, approximate="none"): @@ -632,13 +644,14 @@ def _aten_min(x, axis=None): @op(torch.ops.aten.amin) -def _aten_amin(x, axis=None): - return jnp.min(x, axis=axis) +def _aten_amin(x, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.amin, x, dim, keepdim) @op(torch.ops.aten.argmin) -def _aten_amin(x, axis=None): - return jnp.argmin(x, axis=axis) +def _aten_argmin(self, dim=None, keepdim=False): + return _with_reduction_scalar( + jnp.argmin, self, dim, keepdim) @op(torch.ops.aten.sin) @@ -1202,13 +1215,27 @@ def _aten_abs(self): # generate aten.amax only @op(torch.ops.aten.amax) def _aten_amax(self, dim=None, keepdim=False): - return jnp.amax(self, axis=dim, keepdims=keepdim) - + return _with_reduction_scalar(jnp.amax, self, dim, keepdim) + + +def _with_reduction_scalar(jax_func, self, dim, keepdim): + expanded = False + if self.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + self = jnp.expand_dims(self, 0) + res = jax_func(self, axis=dim, keepdims=keepdim) + if expanded: + res = res.squeeze() + return res # aten.any @op(torch.ops.aten.any) def _aten_any(self, dim=None, keepdim=False): - return jnp.any(self, axis=dim, keepdims=keepdim) + return _with_reduction_scalar(jnp.any, self, dim, keepdim) # aten.arange @@ -1237,7 +1264,8 @@ def _aten_arange(start, # aten.argmax @op(torch.ops.aten.argmax) def _aten_argmax(self, dim=None, keepdim=False): - return jnp.argmax(self, axis=dim, keepdims=keepdim) + return _with_reduction_scalar( + jnp.argmax, self, dim, keepdim) # aten.as_strided @@ -1742,4 +1770,12 @@ def _aten_local_scalar_dense(x): @op(torch.ops.aten.tensor_split.sections) def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) \ No newline at end of file + return jnp.array_split(ary, indices_or_sections, axis) + +@op(torch.ops.aten.outer) +def _aten_outer(a, b): + return jnp.outer(a, b) + +@op(torch.ops.aten.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py index 9fcd5653a86..94320fd7cb2 100644 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -92,6 +92,32 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): # TODO: handle torch.Size return jnp.full(size, fill_value, dtype=dtype) +@register_function(torch.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) + +@register_function(torch.angle) +def _torch_angle(input): + return jnp.angle(input) + + +@register_function(torch.argsort) +def _torch_argsort(input, dim=-1, descending=False, stable=False): + expanded = False + if input == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + input = jnp.expand_dims(input, 0) + res = jnp.argsort(input, axis=dim, descending=descending, + stable=stable) + if expanded: + res = res.squeeze() + return res + + class XLAFunctionMode(torch.overrides.TorchFunctionMode): """Context manager that dispatches torch function calls to JAX.""" diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index e69de29bb2d..6628b7e9510 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -0,0 +1,7 @@ +import torch + + + +torch_ops_override = { + torch.allclose: torch.ops.aten.allclose +} \ No newline at end of file diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index 15e8dc79d6c..9e2fe7270cc 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -14,7 +14,7 @@ release_env: TPUVM_MODE: 1 cuda: - TF_CUDA_COMPUTE_CAPABILITIES: 7.0,7.5,8.0,9.0 + TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}" XLA_CUDA: 1 # Variables that will be passed to shell environment only for building PyTorch and XLA libs. @@ -22,7 +22,7 @@ build_env: common: LD_LIBRARY_PATH: "$LD_LIBRARY_PATH:/usr/local/lib" # Set explicitly to 0 as setup.py defaults this flag to true if unset. - BUILD_CPP_TESTS: 0 + BUILD_CPP_TESTS: "{{ build_cpp_tests }}" # Force GCC because clang/bazel has issues. CC: gcc-10 CXX: g++-10 @@ -31,9 +31,9 @@ build_env: PYTORCH_BUILD_VERSION: "{{ package_version }}" XLA_SANDBOX_BUILD: 1 BAZEL_REMOTE_CACHE: 1 - SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}" + SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}{{ cache_suffix }}" _GLIBCXX_USE_CXX11_ABI: 0 - GIT_VERSIONED_XLA_BUILD: "{{ nightly_release }}" + GIT_VERSIONED_XLA_BUILD: "{{ nightly_release or git_versioned_xla_build }}" amd64: ARCH: amd64 @@ -41,7 +41,7 @@ build_env: aarch64: cuda: - TF_CUDA_COMPUTE_CAPABILITIES: 7.0,7.5,8.0,9.0 + TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}" XLA_CUDA: 1 tpu: diff --git a/infra/ansible/config/vars.yaml b/infra/ansible/config/vars.yaml index 2347d066e84..e5851d0cc77 100644 --- a/infra/ansible/config/vars.yaml +++ b/infra/ansible/config/vars.yaml @@ -1,6 +1,8 @@ # Used for fetching cuda from the right repo, see apt.yaml. cuda_repo: debian11 cuda_version: "11.8" +# Determines supported GPUs. See https://developer.nvidia.com/cuda-gpus +cuda_compute_capabilities: 7.0,7.5,8.0,9.0 # Used for fetching clang from the right repo, see apt.yaml. llvm_debian_repo: bullseye clang_version: 17 @@ -10,3 +12,9 @@ package_version: 2.4.0 nightly_release: false # Whether to preinstall libtpu in the PyTorch/XLA wheel. Ignored for GPU build. bundle_libtpu: 1 +# Suffix for bazel remote cache key +cache_suffix: "" +# Whether to build C++ tests with `torch_xla` wheel +build_cpp_tests: 0 +# Whether to tag wheels with git hash, e.g. X.Y.Z+git123abc +git_versioned_xla_build: false diff --git a/infra/ansible/roles/build_srcs/tasks/main.yaml b/infra/ansible/roles/build_srcs/tasks/main.yaml index d945f150d38..da09a695453 100644 --- a/infra/ansible/roles/build_srcs/tasks/main.yaml +++ b/infra/ansible/roles/build_srcs/tasks/main.yaml @@ -1,3 +1,27 @@ +- name: Read PyTorch pin + ansible.builtin.command: cat {{ (src_root, 'pytorch/xla/.torch_pin') | path_join }} + register: torch_pin + # Pin may not exist + ignore_errors: true + +- name: Checkout PyTorch pin + # ansible.builtin.git wants to fetch the entire history, so check out the pin manually + ansible.builtin.shell: + cmd: | + set -xe + PIN="{{ torch_pin.stdout }}" + if [[ $PIN = \#* ]]; then + PRNUM="${PIN//[!0-9]/}" + git fetch origin "pull/$PRNUM/head" + else + git fetch origin {{ torch_pin.stdout }} + fi + git checkout --recurse-submodules FETCH_HEAD + chdir: "{{ (src_root, 'pytorch') | path_join }}" + args: + executable: /bin/bash + when: torch_pin is succeeded + - name: Build PyTorch ansible.builtin.command: cmd: python setup.py bdist_wheel @@ -77,6 +101,22 @@ state: absent mode: '0755' +- name: Create temp directory for C++ tests + ansible.builtin.file: + path: /tmp/test/bin + state: directory + mode: '0755' + when: build_cpp_tests + +- name: Collect C++ test files + ansible.builtin.shell: | + cd pytorch/xla/build/temp* + bazel query 'kind(".*_test", tests(//:cpp_tests))' --output=label | xargs -n 1 bazel cquery --output=files | xargs cp -t /tmp/test/bin + args: + executable: bash + chdir: "{{ src_root }}" + when: build_cpp_tests + - name: Read Torchvision pin ansible.builtin.command: cat {{ (src_root, 'pytorch') | path_join }}/.github/ci_commit_pins/vision.txt register: torchvision_pin diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index 0229a79c190..16902f663fd 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -35,58 +35,58 @@ nightly_builds = [ versioned_builds = [ # Remove libtpu from PyPI builds { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.8" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.9" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.11" bundle_libtpu = "0" }, # Bundle libtpu for Kaggle { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12+libtpu" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0+libtpu" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "1" }, { - git_tag = "v2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" - package_version = "2.3.0-rc12" + git_tag = "v2.3.0" + pytorch_git_rev = "v2.3.0" + package_version = "2.3.0" accelerator = "cuda" cuda_version = "12.1" python_version = "3.8" }, { - git_tag = "v2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" - package_version = "2.3.0-rc12" + git_tag = "v2.3.0" + pytorch_git_rev = "v2.3.0" + package_version = "2.3.0" accelerator = "cuda" cuda_version = "12.1" python_version = "3.10" diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index 9321d26a1a6..e6863ff711a 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -27,6 +27,9 @@ def physical_chip_count(self) -> int: # TODO: default to actual device count return xu.getenv_as('GPU_NUM_DEVICES', int, 1) + def configure_single_process(self): + pass + def client_create_options(self) -> dict: local_process_rank, global_process_rank = self._get_process_rank() local_world_size, global_world_size = self._get_world_size() diff --git a/scripts/apply_patches.sh b/scripts/apply_patches.sh index 923b68c79d4..7ba0a3ef8e3 100755 --- a/scripts/apply_patches.sh +++ b/scripts/apply_patches.sh @@ -7,7 +7,7 @@ XDIR=$CDIR/.. PTDIR=$XDIR/.. OPENXLADIR=$XDIR/third_party/xla -TORCH_PIN="$XDIR/torch_patches/.torch_pin" +TORCH_PIN="$XDIR/.torch_pin" if [ -f "$TORCH_PIN" ]; then CID=$(cat "$TORCH_PIN") # If starts with # and it's not merged into master, fetch from origin diff --git a/setup.py b/setup.py index d45b0b7fc3c..92ccd1004d3 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240409' +_date = '20240425' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.27.dev{_date}' @@ -223,6 +223,10 @@ def bazel_build(self, ext): f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}" ] + build_cpp_tests = build_util.check_env_flag('BUILD_CPP_TESTS', default='0') + if build_cpp_tests: + bazel_argv.append('//:cpp_tests') + import torch cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None) diff --git a/test/benchmarks/run_tests.sh b/test/benchmarks/run_tests.sh index 7d404a7ee7f..fce6140a4fe 100755 --- a/test/benchmarks/run_tests.sh +++ b/test/benchmarks/run_tests.sh @@ -9,7 +9,7 @@ export PYTHONPATH=$PYTHONPATH:$CDIR/../../benchmarks/ # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" @@ -39,10 +39,14 @@ function run_make_tests { } function run_python_tests { - python3 "$CDIR/test_experiment_runner.py" - python3 "$CDIR/test_benchmark_experiment.py" - python3 "$CDIR/test_benchmark_model.py" - python3 "$CDIR/test_result_analyzer.py" + # HACK: don't confuse local `torch_xla` folder with installed package + # Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559 + pushd $CDIR + python3 "test_experiment_runner.py" + python3 "test_benchmark_experiment.py" + python3 "test_benchmark_model.py" + python3 "test_result_analyzer.py" + popd } function run_tests { diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 74244322840..d6b492dc694 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -5,7 +5,7 @@ BUILDTYPE="opt" VERB= FILTER= LOGFILE=/tmp/pytorch_cpp_test.log -XLA_EXPERIMENTAL="nonzero:masked_select" +XLA_EXPERIMENTAL="nonzero:masked_select:nms" BAZEL_REMOTE_CACHE="0" BAZEL_VERB="test" diff --git a/test/cpp/test_aten_xla_tensor_5.cpp b/test/cpp/test_aten_xla_tensor_5.cpp index 4070779529f..07e4c2dae86 100644 --- a/test/cpp/test_aten_xla_tensor_5.cpp +++ b/test/cpp/test_aten_xla_tensor_5.cpp @@ -267,6 +267,27 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) { }); } +TEST_F(AtenXlaTensorTest, TestEmbeddingBag) { + torch::Tensor weight = + torch::rand({32, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor indices = + torch::randint(0, 31, {10}, torch::TensorOptions(torch::kLong)); + torch::Tensor offsets = torch::arange(0, 10, 3); + auto out = torch::embedding_bag(weight, indices, offsets); + torch::Tensor result = std::get<0>(out); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_weight = CopyToDevice(weight, device); + torch::Tensor xla_indices = CopyToDevice(indices, device); + torch::Tensor xla_offsets = CopyToDevice(offsets, device); + auto xla_out = torch::embedding_bag(xla_weight, xla_indices, xla_offsets); + torch::Tensor xla_result = std::get<0>(xla_out); + AllClose(result, xla_result); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_embedding_bag_forward_only", + cpp_test::GetIgnoredCounters()); + }); +} + TEST_F(AtenXlaTensorTest, TestOneHot) { int num_classes = 5; torch::Tensor input = diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index e7ac2681d5a..3a3eb3d43f1 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -152,7 +152,7 @@ def test_simple_model(self): # Tests that the dynamo bridge automatically moves tensors to XLA device, # then back to the original device. - @unittest.skipIf(xr.device_type() != "CUDA", + @unittest.skipIf(xr.device_type() != "CUDA" or not torch.cuda.is_available(), f"GPU tests should only run on GPU devices.") def test_simple_model_automoves_tensors(self): x = torch.tensor(100.0).to(device="cuda") @@ -489,13 +489,13 @@ def test_resnet18(self): # Graph 1: forward # Graph 2: backward # Graph 3: sync input for backward - self.assertEqual(met.metric_data('CompileTime')[0], 3) + self.assertLessEqual(met.metric_data('CompileTime')[0], 3) # We execute 3 graphs per step. - self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3) + self.assertLessEqual(met.metric_data('ExecuteTime')[0], sample_count * 3) # one for each forward and one for each backward - self.assertEqual( + self.assertLessEqual( met.metric_data('RunCachedGraphInputData')[0], sample_count * 2) - self.assertEqual( + self.assertLessEqual( met.metric_data('RunCachedGraphOutputData')[0], sample_count * 2) @@ -641,10 +641,7 @@ def test_all_cpu_tensor(self): # there should be 18 paramters + 1 input self.assertGreater(len(w), 15) self.assertIn('Found tensor with shape torch.Size', str(w[0].message)) - # no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU - # ops to CPU. - self.assertEqual(len(met.counter_names()), 1) - self.assertIn('MarkStep', met.counter_names()) + self.assertLessEqual(len(met.counter_names()), 1) class DynamoOperationsTests(test_utils.XlaTestCase): diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 0def33ae275..744039a4f58 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -206,7 +206,8 @@ def _runtime_device_attributes(): def test_runtime_device_attributes(self): result = pjrt.run_multiprocess(self._runtime_device_attributes) for device in result.values(): - self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys())) + self.assertCountEqual(['coords', 'core_on_chip', 'num_cores'], + list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) @@ -218,7 +219,7 @@ def test_global_runtime_device_attributes(self): results = pjrt.run_multiprocess(self._global_runtime_device_attributes) for result in results.values(): for device in result: - self.assertCountEqual(['coords', 'core_on_chip', 'name'], + self.assertCountEqual(['coords', 'core_on_chip', 'name', 'num_cores'], list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 88ad0f6bc3d..3a6dcdd96c6 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -70,6 +70,7 @@ 'test_pdist_norm_backward_xla', # pdist_single 'test_pdist_norm_forward_xla', # pdist_single 'test_nuclear_norm_axes_small_brute_force', + 'test_nondeterministic_alert_EmbeddingBag_max_xla', # FIXME: implement embedding_bag_backward 'test_mul_intertype_scalar', 'test_masked_select_discontiguous', # FIXME: wrong result 'test_memory_format_type', diff --git a/test/run_tests.sh b/test/run_tests.sh index 4d4bd530e27..1c5095baa5a 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -8,7 +8,7 @@ VERBOSITY=2 # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" @@ -104,7 +104,7 @@ function run_xla_hlo_debug { function run_dynamic { echo "Running in DynamicShape mode: $@" - XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@" + XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@" } function run_eager_debug { @@ -162,7 +162,6 @@ function run_xla_op_tests1 { run_dynamic "$CDIR/ds/test_dynamic_shapes.py" run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_test "$CDIR/test_grad_checkpoint.py" run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 0595f502da0..d1f6cdc3dce 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -205,6 +205,8 @@ def test_dynamo_spmd_mark_sharding_outside_of_compile(self): dynamo_res = dynamo_linear(xla_x) self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + # https://github.com/pytorch/xla/pull/6921#issuecomment-2062106737 + @unittest.skip("Failing in CI") def test_mark_sharding_inside_compile(self): met.clear_counters() device = xm.xla_device() diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index d1e731abd6e..82650997316 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -18,7 +18,7 @@ class ExportFxPassTest(unittest.TestCase): def test_decompose_dynamic_shape_select(self): args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = ([{0: Dim("bs")}, None, None],) + dynamic_shapes = (({0: Dim("bs")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.select.int) ep = export(m, args, dynamic_shapes=dynamic_shapes) out1 = ep.module()(*args) @@ -55,7 +55,7 @@ def forward(self, x): def test_embedding_indices_flatten(self): args = (torch.rand((20, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64)) - dynamic_shapes = ([None, {0: Dim("bs")}],) + dynamic_shapes = ((None, {0: Dim("bs")}),) m = wrap_func_as_nn_module(torch.ops.aten.embedding.default) ep = export(m, args, dynamic_shapes=dynamic_shapes) print(ep) diff --git a/test/stablehlo/test_exports.py b/test/stablehlo/test_exports.py index a08b65d1ffe..6208ae1ca52 100644 --- a/test/stablehlo/test_exports.py +++ b/test/stablehlo/test_exports.py @@ -45,7 +45,7 @@ def test_interpolate(self): exported = torch.export.export(model, arg) shlo = exported_program_to_stablehlo(exported) ans2 = shlo(*arg).cpu().to(torch.float32) - self.assertTrue(torch.allclose(ans, ans2, atol=1e-5)) + torch.testing.assert_close(ans, ans2, rtol=1e-5, atol=1e-4) def test_constant(self): diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index e185a47007e..3cd17a7fe34 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -27,7 +27,7 @@ class UnboundedDynamismExportTest(unittest.TestCase): def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -45,7 +45,7 @@ def test_add(self): def test_add_scalar(self): args = (torch.rand((10, 197, 768)), 0.345) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -62,7 +62,7 @@ def test_add_scalar(self): def test_addmm(self): args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5))) - dynamic_shapes = ([None, {0: Dim("dim")}, None],) + dynamic_shapes = ((None, {0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.addmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -83,7 +83,7 @@ def test_bmm(self): torch.rand((24, 197, 64)), torch.rand((24, 64, 197)), ) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -104,7 +104,7 @@ def test_bmm_dynamic_out_dim(self): torch.rand((8, 128, 256)), torch.rand((8, 256, 3)), ) - dynamic_shapes = ([None, {2: Dim("dim")}],) + dynamic_shapes = ((None, {2: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -125,7 +125,7 @@ def test_bmm_dynamic_reduction_dim(self): torch.rand((8, 128, 3)), torch.rand((8, 3, 256)), ) - dynamic_shapes = ([{2: Dim("dim")}, {1: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")}, {1: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -143,7 +143,7 @@ def test_bmm_dynamic_reduction_dim(self): def test_cat(self): args = (torch.rand((10, 1, 768)), torch.rand((10, 196, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module( lambda x, y: torch.ops.aten.cat.default([x, y], 1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -166,7 +166,7 @@ def test_conv(self): torch.rand((5, 3, 16, 16)), torch.rand((5)), ) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module( lambda x, y, z: torch.ops.aten.convolution.default( x, @@ -197,7 +197,7 @@ def test_conv1d(self): torch.rand((3, 1, 800)), torch.rand((512, 1, 10)), ) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) # dynamic_shapes = None m = wrap_func_as_nn_module(lambda x, y: torch.ops.aten.convolution.default( x, @@ -225,7 +225,7 @@ def test_conv1d(self): def test_cumsum(self): args = (torch.rand((10, 5)), 1) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.cumsum.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -242,7 +242,7 @@ def test_cumsum(self): def test_div(self): args = (torch.rand((10, 12, 197)), torch.rand((10, 12, 197))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -260,7 +260,7 @@ def test_div(self): def test_div_scalar(self): args = (torch.rand((10, 12, 197)), 8.0) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -277,7 +277,7 @@ def test_div_scalar(self): def test_gelu(self): args = (torch.rand((3, 5)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module(torch.ops.aten.gelu) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -342,7 +342,7 @@ def forward(self, x): def test_mul(self): args = (torch.rand((10, 2, 768)), torch.rand((10, 2, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -360,7 +360,7 @@ def test_mul(self): def test_mul_scalar(self): args = (torch.rand((10, 2, 768)), 0.125) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -483,7 +483,7 @@ def forward(self, x, weight, bias): def test_permute(self): args = (torch.rand((10, 197, 12, 64)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module( lambda x: torch.ops.aten.permute.default(x, [0, 2, 1, 3])) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -502,7 +502,7 @@ def test_permute(self): def test_select(self): args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.select.int) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -519,7 +519,7 @@ def test_select(self): def test_slice(self): args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) - dynamic_shapes = ([{0: Dim("dim")}, None, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -537,7 +537,7 @@ def test_slice(self): def test_slice_2(self): args = (torch.rand((10, 3, 224, 224)), 1, 0, 2) - dynamic_shapes = ([{0: Dim("dim")}, None, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -555,7 +555,7 @@ def test_slice_2(self): def test_softmax(self): args = (torch.rand((10, 12, 197, 197)), -1, False) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -573,7 +573,7 @@ def test_softmax(self): def test_sub(self): args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -591,7 +591,7 @@ def test_sub(self): def test_softmax_reduce_on_dynamic_dim(self): args = (torch.rand((1, 8, 128, 3)), -1, False) - dynamic_shapes = ([{3: Dim("dim")}, None, None],) + dynamic_shapes = (({3: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -609,7 +609,7 @@ def test_softmax_reduce_on_dynamic_dim(self): @unittest.skip("Converted StableHLO contains i1 dtype, not expected.") def test_index(self): args = (torch.rand((2, 10)), torch.arange(5)) - dynamic_shapes = ([None, {0: Dim("dim")}],) + dynamic_shapes = ((None, {0: Dim("dim")}),) m = wrap_func_as_nn_module( lambda x, y: torch.ops.aten.index.Tensor(x, [None, y])) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -628,7 +628,7 @@ def test_index(self): def test_sub_scalar(self): args = (1.0, torch.rand((10, 1, 1, 10))) - dynamic_shapes = ([None, {0: Dim("dim")}],) + dynamic_shapes = ((None, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -670,7 +670,7 @@ def forward(self, x): def test_transpose_on_dynamic_dim(self): args = (torch.rand((1, 8, 3, 256)),) - dynamic_shapes = ([{2: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")},),) m = wrap_func_as_nn_module( lambda x: torch.ops.aten.transpose.int(x, -2, -1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -688,7 +688,7 @@ def test_transpose_on_dynamic_dim(self): def test_unsqueeze_1(self): args = (torch.rand((3, 10)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -705,7 +705,7 @@ def test_unsqueeze_1(self): def test_unsqueeze_2(self): args = (torch.rand((1, 1, 3, 256)),) - dynamic_shapes = ([{2: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")},),) m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 2)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) diff --git a/test/test_operations.py b/test/test_operations.py index 7fb9f5bc3e3..ff32c268927 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -88,6 +88,12 @@ def onlyOnCUDA(fn): return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) +def onlyIfXLAExperimentalContains(feat): + experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":") + return unittest.skipIf(feat not in experimental, + f"XLA_EXPERIMENTAL={feat} required") + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -2454,6 +2460,7 @@ def test_dropout(self): # These tests were extracted and adapted from torchvision. # Source: vision/test/test_ops.py +@onlyIfXLAExperimentalContains("nms") class TestNMS(test_utils.XlaTestCase): def _reference_nms(self, boxes, scores, iou_threshold): diff --git a/test/test_pallas.py b/test/test_pallas.py index 2902b5e21ba..7b8755fc71e 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -10,6 +10,8 @@ from torch_xla import runtime as xr from torch_xla._internal import tpu +import numpy as np + if xr.device_type() == 'TPU': from torch_xla.experimental.custom_kernel import jax_import_guard jax_import_guard() @@ -20,12 +22,52 @@ class PallasTest(unittest.TestCase): - def _attention(self, q, k, v): + # This is to create a diagonal mask where only elements within the same segment + # can attend to each other. Since the mask is to mask out the unrelevant parts, + # therefore we use != instead of ==. + def _make_attention_mask_from_segment_ids(self, q_segment_ids, + kv_segment_ids): + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) + + def _attention(self, q, k, v, *, attn_mask=None): attn_weight = q @ k.transpose(-2, -1) + if attn_mask is not None: + # Masked out the unrelevant parts. + attn_weight = attn_weight.masked_fill(attn_mask, + torch.finfo(attn_weight.dtype).min) attn_weight = nn.functional.softmax(attn_weight, dim=-1) attn_output = attn_weight @ v return attn_output + # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests + # Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py + def _pagedattention_generate_qkv( + self, + seq_lens, + page_size, + max_seq_len, + num_kv_heads, + num_heads, + head_dim, + dtype=torch.float32, + ): + assert max_seq_len % page_size == 0 + pages_per_sequence = max_seq_len // page_size + batch_size = len(seq_lens) + total_pages = batch_size * pages_per_sequence + k_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + v_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + page_indices = torch.randperm( + batch_size * pages_per_sequence, dtype=torch.int32) + page_indices = page_indices.reshape(batch_size, pages_per_sequence) + q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype) + return q, k_pages, v_pages, page_indices + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_tpu_custom_call_pallas_add(self): # This payload is generated by the following Pallas code: @@ -417,6 +459,7 @@ def test__flash_attention_bwd_dkv(self): @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, "This test only works on TPUv3+.") def test_flash_attention_backward(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) @@ -449,9 +492,248 @@ def test_flash_attention_backward(self): loss.backward() xm.mark_step() - mse = torch.nn.MSELoss() for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: - self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4) + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_paged_attention_wrapper(self): + from torch_xla.experimental.custom_kernel import paged_attention + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention + + max_kv_len = 2048 + block_size = 512 + page_size = 64 + num_kv_heads = 8 + q_kv_head_ratio = 8 + head_dim = 256 + dtype = torch.float32 + seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + + output = paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + + q_jax = jnp.array(q.numpy(), dtype=jnp.float32) + k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) + v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32) + seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32) + page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) + expected_output = torch.from_numpy( + np.array( + jax_paged_attention( + q_jax, + k_pages_jax, + v_pages_jax, + seq_lens_jax, + page_indices_jax, + pages_per_compute_block=block_size // page_size, + ))) + + self.assertTrue( + torch.allclose( + output.cpu()[seq_lens > 0], + expected_output.cpu()[seq_lens > 0], + atol=1e-5, + rtol=1e-5)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_paged_attention_wrapper_with_dynamo(self): + from torch_xla.experimental.custom_kernel import paged_attention + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention + + max_kv_len = 2048 + block_size = 512 + page_size = 64 + num_kv_heads = 8 + q_kv_head_ratio = 8 + head_dim = 256 + dtype = torch.float32 + seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + + def paged_attention_wrapper(q, k, v, seq_lens, page_indices, + pages_per_compute_block): + return paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + + compiled_paged_attention = torch.compile( + paged_attention_wrapper, backend="openxla") + output = paged_attention_wrapper( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + + q_jax = jnp.array(q.numpy(), dtype=jnp.float32) + k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) + v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32) + seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32) + page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) + expected_output = torch.from_numpy( + np.array( + jax_paged_attention( + q_jax, + k_pages_jax, + v_pages_jax, + seq_lens_jax, + page_indices_jax, + pages_per_compute_block=block_size // page_size, + ))) + + self.assertTrue( + torch.allclose( + output.cpu()[seq_lens > 0], + expected_output.cpu()[seq_lens > 0], + atol=1e-5, + rtol=1e-5)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_segment_ids_1(self): + from torch_xla.experimental.custom_kernel import flash_attention + from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds + + q = torch.randn(3, 2, 128, 4) + k = torch.randn(3, 2, 128, 4) + v = torch.randn(3, 2, 128, 4) + q_segment_ids = torch.zeros(3, 128) + kv_segment_ids = torch.zeros(3, 128) + o = flash_attention( + q.to("xla"), k.to("xla"), v.to("xla"), False, q_segment_ids.to("xla"), + kv_segment_ids.to("xla")) + + jax_q = jnp.array(q.numpy(), dtype=jnp.float32) + jax_k = jnp.array(k.numpy(), dtype=jnp.float32) + jax_v = jnp.array(v.numpy(), dtype=jnp.float32) + jax_q_segment_ids = jnp.array(q_segment_ids.numpy(), dtype=jnp.float32) + jax_kv_segment_ids = jnp.array(kv_segment_ids.numpy(), dtype=jnp.float32) + expected_o = torch.from_numpy( + np.array( + jax_flash_attention( + jax_q, + jax_k, + jax_v, + segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids), + ))) + + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_segment_ids_2(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + q = torch.randn(3, 2, 128, 4).to("xla") + k = torch.randn(3, 2, 128, 4).to("xla") + v = torch.randn(3, 2, 128, 4).to("xla") + q_segment_ids = torch.zeros(3, 128).to("xla") + kv_segment_ids = torch.zeros(3, 128).to("xla") + o = flash_attention(q, k, v, False, q_segment_ids, kv_segment_ids) + + expected_o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + q_segment_ids, kv_segment_ids)) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_backward_segment_ids(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q_segment_ids = torch.zeros(4, 128).to("xla") + kv_segment_ids = torch.zeros(4, 128).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention(q, k, v, False, q_segment_ids, kv_segment_ids) + loss = o.sum() + loss.backward() + xm.mark_step() + + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q_segment_ids = torch.zeros(4, 128).to("xla") + kv_segment_ids = torch.zeros(4, 128).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + q_segment_ids, kv_segment_ids)) + loss = o.sum() + loss.backward() + xm.mark_step() + + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) if __name__ == '__main__': diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py new file mode 100644 index 00000000000..33434594191 --- /dev/null +++ b/test/test_pallas_spmd.py @@ -0,0 +1,110 @@ +import logging +import os +import unittest + +import torch +from torch import nn as nn + +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs +from torch_xla import runtime as xr +from torch_xla._internal import tpu + +if xr.device_type() == 'TPU': + from torch_xla.experimental.custom_kernel import flash_attention + from torch_xla.experimental.custom_kernel import jax_import_guard + jax_import_guard() + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + + +class PallasTest(unittest.TestCase): + + def _attention(self, q, k, v): + attn_weight = q @ k.transpose(-2, -1) + attn_weight = nn.functional.softmax(attn_weight, dim=-1) + attn_output = attn_weight @ v + return attn_output + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_spmd_data_parallel(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + n_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) + + q = torch.randn(4, 2, 128, 4).to("xla") + k = torch.randn(4, 2, 128, 4).to("xla") + v = torch.randn(4, 2, 128, 4).to("xla") + + o = flash_attention(q, k, v, partition_spec=range(n_devices)) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + + expected_o = self._attention(q, k, v) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_backward_spmd_data_parallel(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + n_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention(q, k, v, partition_spec=range(n_devices)) + loss = o.sum() + loss.backward() + xm.mark_step() + + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(q_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(k_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(v_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention(q, k, v) + loss = o.sum() + loss.backward() + xm.mark_step() + + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + torch_xla._XLAC._xla_set_use_full_mat_mul_precision( + use_full_mat_mul_precision=True) + xr.use_spmd() + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 413951854d6..b2a8fff33dc 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -11,9 +11,10 @@ python3 test/spmd/test_xla_distributed_checkpoint.py python3 test/spmd/test_train_spmd_linear_model.py python3 test/spmd/test_xla_spmd_python_api_interaction.py python3 test/spmd/test_xla_auto_sharding.py -XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v -XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v +XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v +XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v python3 test/test_autocast.py +python3 test/test_grad_checkpoint.py python3 test/dynamo/test_dynamo.py python3 test/spmd/test_spmd_debugging.py python3 test/pjrt/test_dtypes.py diff --git a/torch_patches/README.md b/torch_patches/README.md deleted file mode 100644 index f6476f64ca5..00000000000 --- a/torch_patches/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Guidelines For Patch File Names - -Files with extension '.diff' are consider as git patches by apply script. - -A file for PyTorch PR _N_ needs to be named 'N.diff'. - -Patch files which are not related to PyTorch PRs, should begin with an 'X' character, -followed by a two digit number, followed by a dash ('-'), a name, and '.diff'. -Example: - -``` -X10-optimizer.diff -``` - -Patch file are alphabetically ordered, so PyTorch PR patches are always applied -before the non PyTorch ones. - - -There's a special file `torch_patches/.torch_pin`, which is used to coordinate landing PRs in -`pytorch/pytorch` and `pytorch/xla`. - -To test a `pytorch/xla` PR against a `pytorch/pytorch` PR or branch, -put the PR number or branch name in this file. -Example: - -``` -#32451 -# or -my_awesome_branch # (must live in `pytorch/pytorch`) -``` - -In the case where the pytorch/pytorch PR also depends on the pytorch/xla PR, you will also need to update the https://github.com/pytorch/pytorch/blob/main/.github/ci_commit_pins/xla.txt to match the latest hash of your pytorch/xla PR. To be noted, the hash from a PR produced by a fork won't work in this case. Then you need to find someone from the pytorch/xla team to produe a branch PR for you. diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index ebc0af6c7ad..6b83d45e4b4 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -186,6 +186,27 @@ def _init_xla_lazy_backend(): # TODO @wonjoo come up with a long term fix in Dynamo. torch._dynamo.config.automatic_dynamic_shapes = False +# Activate view-replay on AOTAutograd. +# See: https://github.com/pytorch/pytorch/pull/124488 +import torch._functorch.config + +torch._functorch.config.view_replay_for_aliased_outputs = True + +import importlib.metadata +import warnings + +try: + # TensorFlow TPU distribution has the same package name as GPU, but not CPU + dist = importlib.metadata.distribution('tensorflow') + warnings.warn( + "`tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when" + " using PyTorch/XLA. To silence this warning, `pip uninstall -y " + "tensorflow && pip install tensorflow-cpu`. If you are in a notebook " + "environment such as Colab or Kaggle, restart your notebook runtime " + "afterwards.") +except importlib.metadata.PackageNotFoundError: + pass + from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo from .experimental import plugins diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index a7ae1c47964..56a69ca1e05 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1290,6 +1290,38 @@ at::Tensor XLANativeFunctions::embedding_dense_backward( num_weights, padding_idx, scale_grad_by_freq)); } +std::tuple +XLANativeFunctions::_embedding_bag_forward_only( + const at::Tensor& weight, const at::Tensor& indices, + const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, + bool sparse, const c10::optional& per_sample_weights, + bool include_last_offset, int64_t padding_idx) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) { + return at::native::call_fallback_fn< + &xla_cpu_fallback, + ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets, + scale_grad_by_freq, mode, + sparse, per_sample_weights, + include_last_offset, + padding_idx); + } + auto indices_tensor = bridge::GetXlaTensor(indices); + auto sample_weights = + per_sample_weights.has_value() && per_sample_weights.value().defined() + ? bridge::GetXlaTensor(per_sample_weights.value()) + : tensor_methods::full_like(indices_tensor, 1.0, + *torch_xla::bridge::GetXlaDevice(weight), + at::ScalarType::Float); + auto result = tensor_methods::embedding_bag( + bridge::GetXlaTensor(weight), indices_tensor, + bridge::GetXlaTensor(offsets), mode, sample_weights, include_last_offset); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(result)), + bridge::AtenFromXlaTensor(std::get<1>(result)), + bridge::AtenFromXlaTensor(std::get<2>(result)), + bridge::AtenFromXlaTensor(std::get<3>(result))); +} + at::Tensor XLANativeFunctions::empty_symint( at::SymIntArrayRef sym_size, c10::optional dtype, c10::optional layout, c10::optional device, @@ -3709,6 +3741,7 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, scale_grad_by_freq, sparse); } + // TODO: We need to make use of the TPU embedding core here eventually. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::embedding( bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices))); diff --git a/torch_xla/csrc/ops/embedding_bag.cpp b/torch_xla/csrc/ops/embedding_bag.cpp new file mode 100644 index 00000000000..d2bb034a005 --- /dev/null +++ b/torch_xla/csrc/ops/embedding_bag.cpp @@ -0,0 +1,192 @@ +#include "torch_xla/csrc/ops/embedding_bag.h" + +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/shape_helper.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "tsl/platform/stacktrace.h" +#include "xla/client/lib/constants.h" +#include "xla/client/lib/loops.h" +#include "xla/client/lib/slicing.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace { +const int MODE_SUM = 0; +const int MODE_MEAN = 1; +const int MODE_MAX = 2; +std::vector BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices, + xla::XlaOp offsets, + xla::XlaOp per_sample_weights, + bool include_last_offset, int mode) { + xla::Shape offset_shape = ShapeHelper::ShapeOfXlaOp(offsets); + int64_t n = offset_shape.dimensions(0); + xla::Shape weight_shape = ShapeHelper::ShapeOfXlaOp(weight); + int64_t weight_dim = weight_shape.dimensions(1); + xla::Shape indices_shape = ShapeHelper::ShapeOfXlaOp(indices); + int64_t num_embeddings = indices_shape.dimensions(0); + XLA_CHECK(indices_shape.rank() == 1 || indices_shape.rank() == 2) + << "input has to be a 1D or 2D Tensor, but got Tensor of dimension " + << indices_shape.rank(); + if (indices_shape.rank() == 1) { + XLA_CHECK(offset_shape.rank() == 1) + << "offsets has to be a 1D Tensor, but got Tensor of dimension " + << offset_shape.rank(); + } + XLA_CHECK(weight_shape.rank() == 2) + << "weight has to be a 2D Tensor, but got Tensor of dimension " + << weight_shape.rank(); + + xla::XlaOp output2 = xla::ZerosLike(indices); + xla::XlaOp output3 = xla::ZerosLike(offsets); + std::vector sizes = {n, weight_dim}; + xla::XlaOp output4 = + xla::Zeros(offsets.builder(), + xla::ShapeUtil::MakeShape(offset_shape.element_type(), sizes)); + + xla::XlaOp embeddings = xla::TorchIndexSelect(weight, indices, 0); + xla::XlaOp embeddings_weighted = xla::Mul( + embeddings, xla::ConvertElementType( + xla::BroadcastInDim(per_sample_weights, + {num_embeddings, weight_dim}, {0}), + weight_shape.element_type())); + + std::vector shape_elements = { + xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}), + xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}), + xla::ShapeUtil::MakeShape(weight_shape.element_type(), + {num_embeddings, weight_dim}), + xla::ShapeUtil::MakeShape(weight_shape.element_type(), {1, weight_dim})}; + xla::Shape result_shape = xla::ShapeUtil::MakeTupleShape(shape_elements); + + xla::XlaComputation condition; + { + xla::XlaBuilder builder("condition"); + auto prev = xla::Parameter(&builder, 0, result_shape, "prev"); + auto index = xla::GetTupleElement(prev, 0); + auto final_value = xla::GetTupleElement(prev, 1); + xla::Lt(index, final_value); + condition = builder.Build().value(); + } + + xla::XlaComputation body; + { + xla::XlaBuilder builder("body"); + auto prev = xla::Parameter(&builder, 0, result_shape, "prev"); + auto index = xla::GetTupleElement(prev, 0); + auto emb = xla::GetTupleElement(prev, 2); + auto w = xla::GetTupleElement(prev, 3); + + xla::XlaOp slice = xla::DynamicSlice( + emb, + {index, xla::ConvertElementType(xla::ConstantR0(&builder, 0), + offset_shape.element_type())}, + {1, weight_dim}); + xla::XlaOp result = + mode == MODE_SUM ? xla::Add(w, slice) : xla::Max(w, slice); + + xla::Tuple(&builder, + { + xla::Add(index, xla::ConvertElementType( + xla::ConstantR0(&builder, 1), + offset_shape.element_type())), + xla::GetTupleElement(prev, 1), + xla::GetTupleElement(prev, 2), + result, + }); + body = builder.Build().value(); + } + + xla::Array initial_vector({1, weight_dim}, 0.f); + std::vector results; + for (int64_t i = 0; i < n; i++) { + xla::XlaOp start = xla::DynamicSlice( + offsets, {xla::ConstantR0(offsets.builder(), i)}, {1}); + if (i == n - 1 && include_last_offset) continue; + xla::XlaOp end = + i == n - 1 && !include_last_offset + ? xla::ConvertElementType(xla::ConstantR1( + offsets.builder(), 1, num_embeddings), + offset_shape.element_type()) + : xla::DynamicSlice( + offsets, {xla::ConstantR0(offsets.builder(), i + 1)}, + {1}); + // Create a While node with computations for the condition and the body. + auto init_tuple = xla::Tuple( + offsets.builder(), + {xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}), + embeddings_weighted, + xla::ConvertElementType( + xla::ConstantFromArray(offsets.builder(), initial_vector), + weight_shape.element_type())}); + auto result = xla::While(condition, body, init_tuple); + results.push_back(xla::GetTupleElement(result, 3)); + }; + xla::XlaOp output1 = xla::ConcatInDim(offsets.builder(), results, 0); + return {output1, output2, output3, output4}; +} + +xla::Shape NodeOutputShapes(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset, bool mode) { + auto lower_for_shapes_fn = + [&](absl::Span operands) -> xla::XlaOp { + return xla::Tuple( + operands[0].builder(), + BuildEmbeddingBag(operands[0], operands[1], operands[2], operands[3], + include_last_offset, mode)); + }; + + std::vector input_shapes = { + GetXlaShape(weight), GetXlaShape(indices), GetXlaShape(offsets), + GetXlaShape(per_sample_weights)}; + + return InferOutputShape(absl::MakeSpan(input_shapes), lower_for_shapes_fn); +} +} // namespace + +std::string EmbeddingBag::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString(); + return ss.str(); +} + +EmbeddingBag::EmbeddingBag(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, int64_t mode, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset) + : XlaNode( + torch::lazy::OpKind(at::aten::embedding_bag), + {weight, indices, offsets, per_sample_weights}, + [&]() { + return NodeOutputShapes(weight, indices, offsets, + per_sample_weights, include_last_offset, + mode); + }, + /*num_outputs=*/4, torch::lazy::MHash(mode, include_last_offset)), + mode_(mode), + include_last_offset_(include_last_offset) {} + +torch::lazy::NodePtr EmbeddingBag::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0), operands.at(1), + operands.at(2), mode_, + operands.at(3), false); +} + +XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const { + xla::XlaOp weight = loctx->GetOutputOp(operand(0)); + xla::XlaOp indices = loctx->GetOutputOp(operand(1)); + xla::XlaOp offsets = loctx->GetOutputOp(operand(2)); + xla::XlaOp per_sample_weights = loctx->GetOutputOp(operand(3)); + std::vector ops = + BuildEmbeddingBag(weight, indices, offsets, per_sample_weights, + include_last_offset_, mode_); + return ReturnOps(absl::MakeSpan(ops), loctx); +} + +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/embedding_bag.h b/torch_xla/csrc/ops/embedding_bag.h new file mode 100644 index 00000000000..4d9b0a6eecb --- /dev/null +++ b/torch_xla/csrc/ops/embedding_bag.h @@ -0,0 +1,31 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ +#define XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ + +#include + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class EmbeddingBag : public XlaNode { + public: + EmbeddingBag(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, int64_t mode, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + int64_t mode_; + bool include_last_offset_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ \ No newline at end of file diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 6f746972355..582b69d8a50 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -237,7 +237,7 @@ cc_library( deps = [ ":debug_macros", ":sys_util", - "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", "@xla//xla/pjrt/distributed", ], ) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 029f9268342..c48cf1555ff 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -58,18 +58,6 @@ torch::lazy::hash_t hash_comp_env( xla::ifrt::Client* client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); - auto topology_desc = client->GetTopologyForDevices(ordered_devices); - if (topology_desc.ok()) { - // Some backends support a topology description which provides a better - // view of the specific compilation environment. - auto serialized = topology_desc.value()->Serialize(); - if (serialized.ok()) { - return torch::lazy::HashCombine( - hash, - torch::lazy::DataHash(serialized->data(), serialized->length())); - } - // If serialization fails, fallthrough to the manual approach. - } std::string platform_name(client->platform_name()); std::string platform_version(client->platform_version()); hash = torch::lazy::HashCombine( @@ -78,10 +66,26 @@ torch::lazy::hash_t hash_comp_env( hash = torch::lazy::HashCombine( hash, torch::lazy::StringHash(platform_version.c_str())); // Include global devices in the hash, ensuring order is consistent. + xla::ifrt::DeviceList::Devices ifrt_devices; for (auto& device : ordered_devices) { std::string device_str(device->ToString()); hash = torch::lazy::HashCombine( hash, torch::lazy::StringHash(device_str.c_str())); + ifrt_devices.push_back(device); + } + + xla::ifrt::DeviceList device_list(std::move(ifrt_devices)); + auto topology_desc = client->GetTopologyForDevices(device_list); + if (topology_desc.ok()) { + // Some backends support a topology description which provides a better + // view of the specific compilation environment. + auto serialized = topology_desc.value()->Serialize(); + if (serialized.ok()) { + return torch::lazy::HashCombine( + hash, + torch::lazy::DataHash(serialized->data(), serialized->length())); + } + // If serialization fails, fallthrough to the manual approach. } return hash; } @@ -92,7 +96,7 @@ std::string IfrtComputationClient::IfrtDeviceToString( xla::ifrt::Device* const device) const { std::string platform = absl::AsciiStrToUpper(device->client()->platform_name()); - int ordinal = global_ordinals_.at(device->id()); + int ordinal = global_ordinals_.at(device->Id().value()); std::string str = absl::StrFormat("%s:%d", platform, ordinal); return str; } @@ -120,11 +124,12 @@ IfrtComputationClient::IfrtComputationClient() { // a device's global ordinal separately from its device ID. Order the // devices by increasing ID to assign global ordinals. std::vector ordered_devices(client_->device_count()); - std::partial_sort_copy(client_->devices().begin(), client_->devices().end(), - ordered_devices.begin(), ordered_devices.end(), - [](auto& a, auto& b) { return a->id() < b->id(); }); + std::partial_sort_copy( + client_->devices().begin(), client_->devices().end(), + ordered_devices.begin(), ordered_devices.end(), + [](auto& a, auto& b) { return a->Id().value() < b->Id().value(); }); for (auto* device : ordered_devices) { - global_ordinals_[device->id()] = global_ordinals_.size(); + global_ordinals_[device->Id().value()] = global_ordinals_.size(); std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } @@ -611,7 +616,7 @@ std::vector IfrtComputationClient::GetAllDevices() const { int IfrtComputationClient::GetNumProcesses() const { int max_process_index = client_->process_index(); for (auto* device : client_->devices()) { - max_process_index = std::max(max_process_index, device->process_index()); + max_process_index = std::max(max_process_index, device->ProcessIndex()); } return max_process_index + 1; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index d6d914ad8da..38d0de97204 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -134,7 +134,7 @@ class IfrtComputationClient : public ComputationClient { // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; - std::unordered_map string_to_device_; + std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 188e26f8ac2..a129a476a2e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -463,7 +463,7 @@ std::vector PjRtComputationClient::TransferFromDevice( metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); - std::vector> futures; + std::vector> futures; futures.reserve(handles.size()); std::vector literals; literals.reserve(handles.size()); @@ -679,7 +679,7 @@ PjRtComputationClient::ExecuteComputation( TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device << " Done"; - std::optional> returned_future; + std::optional> returned_future; std::vector> results = pjrt_computation.executable ->ExecuteSharded(buffers, pjrt_device, execute_options, @@ -779,8 +779,8 @@ PjRtComputationClient::ExecuteReplicated( TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for " << spmd_device_str << " Done"; - std::optional>> returned_futures = - std::vector>(); + std::optional>> returned_futures = + std::vector>(); std::vector>> results; { tsl::profiler::TraceMe activity( diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 648076757be..52b06d89cb4 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -21,8 +21,24 @@ namespace runtime { namespace { +// Placeholder plugin for testing only. Does not implement multiprocessing or +// configuration. Very likely will not work from Python code. +class LibraryPlugin : public PjRtPlugin { + public: + std::string library_path() const override { + return sys_util::GetEnvString("PJRT_LIBRARY_PATH", ""); + } + + const std::unordered_map + client_create_options() const override { + return {}; + } + + bool requires_xla_coordinator() const override { return false; } +}; + std::unordered_map> - pjrt_plugins_; + pjrt_plugins_ = {{"LIBRARY", std::make_shared()}}; xla::GpuAllocatorConfig GetGpuAllocatorConfig() { auto allocator_config = xla::GpuAllocatorConfig{}; @@ -60,7 +76,8 @@ InitializePjRt(const std::string& device_type) { std::unique_ptr client; std::unique_ptr coordinator; - if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) { + if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false) && + device_type != "CPU") { std::shared_ptr plugin = GetPjRtPlugin(device_type); if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h index ae85c79a941..fb2cfaf99f5 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.h +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -3,8 +3,8 @@ #include -#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/pjrt/distributed/distributed.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" namespace torch_xla { namespace runtime { diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index f27dc786fb5..fbb240f31d3 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -48,6 +48,7 @@ #include "torch_xla/csrc/ops/dynamic_view.h" #include "torch_xla/csrc/ops/einsum.h" #include "torch_xla/csrc/ops/einsum_backward.h" +#include "torch_xla/csrc/ops/embedding_bag.h" #include "torch_xla/csrc/ops/expand.h" #include "torch_xla/csrc/ops/expand_symint.h" #include "torch_xla/csrc/ops/exponential.h" @@ -1292,6 +1293,20 @@ XLATensorPtr embedding(const XLATensorPtr& weight, return tensor_ops::Embedding(weight, indices); } +std::tuple +embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices, + const XLATensorPtr& offsets, int64_t mode, + const XLATensorPtr& per_sample_weights, + bool include_last_offset) { + torch::lazy::NodePtr node = torch::lazy::MakeNode( + weight->GetIrValue(), indices->GetIrValue(), offsets->GetIrValue(), mode, + per_sample_weights->GetIrValue(), include_last_offset); + return std::make_tuple(weight->CreateFrom(torch::lazy::Value(node, 0)), + weight->CreateFrom(torch::lazy::Value(node, 1)), + weight->CreateFrom(torch::lazy::Value(node, 2)), + weight->CreateFrom(torch::lazy::Value(node, 3))); +} + XLATensorPtr exp(const XLATensorPtr& input) { return input->CreateFrom(Exp(input->GetIrValue())); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index f27465fd67d..6a7005a5f0f 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -381,6 +381,11 @@ XLATensorPtr embedding_dense_backward(const XLATensorPtr& grad_output, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); +std::tuple +embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices, + const XLATensorPtr& offsets, int64_t mode, + const XLATensorPtr& per_sample_weights, bool include_last_offset); + XLATensorPtr embedding(const XLATensorPtr& weight, const XLATensorPtr& indices); XLATensorPtr eq(const XLATensorPtr& input, const at::Scalar& other); diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index dc7df436ec7..6020ef6bc04 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -1,7 +1,9 @@ #include #include +#include "torch_xla/csrc/aten_cpu_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/ops/nms.h" #include "torch_xla/csrc/ops/ops.h" #include "torch_xla/csrc/tensor_methods.h" @@ -11,10 +13,22 @@ namespace torch_xla { namespace manual { namespace { +struct NmsOp { + using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double); + using ptr_schema = schema*; + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") +}; + at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, double iou_threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + if (!DebugUtil::ExperimentEnabled("nms")) { + return at::native::call_fallback_fn<&xla_cpu_fallback, NmsOp>::call( + boxes, scores, iou_threshold); + } + XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor."; XLA_CHECK_EQ(boxes.size(1), 4) << "nms(): boxes should be a 2D tensor of shape [N, 4]."; diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py index abfe1c62ba0..099f25e9fb5 100644 --- a/torch_xla/distributed/spmd/__init__.py +++ b/torch_xla/distributed/spmd/__init__.py @@ -27,4 +27,6 @@ "_mark_manual_sharding", "enable_manual_sharding", "disable_manual_sharding", + "enable_manual_sharding", + "disable_manual_sharding", ] diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index ff4b335058b..9bd050efc29 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -5,6 +5,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs from typing import List, Callable from torch.library import impl @@ -17,7 +18,7 @@ def _extract_backend_config( module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> str | None: """ This algorithm intends to extract the backend config from the compiler IR like the following, - and it is designed to traverse any generic MLIR module. + and it is not designed to traverse any generic MLIR module. module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { @@ -54,17 +55,12 @@ def jax_import_guard(): torch_xla._XLAC._init_computation_client() -def trace_pallas(kernel: Callable, - *args, - static_argnums=None, - static_argnames=None, - **kwargs): +def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct": # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() import jax import jax.numpy as jnp - import jax._src.pallas.mosaic.pallas_call_registration def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: if dtype == torch.float32: @@ -92,14 +88,28 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: else: raise ValueError(f"Unsupported dtype: {dtype}") + return jax.ShapeDtypeStruct(tensor.shape, + convert_torch_dtype_to_jax(tensor.dtype)) + + +def trace_pallas(kernel: Callable, + *args, + static_argnums=None, + static_argnames=None, + **kwargs): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax._src.pallas.mosaic.pallas_call_registration + jax_args = [] # for tracing tensor_args = [] # for execution for i, arg in enumerate(args): # TODO: Could the args be a tuple of tensors or a list of tensors? Flattern them? if torch.is_tensor(arg): # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload. - jax_meta_tensor = jax.ShapeDtypeStruct( - arg.shape, convert_torch_dtype_to_jax(arg.dtype)) + jax_meta_tensor = to_jax_shape_dtype_struct(arg) jax_args.append(jax_meta_tensor) tensor_args.append(arg) else: @@ -166,38 +176,86 @@ class FlashAttention(torch.autograd.Function): "block_k_dq": 256, "block_k_major_dq": 512, } + NUM_LANES = 128 + NUM_SUBLANES = 8 + + @staticmethod + def prepare_segment_ids(q_segment_ids, kv_segment_ids): + from jax.experimental.pallas.ops.tpu.flash_attention import SegmentIds + if q_segment_ids is None or kv_segment_ids is None: + return None, None, None + + assert q_segment_ids is not None and kv_segment_ids is not None, "Both q_segment_ids and kv_segment_ids should be provided." + segment_ids = SegmentIds( + to_jax_shape_dtype_struct(q_segment_ids), + to_jax_shape_dtype_struct(kv_segment_ids)) + q_segment_ids = q_segment_ids.unsqueeze(-1).expand( + [-1 for _ in q_segment_ids.shape] + [FlashAttention.NUM_LANES]) + kv_segment_ids = kv_segment_ids.unsqueeze(1).expand([ + kv_segment_ids.shape[0], FlashAttention.NUM_SUBLANES, + kv_segment_ids.shape[1] + ]) + return segment_ids, q_segment_ids, kv_segment_ids @staticmethod - def forward(ctx, q, k, v, causal=False): + def forward(ctx, + q, + k, + v, + causal=False, + q_segment_ids=None, + kv_segment_ids=None, + partition_spec=None, + mesh=None): # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() + import jax from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl ctx.causal = causal + ctx.partition_spec = partition_spec + ctx.mesh = mesh + ctx.full_shape = None save_residuals = q.requires_grad or k.requires_grad or v.requires_grad - # It returns the shape and type of o, l, m. - def shape_dtype(q, *arg): - if not save_residuals: - return [(q.shape, q.dtype)] + # SPMD integration. + # mark_sharding is in-placed, and therefore save the full q, k, v for the backward. + full_q = q + full_k = k + full_v = v + if partition_spec is not None: + ctx.full_shape = q.shape + q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor + k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor + v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor + + # It computes the shape and type of o, l, m. + shapes = [q.shape] + dtypes = [q.dtype] + if save_residuals: res_shape = list(q.shape) res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE - return [(q.shape, q.dtype), (res_shape, torch.float32), - (res_shape, torch.float32)] - - # We can't directly use flash_attention as we need to override the save_residuals flag which returns - # l and m that is needed for the backward. Then we lose all the shape checks. - # TODO: replicate the shape checks on flash_attention. - _flash_attention_impl = make_kernel_from_pallas(_flash_attention_impl, - shape_dtype) + for _ in range(2): + shapes.append(res_shape) + dtypes.append(torch.float32) + with torch.no_grad(): - o = _flash_attention_impl( + segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids( + q_segment_ids, kv_segment_ids) + ctx.segment_ids = segment_ids + + # We can't directly use flash_attention as we need to override the save_residuals flag which returns + # l and m that is needed for the backward. Then we lose all the shape checks. + # TODO: replicate the shape checks on flash_attention. + # Here we seperate the tracing and execution part just to support SegmentIds. + payload, _ = trace_pallas( + _flash_attention_impl, q, k, v, None, - None, + segment_ids, save_residuals, causal, 1.0, @@ -207,20 +265,45 @@ def shape_dtype(q, *arg): min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]), False, static_argnums=range(5, 13)) + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes) + if not save_residuals: + o = o[0] + # SPMD integration + if partition_spec is not None: + o = xs.disable_manual_sharding( + o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor return o o, *aux = o l, m = (v[..., 0] for v in aux[-2:]) - ctx.save_for_backward(q, k, v, o, l, m) + # SPMD integration + if partition_spec is not None: + o = xs.disable_manual_sharding( + o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor + l = xs.disable_manual_sharding( + l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor + m = xs.disable_manual_sharding( + m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor + + ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids, + kv_segment_ids) return o @staticmethod def backward(ctx, grad_output): from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv - q, k, v, o, l, m = ctx.saved_tensors + q, k, v, o, l, m, q_segment_ids, kv_segment_ids = ctx.saved_tensors causal = ctx.causal + partition_spec = ctx.partition_spec + mesh = ctx.mesh + full_shape = ctx.full_shape + segment_ids = ctx.segment_ids grad_q = grad_k = grad_v = None grad_i = torch.sum( @@ -234,6 +317,20 @@ def backward(ctx, grad_output): expanded_grad_i = grad_i.unsqueeze(-1).expand( [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE]) + # SPMD integration + if partition_spec is not None: + q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor + k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor + v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor + expanded_l = xs.enable_manual_sharding( + expanded_l, partition_spec, mesh=mesh).global_tensor + expanded_m = xs.enable_manual_sharding( + expanded_m, partition_spec, mesh=mesh).global_tensor + grad_output = xs.enable_manual_sharding( + grad_output, partition_spec, mesh=mesh).global_tensor + expanded_grad_i = xs.enable_manual_sharding( + expanded_grad_i, partition_spec, mesh=mesh).global_tensor + if ctx.needs_input_grad[0]: payload, _ = trace_pallas( _flash_attention_bwd_dq, @@ -241,7 +338,7 @@ def backward(ctx, grad_output): k, v, None, - None, + segment_ids, l, m, grad_output, @@ -261,9 +358,13 @@ def backward(ctx, grad_output): "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", "mask_value", "debug" ]) - grad_q = torch_xla._XLAC._xla_tpu_custom_call( - [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i], - payload, [q.shape], [q.dtype])[0] + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + args += [expanded_l, expanded_m, grad_output, expanded_grad_i] + grad_q = torch_xla._XLAC._xla_tpu_custom_call(args, payload, [q.shape], + [q.dtype])[0] if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: payload, _ = trace_pallas( @@ -272,7 +373,7 @@ def backward(ctx, grad_output): k, v, None, - None, + segment_ids, l, m, grad_output, @@ -295,15 +396,29 @@ def backward(ctx, grad_output): "block_q_major", "block_k_major", "block_k", "block_q", "sm_scale", "causal", "mask_value", "debug" ]) - grads = torch_xla._XLAC._xla_tpu_custom_call( - [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i], - payload, [k.shape, v.shape], [k.dtype, v.dtype]) + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + args += [expanded_l, expanded_m, grad_output, expanded_grad_i] + grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, + [k.shape, v.shape], + [k.dtype, v.dtype]) if ctx.needs_input_grad[1]: grad_k = grads[0] if ctx.needs_input_grad[2]: grad_v = grads[1] - return grad_q, grad_k, grad_v, None + # SPMD integration + if partition_spec is not None: + grad_q = xs.disable_manual_sharding( + grad_q, partition_spec, full_shape, mesh=mesh).global_tensor + grad_k = xs.disable_manual_sharding( + grad_k, partition_spec, full_shape, mesh=mesh).global_tensor + grad_v = xs.disable_manual_sharding( + grad_v, partition_spec, full_shape, mesh=mesh).global_tensor + + return grad_q, grad_k, grad_v, None, None, None, None, None def flash_attention( @@ -311,8 +426,75 @@ def flash_attention( k, # [batch_size, num_heads, kv_seq_len, d_model] v, # [batch_size, num_heads, kv_seq_len, d_model] causal=False, -): - return FlashAttention.apply(q, k, v, causal) + q_segment_ids=None, + kv_segment_ids=None, + *, + partition_spec=None, + mesh=None): + # TODO: support SPMD and Dynamo with segment_ids. + return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids, + partition_spec, mesh) + + +def paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention + + payload, tensor_args = trace_pallas( + paged_attention, + q, + k_pages, + v_pages, + lengths, + page_indices, + pages_per_compute_block=pages_per_compute_block, + static_argnames=["pages_per_compute_block"], + ) + + batch_size, num_heads, head_dim = q.shape + num_kv_heads, _, page_size, head_dim_k = k_pages.shape + batch_size_paged_indices, pages_per_sequence = page_indices.shape + q_dtype_for_kernel_launch = q.dtype + if (num_heads // num_kv_heads) % 8 != 0: + q = q.reshape(batch_size, num_heads, 1, head_dim) + q_dtype_for_kernel_launch = torch.float32 + + page_indices_reshaped = page_indices.reshape(-1) + buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") + step = torch.zeros((1,), dtype=torch.int32).to("xla") + output_shape = torch.Size(list(q.shape[:-1]) + [1]) + + output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( + [ + lengths, + page_indices_reshaped, + buffer_index, + step, + q.to(q_dtype_for_kernel_launch), + k_pages, + v_pages, + ], payload, [q.shape, output_shape, output_shape], + [q_dtype_for_kernel_launch, torch.float32, torch.float32]) + + return output.reshape(batch_size, num_heads, head_dim).to(q.dtype) + + +def non_xla_attetion(q, k, v, attention_type): + # This will be called when dynamo use fake tensor to construct the fake output. + # We need to make sure output tensor's shape is correct. + if k.device != torch.device("meta"): + warnings.warn( + f'XLA {attention_type} attention should only be applied to tensors on XLA device' + ) + + # perform a regular attention if input tensors are not on XLA device. + attn_weight = q @ k.transpose(-2, -1) + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) + attn_output = attn_weight @ v + return attn_output XLA_LIB.define( @@ -333,14 +515,26 @@ def flash_attention_non_xla(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False): - # This will be called when dynamo use fake tensor to construct the fake output. - # We need to make sure output tensor's shape is correct. - if k.device != torch.device("meta"): - warnings.warn( - 'XLA flash attention should only be applied to tensors on XLA device') + return non_xla_attetion(q, k, v, "flash") - # perform a regular attention if input tensors are not on XLA device. - attn_weight = q @ k.transpose(-2, -1) - attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) - attn_output = attn_weight @ v - return attn_output + +XLA_LIB.define( + "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor", +) + + +@impl(XLA_LIB, "paged_attention", "XLA") +def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): + return paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block) + + +@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd") +def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): + return non_xla_attetion(q, k, v, "paged") diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 77c2a572de3..620dff7e45c 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -76,7 +76,9 @@ def use_dynamic_plugins(): def using_dynamic_plugins(): - return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool, False) + # TODO: dummy plugin for CPU + return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool, + False) and xr.device_type() != "CPU" def default() -> DevicePlugin: