diff --git a/.circleci/README.md b/.circleci/README.md deleted file mode 100644 index d01e6138317..00000000000 --- a/.circleci/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# CircleCI Overview -PyTorch and PyTorch/XLA use CircleCI to lint, build, and test each PR that is submitted. All CircleCI tests should succeed before the PR is merged into master. PyTorch CircleCI pins PyTorch/XLA to a specific commit. On the other hand, PyTorch/XLA CircleCI pulls PyTorch from master unless a pin is manually provided. This README will go through the reasons of these pins, how to pin a PyTorch/XLA PR to an upstream PyTorch PR, and how to coordinate a merge for breaking PyTorch changes. - -## Why does PyTorch CircleCI pin PyTorch/XLA? -As mentioned above, [PyTorch CircleCI pins PyTorch/XLA](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/common_utils.sh#L119) to a "known good" commit to prevent accidental changes from PyTorch/XLA to break PyTorch CircleCI without warning. PyTorch has hundreds of commits each week, and this pin ensures that PyTorch/XLA as a downstream package does not cause failures in PyTorch CircleCI. - -## Why does PyTorch/XLA CircleCI pull from PyTorch master? -[PyTorch/XLA CircleCI pulls PyTorch from master](https://github.com/pytorch/xla/blob/f3415929683880192b63b285921c72439af55bf0/.circleci/common.sh#L15) unless a PyTorch pin is manually provided. PyTorch/XLA is a downstream package to PyTorch, and pulling from master ensures that PyTorch/XLA will stay up-to-date and works with the latest PyTorch changes. - -## Pinning PyTorch PR in PyTorch/XLA PR -Sometimes a PyTorch/XLA PR needs to be pinned to a specific PyTorch PR to test new featurues, fix breaking changes, etc. Since PyTorch/XLA CircleCI pulls from PyTorch master by default, we need to manually provided a PyTorch pin. In a PyTorch/XLA PR, PyTorch an be manually pinned by creating a `.torch_pin` under `/torch_patches`. The `.torch_pin` should have the corresponding PyTorch PR number prefixed by "#". Take a look at [example here](https://github.com/pytorch/xla/pull/3792/commits/40f41fb98b0f2386d287eeac0bae86e873d4a9d8). Before the PyTorch/XLA PR gets merged, the `.torch_pin` must be deleted. - -## Coodinating merges for breaking PyTorch PRs -When PyTorch PR introduces a breaking change, its PyTorch/XLA CircleCI tests will fail. Steps for fixing and merging such breaking PyTorch change is as following: -1. Create a PyTorch/XLA PR to fix this issue with `.torch_pin` and rebase with master to ensure the PR is up-to-date with the latest commit on PyTorch/XLA. Once this PR is created, it'll create a commit hash that will be used in step 2. If you have multiple commits in the PR, use the last one's hash. **Important note: When you rebase this PR, it'll create a new commit hash and make the old hash obsolete. Be cautious about rebasing, and if you rebase, make sure you inform the PyTorch PR's author.** -2. Rebase (or ask the PR owner to rebase) the PyTorch PR with master. Update the PyTorch PR to pin the PyTorch/XLA to the commit hash created in step 1 by updating `pytorch/.github/ci_commit_pins/xla.txt`. -3. Once CircleCI tests are green on both ends, merge PyTorch PR. -4. Remove the `.torch_pin` in PyTorch/XLA PR and merge. To be noted, `git commit --amend` should be avoided in this step as PyTorch CI will keep using the commit hash created in step 1 until other PRs update that manually or the nightly buildbot updates that automatically. -5. Finally, don't delete your branch until 2 days later. See step 4 for explanations. diff --git a/.circleci/doc_push.sh b/.circleci/doc_push.sh deleted file mode 100755 index 72b4a44f6e7..00000000000 --- a/.circleci/doc_push.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash - -set -ex - -cd /tmp/pytorch/xla - -source ./xla_env -source .circleci/common.sh - -echo "Building docs" -pushd docs -./docs_build.sh -popd - -echo "Pushing to public" -git config --global user.email "pytorchxla@gmail.com" -git config --global user.name "torchxlabot2" -GH_PAGES_BRANCH=gh-pages -GH_PAGES_DIR=gh-pages-tmp -CURRENT_COMMIT=`git rev-parse HEAD` -BRANCH_NAME=`git rev-parse --abbrev-ref HEAD` -if [[ "$BRANCH_NAME" == release/* ]]; then - SUBDIR_NAME=$BRANCH_NAME -else - SUBDIR_NAME="master" -fi -pushd /tmp -git clone --quiet -b "$GH_PAGES_BRANCH" https://github.com/pytorch/xla.git "$GH_PAGES_DIR" -pushd $GH_PAGES_DIR -rm -rf $SUBDIR_NAME -mkdir -p $SUBDIR_NAME -cp -fR /tmp/pytorch/xla/docs/build/* $SUBDIR_NAME -git_status=$(git status --porcelain) -if [[ $git_status ]]; then - echo "Doc is updated... Pushing to public" - echo "${git_status}" - sudo apt-get -qq update - export DEBIAN_FRONTEND=noninteractive - sudo ln -snf /usr/share/zoneinfo/Etc/UTC /etc/localtime - sudo sh -c "echo Etc/UTC > /etc/timezone" - sudo apt-get -qq -y install tzdata - sudo apt-get -qq install expect - git add . - - COMMIT_MSG="Update doc from commit $CURRENT_COMMIT" - git commit -m "$COMMIT_MSG" - set +x -/usr/bin/expect < /dev/null 2>&1 ; then - VER="buster" - else - VER=$(lsb_release -c -s) - fi - echo "$VER" -} - -function install_llvm_clang() { - local DEBVER=$(debian_version) - if ! apt-get install -y -s clang-8 > /dev/null 2>&1 ; then - maybe_append "deb http://apt.llvm.org/${DEBVER}/ llvm-toolchain-${DEBVER}-8 main" /etc/apt/sources.list - maybe_append "deb-src http://apt.llvm.org/${DEBVER}/ llvm-toolchain-${DEBVER}-8 main" /etc/apt/sources.list - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - - sudo apt-get update - fi - # Build config also sets CC=clang-8, CXX=clang++-8 - sudo apt-get install -y clang-8 clang++-8 - sudo apt-get install -y llvm-8 llvm-8-dev llvm-8-tools - sudo ln -s /usr/bin/clang-8 /usr/bin/clang - sudo ln -s /usr/bin/clang++-8 /usr/bin/clang++ - export CC=clang-8 CXX=clang++-8 -} - -install_llvm_clang diff --git a/.circleci/setup_ci_environment.sh b/.circleci/setup_ci_environment.sh index eba2c373b8a..87a61524e7e 100755 --- a/.circleci/setup_ci_environment.sh +++ b/.circleci/setup_ci_environment.sh @@ -58,7 +58,7 @@ sudo apt-get -y remove linux-image-generic linux-headers-generic linux-generic d # How to figure out what the correct versions of these packages are? # My preferred method is to start a Docker instance of the correct # Ubuntu version (e.g., docker run -it ubuntu:16.04) and then ask -# apt what the packages you need are. Note that the CircleCI image +# apt what the packages you need are. Note that the CI image # comes with Docker. # # Using 'retry' here as belt-and-suspenders even though we are diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 87072a65bce..bfff4ef8422 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -/infra @will-cromar @JackCaoG @yeounoh @mateuszlewko @stgpetrovic +/infra @will-cromar @JackCaoG @lsy323 diff --git a/.github/ci.md b/.github/ci.md new file mode 100644 index 00000000000..9dc7b3d8b15 --- /dev/null +++ b/.github/ci.md @@ -0,0 +1,97 @@ +# CI Overview + +PyTorch and PyTorch/XLA use CI to lint, build, and test each PR that is submitted. All CI tests should succeed before the PR is merged into master. PyTorch CI pins PyTorch/XLA to a specific commit. On the other hand, PyTorch/XLA CI pulls PyTorch from master unless a pin is manually provided. This README will go through the reasons of these pins, how to pin a PyTorch/XLA PR to an upstream PyTorch PR, and how to coordinate a merge for breaking PyTorch changes. + +## Usage + +### Pinning PyTorch PR in PyTorch/XLA PR + +Sometimes a PyTorch/XLA PR needs to be pinned to a specific PyTorch PR to test new featurues, fix breaking changes, etc. Since PyTorch/XLA CI pulls from PyTorch master by default, we need to manually provided a PyTorch pin. In a PyTorch/XLA PR, PyTorch an be manually pinned by creating a `.torch_pin` file at the root of the repository. The `.torch_pin` should have the corresponding PyTorch PR number prefixed by "#". Take a look at [example here](https://github.com/pytorch/xla/pull/3792/commits/40f41fb98b0f2386d287eeac0bae86e873d4a9d8). Before the PyTorch/XLA PR gets merged, the `.torch_pin` must be deleted. + +### Coodinating merges for breaking PyTorch PRs + +When PyTorch PR introduces a breaking change, its PyTorch/XLA CI tests will fail. Steps for fixing and merging such breaking PyTorch change is as following: +1. Create a PyTorch/XLA PR to fix this issue with `.torch_pin` and rebase with master to ensure the PR is up-to-date with the latest commit on PyTorch/XLA. Once this PR is created, it'll create a commit hash that will be used in step 2. If you have multiple commits in the PR, use the last one's hash. **Important note: When you rebase this PR, it'll create a new commit hash and make the old hash obsolete. Be cautious about rebasing, and if you rebase, make sure you inform the PyTorch PR's author.** +2. Rebase (or ask the PR owner to rebase) the PyTorch PR with master. Update the PyTorch PR to pin the PyTorch/XLA to the commit hash created in step 1 by updating `pytorch/.github/ci_commit_pins/xla.txt`. +3. Once CI tests are green on both ends, merge PyTorch PR. +4. Remove the `.torch_pin` in PyTorch/XLA PR and merge. To be noted, `git commit --amend` should be avoided in this step as PyTorch CI will keep using the commit hash created in step 1 until other PRs update that manually or the nightly buildbot updates that automatically. +5. Finally, don't delete your branch until 2 days later. See step 4 for explanations. + +### Running TPU tests on PRs + +By default, we only run TPU tests on a postsubmit basis to save capacity. If you are making a sensitive change, add the `tpuci` label to your PR. Note that the label must be present before `build_and_test.yml` triggers. If it has already run, make a new commit or rebase to trigger the CI again. + +## CI Environment + +Before the CI in this repository runs, we build a the base dev image. These are the same images we recommend in our VSCode `.devcontainer` setup and nightly build to ensure consistency between environments. We produce variants with and without CUDA, configured in `infra/ansible` (build config) and `infra/tpu-pytorch-releases/dev_images.tf` (build triggers). + +The CI runs in two environments: + +1. Organization self-hosted runners for CPU and GPU: used for amost every step of the CI. These runners are managed by PyTorch and have access to the shared ECR repository. +2. TPU self-hosted runners: these are managed by us and are only availabe in the `pytorch/xla` repository. See the [_TPU CI_](#tpu-ci) section for more details. + +## Build and test (`build_and_test.yml`) + +We have two build paths for each CI run: + +- `torch_xla`: we build the main package to support for both TPU and GPU[^1], along with a CPU bild of `torch` from HEAD. This build step exports the `torch-xla-wheels` artifact for downstream use in tests. + - Some CI tests also require `torchvision`. To reduce flakiness, we compile `torchvision` from [`torch`'s CI pin](https://github.com/pytorch/pytorch/blob/main/.github/ci_commit_pins/vision.txt). + - C++ tests are piggybacked onto the same build and uploaded in the `cpp-test-bin` artifact. +- `torch_xla_cuda_plugin`: the XLA CUDA runtime can be built independently of either `torch` or `torch_xla` -- it depends only on our pinned OpenXLA. Thus, this build should be almost entirely cached, unless your PR changes the XLA pin or adds a patch. + +Both the main package build and plugin build are configured with ansible at `infra/ansible`, although they run in separate stages (`stage=build_srcs` vs `stage=build_plugin`). This is the same configuration we use for our nightly and release builds. + +The CPU and GPU test configs are defined in the same file, `_test.yml`. Since some of the tests come from the upstream PyTorch repository, we check out PyTorch at the same git rev as the `build` step (taken from `torch_xla.version.__torch_gitrev__`). The tests are split up into multiple groups that run in parallel; the `matrix` section of `_test.yml` corresponds to in `.github/scripts/run_tests.sh`. + +CPU tests run immediately after then `torch_xla` build completes. This will likely be the first test feedback on your commit. GPU tests will launch when both the `torch_xla` and `torch_xla_cuda_plugin` complete. GPU compilation is much slower due to the number of possible optimizations, and the GPU chips themselves are quite outdated, so these tests will take longer to run than the CPU tests. + +![CPU tests launch when `torch_xla` is complete](../docs/assets/ci_test_dependency.png) + +![GPU tests also depend on CUDA plugin](../docs/assets/ci_test_dependency_gpu.png) + +For the C++ test groups in either case, the test binaries are pre-built during the build phase and packaged in `cpp-test-bin`. This will only be downloaded if necessary. + +[^1]: Note: both GPU and TPU support require their respective plugins to be installed. This package will _not_ work on either out of the box. + +### TPU CI + +The TPU CI runs only a subset of our tests due to capacity constraints, defined in `_tpu_ci.yml` `test/tpu/run_tests.sh`. The runners themselves are containers in GKE managed by [ARC](https://github.com/actions/actions-runner-controller). The container image is also based on our dev images, with some changes for ARC compatibility. The Dockerfile for this image lives in `test/tpu/Dockerfile`. + +The actual ARC cluster is defined in Terraform at `infra/tpu-pytorch/tpu_ci.yml`. + +### Reproducing test failures + +The best way to reproduce failures in the CI is to use the recommended container configuration in `.devcontainer`. These use identical images/environments as the CI. + +If you cannot reproduce the failure or need to inspect the package built in a CI run, you can download the `torch-xla-wheels` artifact for that run, [either locally in your web browser or remotely with the `gh` CLI tool](https://docs.github.com/en/actions/managing-workflow-runs/downloading-workflow-artifacts). C++ tests in particular can be quite slow to build. If you need to re-run these yourself, download the `cpp-test-bin` artifact. You'll have to set some additional environment variables for these to load the correct `torch` and plugin binaries, so you should copy the variables we set in `_test.yml` before runnign them. + +### Generating docs + +Our API documentation is generated automatically from the `torch_xla` package with `sphinx`. The workflow to update our static site is defined in `_docs.yml`. The workflow is roughly the following: + +- Changes to `master` update the docs at `/master` on the `gh-pages` branch. +- Changes to a release brance update the docs under `/releases/rX.Y`. + +By default, we redirect to the latest stable version, defined in [`index.md`](https://github.com/pytorch/xla/blob/gh-pages/index.md). + +We build preview docs for every CI, but only push to `gh-pages` for `master` and release branches. To preview doc changes, download the `github-pages` artifact locally and open `index.html` in your browser. + +Changes to `gh-pages` are pushed by our bot account, `torchxlabot2`. + +### FAQ and Troubleshooting + +#### Why does PyTorch CI pin PyTorch/XLA? + +As mentioned above, [PyTorch CI pins PyTorch/XLA](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/common_utils.sh#L119) to a "known good" commit to prevent accidental changes from PyTorch/XLA to break PyTorch CI without warning. PyTorch has hundreds of commits each week, and this pin ensures that PyTorch/XLA as a downstream package does not cause failures in PyTorch CI. + +#### Why does PyTorch/XLA CI pull from PyTorch master? + +[PyTorch/XLA CI pulls PyTorch from master](https://github.com/pytorch/xla/blob/f3415929683880192b63b285921c72439af55bf0/.circleci/common.sh#L15) unless a PyTorch pin is manually provided. PyTorch/XLA is a downstream package to PyTorch, and pulling from master ensures that PyTorch/XLA will stay up-to-date and works with the latest PyTorch changes. + +## Upstream CI image (`build_upstream_image.yml`) + +We use different build tools than the upstream `torch` repository due to our dependency on XLA, namely `bazel`. To ensure the upstream CI has the correct tools to run XLA, we layer some additional tools and changes on top of our dev image and push the result to upstream's ECR instance. The upstream CI image is defined in `.github/upstream`. + +If you are making a breaking change to the image, bump the image version tag in `build_upstream_image.yml` first and then send a PR to `pytorch/pytorch` to update the tag on their side ([example](https://github.com/pytorch/pytorch/pull/125319)). + +Note: the upstream CI still relies on some legacy scripts in `.circleci` rather than our Ansible config. Don't update these without checking if they break the upstream CI first! TODO: finally delete these. diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh new file mode 100755 index 00000000000..ae59a51490d --- /dev/null +++ b/.github/scripts/run_tests.sh @@ -0,0 +1,108 @@ +set -ex + +function run_torch_xla_python_tests() { + PYTORCH_DIR=$1 + XLA_DIR=$2 + USE_COVERAGE="${3:-0}" + + pushd $XLA_DIR + echo "Running Python Tests" + if [ "$USE_COVERAGE" != "0" ]; then + pip install coverage==6.5.0 --upgrade + pip install coverage-lcov + pip install toml + ./test/run_tests.sh + coverage combine + mkdir lcov && cp .coverage lcov/ + coverage-lcov --data_file_path lcov/.coverage + coverage html + cp lcov.info htmlcov/ + mv htmlcov ~/ + chmod -R 755 ~/htmlcov + else + ./test/run_tests.sh + fi + popd +} + +function run_torch_xla_cpp_tests() { + PYTORCH_DIR=$1 + XLA_DIR=$2 + USE_COVERAGE="${3:-0}" + + TORCH_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch').get_filename()))") + export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TORCH_DIR}/lib + if [ -x "$(command -v nvidia-smi)" ]; then + CUDA_PLUGIN_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch_xla_cuda_plugin').get_filename()))") + export PJRT_LIBRARY_PATH=$CUDA_PLUGIN_DIR/lib/pjrt_c_api_gpu_plugin.so + export PJRT_DEVICE=LIBRARY + export PJRT_DYNAMIC_PLUGINS=1 + else + export PJRT_DEVICE=CPU + fi + export XLA_EXPERIMENTAL="nonzero:masked_select:nms" + + test_names1=("test_aten_xla_tensor_1" + "test_aten_xla_tensor_2" + "test_aten_xla_tensor_3" + "test_aten_xla_tensor_4" + "pjrt_computation_client_test" + "ifrt_computation_client_test") + test_names2=("test_aten_xla_tensor_5" + "test_aten_xla_tensor_6" + "test_ir" + "test_lazy" + "test_replication" + "test_tensor" + # disable test_xla_backend_intf since it is flaky on upstream + #"test_xla_backend_intf" + "test_xla_sharding") + if [[ "$RUN_CPP_TESTS1" == "cpp_tests1" ]]; then + test_names=("${test_names1[@]}") + elif [[ "$RUN_CPP_TESTS2" == "cpp_tests2" ]]; then + test_names=("${test_names2[@]}") + else + test_names=("${test_names1[@]}" "${test_names2[@]}") + fi + + for name in "${test_names[@]}"; do + echo "Running $name cpp test..." + /tmp/test/bin/${name} + done +} + +function run_torch_xla_benchmark_tests() { + XLA_DIR=$1 + pushd $XLA_DIR + echo "Running Benchmark Tests" + test/benchmarks/run_tests.sh -L"" +} + +PYTORCH_DIR=$1 +XLA_DIR=$2 +USE_COVERAGE="${3:-0}" +RUN_CPP="${RUN_CPP_TESTS:0}" +RUN_PYTHON="${RUN_PYTHON_TESTS:0}" + +if [ -x "$(command -v nvidia-smi)" ]; then + num_devices=$(nvidia-smi --list-gpus | wc -l) + echo "Found $num_devices GPU devices..." + export GPU_NUM_DEVICES=$num_devices +fi +export PYTORCH_TESTING_DEVICE_ONLY_FOR="xla" +export CXX_ABI=$(python -c "import torch;print(int(torch._C._GLIBCXX_USE_CXX11_ABI))") + +if [[ -z "$RUN_BENCHMARK_TESTS" && -z "$RUN_CPP_TESTS1" && -z "$RUN_CPP_TESTS2" && -z "$RUN_PYTHON_TESTS" ]]; then + run_torch_xla_python_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + run_torch_xla_cpp_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + run_torch_xla_benchmark_tests $XLA_DIR +else + # run tests separately. + if [[ "$RUN_PYTHON_TESTS" == "python_tests" ]]; then + run_torch_xla_python_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + elif [[ "$RUN_BENCHMARK_TESTS" == "benchmark_tests" ]]; then + run_torch_xla_benchmark_tests $XLA_DIR + else + run_torch_xla_cpp_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE + fi +fi diff --git a/.circleci/docker/Dockerfile b/.github/upstream/Dockerfile similarity index 98% rename from .circleci/docker/Dockerfile rename to .github/upstream/Dockerfile index f0cd196511c..006460c2477 100644 --- a/.circleci/docker/Dockerfile +++ b/.github/upstream/Dockerfile @@ -1,3 +1,4 @@ +# Dockerfile for image used by upstream CI # This requires cuda & cudnn packages pre-installed in the base image. # Other available cuda images are listed at https://hub.docker.com/r/nvidia/cuda ARG base_image="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1" diff --git a/.circleci/docker/install_conda.sh b/.github/upstream/install_conda.sh similarity index 100% rename from .circleci/docker/install_conda.sh rename to .github/upstream/install_conda.sh diff --git a/.circleci/docker/install_valgrind.sh b/.github/upstream/install_valgrind.sh similarity index 100% rename from .circleci/docker/install_valgrind.sh rename to .github/upstream/install_valgrind.sh diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml deleted file mode 100644 index 789d0579272..00000000000 --- a/.github/workflows/_build.yml +++ /dev/null @@ -1,111 +0,0 @@ -name: xla-buld -on: - workflow_call: - inputs: - gcr-docker-image: - required: true - type: string - description: Base image for builds - ecr-docker-image-base: - required: true - type: string - description: Container registry to upload image to - runner: - required: false - type: string - description: Runner type for the test - default: linux.12xlarge - cuda: - required: false - type: string - description: Whether to build XLA with CUDA - default: 1 - - secrets: - gcloud-service-key: - required: true - description: Secret to access Bazel build cache - - outputs: - docker-image: - value: ${{ jobs.build.outputs.docker-image }} - description: The docker image containing the built PyTorch. -jobs: - build: - runs-on: ${{ inputs.runner }} - timeout-minutes: 240 - outputs: - docker-image: ${{ steps.upload-docker-image.outputs.docker-image }} - env: - ECR_DOCKER_IMAGE_BASE: ${{ inputs.ecr-docker-image-base }} - GCR_DOCKER_IMAGE: ${{ inputs.gcr-docker-image }} - WORKDIR: /var/lib/jenkins/workspace - SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 - GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} - XLA_CUDA: ${{ inputs.cuda }} - BAZEL_JOBS: 16 - steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Build is done inside the container, to start an interactive session run: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Checkout repo - uses: actions/checkout@v3 - - name: Download docker image from GCR - shell: bash - run: docker pull "${GCR_DOCKER_IMAGE}" - - name: Stage image to ECR - shell: bash - run: | - # This is to stage PyTorch/XLA base image for use in the upstream. - # To allow the upstream workflow to access PyTorch/XLA build images, we - # need to have them in the ECR. This is not expensive, and only pushes it - # if image layers are not present in the repo. - # Note: disable the following 2 lines while testing a new image, so we do not - # push to the upstream. - docker tag "${GCR_DOCKER_IMAGE}" "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null - docker push "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null - - name: Start the container - shell: bash - run: | - pid=$(docker run --privileged -t -d -w "$WORKDIR" "${GCR_DOCKER_IMAGE}") - docker exec -u jenkins "${pid}" sudo chown -R jenkins "${WORKDIR}" - docker cp "${GITHUB_WORKSPACE}/." "$pid:$WORKDIR" - echo "pid=${pid}" >> "${GITHUB_ENV}" - - - name: Prepare build env - shell: bash - run: | - echo "declare -x SCCACHE_BUCKET=${SCCACHE_BUCKET}" | docker exec -i "${pid}" sh -c "cat >> env" - echo "declare -x XLA_CUDA=${XLA_CUDA}" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_JOBS=${BAZEL_JOBS}" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> default_credentials.json" - - - name: Build - shell: bash - run: | - docker exec --privileged -u jenkins "${pid}" bash -c ".circleci/build.sh" - - name: Cleanup build env - shell: bash - run: | - docker exec "${pid}" rm default_credentials.json /tmp/pytorch/xla/default_credentials.json - - - name: Push built docker image to ECR - id: upload-docker-image - shell: bash - run: | - export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:latest-${GITHUB_SHA}" - time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}" - time docker push "${COMMIT_DOCKER_IMAGE}" - echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}" - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() - diff --git a/.github/workflows/_build_torch_xla.yml b/.github/workflows/_build_torch_xla.yml index 3e85b7c4c98..58a783216e4 100644 --- a/.github/workflows/_build_torch_xla.yml +++ b/.github/workflows/_build_torch_xla.yml @@ -26,7 +26,7 @@ jobs: GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json BAZEL_JOBS: 16 BAZEL_REMOTE_CACHE: 1 - # BUILD_CPP_TESTS: 1 + BUILD_CPP_TESTS: 1 steps: - name: Setup gcloud shell: bash @@ -38,7 +38,6 @@ jobs: repository: pytorch/pytorch path: pytorch submodules: recursive - # TODO: correct pin - name: Checkout PyTorch/XLA Repo uses: actions/checkout@v4 with: @@ -47,9 +46,14 @@ jobs: shell: bash run: | cd pytorch/xla/infra/ansible - ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps + ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 build_cpp_tests=1 git_versioned_xla_build=1 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps - name: Upload wheel uses: actions/upload-artifact@v4 with: name: torch-xla-wheels path: /dist/*.whl + - name: Upload CPP test binaries + uses: actions/upload-artifact@v4 + with: + name: cpp-test-bin + path: /tmp/test/bin diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index ed9a4ab0ea9..378dec9697a 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -2,10 +2,10 @@ name: xla-docs-build on: workflow_call: inputs: - docker-image: + dev-image: required: true type: string - description: Image to build docs in + description: Base image for builds runner: required: false type: string @@ -15,35 +15,57 @@ on: torchxla-bot-token: required: true jobs: - push-docs: - runs-on: ${{ inputs.runner }} + build-docs: + runs-on: ubuntu-latest timeout-minutes: 45 + container: + image: ${{ inputs.dev-image }} env: - DOCKER_IMAGE: ${{ inputs.docker-image }} - WORKDIR: /var/lib/jenkins/workspace + BRANCH_NAME: ${{ github.ref_name }} steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Doc builds are done inside container. Interactive session can be started by following: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Download and run docker image from GCR - shell: bash - env: - GITHUB_TORCH_XLA_BOT_TOKEN: ${{ secrets. torchxla-bot-token }} - run: | - echo "DOCKER_IMAGE: ${DOCKER_IMAGE}" - docker pull "${DOCKER_IMAGE}" - pid=$(docker run -e GITHUB_TORCH_XLA_BOT_TOKEN -t -d -w "$WORKDIR" "${DOCKER_IMAGE}") - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> /tmp/pytorch/xla/default_credentials.json" - echo "pid=${pid}" >> "${GITHUB_ENV}" - - name: Build & publish docs - shell: bash - run: docker exec -u jenkins "${pid}" bash -c '.circleci/doc_push.sh' - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() + - name: Fetch wheels + uses: actions/download-artifact@v4 + with: + name: torch-xla-wheels + path: /tmp/wheels/ + - name: Install wheels + shell: bash + run: | + pip install /tmp/wheels/*.whl + - name: Checkout PyTorch/XLA Repo + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Build docs + shell: bash + run: | + cd pytorch/xla/docs + pip install -r requirements.txt + sphinx-build -b html source build + - name: Checkout GitHub Pages + uses: actions/checkout@v4 + with: + path: gh-pages + ref: gh-pages + token: ${{ secrets.torchxla-bot-token }} + - name: Merge changes + shell: bash + run: | + subdir=${{ env.BRANCH_NAME == 'master' && 'master' || format('{0}/{1}', 'release', env.BRANCH_NAME) }} + mkdir -p gh-pages/$subdir + cp -fR pytorch/xla/docs/build/* gh-pages/$subdir + - name: Upload preview as artifact + uses: actions/upload-artifact@v4 + with: + name: github-pages + path: pytorch/xla/docs/build/ + - name: Deploy + shell: bash + run: | + cd gh-pages + git config user.email "pytorchxla@gmail.com" + git config user.name "torchxlabot2" + git add . -v + git diff --cached --exit-code || git commit -m "Update doc from commit ${{ github.sha }}" + git push origin gh-pages + if: github.event_name == 'push' diff --git a/.github/workflows/_test_python.yml b/.github/workflows/_test.yml similarity index 84% rename from .github/workflows/_test_python.yml rename to .github/workflows/_test.yml index bd260cdb2d1..8a454cc075b 100644 --- a/.github/workflows/_test_python.yml +++ b/.github/workflows/_test.yml @@ -53,6 +53,10 @@ jobs: run_xla_op_tests3: 'xla_op3' - run_python_tests: 'python_tests' run_torch_mp_op_tests: 'torch_mp_op' + - run_cpp_tests: 'cpp_tests' + run_cpp_tests1: 'cpp_tests1' + - run_cpp_tests: 'cpp_tests' + run_cpp_tests2: 'cpp_tests2' timeout-minutes: ${{ inputs.timeout-minutes }} env: GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} @@ -64,9 +68,16 @@ jobs: RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }} RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }} RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }} + RUN_CPP_TESTS1: ${{ matrix.run_cpp_tests1 }} + RUN_CPP_TESTS2: ${{ matrix.run_cpp_tests2 }} BAZEL_JOBS: 16 BAZEL_REMOTE_CACHE: 1 steps: + # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 + - name: Clean up workspace + run: | + ls -la + rm -rvf ${GITHUB_WORKSPACE}/* - name: Setup gcloud shell: bash run: | @@ -76,6 +87,19 @@ jobs: with: name: torch-xla-wheels path: /tmp/wheels/ + - name: Fetch CPP test binaries + uses: actions/download-artifact@v4 + with: + name: cpp-test-bin + path: /tmp/test/bin + if: ${{ matrix.run_cpp_tests }} + # GitHub Actions doesn't preserve executable permissions + # https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss + - name: Set CPP test permissions + run: | + chmod +x /tmp/test/bin/* + ls -l /tmp/test/bin + if: ${{ matrix.run_cpp_tests }} - name: Fetch CUDA plugin uses: actions/download-artifact@v4 with: @@ -101,8 +125,17 @@ jobs: # TODO: Add these in setup.py pip install fsspec pip install rich + + echo "Import check..." + python -c "import torch_xla" - name: Record PyTorch commit - run: echo "PYTORCH_COMMIT=$(python -c 'import torch_xla.version; print(torch_xla.version.__torch_gitrev__)')" >> $GITHUB_ENV + run: | + # Don't just pipe output in shell because imports may do extra logging + python -c " + import torch_xla.version + with open('$GITHUB_ENV', 'a') as f: + f.write(f'PYTORCH_COMMIT={torch_xla.version.__torch_gitrev__}\n') + " - name: Checkout PyTorch Repo uses: actions/checkout@v4 with: @@ -125,10 +158,7 @@ jobs: fi - name: Test shell: bash - run: | - source pytorch/xla/.circleci/common.sh - - run_torch_xla_tests pytorch/ pytorch/xla/ $USE_COVERAGE + run: pytorch/xla/.github/scripts/run_tests.sh pytorch/ pytorch/xla/ $USE_COVERAGE - name: Upload coverage results if: ${{ inputs.collect-coverage }} shell: bash diff --git a/.github/workflows/_test_cpp.yml b/.github/workflows/_test_cpp.yml deleted file mode 100644 index d0056d34963..00000000000 --- a/.github/workflows/_test_cpp.yml +++ /dev/null @@ -1,150 +0,0 @@ -name: xla-test -on: - workflow_call: - inputs: - docker-image: - required: true - type: string - description: Image to test on - runner: - required: false - type: string - description: Runner type for the test - default: linux.12xlarge - collect-coverage: - required: false - type: boolean - description: Set to true to collect coverage information - default: false - timeout-minutes: - required: false - type: number - default: 270 - description: | - Set the maximum (in minutes) how long the workflow should take to finish - disable-pjrt: - required: false - type: string - default: 0 - description: Whether to disable PJRT tests - test-script: - required: false - type: string - default: test.sh - description: Which test script to run - - secrets: - gcloud-service-key: - required: true - description: Secret to access Bazel build cache -jobs: - test: - runs-on: ${{ inputs.runner }} - strategy: - fail-fast: false - matrix: - include: - # Use readable strings as they define the workflow titles. - - run_cpp_tests1: 'cpp_tests1' - - run_cpp_tests2: 'cpp_tests2' - timeout-minutes: ${{ inputs.timeout-minutes }} - env: - DOCKER_IMAGE: ${{ inputs.docker-image }} - WORKDIR: /var/lib/jenkins/workspace - GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} - USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }} - XLA_SKIP_TORCH_OP_TESTS: ${{ inputs.disable-pjrt }} - XLA_SKIP_MP_OP_TESTS: ${{ inputs.disable-pjrt }} - RUN_CPP_TESTS1: ${{ matrix.run_cpp_tests1 }} - RUN_CPP_TESTS2: ${{ matrix.run_cpp_tests2 }} - steps: - - name: Setup Linux - uses: pytorch/test-infra/.github/actions/setup-linux@main - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - instructions: | - Tests are done inside the container, to start an interactive session run: - docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Install gcloud CLI - if: ${{ inputs.collect-coverage }} - shell: bash - run: | - sudo tee -a /etc/yum.repos.d/google-cloud-sdk.repo << EOM - [google-cloud-cli] - name=Google Cloud CLI - baseurl=https://packages.cloud.google.com/yum/repos/cloud-sdk-el8-x86_64 - enabled=1 - gpgcheck=1 - repo_gpgcheck=0 - gpgkey=https://packages.cloud.google.com/yum/doc/rpm-package-key.gpg - EOM - sudo yum install -y google-cloud-cli - - name: Auth to GCR - if: ${{ inputs.collect-coverage }} - shell: bash - run: | - echo "${GCLOUD_SERVICE_KEY}" | gcloud auth activate-service-account --key-file=- - - name: Download and run docker image from GCR - shell: bash - run: | - echo "DOCKER_IMAGE: ${DOCKER_IMAGE}" - docker pull "${DOCKER_IMAGE}" - pid=$(docker run --shm-size=16g ${GPU_FLAG:-} -e USE_COVERAGE -e XLA_SKIP_TORCH_OP_TESTS -e XLA_SKIP_MP_OP_TESTS -e RUN_BENCHMARK_TESTS -e RUN_CPP_TESTS1 -e RUN_CPP_TESTS2 -e RUN_PYTHON_TESTS -e RUN_XLA_OP_TESTS1 -e RUN_XLA_OP_TESTS2 -e RUN_XLA_OP_TESTS3 -e RUN_TORCH_MP_OP_TESTS -t -d -w "$WORKDIR" "${DOCKER_IMAGE}") - echo "${GCLOUD_SERVICE_KEY}" | docker exec -i "${pid}" sh -c "cat >> /tmp/pytorch/xla/default_credentials.json" - echo "pid=${pid}" >> "${GITHUB_ENV}" - - name: Test - shell: bash - run: | - docker exec --privileged -u jenkins "${pid}" bash -c '.circleci/${{ inputs.test-script }}' - - name: Upload coverage results - if: ${{ inputs.collect-coverage }} - shell: bash - env: - CIRCLE_WORKFLOW_ID: ${{ github.run_id }} - CIRCLE_BUILD_NUM: ${{ github.run_number }} - BENCHMARK_TEST_NAME: ${{ env.RUN_BENCHMARK_TESTS }} - PYTHON_TEST_NAME: ${{ env.RUN_PYTHON_TESTS }}${{ env.RUN_XLA_OP_TESTS1 }}${{ env.RUN_XLA_OP_TESTS2 }}${{ env.RUN_XLA_OP_TESTS3 }}${{ env.RUN_TORCH_MP_OP_TESTS }} - CPP_TEST_NAME: ${{ env.RUN_CPP_TESTS1 }}${{ env.RUN_CPP_TESTS2 }} - run: | - # TODO(yeounoh) collect coverage report as needed. - if [ -n "${BENCHMARK_TEST_NAME}" ]; then - exit 0 - fi - docker cp "${pid}":/home/jenkins/htmlcov "${GITHUB_WORKSPACE}" - if [ -n "${GPU_FLAG:-}" ]; then - if [ -n "${PYTHON_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out - fi - if [ -n "${CPP_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out - fi - else - if [ -n "${PYTHON_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - fi - - if [ -n "${CPP_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - fi - - if [ "${CPP_TEST_NAME}" == "cpp_tests1" ]; then - ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $ABS_METADATA > abs_metadata.json - gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - - INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $INC_METADATA > inc_metadata.json - gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - fi - fi - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() - diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index e5738b5a6af..60a2eda44cd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -19,53 +19,9 @@ concurrency: cancel-in-progress: true jobs: - # Old CI workflow - build: - name: "Build PyTorch/XLA (GPU)" - uses: ./.github/workflows/_build.yml - with: - ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base - gcr-docker-image: gcr.io/tpu-pytorch/xla_base:dev-3.8_cuda_12.1 - cuda: 1 - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - - test-cpp-cpu: - name: "CPU C++ tests" - uses: ./.github/workflows/_test_cpp.yml - needs: build - with: - docker-image: ${{ needs.build.outputs.docker-image }} - timeout-minutes: 120 - collect-coverage: false # TODO(yeounoh) separate from CPU coverage metrics - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - - test-cpp-cuda: - name: "GPU C++ tests" - uses: ./.github/workflows/_test_cpp.yml - needs: build - with: - docker-image: ${{ needs.build.outputs.docker-image }} - runner: linux.8xlarge.nvidia.gpu - timeout-minutes: 300 - collect-coverage: false # TODO(yeounoh) separate from CPU coverage metrics - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - push-docs: - name: "Build & publish docs" - if: github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || startsWith(github.event.ref, 'refs/tags/r')) - uses: ./.github/workflows/_docs.yml - needs: build - with: - docker-image: ${{ needs.build.outputs.docker-image }} - secrets: - torchxla-bot-token: ${{ secrets.TORCH_XLA_BOT_TOKEN }} - - # New CI workflow build-torch-xla: - name: "Build PyTorch/XLA (TPU)" + name: "Build PyTorch/XLA" uses: ./.github/workflows/_build_torch_xla.yml with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm @@ -81,8 +37,8 @@ jobs: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} test-python-cpu: - name: "CPU Python tests" - uses: ./.github/workflows/_test_python.yml + name: "CPU tests" + uses: ./.github/workflows/_test.yml needs: build-torch-xla with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm @@ -91,9 +47,9 @@ jobs: secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - test-python-cuda: - name: "GPU Python tests" - uses: ./.github/workflows/_test_python.yml + test-cuda: + name: "GPU tests" + uses: ./.github/workflows/_test.yml needs: [build-torch-xla, build-cuda-plugin] with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1 @@ -109,5 +65,13 @@ jobs: uses: ./.github/workflows/_tpu_ci.yml needs: build-torch-xla # Only run this for HEAD and releases - if: github.event_name == 'push' + if: github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'tpuci') + push-docs: + name: "Build docs" + uses: ./.github/workflows/_docs.yml + needs: build-torch-xla + with: + dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm + secrets: + torchxla-bot-token: ${{ secrets.TORCH_XLA_BOT_TOKEN }} diff --git a/.github/workflows/build_upstream_image.yml b/.github/workflows/build_upstream_image.yml new file mode 100644 index 00000000000..446ad366e54 --- /dev/null +++ b/.github/workflows/build_upstream_image.yml @@ -0,0 +1,40 @@ +name: Build upstream image +on: + push: + branches: + - master + - r[0-9]+.[0-9]+ + paths-ignore: + - 'experimental/torch_xla2/**' + workflow_dispatch: +jobs: + build: + runs-on: linux.12xlarge + timeout-minutes: 30 + env: + ECR_DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base + BAZEL_JOBS: 16 + steps: + # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 + - name: Clean up workspace + run: | + ls -la + sudo rm -rvf ${GITHUB_WORKSPACE}/* + - name: Setup Linux + uses: pytorch/test-infra/.github/actions/setup-linux@main + - name: Checkout repo + uses: actions/checkout@v3 + - name: Build Docker image + shell: bash + run: | + docker build -t "${ECR_DOCKER_IMAGE_BASE}:v1.2-lite" .github/upstream + - name: Stage image to ECR + shell: bash + run: | + # This is to stage PyTorch/XLA base image for use in the upstream. + # To allow the upstream workflow to access PyTorch/XLA build images, we + # need to have them in the ECR. This is not expensive, and only pushes it + # if image layers are not present in the repo. + # Note: disable the following line while testing a new image, so we do not + # push to the upstream. + docker push "${ECR_DOCKER_IMAGE_BASE}:v1.2-lite" diff --git a/.github/workflows/lintercheck.yml b/.github/workflows/lintercheck.yml index 6598b98da32..b17c608f883 100644 --- a/.github/workflows/lintercheck.yml +++ b/.github/workflows/lintercheck.yml @@ -24,7 +24,7 @@ jobs: if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' shell: bash run: | - TORCH_PIN=./torch_patches/.torch_pin + TORCH_PIN=./.torch_pin if [[ -f "${TORCH_PIN}" ]]; then echo "Please remove ${TORCH_PIN} before landing." exit 1 diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml index 7c5a88bf430..441addad422 100644 --- a/.github/workflows/torch_xla2.yml +++ b/.github/workflows/torch_xla2.yml @@ -34,10 +34,8 @@ jobs: shell: bash working-directory: experimental/torch_xla2 run: | - pip install pytest absl-py jax[cpu] flatbuffers tensorflow - pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -r test_requirements.txt - pip install -e . + pip install -r test-requirements.txt + pip install -e .[cpu] - name: Run tests working-directory: experimental/torch_xla2 shell: bash diff --git a/BUILD b/BUILD index 6949f6dc748..60b601240fc 100644 --- a/BUILD +++ b/BUILD @@ -30,3 +30,23 @@ cc_binary( "@xla//xla/stream_executor:cuda_platform", ]), ) + +test_suite( + name = "cpp_tests", + # testonly = True, + tests = [ + "//test/cpp:test_aten_xla_tensor_1", + "//test/cpp:test_aten_xla_tensor_2", + "//test/cpp:test_aten_xla_tensor_3", + "//test/cpp:test_aten_xla_tensor_4", + "//test/cpp:test_aten_xla_tensor_5", + "//test/cpp:test_aten_xla_tensor_6", + "//test/cpp:test_ir", + "//test/cpp:test_lazy", + "//test/cpp:test_replication", + "//test/cpp:test_tensor", + "//test/cpp:test_xla_sharding", + "//torch_xla/csrc/runtime:pjrt_computation_client_test", + "//torch_xla/csrc/runtime:ifrt_computation_client_test", + ], +) diff --git a/OP_LOWERING_GUIDE.md b/OP_LOWERING_GUIDE.md index b445a1d8998..535d7cf596c 100644 --- a/OP_LOWERING_GUIDE.md +++ b/OP_LOWERING_GUIDE.md @@ -25,7 +25,7 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e 7. `ops/` directory contains all `ir::ops` declaration and definition. Smaller nodes can be put in `ops/ops.h/.cpp`. More complicated nodes can be put into a separate file. All ops inherit from `ir::ops::Node` and provide a way to lower input `ir::Value` to a sequence of `XlaOp`. ## Unit Test -Our CircleCI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters. +Our CI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters. ## Tips The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/2972). diff --git a/README.md b/README.md index d1653eb7b53..70bdcfd57d9 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ started: To install PyTorch/XLA a new TPU VM: ``` -pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html ``` To update your existing training loop, make the following changes: @@ -132,31 +132,36 @@ Our comprehensive user guides are available at: PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You can now install the main build with `pip install torch_xla`. To also install the -Cloud TPU plugin, install the optional `tpu` dependencies: +Cloud TPU plugin, install the optional `tpu` dependencies after installing the main build with ``` pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html ``` -GPU, XRT (legacy runtime), and nightly builds are available in our public GCS -bucket. +GPU and nightly builds are available in our public GCS bucket. | Version | Cloud TPU/GPU VMs Wheel | | --- | ----------- | -| 2.2 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.3 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.3 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.3 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.3 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` | | nightly (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | | nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` | | nightly (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | +You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel of a specified date. To get the companion pytorch nightly wheel, replace the `torch_xla` with `torch` on above wheel links. +
older versions | Version | Cloud TPU VMs Wheel | |---------|-------------------| +| 2.2 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | +| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | +| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | | 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` | | 2.1 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl` | | 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` | @@ -202,25 +207,29 @@ wheels for `torch` and `torch_xla` at | --- | ----------- | | 2.0 | `https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl` | -You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel -of a specified date. To get the companion pytorch and torchvision nightly wheel, -replace the `torch_xla` with `torch` or `torchvision` on above wheel links.
### Docker | Version | Cloud TPU VMs Docker | | --- | ----------- | +| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm` | | 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm` | | 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm` | | 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | | 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | | nightly python | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` | +To use the above dockers, please pass `--privileged --net host --shm-size=16G` along. Here is an example: +```bash +docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash +``` +
| Version | GPU CUDA 12.1 Docker | | --- | ----------- | +| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1` | | 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1` | | 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1` | | nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1` | diff --git a/WORKSPACE b/WORKSPACE index 07a4b53b910..86c78c57bda 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -51,9 +51,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", ], - strip_prefix = "xla-54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f", + strip_prefix = "xla-80462ef5b22360df177fe24fc13c81b235d3f3a2", urls = [ - "https://github.com/openxla/xla/archive/54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.tar.gz", + "https://github.com/openxla/xla/archive/80462ef5b22360df177fe24fc13c81b235d3f3a2.tar.gz", ], ) diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh index fd8a055bccc..e4e483947d9 100644 --- a/benchmarks/run_benchmark.sh +++ b/benchmarks/run_benchmark.sh @@ -5,7 +5,7 @@ LOGFILE=/tmp/benchmark_test.log # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" diff --git a/build_util.py b/build_util.py index 78e4bd5e453..487f5116323 100644 --- a/build_util.py +++ b/build_util.py @@ -36,10 +36,6 @@ def bazel_options_from_env() -> Iterable[str]: bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' % cache_silo_name) - if check_env_flag('BUILD_CPP_TESTS', default='0'): - bazel_flags.append('//test/cpp:all') - bazel_flags.append('//torch_xla/csrc/runtime:all') - bazel_jobs = os.getenv('BAZEL_JOBS', default='') if bazel_jobs: bazel_flags.append('--jobs=%s' % bazel_jobs) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 199025dc7e1..de5500a0c5b 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -361,6 +361,7 @@ supported: - zero_ - _native_batch_norm_legit - _native_batch_norm_legit.no_stats + - _embedding_bag_forward_only # Note: [functionalization and CompositeExplicitAutograd] # Below are all operators that are "composite" in core, # but require us to explicitly re-enable functionalization in order to use them. diff --git a/docs/README.md b/docs/README.md index 33a0ce5bc36..88ab7f44f03 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,25 +1,25 @@ ## Publish documentation for a new release. -CircleCI job `pytorch_xla_linux_debian11_and_push_doc` is specified to run on `release/*` branches, but it was not +CI job `pytorch_xla_linux_debian11_and_push_doc` is specified to run on `release/*` branches, but it was not run on release branches due to "Only build pull requests" setting. Turning off "Only build pull requests" will result in much larger volumes in jobs which is often unnecessary. We're waiting for [this feature request](https://ideas.circleci.com/ideas/CCI-I-215) to be implemented so that we could override this setting on some branches. Before the feature is available on CircleCi side, we'll use a manual process to publish documentation for release. -[Documentation for master branch](http://pytorch.org/xla/master/) is still updated automatically by the CircleCI job. +[Documentation for master branch](http://pytorch.org/xla/master/) is still updated automatically by the CI job. But we'll need to manually commit the new versioned doc and point http://pytorch.org/xla to the documentation of new stable release. -Take 1.5 release as example: +Take 2.3 release as example: ``` -# Build pytorch/pytorch:release/1.5 and pytorch/xla:release/1.5 respectively. +# Build pytorch/pytorch:release/2.3 and pytorch/xla:release/2.3 respectively. # In pytorch/xla/docs ./docs_build.sh git clone -b gh-pages https://github.com/pytorch/xla.git /tmp/xla -cp -r build/* /tmp/xla/release/1.5 +cp -r build/* /tmp/xla/release/2.3 cd /tmp/xla # Update `redirect_url` in index.md git add . -git commit -m "Publish 1.5 documentation." +git commit -m "Publish 2.3 documentation." git push origin gh-pages -``` \ No newline at end of file +``` diff --git a/docs/assets/ci_test_dependency.png b/docs/assets/ci_test_dependency.png new file mode 100644 index 00000000000..e4b2c397ba0 Binary files /dev/null and b/docs/assets/ci_test_dependency.png differ diff --git a/docs/assets/ci_test_dependency_gpu.png b/docs/assets/ci_test_dependency_gpu.png new file mode 100644 index 00000000000..68cd77ec90c Binary files /dev/null and b/docs/assets/ci_test_dependency_gpu.png differ diff --git a/docs/gpu.md b/docs/gpu.md index c2678164f4e..de1cf807361 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -71,9 +71,12 @@ source ~/.bashrc ### Wheel ``` -pip3 install torch==2.2.0 -pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl +pip3 install torch==2.3.0 +# GPU whl for python 3.10 + cuda 12.1 +pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl ``` +Wheels for other Python version and CUDA version can be found [here](https://github.com/pytorch/xla?tab=readme-ov-file#available-docker-images-and-wheels). + ## Run a simple model In order to run below examples, you need to clone the pytorch/xla repo to access the imagenet example(We already clone it in our docker). diff --git a/docs/pallas.md b/docs/pallas.md new file mode 100644 index 00000000000..46c80b79f2e --- /dev/null +++ b/docs/pallas.md @@ -0,0 +1,57 @@ +# Custom Kernels via Pallas + +With the rise of OpenAI [triton](https://openai.com/research/triton), custom kernels become more and more popular in the GPU community, for instance, the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to provide the feature parity in the TPU world, Google has introduced [Pallas](http://go/jax-pallas) and [Mosaic](http://go/mosaic-tpu). For PyTorch/XLA to continue pushing the performance in TPU, we have to support custom kernels, and the best way is through Pallas and Mosaic. The design doc is [TBA](). + +Let's assume you have a Pallas kernel defined as follow: +```python3 +import jax +from jax.experimental import pallas as pl +import jax.numpy as jnp + +def add_vectors_kernel(x_ref, y_ref, o_ref): + x, y = x_ref[...], y_ref[...] + o_ref[...] = x + y + +@jax.jit +def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: + return pl.pallas_call(add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) +``` + +## Adopt the above kernel to be compatible with PyTorch/XLA + +Example usage: +```python3 +q = torch.randn(3, 2, 128, 4).to("xla") +k = torch.randn(3, 2, 128, 4).to("xla") +v = torch.randn(3, 2, 128, 4).to("xla") + +# Adopts any Pallas kernel +from torch_xla.experimental.custom_kernel import make_kernel_from_pallas +pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)]) +output = pt_kernel(q, k) +``` +For simple kernels, the adoption is just as simple as one liner. For more complicated kernels, you can refer to our Flash Attention implementation for details. + +## Use built-in kernels + +Besides manually wrapping external Pallas kernels, there are built-in kernels where the adoptions are done by PyTorch/XLA already. + +Example usage: +```python3 +# Use built-in kernels +from torch_xla.experimental.custom_kernel import flash_attention +output = flash_attention(q, k, v) +``` + +You can just use it like any other torch.ops. + +## HuggingFace Llama 3 Example +We have a fork of HF Llama 3 to demonstrate a potential integration [here](https://github.com/pytorch-tpu/transformers/tree/alanwaketan/flash_attention). + +## Dependencies +The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX: +```bash +pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index 411e6642ff7..0d0f871b154 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ mistune==0.8.4 -sphinx==5.0.0 +sphinx==5.3.0 docutils==0.16 Jinja2==3.1.3 m2r diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index f30be7ff1da..594d5380882 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -4,7 +4,8 @@ Currently this is only source-installable. Requires Python version >= 3.10. -### NOTE: +### NOTE: + Please don't install torch-xla from instructions in https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md . In particular, the following are not needed: @@ -18,71 +19,58 @@ TorchXLA2 and torch-xla have different installation instructions, please follow the instructions below from scratch (fresh venv / conda environment.) -### 1. Install dependencies - -#### 1.0 (optional) Make a virtualenv / conda env, and activate it. - -```bash -conda create --name python=3.10 -conda activate -``` -Or, -```bash -python -m venv create my_venv -source my_venv/bin/activate -``` - -#### 1.1 Install torch CPU, even if your device has GPU or TPU: +### 1. Installing `torch_xla2` -```bash -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -``` +#### 1.0 (recommended) Make a virtualenv / conda env -Or, follow official instructions in [pytorch.org](https://pytorch.org/get-started/locally/) to install for your OS. +If you are using VSCode, then [you can create a new environment from +UI](https://code.visualstudio.com/docs/python/environments). Select the +`dev-requirements.txt` when asked to install project dependencies. -#### 1.2 Install Jax for either GPU or TPU +Otherwise create a new environment from the command line. -If you are using Google Cloud TPU, then ```bash -pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` +# Option 1: venv +python -m venv create my_venv +source my_venv/bin/activate -If you are using a machine with NVidia GPU: +# Option 2: conda +conda create --name python=3.10 +conda activate -```bash -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# Either way, install the dev requirements. +pip install -r dev-requirements.txt ``` -If you are using a CPU-only machine: -```bash -pip install --upgrade "jax[cpu]" -``` +Note: `dev-requirements.txt` will install the CPU-only version of PyTorch. -Or, follow the official instructions in https://jax.readthedocs.io/en/latest/installation.html to install for your OS or Device. +#### 1.1 Install this package -#### 1.3 Install this package +Install `torch_xla2` from source for your platform: ```bash -pip install -e . +pip install -e .[cpu] +pip install -e .[cuda] +pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html ``` -#### 1.4 (optional) verify installation by running tests +#### 1.2 (optional) verify installation by running tests ```bash -pip install -r test_requirements.txt +pip install -r test-requirements.txt pytest test ``` - ## Run a model Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model it can be in theory any instance of `torch.nn.Module`. ```python +import torch +import torch.nn as nn +import torch.nn.functional as F -import torch_xla2 -from torch import nn class MyModel(nn.Module): def __init__(self): @@ -101,8 +89,8 @@ class MyModel(nn.Module): m = MyModel() # Execute this model using torch -inputs = (torch.randn(3, 3, 28, 28), ) -print(m(*inputs)) +inputs = torch.randn(3, 3, 28, 28) +print(m(inputs)) ``` This model `m` contains 2 parts: the weights that is stored inside of the model @@ -114,6 +102,7 @@ to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_devic We need move both the weights and the input to xla devices: ```python +import torch_xla2 from torch.utils import _pytree as pytree from torch_xla2.tensor import move_to_device @@ -121,7 +110,7 @@ inputs = move_to_device(inputs) new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict()) m.load_state_dict(new_state_dict, assign=True) -res = m(*inputs) +res = m(inputs) print(type(res)) # outputs XLATensor2 ``` @@ -164,5 +153,3 @@ from torch_xla2.extra import jax_jit model_func_jitted = jax_jit(model_func) print(model_func_jitted(new_state_dict, inputs)) ``` - - diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt index 4a32310fbda..208f70d5fef 100644 --- a/experimental/torch_xla2/dev-requirements.txt +++ b/experimental/torch_xla2/dev-requirements.txt @@ -1,9 +1,3 @@ -absl-py==2.0.0 -flatbuffers==23.5.26 -jax==0.4.23 -jaxlib==0.4.23 -pytest -tensorflow -torch==2.2.1+cpu -immutabledict -sentencepiece \ No newline at end of file +-f https://download.pytorch.org/whl/torch +torch==2.3.0+cpu +ruff~=0.3.5 diff --git a/experimental/torch_xla2/docs/fixing_op_info_test.md b/experimental/torch_xla2/docs/fixing_op_info_test.md new file mode 100644 index 00000000000..03624f9487e --- /dev/null +++ b/experimental/torch_xla2/docs/fixing_op_info_test.md @@ -0,0 +1,211 @@ +# How to fix an op info test. + +## What is OpInfo test + +PyTorch created a list of python objects (OpInfo) to keep +track how to test each op. This is useful to us because it +ensures that the ops we implement produces the same results +pytorch would produce. + +Context: +* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253 +* https://github.com/pytorch/pytorch/issues/54261 + + +## How to fix one + +### Remove one op from skiplist + +Open [test/test_ops.py](../test/test_ops.py) with your +favorite text editor. +Remove one line from the `skiplist` set. + +i.e. + +```bash +(base) hanq-macbookpro:torch_xla2 hanq$ git diff +diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py +index 72a39ae85..2a156cbce 100644 +--- a/experimental/torch_xla2/test/test_ops.py ++++ b/experimental/torch_xla2/test/test_ops.py +@@ -15,7 +15,6 @@ skiplist = { + "_native_batch_norm_legit", + "_segment_reduce", + "_upsample_bilinear2d_aa", +- "addbmm", + "addmm", + "addmv", + "addr", +``` + +### Run test to see what failure + +Error gotten: + +``` +E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm') +``` + +From here we have 2 strategies for fixing this test: + +1. Add an implementation to `aten::addbmm` operator using Jax ops. Or, +2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions"). + +Either way works for torch_xla2. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of +upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) +so other projects can benefit from it. + +For illustration purposes, let's implement this op in Jax. + +(NOTE: this doesn't stop us from upstreaming a decomposition later if we want) + +### First Impl + +To implement this op using jax ops, we first find what +is the exact semantics in this page: +https://pytorch.org/docs/stable/generated/torch.addbmm.html + +From it's math formula: we can implement it as follows. + +``` ++@op(torch.ops.aten.addbmm.default) ++def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): ++ ++ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) ++ return beta * input + alpha * mm +``` + +Now running test again: + +``` +python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64 +``` + +(NOTE: the exact test command is printed out when we run +`pytest test/test_ops.py` so we can only run the failed test instead of running all tests.) + +We now see this error: + +``` +FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001] +---------------------------------------------------------------------- +Traceback (most recent call last): + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare + diff_output( + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output + testcase.assertTrue( +AssertionError: False is not true +``` + +This is telling me that our implementation did not produce +the same result as the ops in PyTorch. + +To debug this, let's figure out what exact input caused this. +We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. Here we can +inspect values of `res` and `res2`, as well as the `sample_input`. + +The sample input we get is +``` +SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2], + [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5], + [-2, -1, 5, -2, -3, 0, 5, -4, 9, -6], + [-1, -7, 6, 3, 8, 3, 8, 9, -5, 7], + [-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8], + [-6, -2, 5, 7, 7], + [-8, -3, 2, 5, -3], + [-4, 7, 0, -9, 8], + [ 3, 9, -9, -2, 0]], + + [[-7, 1, -3, 7, -4], + [ 3, 5, 4, 6, 5], + [-2, 8, 3, 5, 7], + [ 8, -2, -8, 2, 0], + [ 6, 1, -8, 8, 0]], + + [[ 2, -1, -5, -8, -9], + [ 5, 0, -4, -1, -6], + [-6, 2, -5, -2, -5], + [-5, -3, -5, -4, 9], + [-3, 4, -9, -9, 7]], + + [[ 2, 5, -7, -3, 8], + [-5, -7, -8, -4, 4], + [-4, -6, -3, 0, 6], + [ 8, 0, -3, -8, 2], + [-4, 3, -9, -6, 7]], + + [[ 2, 1, -6, 2, 8], + [ 2, 6, 4, 1, 8], + [-9, 9, -5, 8, 3], + [-5, 0, -2, 4, 0], + [ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5], + [-4, -7, 2, 2, 1, -9, 2, 7, -1, -1], + [ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4], + [-4, 1, -9, 3, 4, 6, 0, -2, -2, -7], + [ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]], + + [[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5], + [-5, -3, -2, 8, 1, -2, 4, -7, 5, 3], + [-4, 4, 1, -4, -8, 2, -5, 2, 9, -7], + [ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4], + [-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]], + + [[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3], + [-3, 3, -9, -7, -9, -8, 1, -3, 7, -2], + [ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7], + [-1, 6, -8, 7, -1, -5, -8, 6, -2, 8], + [-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]], + + [[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8], + [-6, -4, 3, 9, -9, -8, -7, 3, 9, 0], + [ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7], + [-6, 9, 5, -1, 7, 7, 8, -3, -8, 0], + [-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]], + + [[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8], + [-8, -8, -2, -5, -7, -8, 1, 0, 0, -6], + [ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8], + [-4, 0, 9, 6, -1, -6, 6, -6, -2, -1], + [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='') +``` + +And the `res` from torch is + +``` +tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) +``` + +So few observation is: +1. Input tensor are of type int64 +2. alpha and beta are both floats. + +So one can suspect that it has to do with rounding. +Reading the doc more carefully, we can find this sentence + + For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers. + +So likely torch first casted the float alpha and beta to integer, which yields 0, then used them in math to get a matrix with all zeros. + +### Second Impl + +```python ++@op(torch.ops.aten.addbmm.default) ++def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): ++ alpha = jnp.array(alpha).astype(batch1.dtype) ++ beta = jnp.array(beta).astype(batch1.dtype) ++ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) ++ return jax.lax.cond(beta == 0, ++ lambda: alpha * mm, ++ lambda: beta*input + alpha*mm) ++ +``` + +Adding type casts makes the tests passes. + +### Submit +Now, let's remove the pdb and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993 + diff --git a/experimental/torch_xla2/docs/ops_registry.md b/experimental/torch_xla2/docs/ops_registry.md new file mode 100644 index 00000000000..c0e68f42fc4 --- /dev/null +++ b/experimental/torch_xla2/docs/ops_registry.md @@ -0,0 +1,40 @@ +# Ops Registry + +## Background + +In the [How it works](how_it_works.md) doc, we mentioned 2 important pieces: + +1. A mechanism to route `ATen` ops to implementation written in + Jax or in PyTorch, and + +2. The ops themselves. + + +Ops Registry is there to help us to organize the ops themselves. + +An op implementation can written in terms of Jax, or in other PyTorch ops. +The latter is also known as "decompositions". For decompositions, +one need to be careful of not introducing circular dependencies. + +Here we simply store the operator implementations in a dictionary, +which key the torch / Aten callable that we wish to override, and +value an instance of `Operator` class. + +`Operator` class has this schema: + +```python +@dataclasses.dataclass +class Operator: + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool +``` + +The `torch_op` is the corresponding torch callable, and `func` the implementation. `is_jax_function` is True if `func` is implemented using Jax, False if `func` is implemented using other torch ops. We can use this information to decide how to call it. + +If `needs_env` is true, `func` will recieve an extra kwarg with name `env`. +This will be the "Environment" in which this op operate on. In particular, +the environment will contain the Jax random number generator key, that might be useful for ops like `aten::rand`. + diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index 5d3f5a734c5..29e55700a32 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -10,7 +10,11 @@ from torch.utils import _pytree as pytree import torchvision import torchvision.transforms as transforms -import torch_xla2 +import torch_xla2.tensor + + +xla_env = torch_xla2.tensor.Environment(0) +mode = xla_env.mode() # PyTorch TensorBoard support from torch.utils.tensorboard import SummaryWriter @@ -80,6 +84,7 @@ def forward(self, x): model = GarmentClassifier() +model = xla_env.to_xla(model) loss_fn = torch.nn.CrossEntropyLoss() @@ -96,13 +101,6 @@ def forward(self, x): print('Total loss for this batch: {}'.format(loss.item())) # Optimizers specified in the torch.optim package - -# NEW: Move model to XLA device -state_dict = model.state_dict() -state_dict = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, state_dict) -model.load_state_dict(state_dict, strict=False, assign=True) - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) def train_one_epoch(epoch_index, tb_writer): @@ -115,14 +113,14 @@ def train_one_epoch(epoch_index, tb_writer): for i, data in enumerate(training_loader): # Every data instance is an input + label pair # NEW: Move model to XLA device - data = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, data) + data = xla_env.to_xla(data) inputs, labels = data # Zero your gradients for every batch! optimizer.zero_grad() # Make predictions for this batch + outputs = model(inputs) # Compute the loss and its gradients @@ -169,14 +167,11 @@ def train_one_epoch(epoch_index, tb_writer): # Disable gradient computation and reduce memory consumption. with torch.no_grad(): for i, vdata in enumerate(validation_loader): - # NOTE: move to XLA device - vinputs, vlabels = pytree.tree_map_only( - torch.Tensor, - torch_xla2.tensor.move_to_device, - vdata) - voutputs = model(vinputs) # call model's forward - vloss = loss_fn(voutputs, vlabels) - running_vloss += vloss + # NOTE: move to XLA device + vinputs, vlabels = xla_env.to_xla(vdata) + voutputs = model(vinputs) # call model's forward + vloss = loss_fn(voutputs, vlabels) + running_vloss += vloss avg_vloss = running_vloss / (i + 1) print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/experimental/torch_xla2/examples/basic_training_jax.py index 3941fcdf8fe..ae6efdf4856 100644 --- a/experimental/torch_xla2/examples/basic_training_jax.py +++ b/experimental/torch_xla2/examples/basic_training_jax.py @@ -8,7 +8,7 @@ import torchvision import torchvision.transforms as transforms import torch_xla2 -import torch_xla2.extra +import torch_xla2.interop import jax import optax import numpy as np @@ -91,7 +91,7 @@ def forward(self, x): def jax_loss(weights, data, label): pred = jax_func(weights, data) - loss = torch_xla2.extra.call_torch(loss_fn, pred, label) + loss = torch_xla2.interop.call_torch(loss_fn, pred, label) return loss grad_fn = jax.jit(jax.value_and_grad(jax_loss)) @@ -155,12 +155,6 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): # Make sure gradient tracking is on, and do a pass over the data model.train(True) - # NEW: Move model to XLA device - state_dict = model.state_dict() - state_dict = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, state_dict) - model.load_state_dict(state_dict, strict=False, assign=True) - avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer) running_vloss = 0.0 @@ -174,7 +168,7 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata) voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward - vloss = torch_xla2.extra.call_torch(loss_fn, voutputs, vlabels) + vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels) running_vloss += vloss avg_vloss = running_vloss / (i + 1) diff --git a/experimental/torch_xla2/examples/eager_mode.py b/experimental/torch_xla2/examples/eager_mode.py index 358ee6256c6..755f24b0d2b 100644 --- a/experimental/torch_xla2/examples/eager_mode.py +++ b/experimental/torch_xla2/examples/eager_mode.py @@ -1,10 +1,9 @@ - -from torch_xla2.tensor import move_to_device import torch_xla2 from torch import nn from torch.nn import functional as F import torch -from torch.utils import _pytree as pytree + +xla_env = torch_xla2.default_env() class MyModel(nn.Module): @@ -22,21 +21,21 @@ def forward(self, x): return x m = MyModel() +m = xla_env.to_xla(m) # Execute this model using torch inputs = (torch.randn(3, 3, 28, 28), ) +inputs = xla_env.to_xla(inputs) -inputs, state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, (inputs, m.state_dict())) -m.load_state_dict(state_dict, strict=False, assign=True) print(m(*inputs)) print('---=====') -from torch_xla2.extra import jax_jit +from torch_xla2.interop import jax_jit @jax_jit def model_func(param, inputs): return torch.func.functional_call(m, param, inputs) -print(model_func(state_dict, inputs)) +print(model_func(m.state_dict(), inputs)) diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml index d0d2a42dec8..0c2101dbcb9 100644 --- a/experimental/torch_xla2/pyproject.toml +++ b/experimental/torch_xla2/pyproject.toml @@ -2,29 +2,30 @@ requires = ["hatchling"] build-backend = "hatchling.build" - [project] version = "0.0.1" name = "torch_xla2" dependencies = [ "absl-py", - "flatbuffers", + "immutabledict", + "jax>=0.4.24", "pytest", - "tensorflow", - - # Note: Exclude these because otherwise on pip install . - # pip will install libs from pypi which is the GPU version - # of these libs. - # We most likely need CPU version of torch and TPU version of - # jax. So it's best for users to install them by hand - # See more at README.md - # "jax>=0.4.24", - # "jaxlib>=0.4.24", - # "torch", + "tensorflow-cpu", + # Developers should install `dev-requirements.txt` first + "torch>=2.2.1", ] - requires-python = ">=3.10" license = {file = "LICENSE"} +[project.optional-dependencies] +cpu = ["jax[cpu]"] +# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html` +tpu = ["jax[tpu]"] +cuda = ["jax[cuda12]"] + [tool.pytest.ini_options] addopts="-n auto" + +[tool.ruff] +line-length = 80 +indent-width = 2 diff --git a/experimental/torch_xla2/test-requirements.txt b/experimental/torch_xla2/test-requirements.txt new file mode 100644 index 00000000000..1deead455a1 --- /dev/null +++ b/experimental/torch_xla2/test-requirements.txt @@ -0,0 +1,5 @@ +-r dev-requirements.txt +pytest +pytest-xdist +sentencepiece +expecttest diff --git a/experimental/torch_xla2/test/gemma/test_gemma.py b/experimental/torch_xla2/test/gemma/test_gemma.py index bd0bb21dbb1..4d91bc6f9b0 100644 --- a/experimental/torch_xla2/test/gemma/test_gemma.py +++ b/experimental/torch_xla2/test/gemma/test_gemma.py @@ -74,7 +74,7 @@ def test_gemma(self): weights, jax_func = torch_xla2.extract_jax(model) inputs_jax = pytree.tree_map_only( - torch.Tensor, torch_xla2.tensor.move_to_device, inputs) + torch.Tensor, torch_xla2.tensor.t2j, inputs) import jax print(jax.jit(jax_func)(weights, inputs_jax)) diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py index dae7bf0cc5c..083116ab89e 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -1,8 +1,5 @@ -import unittest -import jax import torch -from torch._functorch.make_functional import make_functional_with_buffers -from torch_xla2 import tensor, ops # pylint: disable=unused-import +from torch_xla2 import tensor # pylint: disable=unused-import import torch_xla2 from .. import test_base diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py index 1a75a7d23d0..a6bcda5113a 100644 --- a/experimental/torch_xla2/test/test_context.py +++ b/experimental/torch_xla2/test/test_context.py @@ -1,20 +1,22 @@ import unittest import torch -import torch_xla2 from torch_xla2 import tensor +xla_env = tensor.Environment(0) + class TestContext(unittest.TestCase): + def test_mode_context_manager(self): - with torch_xla2.mode(): + with xla_env: x = torch.full((3, 3), -1) self.assertIsInstance(x, tensor.XLATensor2) y = x.abs() self.assertIsInstance(y, tensor.XLATensor2) @staticmethod - @torch_xla2.mode() + @xla_env def _test_mode_decorator(): x = torch.full((3, 3), -1) y = x.abs() diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 357e41c9101..6a1cef306be 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -1,7 +1,6 @@ import unittest import torch -from torch_xla2 import ops_registry from torch_xla2 import tensor from . import test_base @@ -34,12 +33,13 @@ def run_export_and_compare(testcase, rtol=1e-5, equal_nan=True, ignore_indices=False): + with testcase.subTest("torch_eval"): res = func(*args, **kwargs) with testcase.subTest("torch_xla2_eval"): - args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, - (args, kwargs)) - res2 = func(*args2, **kwargs2) + args2, kwargs2 = testcase.env.to_xla((args, kwargs)) + with testcase.env: + res2 = func(*args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) # import pdb; pdb.set_trace() with testcase.subTest("torch_xla2_diff:" + str(atol)): @@ -61,11 +61,11 @@ class TestCoreAtenOps(unittest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - ops_registry.print_missing_ops() def setUp(self): super().setUp() torch.manual_seed(0) + self.env = tensor.Environment(0) def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) @@ -2109,7 +2109,7 @@ def test_aten_logit_0(self): def test_aten_logit_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) + run_export_and_compare(self, torch.ops.aten.logit, args, kwargs, atol=0.01,) def test_aten_logit_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) @@ -3639,8 +3639,9 @@ def test_aten__softmax_1(self): def _compare_sorted_result(self, args): res = torch.ops.aten.sort(*args) with self.subTest("torch_xla2_eval"): - args2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, args) - res2 = torch.ops.aten.sort(*args2) + args2 = self.env.to_xla(args) + with self.env: + res2 = torch.ops.aten.sort(*args2) # The second argument is the sorted index. These might not be # identical from torch vs. jax; but both can be correct diff --git a/experimental/torch_xla2/test/test_extra.py b/experimental/torch_xla2/test/test_extra.py deleted file mode 100644 index 768488d6a99..00000000000 --- a/experimental/torch_xla2/test/test_extra.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -import jax -import jax.numpy as jnp -import torch_xla2 -from torch_xla2 import tensor, extra - - -class ExtraTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - - def test_standard_callable(self): - def f(a, b): - return torch.add(a, b) - - a = jnp.ones((10, )) - b = jnp.ones((10, )) - - c = extra.jax_view(f)(a, b) - self.assertTrue(jnp.allclose(c, a + b)) - - def f2(a, b): - return jnp.add(a, b) - - a = tensor.move_to_device(torch.ones((10, ))) - b = tensor.move_to_device(torch.ones((10, ))) - c2 = extra.torch_view(f2)(a, b) - - self.assertTrue(jnp.allclose(c2._elem, c)) - - - - def test_fori_loop(self): - a = tensor.move_to_device(torch.ones((10, 10))) - - def body(i, c): - return c + a[i] - - init_val = tensor.move_to_device(torch.zeros(10)) - res = extra.fori_loop(0, 10, body, init_val) - expect = torch.ones(10) * 10 - self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect)) - - def test_jax_jit(self): - - # functions that acts on torch tensor - def f(a, b): - return torch.sin(a) + torch.cos(b) - - fjitted = extra.jax_jit(f) - a = torch.rand((10, 10)) - b = torch.rand((10, 10)) - aj = tensor.move_to_device(a) - bj = tensor.move_to_device(b) - res = f(a, b) - res2 = fjitted(aj, bj) - self.assertTrue(torch.allclose(res, tensor.j2t(res2._elem))) - - -if __name__ == '__main__': - unittest.main() diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 76e842d6fdd..2d624b25b5b 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -3,12 +3,14 @@ from absl.testing import parameterized import torch import torch_xla2 -import torch_xla2.functions import torch_xla2.tensor class TestTorchFunctions(parameterized.TestCase): + def setUp(self): + self.env = torch_xla2.tensor.Environment(0) + @parameterized.named_parameters( ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), ('tensor_1d', lambda: torch.tensor([0, 1],)), @@ -32,7 +34,7 @@ class TestTorchFunctions(parameterized.TestCase): def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): expected = func() - with torch_xla2.functions.XLAFunctionMode(): + with self.env: actual = func() self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2) diff --git a/experimental/torch_xla2/test/test_mutations.py b/experimental/torch_xla2/test/test_mutations.py index 2f9ddca975b..50d78aa0fae 100644 --- a/experimental/torch_xla2/test/test_mutations.py +++ b/experimental/torch_xla2/test/test_mutations.py @@ -6,46 +6,43 @@ class TestMutations(TestCase): - def test_add(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + def setUp(self): + self.env = torch_xla2.tensor.Environment(0) - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.add_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32)) + def test_add(self): + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + x.add_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32)) def test_sub(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) - - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.sub_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32)) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + x.sub_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32)) def test_mul(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.mul_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32)) + x.mul_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32)) def test_div(self): - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) - - x = torch_xla2.tensor.move_to_device(x) - y = torch_xla2.tensor.move_to_device(y) - x.div_(y) - xt = torch_xla2.tensor.j2t(x._elem) - self.assertEqual(xt, - torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float)) + with self.env: + x = torch.tensor([1, 2, 3], dtype=torch.int32) + y = torch.tensor([4, 5, 6], dtype=torch.int32) + + x.div_(y) + xt = torch_xla2.tensor.j2t(x._elem) + self.assertEqual(xt, + torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float)) if __name__ == '__main__': diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index ed14e636e5c..20686f2fe6c 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -8,6 +8,7 @@ from torch.utils import _pytree as pytree from torch_xla2 import tensor + skiplist = { "__getitem__", "__rmatmul__", @@ -15,19 +16,6 @@ "_native_batch_norm_legit", "_segment_reduce", "_upsample_bilinear2d_aa", - "addbmm", - "addmm", - "addmv", - "addr", - "all", - "allclose", - "amax", - "amin", - "aminmax", - "angle", - "any", - "argmax", - "argmin", "argsort", "as_strided", "as_strided_scatter", @@ -570,6 +558,7 @@ "special.xlog1py", "split", "split_with_sizes", + "split_with_sizes_copy", "sqrt", "square", "stack", @@ -636,10 +625,10 @@ def run_export_and_compare(testcase, with testcase.subTest("torch_eval"): res = func(sample_input.input, *sample_input.args, **sample_input.kwargs) with testcase.subTest("torch_xla2_eval"): - input2, args2, kwargs2 = pytree.tree_map_only( - torch.Tensor, tensor.move_to_device, - (sample_input.input, sample_input.args, sample_input.kwargs)) - res2 = func(input2, *args2, **kwargs2) + input2, args2, kwargs2 = testcase.env.to_xla(( + sample_input.input, sample_input.args, sample_input.kwargs)) + with testcase.env: + res2 = func(input2, *args2, **kwargs2) res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) with testcase.subTest("torch_xla2_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: @@ -664,6 +653,9 @@ class TestOpInfo(TestCase): def setUpClass(cls): print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) + def setUp(self): + self.env = tensor.Environment(0) + @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) diff --git a/experimental/torch_xla2/test_requirements.txt b/experimental/torch_xla2/test_requirements.txt deleted file mode 100644 index c8596327236..00000000000 --- a/experimental/torch_xla2/test_requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -pytest -immutabledict -sentencepiece -pytest-xdist -expecttest \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index b0bb20712d4..bd0e00fa6ca 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -1,31 +1,34 @@ -import contextlib import jax import torch from torch._functorch import make_functional from torch.utils import _pytree as pytree -from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions +from torch_xla2 import export, tensor, tf_integration jax.config.update('jax_enable_x64', True) +env = None +def default_env(): + global env + if env is None: + env = tensor.Environment(0) + return env -@contextlib.contextmanager -def mode(): - with tensor.XLADispatchMode(), functions.XLAFunctionMode(): - yield -def extract_jax(mod: torch.nn.Module): +def extract_jax(mod: torch.nn.Module, env=None): """Returns a pytree of jax.ndarray and a jax callable.""" + if env is None: + env = default_env() func, weights, buffer = make_functional.make_functional_with_buffers(mod) - states = (weights, buffer) + states = mod.state_dict() + states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states) #@jax.jit def jax_func(states, inputs): - (states, inputs) = tensor.wrap((states, inputs)) - weights, buffer = states - with tensor.XLADispatchMode(): - res = func(weights, buffer, *inputs) - return tensor.unwrap(res) + (states, inputs) = env.j2t_iso((states, inputs)) + with env: + res = torch.func.functional_call(mod, states, inputs) + return env.t2j_iso(res) return states, jax_func diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py deleted file mode 100644 index fe0f97a0f01..00000000000 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ /dev/null @@ -1,1745 +0,0 @@ -# pylint: disable -"""Torch ops implemented using jax.""" -import sys - -import jax -from jax import numpy as jnp -import numpy as np -import torch -from torch_xla2 import ops_registry -from torch_xla2 import tensor - - -class TorchFunctionLowering: - - def __init__(self, func, is_jax_func, should_jit=False): - if is_jax_func and should_jit: - func = jax.jit(func) - self.func = func - self.is_jax_func = is_jax_func - - def __call__(self, *args, **kwargs): - if self.is_jax_func: - (args, kwargs) = tensor.unwrap((args, kwargs)) - res = self.func(*args, **kwargs) - if self.is_jax_func: - res = tensor.wrap(res) - return res - - -def op(aten_op, is_jax_func=True): - """if is_jax_func is true, then the function it will register - - should takes jax array as input and returns jax array. - - Which means we need to wrap it - """ - - def inner(func): - ops_registry.lowerings.register(aten_op, - TorchFunctionLowering(func, is_jax_func)) - return func - - return inner - - -@op(torch.ops.aten.view_copy) -@op(torch.ops.aten.view) -@op(torch.ops.aten._unsafe_view) -@op(torch.ops.aten.reshape) -def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) - - -@op(torch.ops.aten.add) -def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - return x + y * alpha - - -@op(torch.ops.aten.copy_, is_jax_func=False) -def _aten_copy(x, y, memory_format=None): - if isinstance(x, tensor.XLATensor2): - x._elem = y._elem - elif isinstance(x, tensor.SliceView): - x.mutate(y) - return x - - -@op(torch.ops.aten.clone) -def _aten_clone(x, memory_format=None): - return jnp.copy(x) - - -@op(torch.ops.aten.full) -def _aten_full(size, value, **kwargs): - return jnp.full(size, value) - - -@op(torch.ops.aten.index_copy) -def _aten_index_copy(x, dim, indexes, source): - # return jax.lax.scatter(x, index, dim) - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x.at[dim].set(source) - - -@op(torch.ops.aten.select) -@op(torch.ops.aten.index_select) -@op(torch.ops.aten.select_copy) -def _aten_index_select(x, dim, indexes): - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x[tuple(dims)] - - -@op(torch.ops.aten.mean) -def _aten_mean(x, dim=None, keepdim=False): - return jnp.mean(x, dim, keepdims=keepdim) - - -def _torch_binary_scalar_type(scalar, tensor): - if "float" in str(tensor.dtype): - return tensor.dtype - - if isinstance(scalar, int): - if "int" in str(tensor.dtype): - return tensor.dtype - - return jnp.float32 - - -@op(torch.ops.aten.sub) -def _aten_sub(x, y): - if isinstance(x, float): - dtype = _torch_binary_scalar_type(x, y) - x = jnp.array(x, dtype=dtype) - if isinstance(y, float): - dtype = _torch_binary_scalar_type(y, x) - y = jnp.array(y, dtype=dtype) - return x - y - - -@op(torch.ops.aten.mm) -def _aten_mm(x, y): - res = x @ y - return res - - -@op(torch.ops.aten.mul) -def _aten_mul(x, y): - return x * y - - -@op(torch.ops.aten.silu) -def _aten_silu(x): - return jax.nn.silu(x) - - -@op(torch.ops.aten.t) -def _aten_t(x): - return jnp.transpose(x) - - -@op(torch.ops.aten.transpose) -@op(torch.ops.aten.transpose_copy) -def _aten_transpose(x, dim0, dim1): - shape = list(range(len(x.shape))) - shape[dim0], shape[dim1] = shape[dim1], shape[dim0] - return jnp.transpose(x, shape) - - -@op(torch.ops.aten.triu) -def _aten_triu(m, k): - return jnp.triu(m, k) - - -@op(torch.ops.aten.slice) -@op(torch.ops.aten.slice_copy) -def _aten_slice(self, dim=0, start=None, end=None, step=1): - if end == sys.maxsize: - end = self.shape[dim] - sl = slice(start, end, step) - dims = [] - for i in range(len(self.shape)): - if i == dim: - dims.append(sl) - else: - dims.append(slice(None, None, None)) - return self[tuple(dims)] - - -@op(torch.ops.aten.detach) -def _aten_detach(self): - return self - - -@op(torch.ops.aten.view_as_real) -def _aten_view_as_real(x): - real = jnp.real(x) - im = jnp.imag(x) - res = jnp.stack([real, im], -1) - return res - - -@op(torch.ops.aten.stack) -def _aten_stack(tensors, dim=0): - return jnp.stack(tensors, dim) - - -@op(torch.ops.aten._softmax) -def _aten_softmax(x, dim, halftofloat): - return jax.nn.softmax(x, dim) - - -@op(torch.ops.aten.pow) -def _aten_pow(x, y): - if isinstance(y, int): - y = float(y) - return jnp.power(x, y) - - -@op(torch.ops.aten.view_as_complex) -def _aten_view_as_complex(input): - if input.dtype == jnp.bfloat16: - input = input.astype(jnp.float32) - x, y = input[..., 0], input[..., 1] - return jax.lax.complex(x, y) - - -@op(torch.ops.aten.div) -def _aten_div(x, y, rounding_mode=""): - res = x / y - if rounding_mode == "trunc": - res = jnp.trunc(res) - return res - - -@op(torch.ops.aten.div_, is_jax_func=False) -def _aten_div_(x, y, rounding_mode=""): - x._elem = _aten_div(x._elem, y._elem, rounding_mode) - return x - - -@op(torch.ops.aten.true_divide) -def _aten_true_divide(x, y): - return x / y - - -@op(torch.ops.aten.bmm) -def _aten_bmm(x, y): - res = x @ y - return res - # return jnp.einsum('bnm,bmk->bnk', x, y) - - -@op(torch.ops.aten.embedding) -# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -def _aten_embedding(a, w, padding_idx=-1): - return jnp.take(a, w, axis=0) - - -@op(torch.ops.aten.rsqrt) -def _aten_rsqrt(x): - if isinstance(x, int): - x = float(x) - if x.dtype == jnp.int32: - x = x.astype(jnp.float32) - return jax.lax.rsqrt(x) - - -@op(torch.ops.aten.expand) -@op(torch.ops.aten.expand_copy) -def _aten_expand(x, dims): - - def fix_dims(d, xs): - if d == -1: - return xs - return d - - dims = [fix_dims(p, s) for p, s in zip(dims, x.shape)] - return jnp.broadcast_to(x, dims) - - -@op(torch.ops.aten.dot) -def _aten_dot(x, y): - return jnp.dot(x, y) - - -@op(torch.ops.aten._to_copy) -def _aten__to_copy(self, **kwargs): - dtype = tensor.t2j_dtype(kwargs["dtype"]) - if dtype != self.dtype: - return self.astype(dtype) - return jnp.copy(self) - - -@op(torch.ops.aten.empty) -def _aten_empty(sizes, **kwargs): - return jnp.zeros(sizes) - - -@op(torch.ops.aten.index_put_) -@op(torch.ops.aten.index_put) -def _aten_index_put(self, indexes, values, accumulate=False): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - if accumulate: - return self.at[indexes].add(values) - else: - return self.at[indexes].set(values) - - -@op(torch.ops.aten.index) -@op(torch.ops.aten._unsafe_index) -@op(torch.ops.aten.index.Tensor) -def _aten_index(self, indexes): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - return self[indexes] - - -@op(torch.ops.aten.split) -@op(torch.ops.aten.split_copy) -@op(torch.ops.aten.split_with_sizes) -def split_with_sizes(x, sizes, dim=0): - """Splits an array `x` into sub-arrays based on static sizes `sizes`. - - Args: - x: The input array to split. - sizes: A 1D array of integer sizes for each sub-array. - - Returns: - A list of sub-arrays. - """ - if isinstance(sizes, int): - # split equal size - new_sizes = [sizes] * (x.shape[dim] // sizes) - sizes = new_sizes - rank = x.ndim - splits = np.cumsum(sizes) # Cumulative sum for split points - - def make_range(rank, dim, start, end): - res = [slice(None, None, None)] * rank - res[dim] = slice(start, end) - return tuple(res) - - return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) - ] - - -@op(torch.ops.aten.permute) -@op(torch.ops.aten.permute_copy) -def permute(t, dims): - return jnp.transpose(t, dims) - - -@op(torch.ops.aten.unsqueeze) -@op(torch.ops.aten.unsqueeze_copy) -@op(torch.ops.aten.unsqueeze.default) -def _aten_unsqueeze(self, dim): - if dim < 0: - dim += self.ndim + 1 - return jnp.expand_dims(self, dim) - - -@op(torch.ops.aten.ne) -def _aten_ne(x, y): - return jnp.not_equal(x, y) - - -@op(torch.ops.aten.cumsum) -def _aten_cumsum(x, y, dtype=None): - dtype = tensor.t2j_dtype(dtype) - res = jnp.cumsum(x, y, dtype) - return res - - -@op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): - """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. - - Args: - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - output: The normalized tensor. - mean: The calculated mean tensor. - std: The calculated standard deviation tensor. - """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - axis = [i for i, d in enumerate(input.shape) if d in normalized_shape] - - # Calculate mean and standard deviation - mean = jnp.mean(input, axis=axis, keepdims=True) - var = jnp.var(input, axis=axis, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) - - # Normalize the input - norm_x = (input - mean) * rstd - - # Apply affine transformation (if provided) - if weight is not None: - norm_x *= weight - if bias is not None: - norm_x += bias - return norm_x, mean, rstd - - -# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor -@op(torch.ops.aten.addmm) -def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - self *= beta - self += alpha * jnp.matmul(mat1, mat2) - return self - - -@op(torch.ops.aten.gelu) -def _aten_gelu(self, *, approximate="none"): - approx = approximate == "tanh" - return jax.nn.gelu(self, approx) - - -@op(torch.ops.aten.squeeze) -@op(torch.ops.aten.squeeze_copy) -def _aten_squeeze_dim(self, dim): - """Squeezes a Jax tensor by removing a single dimension of size 1. - - Args: - self: The input tensor. - dim: The dimension to squeeze. - - Returns: - The squeezed tensor with the specified dimension removed if it is 1, - otherwise the original tensor is returned. - """ - - # Validate input arguments - if not isinstance(self, jnp.ndarray): - raise TypeError(f"Expected a Jax tensor, got {type(self)}.") - if isinstance(dim, int): - dim = [dim] - - # Check if the specified dimension has size 1 - if all([self.shape[d] != 1 for d in dim]): - return self - - # Use slicing to remove the dimension if it is 1 - new_shape = list(self.shape) - - def fix_dim(p): - if p < 0: - return p + len(self.shape) - return p - - dim = [fix_dim(d) for d in dim] - new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1] - return self.reshape(new_shape) - - -@op(torch.ops.aten.convolution) -def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, -): - if transposed: - raise NotImplementedError("Transposed convolution is not implemented.") - - def make_padding(padding): - return ((p, p) for p in padding) - - def create_default_conv_dimension_numbers(num_spatial_dims): - # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 - # (batch dimension, feature dimension, spatial dimensions...) - lhs_spec = [0, 1] - # (out feature dimension, in feature dimension, spatial dimensions...) - rhs_spec = [0, 1] - # (batch dimension, feature dimension, spatial dimensions...) - out_spec = [0, 1] - for i in range(0, num_spatial_dims): - lhs_spec.append(i + 2) - rhs_spec.append(i + 2) - out_spec.append(i + 2) - return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) - - res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - - if bias is not None: - # TODO(qihqi): bias always on channel? - if len(bias.shape) == 1: - shape = [1] * len(res.shape) - shape[1] = bias.shape[0] - bias = bias.reshape(tuple(shape)) - res = res + bias - return res - - -# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -@op(torch.ops.aten._native_batch_norm_legit) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) - - -@op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): - if weight is None: - weight = jnp.ones_like(running_mean) - if bias is None: - bias = jnp.zeros_like(running_mean) - - def broadcast(t): - return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) - - a = input - broadcast(running_mean) - b = broadcast(jnp.sqrt(running_var + eps)) - return ( - a / b * broadcast(weight) + broadcast(bias), - jnp.array([]), - jnp.array([]), - ) - - -@op(torch.ops.aten.relu) -def _aten_relu(self): - return jax.nn.relu(self) - - -@op(torch.ops.aten.cat) -def _aten_cat(tensors, dims=0): - return jnp.concatenate(tensors, dims) - - -@op(torch.ops.aten.max_pool2d_with_indices) -@op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides, - padding=0, - dilation=1, - ceil_mode=False): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) - if isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(len(kernel_size))) - elif isinstance(padding, list): - padding = tuple((p, p) for p in padding) - - window_shape = kernel_size - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f'len({window_shape}) must equal len({strides})' - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f'padding {padding} must specify pads for same number of dims as ' - f'window_shape {window_shape}') - assert all([len(x) == 2 for x in padding - ]), f'each entry in padding {padding} must be length 2' - padding = ((0, 0), (0, 0)) + padding - - indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) - - def reduce_fn(a, b): - ai, av = a - bi, bv = b - which = av > bv - return jnp.where(which, ai, bi), jnp.where(which, av, bv) - - init_val = -jnp.inf - if inputs.dtype in (jnp.int32, jnp.int64): - init_val = -(1 << 31) - init_val = jnp.array(init_val).astype(inputs.dtype) - - indices, y = jax.lax.reduce_window((indices, inputs), (0, init_val), - reduce_fn, dims, strides, padding) - if is_single_input: - indices = jnp.squeeze(indices, axis=0) - y = jnp.squeeze(y, axis=0) - return y, indices - - batch_result = pool(inputs, -jnp.inf, jax.lax.max, kernel_size, strides, - padding) - indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) - return batch_result, indices - - -# TODO add more ops - - -@op(torch.ops.aten.min) -def _aten_min(x, axis=None): - return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64) - - -@op(torch.ops.aten.amin) -def _aten_amin(x, axis=None): - return jnp.min(x, axis=axis) - - -@op(torch.ops.aten.argmin) -def _aten_amin(x, axis=None): - return jnp.argmin(x, axis=axis) - - -@op(torch.ops.aten.sin) -def _aten_sin(x): - return jnp.sin(x) - - -@op(torch.ops.aten.sym_size) -def _aten_sym_size(x, dim): - return x.shape[dim] - - -@op(torch.ops.aten.var) -@op(torch.ops.prims.var) -def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): - return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) - - -@op(torch.ops.prims.broadcast_in_dim) -def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): - return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) - - -# aten.native_group_norm -- should use decomp table -# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) - - -@op(torch.ops.aten.native_group_norm) -def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): - """Group Normalization implementation in JAX. - - Args: - input: Input tensor. Expected shape (batch_size, channels, ... spatial dims - ...) - weight: Optional scaling (gamma) parameter. Shape (channels,) - bias: Optional shifting (beta) parameter. Shape (channels,) - N: Batch size. - C: Number of channels. - HxW: Product of spatial dimensions (number of elements per channel after - flattening). - group: Number of groups for Group Normalization. - eps: Small value added for numerical stability. - - Returns: - A tuple of (normalized_output, mean, rstd) - """ - - input_shape = input.shape - - # Reshape for group-wise normalization - reshaped_input = jnp.reshape(input, (1, N * group, -1)) - - # **Core Group Normalization** - def group_norm_body(x): # Function to apply within each group - mean = jnp.mean(x, axis=-1, keepdims=True) - var = jnp.var(x, axis=-1, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon - normalized = (x - mean) * rstd - return normalized, mean, rstd - - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) - - # Reshape back to original input shape - output = jnp.reshape(normalized, input_shape) - - # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting - if weight is not None and bias is not None: - output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) - elif weight is not None: - output = output * weight.reshape(affine_shape) - elif bias is not None: - output = output + bias.reshape(affine_shape) - - # Reshape mean and rstd - mean = jnp.reshape(group_mean, (N, group)) - rstd = jnp.reshape(group_rstd, (N, group)) - - return output, mean, rstd - - -@op(torch.ops.aten.linalg_vector_norm) -def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): - """Calculates the vector norm along specified dimensions. - - Args: - self: The input tensor. - ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. - Default is 2 (Euclidean norm). - dim: Dimensions along which to calculate the norm. If None, the norm is - calculated over all dimensions. - keepdim: Whether to keep the reduced dimensions. - dtype: Optional data type for the output. - - Returns: - The tensor containing the calculated vector norms. - """ - - if ord not in {2, float("inf"), float("-inf"), "fro"}: - raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") - - # Special cases (for efficiency and clarity) - if ord == 2: # Euclidean norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) - - elif ord == float("inf"): - result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) - - elif ord == float("-inf"): - result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) - - elif ord == "fro": # Frobenius norm - result = jnp.sqrt(jnp.sum(jnp.abs(self)**2, axis=dim, keepdims=keepdim)) - - else: # General case (e.g., ord = 1, ord = 3) - result = jnp.sum( - jnp.abs(self)**ord, axis=dim, keepdims=keepdim)**(1.0 / ord) - - # (Optional) dtype conversion - if dtype is not None: - result = result.astype(dtype) - - return result - - -# aten.reflection_pad1d -@op(torch.ops.aten.reflection_pad1d) -def _aten_reflection_pad1d(input, padding): - rank = len(input.shape) - pad_size = [(0, 0)] * rank - pad_size[-1] = padding - return jnp.pad(input, pad_size, mode="reflect") - - -# aten.alias -@op(torch.ops.aten.alias) -def _aten_alias(self, *args): - return self - - -# aten.sinh -@op(torch.ops.aten.sinh) -def _aten_sinh(self): - return jnp.sinh(self) - - -# aten.native_layer_norm_backward -@op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): - """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. - - Args: - grad_out: The gradient of the output tensor. - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - A tuple of (grad_input, grad_weight, grad_bias). - """ - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) - - -# aten.reflection_pad3d_backward -# aten.reflection_pad2d - - -# aten.atanh -@op(torch.ops.aten.atanh) -def _aten_atanh(self): - return jnp.arctanh(self) - - -# aten.bitwise_not -@op(torch.ops.aten.bitwise_not) -def _aten_bitwise_not(self): - return ~self - - -# aten.embedding_dense_backward - - -# aten.sum -@op(torch.ops.aten.sum) -def _aten_sum(self, dim=None, keepdim=False, dtype=None): - return jnp.sum(self, axis=dim, keepdims=keepdim, dtype=dtype) - - -# aten.sqrt -@op(torch.ops.aten.sqrt) -def _aten_sqrt(self): - return jnp.sqrt(self) - - -@op(torch.ops.aten.tan) -def _aten_tanh(self): - return jnp.tan(self) - - -# aten.tanh -@op(torch.ops.aten.tanh) -def _aten_tanh(self): - return jnp.tanh(self) - - -# aten.ceil -@op(torch.ops.aten.ceil) -def _aten_ceil(self): - return jnp.ceil(self) - - -# aten.asin -@op(torch.ops.aten.asin) -def _aten_asin(self): - return jnp.arcsin(self) - - -# aten.minimum -@op(torch.ops.aten.minimum) -def _aten_minimum(self, other): - return jnp.minimum(self, other) - - -# aten.max_pool2d_backward - - -def _scatter_index(dim, index): - """Returns a tuple of indexes; - - The first is to select in input (to modify), - the second is to select from the values. - """ - index_shape = list(index.shape) - input_indexes = [] - source_indexes = [] - for i in range(len(index_shape)): - source_indexes.append(slice(0, index_shape[i])) - if i == dim: - input_indexes.append(index) - else: - target_shape = [1] * len(index_shape) - target_shape[i] = index_shape[i] - input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) - return tuple(input_indexes), tuple(source_indexes) - - -# aten.scatter_add -@op(torch.ops.aten.scatter_add) -def _aten_scatter_add(input, dim, index, src): - """JAX implementation of scatter, mimicking torch.scatter behavior""" - - input_indexes, source_indexes = _scatter_index(dim, index) - return input.at[input_indexes].add(src[source_indexes]) - - -# aten.logical_not - - -# aten.sign -@op(torch.ops.aten.sign) -def _aten_sign(x): - return jnp.sign(x) - - -# aten.sigmoid -@op(torch.ops.aten.sigmoid) -def _aten_sigmoid(x): - if x.dtype in (jnp.int32, jnp.int64): - x = x.astype(jnp.float32) - return jax.nn.sigmoid(x) - - -# implement aten.asinh in jax -@op(torch.ops.aten.asinh) -def _aten_asinh(self): - return jnp.arcsinh(self) - - -# aten.atan -@op(torch.ops.aten.atan) -def _aten_atan(self): - return jnp.arctan(self) - - -# aten.scatter_reduce -@op(torch.ops.aten.scatter_reduce) -def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): - input_indexes, source_indexes = _scatter_index(dim, index) - if reduce == "sum": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "prod": - return input.at[input_indexes].multiply(src[source_indexes]) - elif reduce == "mean": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "amax": - return input.at[input_indexes].max(src[source_indexes]) - elif reduce == "amin": - return input.at[input_indexes].min(src[source_indexes]) - else: - raise RuntimeError('Unknow reduction type: ', reduce) - - -# aten.acos -@op(torch.ops.aten.acos) -def _aten_acos(self): - return jnp.arccos(self) - - -# aten.sym_storage_offset -# aten.native_layer_norm_backward -# aten.max_pool3d_with_indices - - -# aten.gt -@op(torch.ops.aten.gt) -def _aten_gt(self, other): - return self > other - - -# aten.pixel_shuffle -@op(torch.ops.aten.pixel_shuffle) -def _aten_pixel_shuffle(x, upscale_factor): - """PixelShuffle implementation in JAX. - - Args: - x: Input tensor. Typically a feature map. - upscale_factor: Integer by which to upscale the spatial dimensions. - - Returns: - Tensor after PixelShuffle operation. - """ - - batch_size, channels, height, width = x.shape - - if channels % (upscale_factor**2) != 0: - raise ValueError( - 'Number of channels must be divisible by the square of the upscale factor.' - ) - - new_channels = channels // (upscale_factor**2) - new_height = height * upscale_factor - new_width = width * upscale_factor - - x = x.reshape(batch_size, new_channels, upscale_factor, upscale_factor, - height, width) - x = jnp.transpose(x, - (0, 1, 2, 4, 3, 5)) # Move channels to spatial dimensions - x = x.reshape(batch_size, new_channels, new_height, new_width) - - return x - - -# aten.sym_stride -# aten.lt -@op(torch.ops.aten.lt) -def _aten_lt(self, other): - return self < other - - -def pool(inputs, init, reduce_fn, window_shape, strides, padding): - """Helper function to define pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply that - pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form ``(T, T) -> T``. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of ``n`` integers, representing the inter-window - strides (default: ``(1, ..., 1)``). - padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence - of ``n`` ``(low, high)`` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f'len({window_shape}) must equal len({strides})' - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f'padding {padding} must specify pads for same number of dims as ' - f'window_shape {window_shape}') - assert all([len(x) == 2 for x in padding - ]), f'each entry in padding {padding} must be length 2' - padding = ((0, 0), (0, 0)) + padding - y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - if is_single_input: - y = jnp.squeeze(y, axis=0) - return y - - -@op(torch.ops.aten._adaptive_avg_pool3d) -def _aten_adaptive_avg_pool3d(x, output_shape): - return _aten_adaptive_avg_pool(x, output_shape, 3) - - -@op(torch.ops.aten._adaptive_avg_pool2d) -def _aten_adaptive_avg_pool3d(x, output_shape): - return _aten_adaptive_avg_pool(x, output_shape, 2) - - -def _aten_adaptive_avg_pool(x, output_shape, pool_dim): - - def adaptive_kernel_size(input_shape, output_shape): - sizes = [1, 1] - spatial_dim_off = len(input_shape) - pool_dim - for spatial_dim in range(pool_dim): - sizes.append(input_shape[spatial_dim_off + spatial_dim] // - output_shape[spatial_dim]) - return tuple(sizes) - - kernel_sizes = adaptive_kernel_size(x.shape, output_shape) - y = pool(x, 0.0, jax.lax.add, kernel_sizes, kernel_sizes, padding='VALID') - - div_shape = list(x.shape) - num_batch_dims = len(x.shape) - pool_dim - 1 - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_sizes): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, - 'VALID') - return y - - -# aten.avg_pool2d -@op(torch.ops.aten.avg_pool2d) -@op(torch.ops.aten.avg_pool3d) -def _aten_avg_pool(inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None): - - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) - if isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(len(kernel_size))) - elif isinstance(padding, list): - padding = tuple((p, p) for p in padding) - - y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) - if count_include_pad: - y = y / np.prod(kernel_size) - else: - div_shape = list(inputs.shape) - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_size): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding) - return y - - -# aten.sym_numel -# aten.reciprocal -@op(torch.ops.aten.reciprocal) -def _aten_reciprocal(a): - return 1 / a - - -# aten.scatter -@op(torch.ops.aten.select_scatter) -def _aten_select_scatter(input, src, dim, index): - input_indexes = [] - for x in range(len(input.shape)): - if x == dim: - input_indexes.append(index) - else: - input_indexes.append(slice(None, None, None)) - return input.at[tuple(input_indexes)].set(src) - - -@op(torch.ops.aten.scatter.src) -def _aten_scatter_src(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src[source_indexes]) - - -@op(torch.ops.aten.scatter.value) -def _aten_scatter(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src) - - -# aten.acosh -@op(torch.ops.aten.acosh) -def _aten_acosh(self): - return jnp.arccosh(self) - - -# aten.avg_pool2d_backward -# aten.col2im -# aten.avg_pool3d -# aten.round -@op(torch.ops.aten.round) -def _aten_round(input, decimals=0): - return jnp.round(input, decimals) - - -# aten.max -@op(torch.ops.aten.max) -def _aten_max(self, dim=None, keepdim=False): - return jnp.max( - self, axis=dim, keepdims=keepdim), jnp.argmax( - self, axis=dim, keepdims=keepdim) - - -# aten.maximum -@op(torch.ops.aten.maximum) -def _aten_maximum(self, other): - return jnp.maximum(self, other) - - -# aten.abs -@op(torch.ops.aten.abs) -def _aten_abs(self): - return jnp.abs(self) - - -# generate aten.amax only -@op(torch.ops.aten.amax) -def _aten_amax(self, dim=None, keepdim=False): - return jnp.amax(self, axis=dim, keepdims=keepdim) - - -# aten.any -@op(torch.ops.aten.any) -def _aten_any(self, dim=None, keepdim=False): - return jnp.any(self, axis=dim, keepdims=keepdim) - - -# aten.arange -@op(torch.ops.aten.arange) -def _aten_arange(start, - end=None, - step=1, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False): - if end is None: - end = start - start = 0 - dtype = tensor.t2j_dtype(dtype) - return jnp.arange( - start, - end, - step, - dtype=dtype, - ) - - -# aten.argmax -@op(torch.ops.aten.argmax) -def _aten_argmax(self, dim=None, keepdim=False): - return jnp.argmax(self, axis=dim, keepdims=keepdim) - - -# aten.as_strided -@op(torch.ops.aten.as_strided) -@op(torch.ops.aten.as_strided_copy) -def _aten_as_strided(x, sizes, strides, storage_offset=None): - ind = jnp.zeros(sizes, dtype=jnp.int32) - - for i, (size, stride) in enumerate(zip(sizes, strides)): - result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) - indexes = (jnp.arange(size) * stride).reshape(result_shape) - ind += indexes - - return jnp.ravel(x)[ind] - - -# aten.atan2 -@op(torch.ops.aten.atan2) -def _aten_atan2(self, other): - return jnp.arctan2(self, other) - - -# aten.bitwise_and -@op(torch.ops.aten.bitwise_and) -def _aten_bitwise_and(self, other): - return self & other - - -# aten.bitwise_or -@op(torch.ops.aten.bitwise_or) -def _aten_bitwise_or(self, other): - return self | other - - -# aten.bitwise_xor -@op(torch.ops.aten.bitwise_xor) -def _aten_bitwise_xor(self, other): - return self ^ other - - -# aten.clamp -@op(torch.ops.aten.clamp) -def _aten_clamp(self, min=None, max=None): - return jnp.clip(self, min, max) - - -# aten.constant_pad_nd -@op(torch.ops.aten.constant_pad_nd) -def _aten_constant_pad_nd(input, padding, value=0): - # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) - # means last dim get padded 1 in front and 1 in back; - # and second last dim get padded 2 in front and 2 in back. - # Jax padding tuple of 2-tuple: the same padding is - # [(0, 0), ..., (2,2), (1,1)] - m = len(padding) - rev_padding = [(padding[i - 1], padding[i]) for i in range(m - 1, 0, -2)] - pad_dim = tuple(([(0, 0)] * (len(input.shape) - m // 2)) + rev_padding) - return jnp.pad(input, pad_dim, mode="constant", constant_values=value) - - -# aten.convolution_backward -@op(torch.ops.aten.copy) -@op(torch.ops.aten.lift_fresh_copy) -def _aten_copy(x): - return jnp.copy(x) - - -@op(torch.ops.aten._cdist_forward) -def _aten_cdist_forward(x1, x2, p, compute_mode=''): - # x1 is B x P x M - # x2 is B x Q x M - # res is B x P x Q - x1 = jnp.expand_dims(x1, len(x1.shape) - 1) - x2 = jnp.expand_dims(x2, len(x2.shape) - 2) - return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) - - -@op(torch.ops.aten._pdist_forward) -def _aten__pdist_forward(x, p): - pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] - return condensed_dists - - -# aten.cos -@op(torch.ops.aten.cos) -def _aten_cos(input): - return jnp.cos(input) - - -# aten.cosh -@op(torch.ops.aten.cosh) -def _aten_cosh(input): - return jnp.cosh(input) - - -# aten.diagonal -@op(torch.ops.aten.diagonal) -def _aten_diagonal(input, offset=0, dim1=0, dim2=1): - return jnp.diagonal(input, offset, dim1, dim2) - - -# aten.empty_strided -# aten.eq -@op(torch.ops.aten.eq) -def _aten_eq(input1, input2): - return input1 == input2 - - -# aten.erf -@op(torch.ops.aten.erf) -def _aten_erf(x): - if x.dtype in (jnp.int32, jnp.int64): - x = x.astype(jnp.float32) - return jax.lax.erf(x) - - -# aten.exp -@op(torch.ops.aten.exp) -def _aten_exp(input): - return jnp.exp(input) - - -# aten.expm1 -@op(torch.ops.aten.expm1) -def _aten_expm1(input): - return jnp.expm1(input) - - -# aten.fill -@op(torch.ops.aten.fill) -@op(torch.ops.aten.full_like) -def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None): - if dtype is None: - dtype = x.dtype - else: - dtype = tensor.t2j_dtype(dtype) - return jnp.full(x.shape, value, dtype) - - -# aten.flip -@op(torch.ops.aten.flip) -def _aten_flip(input, dims): - if dims is not None: - return jnp.flip(input, tuple(dims)) - else: - return jnp.flip(input) - - -# aten.floor -@op(torch.ops.aten.floor) -def _aten_floor(input): - return jnp.floor(input) - - -# aten.fmod -@op(torch.ops.aten.fmod) -def _aten_fmod(input, other): - return input - other * _aten_div(input, other, 'trunc') - - -# aten.gather -@op(torch.ops.aten.gather) -def _aten_gather(input, dim, index): - input_indexes, source_indexes = _scatter_index(dim, index) - return input[input_indexes] - - -# aten.ge -@op(torch.ops.aten.ge) -def _aten_ge(self, other): - return self >= other - - -@op(torch.ops.aten.glu) -@op(torch.ops.aten.glu.default) -def _aten_glu(x, dim=-1): - return jax.nn.glu(x, dim) - - -# aten.hardtanh -@op(torch.ops.aten.hardtanh) -def _aten_hardtanh(input, min_val=-1., max_val=1., inplace=False): - return jnp.clip(input, min_val, max_val) - - -# aten.isinf -@op(torch.ops.aten.isinf) -def _aten_isinf(input): - return jnp.isinf(input) - - -# aten.isnan -@op(torch.ops.aten.isnan) -def _aten_isnan(input): - return jnp.isnan(input) - - -@op(torch.ops.aten.le) -def _aten_le(self, other): - return self <= other - - -# aten.leaky_relu -@op(torch.ops.aten.leaky_relu) -def _aten_leaky_relu(x, negative_slope): - return jax.nn.leaky_relu(x, negative_slope) - - -# aten.log -@op(torch.ops.aten.log) -def _aten_log(x): - return jnp.log(x) - - -# aten.log10 -@op(torch.ops.aten.log10) -def _aten_log10(x): - return jnp.log10(x) - - -# aten.log1p -@op(torch.ops.aten.log1p) -def _aten_log1p(x): - return jnp.log1p(x) - - -# aten.log2 -@op(torch.ops.aten.log2) -def _aten_log2(x): - return jnp.log2(x) - - -# aten.logical_and -@op(torch.ops.aten.logical_and) -def _aten_logical_and(self, other): - return jnp.logical_and(self, other) - - -# aten.logical_or -@op(torch.ops.aten.logical_or) -def _aten_logical_or(self, other): - return jnp.logical_or(self, other) - - -# aten.logical_not -@op(torch.ops.aten.logical_not) -def _aten_logical_not(self): - return jnp.logical_not(self) - - -# aten.log_softmax -@op(torch.ops.aten._log_softmax) -def _aten_log_softmax(self, axis=-1, half_to_float=False): - return jax.nn.log_softmax(self, axis) - - -# aten.max_pool3d_backward -# aten.logical_xor -@op(torch.ops.aten.logical_xor) -def _aten_logical_xor(self, other): - return jnp.logical_xor(self, other) - - -# aten.max_pool2d_with_indices_backward -# aten.native_dropout -# aten.native_group_norm_backward -# aten.neg -@op(torch.ops.aten.neg) -def _aten_neg(x): - return -1 * x - - -# aten.nonzero -@op(torch.ops.aten.nonzero) -def _aten_nonzero(x): - index_tuple = jnp.nonzero(x) - index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] - return jnp.concatenate(index_tuple, axis=-1) - - -# aten.prod - - -@op(torch.ops.aten.prod) -def _aten_prod(self, dim=None, keepdim=False): - return jnp.prod(self, axis=dim, keepdims=keepdim) - - -# aten.rand -# aten.randn -# aten.randperm -# aten.reflection_pad3d -# aten.remainder -@op(torch.ops.aten.remainder) -def _aten_remainder(inputs, other): - return inputs % other - - -# aten.repeat -@op(torch.ops.aten.repeat) -def _aten_repeat(x, reps): - return jnp.tile(x, reps) - - -# aten.replication_pad2d -# aten.replication_pad3d -# aten.roll -@op(torch.ops.aten.roll) -def _aten_roll(input, shifts, dims=None): - return jnp.roll(input, shifts, dims) - - -# aten.scalar_tensor -# aten.slice_scatter -@op(torch.ops.aten.slice_scatter) -def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): - input_index = [] - for x in range(len(input.shape)): - if x == dim: - input_index.append(slice(start, end, step)) - else: - input_index.append(slice(None, None, None)) - return input.at[tuple(input_index)].set(src) - - -# aten.sort -# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) -@op(torch.ops.aten.sort) -def _aten_sort(a, dim=-1, descending=False, stable=False): - return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), - ) - - -# aten.sym_size - - -# aten.topk -@op(torch.ops.aten.topk) -def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - """JAX top-k implementation using jax.lax.top_k for improved efficiency. - - Args: - input: The input JAX array. - k: The number of top elements to return. - dim: The dimension along which to find the top-k. If None, operates on the - flattened array. - largest: If True, returns the largest k elements. Otherwise, smallest k. - sorted: If True, returns the elements in sorted order. - - Returns: - A tuple (values, indices) containing: - - values: The top k values. - - indices: The indices of the top k values in the original array. - """ - if dim is None: - input = input.flatten() - dim = 0 - - if not largest: - input = -input # Find top-k of negated input if we want the smallest - - transpose_shape = None - if dim != -1 and dim != len(input.shape) - 1: - transpose_shape = list(range(len(input.shape))) - transpose_shape[dim], transpose_shape[-1] = (transpose_shape[-1], - transpose_shape[dim]) - input = jnp.transpose(input, transpose_shape) - - values, indices = jax.lax.top_k(input, k) - - if sorted: - values = jnp.sort(values, descending=True) - indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) - - if not largest: - values = -values # Negate values back if we found smallest - - if transpose_shape is not None: - values = jnp.transpose(values, transpose_shape) - indices = jnp.transpose(indices, transpose_shape) - - return values, indices - - -# aten.trunc -@op(torch.ops.aten.trunc) -def _aten_trunc(a): - return jnp.trunc(a) - - -@op(torch.ops.aten.unbind) -@op(torch.ops.aten.unbind_copy) -def _aten_unbind(a, dim=0): - return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim])) - - -# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d -# despite those being core aten ops, they also have decompositions. -# here we are using torch decompositions. - - -# aten.where -@op(torch.ops.aten.where) -def _aten_where(condition, x, y): - return jnp.where(condition, x, y) - - -# aten.to.dtype -#Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None -@op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): - jaxdtype = tensor.t2j_dtype(dtype) - return a.astype(jaxdtype) - - -# aten.to.device - - -#Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False -@op(torch.ops.aten.var_mean.correction) -def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): - return (jnp.var(self, axis=dim, ddof=correction, - keepdims=keepdim), jnp.mean(self, dim, keepdims=keepdim)) - - -@op(torch.ops.aten.scalar_tensor) -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): - if dtype is not None: - dtype = tensor.t2j_dtype(dtype) - return jnp.array(s, dtype=dtype) - return jnp.array(s) - - -@op(torch.ops.aten.to.device) -def _aten_to_device(x,device, dtype): - return x - - -@op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices): - - """ - Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. - - Args: - grad_output: The gradient tensor from the preceding layer. - self: The input tensor on which the original max pooling was performed. - kernel_size: The size of the pooling window. - stride: The stride of the pooling window. - padding: The padding applied during max pooling. - dilation: The dilation factor for the pooling operation. - ceil_mode: Whether to use ceil or floor when calculating output shapes. - indices: The indices of the maximum values, as produced by max_pool2d_with_indices. - - Returns: - The calculated gradient with respect to the input (grad_input). - """ - - kH, kW = kernel_size - dH, dW = stride - padH, padW = padding - dilH, dilW = dilation - - # Calculate output shape (may need adjustment based on ceil_mode) - out_shape = jnp.array(self.shape) - grad_input = jnp.zeros_like(self) - - # Iterate over the flattened input and output tensors - for i, idx in enumerate(indices.flatten()): - # Calculate input coordinates corresponding to the maximum value - out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] - in_y = out_y * dH - padH + out_y * (dilH - 1) - in_x = out_x * dW - padW + out_x * (dilW - 1) - - # Scatter the gradient to the appropriate input locations (handling potential overlaps) - for y in range(in_y, in_y + kH): - for x in range(in_x, in_x + kW): - if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: - grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) - - return grad_input - - -@op(torch.ops.aten._local_scalar_dense) -def _aten_local_scalar_dense(x): - return x.item() - -@op(torch.ops.aten.tensor_split.sections) -def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py index e85e49e13ee..81b48bb5da8 100644 --- a/experimental/torch_xla2/torch_xla2/decompositions.py +++ b/experimental/torch_xla2/torch_xla2/decompositions.py @@ -90,4 +90,21 @@ def _reflection_or_replication_pad( return result _try_register(aten.replication_pad1d, _replication_pad) -_try_register(aten.replication_pad3d, _replication_pad) \ No newline at end of file +_try_register(aten.replication_pad3d, _replication_pad) + +EXTRA_DECOMP = decomp.get_decompositions([ + torch.ops.aten.upsample_nearest2d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.native_dropout, + torch.ops.aten.reflection_pad1d, + torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, +]) + +EXTRA_DECOMP[torch.ops.aten.uniform] = torch.ops.aten.rand \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/environment.py b/experimental/torch_xla2/torch_xla2/environment.py index 6a71c7d51c0..139597f9cb0 100644 --- a/experimental/torch_xla2/torch_xla2/environment.py +++ b/experimental/torch_xla2/torch_xla2/environment.py @@ -1,26 +1,2 @@ -import jax - - -class Environment: - """This class holds a set of configurations and "globals" needed - - for executing torch program using jax. - Things included so far: - - op registry - PRNGKey - Configs - - Also helper functions to manipulate those. - """ - - _prng_key: jax.random.PRNGKey - - - def __init__(self, random_seed): - self._prng_key = jax.random.PRNGKey(random_seed) - - def get_and_rotate_prng_key(self): - self._prng_key, key = jax.random.split(self._prng_key) diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 64a3f9d175c..78430a6d537 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -2,146 +2,12 @@ """Utilities for exporting a torch program to jax/stablehlo.""" import copy from typing import Any, Dict, Tuple -import jax import torch -from torch.fx import _pytree as fx_pytree -from torch_xla2 import ops_registry, tensor +from torch_xla2.ops import ops_registry +from torch_xla2 import tensor from torch.utils import _pytree as pytree -class JaxProgram: - - def _wrap_inputs(self, xs, allow_torch_tensor=False): - - def convert(t): - if isinstance(t, tensor.XLATensor2): - return t - if isinstance(t, torch.Tensor): - if allow_torch_tensor: - return tensor.move_to_device(t) - else: - raise ValueError('Regular torch.Tensor is not allowed.') - if isinstance(t, jax.Array): - return tensor.XLATensor2(t) - return t - - return jax.tree_util.tree_map(convert, xs) - - def _unwrap_outputs(self, xs): - - def convert(t): - if isinstance(t, tensor.XLATensor2): - return t.jax() - if isinstance(t, torch.Tensor): - raise ValueError('Regular torch.Tensor is not allowed.') - return t - - return jax.tree_util.tree_map(convert, xs) - - def __init__( - self, - exported_program, - param_buffer_values, - ordered_tensor_constants, - ): - - self.param_buffer_values = self._wrap_inputs( - param_buffer_values, allow_torch_tensor=True) - self.ordered_tensor_constants = self._wrap_inputs( - ordered_tensor_constants, allow_torch_tensor=True) - self.exported_program = exported_program - - def __hash__(self): - return hash(self.exported_program) - - @property - def example_inputs(self): - args, kwargs = self.exported_program.example_inputs - args = pytree.tree_map(tensor.t2j, args) - kwargs = pytree.tree_map(tensor.t2j, kwargs) - return args, kwargs - - def flatten_inputs(self, args, kwargs): - if args is None: - args = tuple() - if kwargs is None: - kwargs = {} - - if (in_spec := self.exported_program.call_spec.in_spec) is not None: - if (in_spec.type == tuple and len(in_spec.children_specs) == 2 and - in_spec.children_specs[0].type == tuple and - in_spec.children_specs[1].type == dict): - # NOTE: this is the case where in_spec is for both args and kwargs - return fx_pytree.tree_flatten_spec((args, kwargs), in_spec) - return fx_pytree.tree_flatten_spec(args, in_spec) - return copy.deepcopy(args) - - def unflatten_outputs(self, res): - return pytree.tree_unflatten(res, self.exported_program.call_spec.out_spec) - - def __call__(self, *args, **kwargs): - - inputs = self.flatten_inputs(args, kwargs) - res = self.flatten_callable(*inputs) - res = self.unflatten_outputs(res) - - return res - - @property - def flatten_callable(self): - - def func(*inputs: jax.Array): - nonlocal self - inputs = self._wrap_inputs(inputs) - num_mutations = len( - self.exported_program.graph_signature.buffers_to_mutate) - res = torch.fx.Interpreter(self.exported_program.graph_module).run( - *self.param_buffer_values, - *inputs, - *self.ordered_tensor_constants, - enable_io_processing=False, - ) - res = res[num_mutations:] - res = self._unwrap_outputs(res) - return res - - return func - - def jit(self, *args, **kwargs): - """Returns `jax.jit(self, *args, **kwargs)`.""" - return jax.jit(self, *args, **kwargs) - - def jit_lower(self, *args, **kwargs): - """Returns `jax.jit(self, *args, **kwargs).lower(...)` with example_inputs used in export.""" - example_args, example_kwargs = self.example_inputs - return self.jit(*args, **kwargs).lower(*example_args, **example_kwargs) - - -def exported_program_to_jax_program(ep): - """exported_program_to_jax_program. - - Args: - ep: torch.export.ExportedProgram - - Returns: - JaxProgram - - """ - if torch.__version__ >= '2.2': - ep = ep.run_decompositions() - - param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers - param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys) - - if hasattr(ep.graph_signature, 'lifted_tensor_constants'): - ordered_tensor_constants = tuple( - ep.tensor_constants[name] - for name in ep.graph_signature.lifted_tensor_constants) - else: - ordered_tensor_constants = tuple() - - return JaxProgram(ep, param_buffer_values, ordered_tensor_constants) - DEBUG = False @@ -149,6 +15,11 @@ def exported_program_to_jax_program(ep): class JaxInterpreter(torch.fx.Interpreter): """Experimental.""" + def __init__(self, graph_module): + super().__init__(graph_module) + import torch_xla2.ops.jaten + import torch_xla2.ops.jtorch + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): @@ -157,7 +28,9 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: if DEBUG: print('Running ', target.name(), '--------') - op = ops_registry.lowerings.lookup(target) + op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) if op is None: print(target.name(), target.tags) raise RuntimeError('No lowering found for', target.name()) diff --git a/experimental/torch_xla2/torch_xla2/extra.py b/experimental/torch_xla2/torch_xla2/extra.py deleted file mode 100644 index ebfdb96b1db..00000000000 --- a/experimental/torch_xla2/torch_xla2/extra.py +++ /dev/null @@ -1,62 +0,0 @@ -import jax -import jax.numpy as jnp -import functools -import torch -from torch.utils import _pytree as pytree -from torch_xla2 import tensor - -def torch_view(t): - # t is an object from jax land - # view it as-if it's a torch land object - if isinstance(t, jax.Array): - return tensor.XLATensor2(t) - if isinstance(t, type(jnp.int32)): - return tensor.t2j_type(t) - if callable(t): - def new_t(*args, **kwargs): - # args, kwargs are torch-land - args, kwargs = pytree.tree_map(jax_view, (args, kwargs)) - # now they are objs in jax-land - res = t(*args, **kwargs) # t is jax callable - # res is jax-land obj - return pytree.tree_map(torch_view, res) - return new_t - # regular types are not changed - return t - - -def jax_view(t): - # t is an object from torch land - # view it as-if it's a jax land object - if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.XLATensor2) - return t.jax() - if isinstance(t, type(torch.int32)): - return tensor.j2t_dtype(t) - if callable(t): - def new_t(*args, **kwargs): - # args, kwargs are jax-land - args, kwargs = pytree.tree_map(torch_view, (args, kwargs)) - # now they are objs in torch-land - res = t(*args, **kwargs) - # res is torch-land obj - return pytree.tree_map(jax_view, res) - return new_t - # regular types are not changed - return t - -def call_jax(jax_func, *args, **kwargs): - return torch_view(jax_func)(*args, **kwargs) - - -def call_torch(torch_func, *args, **kwargs): - return jax_view(torch_func)(*args, **kwargs) - - -fori_loop = torch_view(jax.lax.fori_loop) - -def jax_jit(torch_function, kwargs_for_jax_jit=None): - kwargs_for_jax_jit = kwargs_for_jax_jit or {} - jax_func = jax_view(torch_function) - jitted = jax.jit(jax_func, **kwargs_for_jax_jit) - return torch_view(jitted) diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py deleted file mode 100644 index 9fcd5653a86..00000000000 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Tensor constructor overrides""" -import functools -import logging -from typing import Callable, Optional, ParamSpec, Sequence - -import jax -import torch -import jax.numpy as jnp -from torch_xla2 import tensor - -registry = {} - -P = ParamSpec('P') - - -def register_function(torch_func: Callable[P, torch.Tensor]): - """Registers a function as the JAX implementation of a torch function.""" - - def decorator(jax_impl: Callable[P, jax.Array]): - registry[torch_func] = jax_impl - return jax_impl - - return decorator - - -def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. - - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. - - Returns: - A decorator that wraps a JAX implementation of a torch function. - """ - - def decorator(func: Callable[P, torch.Tensor]): - - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - jax_dtype = tensor.t2j_dtype(dtype) - - return func(*args, dtype=jax_dtype, **kwargs) - - return wrapper - - return decorator - - -@register_function(torch.tensor) -@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements -def _tensor(data, *, dtype=None, **kwargs): - python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, - } - if not dtype: - leaves = jax.tree_util.tree_leaves(data) - if len(leaves) > 0: - dtype = python_types_to_torch_types.get(type(leaves[0])) - - return jnp.array( - data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) - - -@register_function(torch.ones) -@convert_dtype() -def _ones(*size: int, dtype=None, **kwargs): - return jnp.ones(size, dtype) - - -@register_function(torch.zeros) -@convert_dtype() -def _zeros(*size: int, dtype=None, **kwargs): - return jnp.zeros(size, dtype) - - -@register_function(torch.eye) -@convert_dtype() -def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): - return jnp.eye(n, m, dtype=dtype) - - -@register_function(torch.full) -@convert_dtype() -def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) - - -class XLAFunctionMode(torch.overrides.TorchFunctionMode): - """Context manager that dispatches torch function calls to JAX.""" - - def __torch_function__(self, - func, - types, - args=(), - kwargs=None) -> torch.Tensor: - jax_func = registry.get(func) - if not jax_func: - return func(*args, **(kwargs or {})) - - # TODO: unwrap args here or in implementations? - return tensor.wrap(jax_func(*args, **(kwargs or {}))) diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py new file mode 100644 index 00000000000..fbcd47922e1 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -0,0 +1,65 @@ +import functools +import torch +import jax +import jax.numpy as jnp +from jax import tree_util as pytree +from torch_xla2 import tensor +import torch_xla2 + +from torch_xla2.types import JaxValue, TorchValue, JaxCallable, TorchCallable + + + + +def torch_view(t: JaxValue) -> TorchValue: + # t is an object from jax land + # view it as-if it's a torch land object + if isinstance(t, jax.Array): + # TODO + return tensor.XLATensor2(t, torch_xla2.default_env()) + if isinstance(t, type(jnp.int32)): + return tensor.t2j_type(t) + if callable(t): # t is a JaxCallable + return functools.partial(call_jax, t) + # regular types are not changed + return t + + +def jax_view(t: TorchValue) -> JaxValue: + # t is an object from torch land + # view it as-if it's a jax land object + if isinstance(t, torch.Tensor): + assert isinstance(t, tensor.XLATensor2) + return t.jax() + if isinstance(t, type(torch.int32)): + return tensor.j2t_dtype(t) + + # torch.nn.Module needs special handling + if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable + return functools.partial(call_torch, t) + # regular types are not changed + return t + + +def call_jax(jax_func: JaxCallable, + *args: TorchValue, + **kwargs: TorchValue) -> TorchValue: + args, kwargs = pytree.tree_map(jax_view, (args, kwargs)) + res: JaxValue = jax_func(*args, **kwargs) + return torch_view(res) + + +def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue: + args, kwargs = pytree.tree_map(torch_view, (args, kwargs)) + with torch_xla2.default_env(): + res: TorchValue = torch_func(*args, **kwargs) + return jax_view(res) + + +fori_loop = torch_view(jax.lax.fori_loop) + +def jax_jit(torch_function, kwargs_for_jax_jit=None): + kwargs_for_jax_jit = kwargs_for_jax_jit or {} + jax_func = jax_view(torch_function) + jitted = jax.jit(jax_func, **kwargs_for_jax_jit) + return torch_view(jitted) diff --git a/experimental/torch_xla2/torch_xla2/ops/__init__.py b/experimental/torch_xla2/torch_xla2/ops/__init__.py index e69de29bb2d..abefc8344b1 100644 --- a/experimental/torch_xla2/torch_xla2/ops/__init__.py +++ b/experimental/torch_xla2/torch_xla2/ops/__init__.py @@ -0,0 +1,9 @@ +def all_aten_jax_ops(): + # to load the ops + import torch_xla2.jaten # type: ignore + import torch_xla2.ops_registry # type: ignore + return { + key: val.func + for key, val in torch_xla2.ops_registry.all_aten_ops + if val.is_jax_function + } \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a30fae82de8..f6adc702a14 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1,5 +1,14 @@ -"""This module contains implementation of ATen ops.""" +"""Torch ops implemented using jax.""" + +import sys + +import jax +from jax import numpy as jnp +import numpy as np import torch +from torch_xla2.ops import ops_registry +from torch_xla2 import tensor +from torch_xla2.ops import op_base # Keys are OpOverload, value is a callable that takes # XLATensor2 @@ -9,29 +18,1933 @@ # and need to be implemented in jax mutation_ops_to_functional = { - torch.ops.aten.add_: torch.ops.aten.add, - torch.ops.aten.sub_: torch.ops.aten.sub, - torch.ops.aten.mul_: torch.ops.aten.mul, - torch.ops.aten.div_: torch.ops.aten.div, - torch.ops.aten.pow_: torch.ops.aten.pow, - torch.ops.aten.lt_: torch.ops.aten.lt, - torch.ops.aten.le_: torch.ops.aten.le, - torch.ops.aten.gt_: torch.ops.aten.gt, - torch.ops.aten.ge_: torch.ops.aten.ge, - torch.ops.aten.eq_: torch.ops.aten.eq, - torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.add_: torch.ops.aten.add, + torch.ops.aten.sub_: torch.ops.aten.sub, + torch.ops.aten.mul_: torch.ops.aten.mul, + torch.ops.aten.div_: torch.ops.aten.div, + torch.ops.aten.pow_: torch.ops.aten.pow, + torch.ops.aten.lt_: torch.ops.aten.lt, + torch.ops.aten.le_: torch.ops.aten.le, + torch.ops.aten.gt_: torch.ops.aten.gt, + torch.ops.aten.ge_: torch.ops.aten.ge, + torch.ops.aten.eq_: torch.ops.aten.eq, + torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.uniform_: torch.ops.aten.uniform, } def make_mutation(op): + return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0) - def f(*args, **kwargs): - res = mutation_ops_to_functional[op](*args, **kwargs) - args[0].copy_(res) - return args[0] - return f +for op in mutation_ops_to_functional.keys(): + ops_registry.register_torch_dispatch_op( + op, make_mutation(op), is_jax_function=False + ) -for op in mutation_ops_to_functional.keys(): - all_ops[op] = make_mutation(op) +def op(*aten, **kwargs): + def inner(func): + for a in aten: + ops_registry.register_torch_dispatch_op(a, func, **kwargs) + return func + + return inner + + +@op( + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, +) +def _aten_unsafe_view(x, shape): + return jnp.reshape(x, shape) + + +@op(torch.ops.aten.add.Tensor) +@op(torch.ops.aten.add.Scalar) +def _aten_add(x, y, *, alpha=1): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + return x + y * alpha + + +@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False) +def _aten_copy(x, y, memory_format=None): + if isinstance(x, tensor.XLATensor2): + x._elem = y._elem + elif isinstance(x, tensor.SliceView): + x.mutate(y) + return x + + +@op(torch.ops.aten.clone) +@op(torch.ops.aten.clone.default) +def _aten_clone(x, memory_format=None): + return jnp.copy(x) + + +@op(torch.ops.aten.full) +def _aten_full(size, value, **kwargs): + return jnp.full(size, value) + + +@op(torch.ops.aten.index_copy) +def _aten_index_copy(x, dim, indexes, source): + # return jax.lax.scatter(x, index, dim) + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x.at[dim].set(source) + + +@op(torch.ops.aten.select) +@op(torch.ops.aten.index_select) +@op(torch.ops.aten.select_copy) +def _aten_index_select(x, dim, indexes): + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x[tuple(dims)] + + +@op(torch.ops.aten.mean) +def _aten_mean(x, dim=None, keepdim=False): + return jnp.mean(x, dim, keepdims=keepdim) + + +def _torch_binary_scalar_type(scalar, tensor): + if "float" in str(tensor.dtype): + return tensor.dtype + + if isinstance(scalar, int): + if "int" in str(tensor.dtype): + return tensor.dtype + + return jnp.float32 + + +@op(torch.ops.aten.sub.Tensor) +@op(torch.ops.aten.sub.Scalar) +def _aten_sub(x, y): + if isinstance(x, float): + dtype = _torch_binary_scalar_type(x, y) + x = jnp.array(x, dtype=dtype) + if isinstance(y, float): + dtype = _torch_binary_scalar_type(y, x) + y = jnp.array(y, dtype=dtype) + return x - y + + +@op(torch.ops.aten.mm) +def _aten_mm(x, y): + res = x @ y + return res + + +@op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) +def _aten_mul(x, y): + return x * y + + +@op(torch.ops.aten.silu) +def _aten_silu(x): + return jax.nn.silu(x) + + +@op(torch.ops.aten.t) +def _aten_t(x): + return jnp.transpose(x) + + +@op(torch.ops.aten.transpose) +@op(torch.ops.aten.transpose_copy) +def _aten_transpose(x, dim0, dim1): + shape = list(range(len(x.shape))) + shape[dim0], shape[dim1] = shape[dim1], shape[dim0] + return jnp.transpose(x, shape) + + +@op(torch.ops.aten.triu) +def _aten_triu(m, k): + return jnp.triu(m, k) + + +@op(torch.ops.aten.slice) +@op(torch.ops.aten.slice_copy) +def _aten_slice(self, dim=0, start=None, end=None, step=1): + if end == sys.maxsize: + end = self.shape[dim] + sl = slice(start, end, step) + dims = [] + for i in range(len(self.shape)): + if i == dim: + dims.append(sl) + else: + dims.append(slice(None, None, None)) + return self[tuple(dims)] + + +@op(torch.ops.aten.detach) +def _aten_detach(self): + return self + + +@op(torch.ops.aten.view_as_real) +def _aten_view_as_real(x): + real = jnp.real(x) + im = jnp.imag(x) + res = jnp.stack([real, im], -1) + return res + + +@op(torch.ops.aten.stack) +def _aten_stack(tensors, dim=0): + return jnp.stack(tensors, dim) + + +@op(torch.ops.aten._softmax) +def _aten_softmax(x, dim, halftofloat): + return jax.nn.softmax(x, dim) + + +@op(torch.ops.aten.pow) +def _aten_pow(x, y): + if isinstance(y, int): + y = float(y) + return jnp.power(x, y) + + +@op(torch.ops.aten.view_as_complex) +def _aten_view_as_complex(input): + if input.dtype == jnp.bfloat16: + input = input.astype(jnp.float32) + x, y = input[..., 0], input[..., 1] + return jax.lax.complex(x, y) + + +@op(torch.ops.aten.div) +def _aten_div(x, y, rounding_mode=""): + res = x / y + if rounding_mode == "trunc": + res = jnp.trunc(res) + return res + + +@op(torch.ops.aten.div_, is_jax_function=False) +def _aten_div_(x, y, rounding_mode=""): + x._elem = _aten_div(x._elem, y._elem, rounding_mode) + return x + + +@op(torch.ops.aten.true_divide) +def _aten_true_divide(x, y): + return x / y + + +@op(torch.ops.aten.bmm) +def _aten_bmm(x, y): + res = x @ y + return res + # return jnp.einsum('bnm,bmk->bnk', x, y) + + +@op(torch.ops.aten.embedding) +# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) +def _aten_embedding(a, w, padding_idx=-1): + return jnp.take(a, w, axis=0) + + +@op(torch.ops.aten.rsqrt) +def _aten_rsqrt(x): + if isinstance(x, int): + x = float(x) + if x.dtype == jnp.int32: + x = x.astype(jnp.float32) + return jax.lax.rsqrt(x) + + +@op(torch.ops.aten.expand) +@op(torch.ops.aten.expand_copy) +def _aten_expand(x, dims): + def fix_dims(d, xs): + if d == -1: + return xs + return d + + dims = [fix_dims(p, s) for p, s in zip(dims, x.shape)] + return jnp.broadcast_to(x, dims) + + +@op(torch.ops.aten.dot) +def _aten_dot(x, y): + return jnp.dot(x, y) + + +@op(torch.ops.aten._to_copy) +def _aten__to_copy(self, **kwargs): + dtype = tensor.t2j_dtype(kwargs["dtype"]) + if dtype != self.dtype: + return self.astype(dtype) + return jnp.copy(self) + + +@op(torch.ops.aten.empty) +def _aten_empty(sizes, **kwargs): + return jnp.zeros(sizes) + + +@op(torch.ops.aten.index_put_) +@op(torch.ops.aten.index_put) +def _aten_index_put(self, indexes, values, accumulate=False): + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + if accumulate: + return self.at[indexes].add(values) + else: + return self.at[indexes].set(values) + + +@op(torch.ops.aten.index) +@op(torch.ops.aten._unsafe_index) +@op(torch.ops.aten.index.Tensor) +def _aten_index(self, indexes): + print(indexes) + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self[indexes] + + +@op(torch.ops.aten.split) +@op(torch.ops.aten.split_copy) +@op(torch.ops.aten.split_with_sizes) +def split_with_sizes(x, sizes, dim=0): + """Splits an array `x` into sub-arrays based on static sizes `sizes`. + + Args: + x: The input array to split. + sizes: A 1D array of integer sizes for each sub-array. + + Returns: + A list of sub-arrays. + """ + if isinstance(sizes, int): + # split equal size + new_sizes = [sizes] * (x.shape[dim] // sizes) + sizes = new_sizes + rank = x.ndim + splits = np.cumsum(sizes) # Cumulative sum for split points + + def make_range(rank, dim, start, end): + res = [slice(None, None, None)] * rank + res[dim] = slice(start, end) + return tuple(res) + + return [ + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) + ] + + +@op(torch.ops.aten.permute) +@op(torch.ops.aten.permute_copy) +def permute(t, dims): + return jnp.transpose(t, dims) + + +@op(torch.ops.aten.unsqueeze) +@op(torch.ops.aten.unsqueeze_copy) +@op(torch.ops.aten.unsqueeze.default) +def _aten_unsqueeze(self, dim): + if dim < 0: + dim += self.ndim + 1 + return jnp.expand_dims(self, dim) + + +@op(torch.ops.aten.ne) +def _aten_ne(x, y): + return jnp.not_equal(x, y) + + +@op(torch.ops.aten.cumsum) +def _aten_cumsum(x, y, dtype=None): + if dtype: + dtype = tensor.t2j_dtype(dtype) + res = jnp.cumsum(x, y, dtype) + return res + + +@op(torch.ops.aten.native_layer_norm) +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): + """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. + + Args: + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + output: The normalized tensor. + mean: The calculated mean tensor. + std: The calculated standard deviation tensor. + """ + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + axis = [i for i, d in enumerate(input.shape) if d in normalized_shape] + + # Calculate mean and standard deviation + mean = jnp.mean(input, axis=axis, keepdims=True) + var = jnp.var(input, axis=axis, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) + + # Normalize the input + norm_x = (input - mean) * rstd + + # Apply affine transformation (if provided) + if weight is not None: + norm_x *= weight + if bias is not None: + norm_x += bias + return norm_x, mean, rstd + + +# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +@op(torch.ops.aten.addmm) +@op(torch.ops.aten.addmv) +def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): + alpha = jnp.array(alpha).astype(mat1.dtype) + beta = jnp.array(beta).astype(mat1.dtype) + self *= beta + self += alpha * jnp.matmul(mat1, mat2) + return self + + +@op(torch.ops.aten.addbmm.default) +def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): + alpha = jnp.array(alpha).astype(batch1.dtype) + beta = jnp.array(beta).astype(batch1.dtype) + mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) + + +@op(torch.ops.aten.gelu) +def _aten_gelu(self, *, approximate="none"): + approx = approximate == "tanh" + return jax.nn.gelu(self, approx) + + +@op(torch.ops.aten.squeeze) +@op(torch.ops.aten.squeeze_copy) +def _aten_squeeze_dim(self, dim): + """Squeezes a Jax tensor by removing a single dimension of size 1. + + Args: + self: The input tensor. + dim: The dimension to squeeze. + + Returns: + The squeezed tensor with the specified dimension removed if it is 1, + otherwise the original tensor is returned. + """ + + # Validate input arguments + if not isinstance(self, jnp.ndarray): + raise TypeError(f"Expected a Jax tensor, got {type(self)}.") + if isinstance(dim, int): + dim = [dim] + + # Check if the specified dimension has size 1 + if all([self.shape[d] != 1 for d in dim]): + return self + + # Use slicing to remove the dimension if it is 1 + new_shape = list(self.shape) + + def fix_dim(p): + if p < 0: + return p + len(self.shape) + return p + + dim = [fix_dim(d) for d in dim] + new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1] + return self.reshape(new_shape) + + +@op(torch.ops.aten.convolution) +def _aten_convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + if transposed: + raise NotImplementedError("Transposed convolution is not implemented.") + + def make_padding(padding): + return ((p, p) for p in padding) + + def create_default_conv_dimension_numbers(num_spatial_dims): + # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 + # (batch dimension, feature dimension, spatial dimensions...) + lhs_spec = [0, 1] + # (out feature dimension, in feature dimension, spatial dimensions...) + rhs_spec = [0, 1] + # (batch dimension, feature dimension, spatial dimensions...) + out_spec = [0, 1] + for i in range(0, num_spatial_dims): + lhs_spec.append(i + 2) + rhs_spec.append(i + 2) + out_spec.append(i + 2) + return jax.lax.ConvDimensionNumbers( + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) + + res = jax.lax.conv_general_dilated( + input, + weight, + stride, + make_padding(padding), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, + ) + + if bias is not None: + # TODO(qihqi): bias always on channel? + if len(bias.shape) == 1: + shape = [1] * len(res.shape) + shape[1] = bias.shape[0] + bias = bias.reshape(tuple(shape)) + res = res + bias + return res + + +# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) +@op(torch.ops.aten._native_batch_norm_legit) +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@op(torch.ops.aten._native_batch_norm_legit_no_training) +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): + if weight is None: + weight = jnp.ones_like(running_mean) + if bias is None: + bias = jnp.zeros_like(running_mean) + + def broadcast(t): + return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) + + if running_mean is not None: + a = input - broadcast(running_mean) + else: + a = input + if running_var is not None: + b = broadcast(jnp.sqrt(running_var + eps)) + else: + b = broadcast(jnp.sqrt(eps)) + return ( + a / b * broadcast(weight) + broadcast(bias), + jnp.array([]), + jnp.array([]), + ) + + +@op(torch.ops.aten.relu) +def _aten_relu(self): + return jax.nn.relu(self) + + +@op(torch.ops.aten.cat) +def _aten_cat(tensors, dims=0): + return jnp.concatenate(tensors, dims) + + +@op(torch.ops.aten.max_pool2d_with_indices) +@op(torch.ops.aten.max_pool3d_with_indices) +def _aten_max_pool2d_with_indices( + inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False +): + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + window_shape = kernel_size + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f"len({window_shape}) must equal len({strides})" + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" + padding = ((0, 0), (0, 0)) + padding + + indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) + + def reduce_fn(a, b): + ai, av = a + bi, bv = b + which = av > bv + return jnp.where(which, ai, bi), jnp.where(which, av, bv) + + init_val = -jnp.inf + if inputs.dtype in (jnp.int32, jnp.int64): + init_val = -(1 << 31) + init_val = jnp.array(init_val).astype(inputs.dtype) + + indices, y = jax.lax.reduce_window( + (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding + ) + if is_single_input: + indices = jnp.squeeze(indices, axis=0) + y = jnp.squeeze(y, axis=0) + return y, indices + + batch_result = pool( + inputs, -jnp.inf, jax.lax.max, kernel_size, strides, padding + ) + indices = pool(inputs, 0, jnp.argmax, kernel_size, strides, padding) + return batch_result, indices + + +# TODO add more ops + + +@op(torch.ops.aten.min) +def _aten_min(x, axis=None): + return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64) + + +@op(torch.ops.aten.amin) +def _aten_amin(x, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.amin, x, dim, keepdim) + + +@op(torch.ops.aten.argmin) +def _aten_argmin(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) + + +@op(torch.ops.aten.sin) +def _aten_sin(x): + return jnp.sin(x) + + +@op(torch.ops.aten.sym_size) +def _aten_sym_size(x, dim): + return x.shape[dim] + + +@op(torch.ops.aten.var.correction) +@op(torch.ops.prims.var) +def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): + return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) + + +@op(torch.ops.prims.broadcast_in_dim) +def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): + return jax.lax.broadcast_in_dim( + t, shape, broadcast_dimensions=broadcast_dimensions + ) + + +# aten.native_group_norm -- should use decomp table +# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + + +@op(torch.ops.aten.native_group_norm) +def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): + """Group Normalization implementation in JAX. + + Args: + input: Input tensor. Expected shape (batch_size, channels, ... spatial dims + ...) + weight: Optional scaling (gamma) parameter. Shape (channels,) + bias: Optional shifting (beta) parameter. Shape (channels,) + N: Batch size. + C: Number of channels. + HxW: Product of spatial dimensions (number of elements per channel after + flattening). + group: Number of groups for Group Normalization. + eps: Small value added for numerical stability. + + Returns: + A tuple of (normalized_output, mean, rstd) + """ + + input_shape = input.shape + + # Reshape for group-wise normalization + reshaped_input = jnp.reshape(input, (1, N * group, -1)) + + # **Core Group Normalization** + def group_norm_body(x): # Function to apply within each group + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon + normalized = (x - mean) * rstd + return normalized, mean, rstd + + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) + + # Reshape back to original input shape + output = jnp.reshape(normalized, input_shape) + + # **Affine transformation** + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting + if weight is not None and bias is not None: + output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) + elif weight is not None: + output = output * weight.reshape(affine_shape) + elif bias is not None: + output = output + bias.reshape(affine_shape) + + # Reshape mean and rstd + mean = jnp.reshape(group_mean, (N, group)) + rstd = jnp.reshape(group_rstd, (N, group)) + + return output, mean, rstd + + +@op(torch.ops.aten.linalg_vector_norm) +def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): + """Calculates the vector norm along specified dimensions. + + Args: + self: The input tensor. + ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. + Default is 2 (Euclidean norm). + dim: Dimensions along which to calculate the norm. If None, the norm is + calculated over all dimensions. + keepdim: Whether to keep the reduced dimensions. + dtype: Optional data type for the output. + + Returns: + The tensor containing the calculated vector norms. + """ + + if ord not in {2, float("inf"), float("-inf"), "fro"}: + raise ValueError( + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) + + # Special cases (for efficiency and clarity) + if ord == 2: # Euclidean norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + elif ord == float("inf"): + result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == float("-inf"): + result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == "fro": # Frobenius norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + else: # General case (e.g., ord = 1, ord = 3) + result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** ( + 1.0 / ord + ) + + # (Optional) dtype conversion + if dtype is not None: + result = result.astype(dtype) + + return result + + +# aten.reflection_pad1d +@op(torch.ops.aten.reflection_pad1d) +def _aten_reflection_pad1d(input, padding): + rank = len(input.shape) + pad_size = [(0, 0)] * rank + pad_size[-1] = padding + return jnp.pad(input, pad_size, mode="reflect") + + +# aten.alias +@op(torch.ops.aten.alias) +def _aten_alias(self, *args): + return self + + +# aten.sinh +@op(torch.ops.aten.sinh) +def _aten_sinh(self): + return jnp.sinh(self) + + +# aten.native_layer_norm_backward +@op(torch.ops.aten.native_layer_norm_backward) +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): + """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. + + Args: + grad_out: The gradient of the output tensor. + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + A tuple of (grad_input, grad_weight, grad_bias). + """ + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) + + +# aten.reflection_pad3d_backward +# aten.reflection_pad2d + + +# aten.atanh +@op(torch.ops.aten.atanh) +def _aten_atanh(self): + return jnp.arctanh(self) + + +# aten.bitwise_not +@op(torch.ops.aten.bitwise_not) +def _aten_bitwise_not(self): + return ~self + + +# aten.embedding_dense_backward + + +# aten.sum +@op(torch.ops.aten.sum) +def _aten_sum(self, dim=None, keepdim=False, dtype=None): + if not dim: + dim = None + return jnp.sum(self, axis=dim, keepdims=keepdim, dtype=dtype) + + +# aten.sqrt +@op(torch.ops.aten.sqrt) +def _aten_sqrt(self): + return jnp.sqrt(self) + + +@op(torch.ops.aten.tan) +def _aten_tanh(self): + return jnp.tan(self) + + +# aten.tanh +@op(torch.ops.aten.tanh) +def _aten_tanh(self): + return jnp.tanh(self) + + +# aten.ceil +@op(torch.ops.aten.ceil) +def _aten_ceil(self): + return jnp.ceil(self) + + +# aten.asin +@op(torch.ops.aten.asin) +def _aten_asin(self): + return jnp.arcsin(self) + + +# aten.minimum +@op(torch.ops.aten.minimum) +def _aten_minimum(self, other): + return jnp.minimum(self, other) + + +# aten.max_pool2d_backward + + +def _scatter_index(dim, index): + """Returns a tuple of indexes; + + The first is to select in input (to modify), + the second is to select from the values. + """ + index_shape = list(index.shape) + input_indexes = [] + source_indexes = [] + for i in range(len(index_shape)): + source_indexes.append(slice(0, index_shape[i])) + if i == dim: + input_indexes.append(index) + else: + target_shape = [1] * len(index_shape) + target_shape[i] = index_shape[i] + input_indexes.append( + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), index_shape + ) + ) + return tuple(input_indexes), tuple(source_indexes) + + +# aten.scatter_add +@op(torch.ops.aten.scatter_add) +def _aten_scatter_add(input, dim, index, src): + """JAX implementation of scatter, mimicking torch.scatter behavior""" + + input_indexes, source_indexes = _scatter_index(dim, index) + return input.at[input_indexes].add(src[source_indexes]) + + +# aten.logical_not + + +# aten.sign +@op(torch.ops.aten.sign) +def _aten_sign(x): + return jnp.sign(x) + + +# aten.sigmoid +@op(torch.ops.aten.sigmoid) +def _aten_sigmoid(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.nn.sigmoid(x) + + +# implement aten.asinh in jax +@op(torch.ops.aten.asinh) +def _aten_asinh(self): + return jnp.arcsinh(self) + + +# aten.atan +@op(torch.ops.aten.atan) +def _aten_atan(self): + return jnp.arctan(self) + + +# aten.scatter_reduce +@op(torch.ops.aten.scatter_reduce) +def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): + input_indexes, source_indexes = _scatter_index(dim, index) + if reduce == "sum": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "prod": + return input.at[input_indexes].multiply(src[source_indexes]) + elif reduce == "mean": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "amax": + return input.at[input_indexes].max(src[source_indexes]) + elif reduce == "amin": + return input.at[input_indexes].min(src[source_indexes]) + else: + raise RuntimeError("Unknow reduction type: ", reduce) + + +# aten.acos +@op(torch.ops.aten.acos) +def _aten_acos(self): + return jnp.arccos(self) + + +# aten.sym_storage_offset +# aten.native_layer_norm_backward +# aten.max_pool3d_with_indices + + +# aten.gt +@op(torch.ops.aten.gt) +def _aten_gt(self, other): + return self > other + + +# aten.pixel_shuffle +@op(torch.ops.aten.pixel_shuffle) +def _aten_pixel_shuffle(x, upscale_factor): + """PixelShuffle implementation in JAX. + + Args: + x: Input tensor. Typically a feature map. + upscale_factor: Integer by which to upscale the spatial dimensions. + + Returns: + Tensor after PixelShuffle operation. + """ + + batch_size, channels, height, width = x.shape + + if channels % (upscale_factor**2) != 0: + raise ValueError( + "Number of channels must be divisible by the square of the upscale factor." + ) + + new_channels = channels // (upscale_factor**2) + new_height = height * upscale_factor + new_width = width * upscale_factor + + x = x.reshape( + batch_size, new_channels, upscale_factor, upscale_factor, height, width + ) + x = jnp.transpose( + x, (0, 1, 2, 4, 3, 5) + ) # Move channels to spatial dimensions + x = x.reshape(batch_size, new_channels, new_height, new_width) + + return x + + +# aten.sym_stride +# aten.lt +@op(torch.ops.aten.lt) +def _aten_lt(self, other): + return self < other + + +def pool(inputs, init, reduce_fn, window_shape, strides, padding): + """Helper function to define pooling functions. + + Pooling functions are implemented using the ReduceWindow XLA op. + NOTE: Be aware that pooling is not generally differentiable. + That means providing a reduce_fn that is differentiable does not imply that + pool is differentiable. + + Args: + inputs: input data with dimensions (batch, window dims..., features). + init: the initial value for the reduction + reduce_fn: a reduce function of the form ``(T, T) -> T``. + window_shape: a shape tuple defining the window to reduce over. + strides: a sequence of ``n`` integers, representing the inter-window + strides (default: ``(1, ..., 1)``). + padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence + of ``n`` ``(low, high)`` integer pairs that give the padding to apply before + and after each spatial dimension. + Returns: + The output of the reduction for each window slice. + """ + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f"len({window_shape}) must equal len({strides})" + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all( + [len(x) == 2 for x in padding] + ), f"each entry in padding {padding} must be length 2" + padding = ((0, 0), (0, 0)) + padding + y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) + if is_single_input: + y = jnp.squeeze(y, axis=0) + return y + + +@op(torch.ops.aten._adaptive_avg_pool3d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 3) + + +@op(torch.ops.aten._adaptive_avg_pool2d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 2) + + +def _aten_adaptive_avg_pool(x, output_shape, pool_dim): + def adaptive_kernel_size(input_shape, output_shape): + sizes = [1, 1] + spatial_dim_off = len(input_shape) - pool_dim + for spatial_dim in range(pool_dim): + sizes.append( + input_shape[spatial_dim_off + spatial_dim] // output_shape[spatial_dim] + ) + return tuple(sizes) + + kernel_sizes = adaptive_kernel_size(x.shape, output_shape) + y = pool(x, 0.0, jax.lax.add, kernel_sizes, kernel_sizes, padding="VALID") + + div_shape = list(x.shape) + num_batch_dims = len(x.shape) - pool_dim - 1 + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_sizes): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, "VALID" + ) + return y + + +# aten.avg_pool2d +@op(torch.ops.aten.avg_pool2d) +@op(torch.ops.aten.avg_pool3d) +def _aten_avg_pool( + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) + if count_include_pad: + y = y / np.prod(kernel_size) + else: + div_shape = list(inputs.shape) + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_size): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding + ) + return y + + +# aten.sym_numel +# aten.reciprocal +@op(torch.ops.aten.reciprocal) +def _aten_reciprocal(a): + return 1 / a + + +# aten.scatter +@op(torch.ops.aten.select_scatter) +def _aten_select_scatter(input, src, dim, index): + input_indexes = [] + for x in range(len(input.shape)): + if x == dim: + input_indexes.append(index) + else: + input_indexes.append(slice(None, None, None)) + return input.at[tuple(input_indexes)].set(src) + + +@op(torch.ops.aten.scatter.src) +def _aten_scatter_src(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src[source_indexes]) + + +@op(torch.ops.aten.scatter.value) +def _aten_scatter(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src) + + +# aten.acosh +@op(torch.ops.aten.acosh) +def _aten_acosh(self): + return jnp.arccosh(self) + + +# aten.avg_pool2d_backward +# aten.col2im +# aten.avg_pool3d +# aten.round +@op(torch.ops.aten.round) +def _aten_round(input, decimals=0): + return jnp.round(input, decimals) + + +# aten.max +@op(torch.ops.aten.max) +def _aten_max(self, dim=None, keepdim=False): + return jnp.max(self, axis=dim, keepdims=keepdim), jnp.argmax( + self, axis=dim, keepdims=keepdim + ) + + +# aten.maximum +@op(torch.ops.aten.maximum) +def _aten_maximum(self, other): + return jnp.maximum(self, other) + + +# aten.abs +@op(torch.ops.aten.abs) +def _aten_abs(self): + return jnp.abs(self) + + +# generate aten.amax only +@op(torch.ops.aten.amax) +def _aten_amax(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.amax, self, dim, keepdim) + + +def _with_reduction_scalar(jax_func, self, dim, keepdim): + expanded = False + if self.ndim == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + self = jnp.expand_dims(self, 0) + res = jax_func(self, axis=dim, keepdims=keepdim) + if expanded: + res = res.squeeze() + return res + + +# aten.any +@op(torch.ops.aten.any) +def _aten_any(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.any, self, dim, keepdim) + + +# aten.arange +@op(torch.ops.aten.arange.start_step) +@op(torch.ops.aten.arange.start) +@op(torch.ops.aten.arange.default) +def _aten_arange( + start, + end=None, + step=1, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, +): + if end is None: + end = start + start = 0 + if dtype: + dtype = tensor.t2j_dtype(dtype) + return jnp.arange( + start, + end, + step, + dtype=dtype, + ) + + +# aten.argmax +@op(torch.ops.aten.argmax) +def _aten_argmax(self, dim=None, keepdim=False): + return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) + + +# aten.as_strided +@op(torch.ops.aten.as_strided) +@op(torch.ops.aten.as_strided_copy) +def _aten_as_strided(x, sizes, strides, storage_offset=None): + ind = jnp.zeros(sizes, dtype=jnp.int32) + + for i, (size, stride) in enumerate(zip(sizes, strides)): + result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) + indexes = (jnp.arange(size) * stride).reshape(result_shape) + ind += indexes + + return jnp.ravel(x)[ind] + + +# aten.atan2 +@op(torch.ops.aten.atan2) +def _aten_atan2(self, other): + return jnp.arctan2(self, other) + + +# aten.bitwise_and +@op(torch.ops.aten.bitwise_and) +def _aten_bitwise_and(self, other): + return self & other + + +# aten.bitwise_or +@op(torch.ops.aten.bitwise_or) +def _aten_bitwise_or(self, other): + return self | other + + +# aten.bitwise_xor +@op(torch.ops.aten.bitwise_xor) +def _aten_bitwise_xor(self, other): + return self ^ other + + +# aten.clamp +@op(torch.ops.aten.clamp.default) +@op(torch.ops.aten.clamp.Tensor) +def _aten_clamp(self, min=None, max=None): + return jnp.clip(self, min, max) + + +# aten.constant_pad_nd +@op(torch.ops.aten.constant_pad_nd) +def _aten_constant_pad_nd(input, padding, value=0): + # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) + # means last dim get padded 1 in front and 1 in back; + # and second last dim get padded 2 in front and 2 in back. + # Jax padding tuple of 2-tuple: the same padding is + # [(0, 0), ..., (2,2), (1,1)] + m = len(padding) + rev_padding = [(padding[i - 1], padding[i]) for i in range(m - 1, 0, -2)] + pad_dim = tuple(([(0, 0)] * (len(input.shape) - m // 2)) + rev_padding) + return jnp.pad(input, pad_dim, mode="constant", constant_values=value) + + +# aten.convolution_backward +@op(torch.ops.aten.copy) +@op(torch.ops.aten.lift_fresh_copy) +def _aten_copy(x): + return jnp.copy(x) + + +@op(torch.ops.aten._cdist_forward) +def _aten_cdist_forward(x1, x2, p, compute_mode=""): + # x1 is B x P x M + # x2 is B x Q x M + # res is B x P x Q + x1 = jnp.expand_dims(x1, len(x1.shape) - 1) + x2 = jnp.expand_dims(x2, len(x2.shape) - 2) + return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) + + +@op(torch.ops.aten._pdist_forward) +def _aten__pdist_forward(x, p): + pairwise_dists = _aten_cdist_forward(x, x, p) + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] + return condensed_dists + + +# aten.cos +@op(torch.ops.aten.cos) +def _aten_cos(input): + return jnp.cos(input) + + +# aten.cosh +@op(torch.ops.aten.cosh) +def _aten_cosh(input): + return jnp.cosh(input) + + +# aten.diagonal +@op(torch.ops.aten.diagonal) +def _aten_diagonal(input, offset=0, dim1=0, dim2=1): + return jnp.diagonal(input, offset, dim1, dim2) + + +# aten.empty_strided +# aten.eq +@op(torch.ops.aten.eq) +def _aten_eq(input1, input2): + return input1 == input2 + + +# aten.erf +@op(torch.ops.aten.erf) +def _aten_erf(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.lax.erf(x) + + +# aten.exp +@op(torch.ops.aten.exp) +def _aten_exp(input): + return jnp.exp(input) + + +# aten.expm1 +@op(torch.ops.aten.expm1) +def _aten_expm1(input): + return jnp.expm1(input) + + +# aten.fill +@op(torch.ops.aten.fill) +@op(torch.ops.aten.full_like) +def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None): + if dtype is None: + dtype = x.dtype + else: + dtype = tensor.t2j_dtype(dtype) + return jnp.full(x.shape, value, dtype) + + +# aten.flip +@op(torch.ops.aten.flip) +def _aten_flip(input, dims): + if dims is not None: + return jnp.flip(input, tuple(dims)) + else: + return jnp.flip(input) + + +# aten.floor +@op(torch.ops.aten.floor) +def _aten_floor(input): + return jnp.floor(input) + + +# aten.fmod +@op(torch.ops.aten.fmod) +def _aten_fmod(input, other): + return input - other * _aten_div(input, other, "trunc") + + +# aten.gather +@op(torch.ops.aten.gather) +def _aten_gather(input, dim, index): + input_indexes, source_indexes = _scatter_index(dim, index) + return input[input_indexes] + + +# aten.ge +@op(torch.ops.aten.ge) +def _aten_ge(self, other): + return self >= other + + +@op(torch.ops.aten.glu) +@op(torch.ops.aten.glu.default) +def _aten_glu(x, dim=-1): + return jax.nn.glu(x, dim) + + +# aten.hardtanh +@op(torch.ops.aten.hardtanh) +def _aten_hardtanh(input, min_val=-1.0, max_val=1.0, inplace=False): + return jnp.clip(input, min_val, max_val) + + +# aten.isinf +@op(torch.ops.aten.isinf) +def _aten_isinf(input): + return jnp.isinf(input) + + +# aten.isnan +@op(torch.ops.aten.isnan) +def _aten_isnan(input): + return jnp.isnan(input) + + +@op(torch.ops.aten.le) +def _aten_le(self, other): + return self <= other + + +# aten.leaky_relu +@op(torch.ops.aten.leaky_relu) +def _aten_leaky_relu(x, negative_slope): + return jax.nn.leaky_relu(x, negative_slope) + + +# aten.log +@op(torch.ops.aten.log) +def _aten_log(x): + return jnp.log(x) + + +# aten.log10 +@op(torch.ops.aten.log10) +def _aten_log10(x): + return jnp.log10(x) + + +# aten.log1p +@op(torch.ops.aten.log1p) +def _aten_log1p(x): + return jnp.log1p(x) + + +# aten.log2 +@op(torch.ops.aten.log2) +def _aten_log2(x): + return jnp.log2(x) + + +# aten.logical_and +@op(torch.ops.aten.logical_and) +def _aten_logical_and(self, other): + return jnp.logical_and(self, other) + + +# aten.logical_or +@op(torch.ops.aten.logical_or) +def _aten_logical_or(self, other): + return jnp.logical_or(self, other) + + +# aten.logical_not +@op(torch.ops.aten.logical_not) +def _aten_logical_not(self): + return jnp.logical_not(self) + + +# aten.log_softmax +@op(torch.ops.aten._log_softmax) +def _aten_log_softmax(self, axis=-1, half_to_float=False): + return jax.nn.log_softmax(self, axis) + + +# aten.max_pool3d_backward +# aten.logical_xor +@op(torch.ops.aten.logical_xor) +def _aten_logical_xor(self, other): + return jnp.logical_xor(self, other) + + +# aten.max_pool2d_with_indices_backward +# aten.native_dropout +# aten.native_group_norm_backward +# aten.neg +@op(torch.ops.aten.neg) +def _aten_neg(x): + return -1 * x + + +# aten.nonzero +@op(torch.ops.aten.nonzero) +def _aten_nonzero(x): + index_tuple = jnp.nonzero(x) + index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] + return jnp.concatenate(index_tuple, axis=-1) + + +# aten.prod + + +@op(torch.ops.aten.prod) +def _aten_prod(self, dim=None, keepdim=False): + return jnp.prod(self, axis=dim, keepdims=keepdim) + + +# aten.randperm + + +# aten.reflection_pad3d + + +# aten.remainder +@op(torch.ops.aten.remainder) +def _aten_remainder(inputs, other): + return inputs % other + + +# aten.repeat +@op(torch.ops.aten.repeat) +def _aten_repeat(x, reps): + return jnp.tile(x, reps) + + +# aten.replication_pad2d +# aten.replication_pad3d +# aten.roll +@op(torch.ops.aten.roll) +def _aten_roll(input, shifts, dims=None): + return jnp.roll(input, shifts, dims) + + +# aten.scalar_tensor +# aten.slice_scatter +@op(torch.ops.aten.slice_scatter) +def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): + input_index = [] + for x in range(len(input.shape)): + if x == dim: + input_index.append(slice(start, end, step)) + else: + input_index.append(slice(None, None, None)) + return input.at[tuple(input_index)].set(src) + + +# aten.sort +# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) +@op(torch.ops.aten.sort) +def _aten_sort(a, dim=-1, descending=False, stable=False): + return ( + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), + ) + + +# aten.sym_size + + +# aten.topk +@op(torch.ops.aten.topk) +def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): + """JAX top-k implementation using jax.lax.top_k for improved efficiency. + + Args: + input: The input JAX array. + k: The number of top elements to return. + dim: The dimension along which to find the top-k. If None, operates on the + flattened array. + largest: If True, returns the largest k elements. Otherwise, smallest k. + sorted: If True, returns the elements in sorted order. + + Returns: + A tuple (values, indices) containing: + - values: The top k values. + - indices: The indices of the top k values in the original array. + """ + if dim is None: + input = input.flatten() + dim = 0 + + if not largest: + input = -input # Find top-k of negated input if we want the smallest + + transpose_shape = None + if dim != -1 and dim != len(input.shape) - 1: + transpose_shape = list(range(len(input.shape))) + transpose_shape[dim], transpose_shape[-1] = ( + transpose_shape[-1], + transpose_shape[dim], + ) + input = jnp.transpose(input, transpose_shape) + + values, indices = jax.lax.top_k(input, k) + + if sorted: + values = jnp.sort(values, descending=True) + indices = jnp.take_along_axis( + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) + + if not largest: + values = -values # Negate values back if we found smallest + + if transpose_shape is not None: + values = jnp.transpose(values, transpose_shape) + indices = jnp.transpose(indices, transpose_shape) + + return values, indices + + +# aten.trunc +@op(torch.ops.aten.trunc) +def _aten_trunc(a): + return jnp.trunc(a) + + +@op(torch.ops.aten.unbind) +@op(torch.ops.aten.unbind_copy) +def _aten_unbind(a, dim=0): + return tuple( + _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) + for i in range(a.shape[dim]) + ) + + +# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d +# despite those being core aten ops, they also have decompositions. +# here we are using torch decompositions. + + +# aten.where +@op(torch.ops.aten.where.self) +@op(torch.ops.aten.where.ScalarSelf) +@op(torch.ops.aten.where.ScalarOther) +def _aten_where(condition, x, y): + return jnp.where(condition, x, y) + + +# aten.to.dtype +# Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None +@op(torch.ops.aten.to.dtype) +def _aten_to_dtype( + a, dtype, non_blocking=False, copy=False, memory_format=None +): + if dtype: + jaxdtype = tensor.t2j_dtype(dtype) + return a.astype(jaxdtype) + + +# aten.to.device + + +# Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False +@op(torch.ops.aten.var_mean.correction) +def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): + return ( + jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), + jnp.mean(self, dim, keepdims=keepdim), + ) + + +@op(torch.ops.aten.scalar_tensor) +def _aten_scalar_tensor( + s, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + return jnp.array(s, dtype=dtype) + return jnp.array(s) + + +@op(torch.ops.aten.to.device) +def _aten_to_device(x, device, dtype): + return x + + +@op(torch.ops.aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input + + +@op(torch.ops.aten._local_scalar_dense) +def _aten_local_scalar_dense(x): + return x.item() + + +@op(torch.ops.aten.tensor_split.sections) +def _aten_tensor_split(ary, indices_or_sections, axis=0): + return jnp.array_split(ary, indices_or_sections, axis) + + +@op(torch.ops.aten.randn, needs_env=True) +def _randn( + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, +): + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key() + res = jax.random.normal(key, shape) + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + res = res.astype(dtype) + return res + + +@op(torch.ops.aten.rand, needs_env=True) +def _rand( + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, +): + shape = size + if len(shape) == 1 and isinstance(shape[0], (list, tuple)): + shape = shape[0] + key = env.get_and_rotate_prng_key() + res = jax.random.uniform(key, shape) + if dtype is not None: + dtype = tensor.t2j_dtype(dtype) + res = res.astype(dtype) + return res + + +@op(torch.ops.aten.scalar_tensor.default) +def _aten_scalar_tensor(val, **kwargs): + p = torch.ops.aten.scalar_tensor(val) + return tensor.t2j(p) + + +@op(torch.ops.aten.to.device) +def _aten_to_device(x, device, dtype): + return x + + +@op(torch.ops.aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward_custom( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). + """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input + + +@op(torch.ops.aten._local_scalar_dense) +def _aten_local_scalar_dense(x): + return x.item() + + +@op(torch.ops.aten.tensor_split.sections) +def _aten_tensor_split(ary, indices_or_sections, axis=0): + return jnp.array_split(ary, indices_or_sections, axis) + + +@op(torch.ops.aten.outer) +def _aten_outer(a, b): + return jnp.outer(a, b) + + +@op(torch.ops.aten.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) + diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index e69de29bb2d..ddc04fa4b1b 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -0,0 +1,116 @@ +"""Tensor constructor overrides""" +import functools +from typing import Callable, Optional, ParamSpec, Sequence + +import jax +import torch +import jax.numpy as jnp +from torch_xla2 import tensor +from torch_xla2.ops.ops_registry import register_torch_function_op + +def register_function(torch_func, **kwargs): + return functools.partial(register_torch_function_op, torch_func, **kwargs) + + +P = ParamSpec('P') + + +def convert_dtype(use_default_dtype: bool = True): + """Converts `dtype` kwarg of function from torch to JAX. + + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. + + Returns: + A decorator that wraps a JAX implementation of a torch function. + """ + + def decorator(func: Callable[P, torch.Tensor]): + + @functools.wraps(func) + def wrapper(*args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs): + if not dtype and use_default_dtype: + dtype = torch.get_default_dtype() + jax_dtype = tensor.t2j_dtype(dtype) + + return func(*args, dtype=jax_dtype, **kwargs) + + return wrapper + + return decorator + + +@register_function(torch.tensor) +@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements +def _tensor(data, *, dtype=None, **kwargs): + python_types_to_torch_types = { + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, + } + if not dtype: + leaves = jax.tree_util.tree_leaves(data) + if len(leaves) > 0: + dtype = python_types_to_torch_types.get(type(leaves[0])) + + return jnp.array( + data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) + + +@register_function(torch.ones) +@convert_dtype() +def _ones(*size: int, dtype=None, **kwargs): + return jnp.ones(size, dtype) + + +@register_function(torch.zeros) +@convert_dtype() +def _zeros(*size: int, dtype=None, **kwargs): + return jnp.zeros(size, dtype) + + +@register_function(torch.eye) +@convert_dtype() +def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): + return jnp.eye(n, m, dtype=dtype) + + +@register_function(torch.full) +@convert_dtype() +def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): + # TODO: handle torch.Size + return jnp.full(size, fill_value, dtype=dtype) + + +@register_function(torch.allclose) +def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(input, other, rtol, atol, equal_nan) + +@register_function(torch.angle) +def _torch_angle(input): + return jnp.angle(input) + + +@register_function(torch.argsort) +def _torch_argsort(input, dim=-1, descending=False, stable=False): + expanded = False + if input == 0: + # for self of rank 0: + # torch.any(x, 0), torch.any(x, -1) works; + # torch.any(x, 1) throws out of bounds, so it's + # behavior is the same as a jnp array of rank 1 + expanded = True + input = jnp.expand_dims(input, 0) + res = jnp.argsort(input, axis=dim, descending=descending, + stable=stable) + if expanded: + res = res.squeeze() + return res + +@register_function(torch.einsum) +def _einsum(equation, *operands): + assert isinstance(equation, str), 'Only accept str equation' + return jnp.einsum(equation, *operands) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 62df160edc9..983d20fb660 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -1,22 +1,11 @@ import torch -from torch_xla2 import extra - -class JaxOperator: - """This is a aten op backed by jax function.""" - - def __init__(self, jax_callable): - self.jax = jax_callable - - def __call__(self, *args, **kwargs): - # args are torch.Tensor - res = call_jax(self.jax, args, kwargs) - return res +from torch_xla2 import interop class BinaryOpWithPromotion: - def __init__(self, jax_callable): - self.jax = jax_callable + def __init__(self, inner): + self.inner = inner def _get_dtype(self, obj): if isinstance(obj, torch.Tensor): @@ -31,7 +20,7 @@ def _get_dtype(self, obj): def __call__(self, *args, **kwargs): # args are torch.Tensor - res = extra.torch_view(self.jax)(*args, **kwargs) + res = interop.torch_view(self.jax)(*args, **kwargs) dtype = torch.promote_types( self._get_dtype(args[0]), @@ -41,15 +30,6 @@ def __call__(self, *args, **kwargs): return res -class TorchLowering: - - def __init__(self, lowering): - self.lowering = lowering - - def __call__(self, *args, **kwargs): - return self.lowering(*args, **kwargs) - - class InplaceOp: def __init__(self, functional_op, position_to_mutate=0): @@ -58,7 +38,7 @@ def __init__(self, functional_op, position_to_mutate=0): def __call__(self, *args, **kwargs): to_mutate = args[0] - to_mutate._elem = self.functional(*args, **kwargs)._elem + to_mutate.copy_(self.functional(*args, **kwargs)) return to_mutate diff --git a/experimental/torch_xla2/torch_xla2/ops/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops/ops_registry.py new file mode 100644 index 00000000000..e75d1549456 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/ops/ops_registry.py @@ -0,0 +1,47 @@ +import dataclasses +from torch_xla2.types import JaxCallable, TorchCallable + +from typing import Union, Dict + + +@dataclasses.dataclass +class Operator: + torch_op: TorchCallable + func: Union[TorchCallable, JaxCallable] + is_jax_function: bool + is_user_defined: bool + needs_env: bool + + +all_aten_ops: Dict[TorchCallable, Operator] = {} +all_torch_functions: Dict[TorchCallable, Operator] = {} + + +def register_torch_dispatch_op( + aten_op, impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, +): + op = Operator( + aten_op, impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env) + all_aten_ops[aten_op] = op + return impl_callable + + +def register_torch_function_op( + torch_func, impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, +): + op = Operator( + torch_func, impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env) + all_torch_functions[torch_func] = op + return impl_callable \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py deleted file mode 100644 index f1d115864d3..00000000000 --- a/experimental/torch_xla2/torch_xla2/ops_registry.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch._decomp as decomp -import torch_xla2.decompositions - -class LoweringRegistry: - - def __init__(self): - self.registered_ops = {} - self.decomps = {} - - def lookup(self, op_or_name): - candidate = self._lookup(op_or_name) - if candidate is None: - if isinstance(op_or_name, torch._ops.OpOverloadPacket): - candidate = self._lookup(op_or_name.default) - if isinstance(op_or_name, torch._ops.OpOverload): - candidate = self._lookup(op_or_name.overloadpacket) - return candidate - - def _lookup(self, op): - candidate = self.registered_ops.get(op) - if candidate is None: - candidate = self.decomp.get(op) - return candidate - - def register(self, op, lowering): - if isinstance(op, torch._ops.OpOverloadPacket): - if hasattr(op, 'default'): - self.registered_ops[op.default] = lowering - self.registered_ops[op] = lowering - - -lowerings = LoweringRegistry() -EXTRA_DECOMP = decomp.get_decompositions([ - torch.ops.aten.upsample_nearest2d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, -]) -CORE_ATEN_DECOMP = decomp.core_aten_decompositions() -CORE_ATEN_DECOMP.update(EXTRA_DECOMP) -lowerings.decomp = CORE_ATEN_DECOMP - - -def _all_core_ops(): - """Yields all core ops.""" - import torch._ops - - for k, v in torch.ops.aten.__dict__.items(): - if k.startswith('__'): - continue - if k.startswith('_'): - continue - if isinstance(v, torch._ops.OpOverloadPacket): - for overload in v.overloads(): - op = getattr(v, overload) - if torch.Tag.core in op.tags: - yield v - break - - -def print_missing_ops(): - core_aten = set(_all_core_ops()) - existing = set(lowerings.registered_ops.keys()) - for v in core_aten - existing: - print(v) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 98953a8b04c..262bc95f566 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -1,53 +1,16 @@ -import functools +import contextlib import jax from jax import dlpack as jaxdl import jax.numpy as jnp import numpy import torch import torch.func -import torch._decomp.decompositions -from torch_xla2 import ops_registry import torch.utils._python_dispatch as torch_dispatch import torch.utils._pytree as torch_pytree import torch.utils.dlpack as torchdl -from torch_xla2.ops import jaten -from torch._subclasses.fake_tensor import FakeTensorMode -fake_mode = FakeTensorMode() - - -class XLADispatchMode(torch_dispatch.TorchDispatchMode): - - def __torch_dispatch__(self, fn, types, args=(), kwargs=None): - if fn in constructors: - args, kwargs = unwrap((args, kwargs)) - res = constructors[fn](*args, **kwargs) - return wrap(res) - - return fn(*args, **kwargs) - - -def _aten_arange(start, - end, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False): - return jnp.arange(start, end, 1) - - -def _aten_scalar_tensor(val, **kwargs): - p = torch.ops.aten.scalar_tensor(val) - return wrap(t2j(p)) - - -constructors = { - torch.ops.aten.scalar_tensor.default: _aten_scalar_tensor, - torch.ops.aten.arange.default: functools.partial(_aten_arange, 0), - torch.ops.aten.arange.start: _aten_arange, -} +class OperatorNotFound(Exception): + pass def wrap(jaxarray): @@ -61,7 +24,9 @@ def unwrap(torchtensors): def t2j(t): if isinstance(t, XLATensor2): return t._elem + is_bool = False if t.dtype == torch.bool: + is_bool = True t = t.to(torch.int8) if not t.is_contiguous(): @@ -82,7 +47,7 @@ def t2j(t): if t.dtype == torch.bfloat16: res = res.astype(jnp.bfloat16) - if t.dtype == torch.bool: + if is_bool: res = res.astype(jnp.bool_) return res @@ -97,48 +62,41 @@ def j2t(x): res = res.to(torch.bool) return res +TORCH_DTYPE_TO_JAX = { + torch.float16: jnp.dtype('float16'), + torch.bfloat16: jnp.dtype('bfloat16'), + torch.half: jnp.dtype('float16'), + torch.float32: jnp.dtype('float32'), + torch.double: jnp.dtype('double'), + torch.long: jnp.dtype('int64'), + torch.int32: jnp.dtype('int32'), + torch.int16: jnp.dtype('int16'), + torch.int8: jnp.dtype('int8'), + torch.uint8: jnp.dtype('uint8'), + torch.bool: jnp.dtype('bool_'), + torch.complex64: jnp.dtype('complex64'), + torch.complex128: jnp.dtype('complex128'), + None: None, +} + +JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} def t2j_dtype(dtype): - return { - torch.float16: jnp.float16, - torch.bfloat16: jnp.bfloat16, - torch.half: jnp.float16, - torch.float32: jnp.float32, - torch.double: jnp.double, - torch.long: jnp.int64, - torch.int32: jnp.int32, - torch.int16: jnp.int16, - torch.int8: jnp.int8, - torch.uint8: jnp.uint8, - torch.bool: jnp.bool_, - torch.complex64: jnp.complex64, - torch.complex128: jnp.complex128, - }.get(dtype) + if dtype not in TORCH_DTYPE_TO_JAX: + raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,') + return TORCH_DTYPE_TO_JAX[dtype] def j2t_dtype(dtype): - return { - jnp.float16: torch.float16, - jnp.bfloat16: torch.bfloat16, - jnp.double: torch.double, - jnp.float32: torch.float32, - jnp.float16: torch.half, - jnp.int64: torch.long, - jnp.int32: torch.int32, - jnp.int16: torch.int16, - jnp.bool_: torch.bool, - jnp.complex64: torch.complex64, - }.get(dtype) - - -def move_to_device(t): - return XLATensor2(t2j(t)) + if dtype not in JAX_DTYPE_TO_TORCH: + raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,') + return JAX_DTYPE_TO_TORCH[dtype] class XLATensor2(torch.Tensor): @staticmethod - def __new__(cls, elem): + def __new__(cls, elem, env): dtype = j2t_dtype(elem.dtype) shape = list(elem.shape) for i, s in enumerate(shape): @@ -154,9 +112,10 @@ def __new__(cls, elem): requires_grad=False, ) - def __init__(self, elem: jax.Array): + def __init__(self, elem: jax.Array, env: 'Environment'): super().__init__() self._elem = elem + self._env = env def __str__(self): return "XLATensor2({} {})".format(str(type(self._elem)), str(self._elem)) @@ -178,7 +137,7 @@ def flatten(self, start_dim=0, end_dim=-1): new_shape = ( self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim:]) new_elem = jnp.reshape(self._elem, new_shape) - return XLATensor2(new_elem) + return XLATensor2(new_elem, self._env) # return torch.reshape(self, new_shape) def __setitem__(self, key, val): @@ -193,32 +152,17 @@ def type_as(self, other): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - with jax.named_scope(func.name()): + env = None + for arg in torch_pytree.arg_tree_leaves(*args, **kwargs): + if isinstance(arg, XLATensor2): + env = arg._env + break - if isinstance(func, torch._ops.OpOverloadPacket): - return func(*args, **kwargs) - - if func.name() == 'aten::copy_': - x, y = args - x._elem = y._elem - return - - if func.overloadpacket in jaten.all_ops: - return jaten.all_ops[func.overloadpacket](*args, **kwargs) - - lowering = ops_registry.lowerings.lookup(func) - - if lowering is None: - raise RuntimeError("No lowering found for", func.name()) - - with XLADispatchMode(): - res = lowering(*args, **kwargs) - debug_accuracy(func, args, kwargs, res) - return res + with env: + return func(*args, **(kwargs or {})) def detach(self): - return XLATensor2(jax.lax.stop_gradient(self.jax())) + return XLATensor2(jax.lax.stop_gradient(self.jax()), self._env) def numpy(self) -> numpy.ndarray: import numpy as np @@ -231,6 +175,20 @@ def jax(self) -> jax.Array: def torch(self) -> torch.Tensor: return j2t(self.jax()) + def to(self, *args, **kwargs): + if len(args) == 1: + if isinstance(args[0], torch.dtype): + return XLATensor2(self._elem.astype(t2j_dtype(args[0])), self._env) + if 'dtype' in kwargs: + dtype = kwargs['dtype'] + return XLATensor2(self._elem.astype(t2j_dtype(dtype)), self._env) + return self + + @property + def dtype(self): + return j2t_dtype(self._elem.dtype) + + # TODO: slice of slice should also be another slice class SliceView(XLATensor2): @@ -281,3 +239,159 @@ def debug_accuracy(func, args, kwargs, current_output): pdb.set_trace() return True + + +class XLAFunctionMode(torch.overrides.TorchFunctionMode): + """Context manager that dispatches torch function calls to JAX.""" + + def __init__(self, env): + self.env = env + + def __torch_function__(self, + func, + types, + args=(), + kwargs=None) -> torch.Tensor: + try: + return self.env.dispatch(func, types, args, kwargs) + except OperatorNotFound: + return func(*args, **(kwargs or {})) + + +class XLADispatchMode(torch_dispatch.TorchDispatchMode): + + def __init__(self, env): + self.env = env + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if isinstance(func, torch._ops.OpOverloadPacket): + with self: + return func(*args, **kwargs) + if func.namespace != 'aten': + return func(*args, **kwargs) + return self.env.dispatch(func, types, args, kwargs) + +def _name_of_func(func): + if hasattr(func, 'name'): + return func.name() + return func.__name__ + + +class Environment(contextlib.ContextDecorator): + """This class holds a set of configurations and "globals" needed + + for executing torch program using jax. + Things included so far: + + op registry + PRNGKey + Configs + + Also helper functions to manipulate those. + """ + + _prng_key: jax.random.PRNGKey + + + def __init__(self, random_seed): + self._prng_key = jax.random.PRNGKey(random_seed) + self._function_mode = XLAFunctionMode(self) + self._dispatch_mode = XLADispatchMode(self) + + # name is torch callable + self._ops = {} + self.load_ops() + + def load_ops(self): + from torch_xla2.ops import jaten, jtorch, ops_registry + self._ops.update(ops_registry.all_aten_ops) + self._ops.update(ops_registry.all_torch_functions) + + decomps = torch._decomp.core_aten_decompositions() + from torch_xla2.decompositions import EXTRA_DECOMP + decomps.update(EXTRA_DECOMP) + for k, v in decomps.items(): + if k not in self._ops: + self._ops[k] = ops_registry.Operator( + k, + v, + is_jax_function=False, + is_user_defined=False, + needs_env=False + ) + + def get_and_rotate_prng_key(self): + self._prng_key, key = jax.random.split(self._prng_key) + return key + + def dispatch(self, func, types, args, kwargs): + with jax.named_scope(_name_of_func(func)): + kwargs = kwargs or {} + op = self._ops.get(func) + + if op is None and isinstance(func, torch._ops.OpOverloadPacket): + op = self._ops.get(func.default) + + if op is None and isinstance(func, torch._ops.OpOverload): + op = self._ops.get(func.overloadpacket) + + if op is None: + raise OperatorNotFound( + f'Operator with name {_name_of_func(func)} has no lowering') + + if op.is_jax_function: + args, kwargs = self.t2j_iso((args, kwargs)) + + if op.needs_env: + kwargs['env'] = self + + with self: + res = op.func(*args, **kwargs) + + if op.is_jax_function: + res = self.j2t_iso(res) + + #if self.config.debug_accuracy_for_each_op: + # debug_accuracy(func, args, kwargs, res) + return res + + def __enter__(self): + self._dispatch_mode.__enter__() + self._function_mode.__enter__() + return self + + def __exit__(self, *exc): + self._function_mode.__exit__(*exc) + self._dispatch_mode.__exit__(*exc) + + def _move_one_value(self, val): + if isinstance(val, torch.nn.Module): + state_dict = self.to_xla(val.state_dict()) + val.load_state_dict(state_dict, assign=True) + return val + if isinstance(val, XLATensor2): + return val + if isinstance(val, torch.Tensor): + return XLATensor2(t2j(val), self) + return val + + def to_xla(self, torchvalues): + # tensors are torch.Tensors (not XLATensor) + res = torch_pytree.tree_map( + self._move_one_value, + torchvalues) + return res + + def t2j_iso(self, torchtensors): + return torch_pytree.tree_map_only( + XLATensor2, lambda x: x.jax(), torchtensors) + + def j2t_iso(self, jaxarray): + return torch_pytree.tree_map_only( + jnp.ndarray, lambda x: XLATensor2(x, self), jaxarray) + + def j2t_copy(self, args): + pass + + def j2t_copy(self, args): + pass diff --git a/experimental/torch_xla2/torch_xla2/types.py b/experimental/torch_xla2/torch_xla2/types.py new file mode 100644 index 00000000000..f39d530c18d --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/types.py @@ -0,0 +1,12 @@ +from typing import TypeAlias, Callable, ParamSpec, Any, Union +import torch +import jax +import jax.numpy as jnp + + +P = ParamSpec('P') + +TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] +TorchCallable: TypeAlias = Callable[P, TorchValue] +JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] +JaxCallable: TypeAlias = Callable[P, JaxValue] \ No newline at end of file diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index d324729ce11..9e2fe7270cc 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -22,7 +22,7 @@ build_env: common: LD_LIBRARY_PATH: "$LD_LIBRARY_PATH:/usr/local/lib" # Set explicitly to 0 as setup.py defaults this flag to true if unset. - BUILD_CPP_TESTS: 0 + BUILD_CPP_TESTS: "{{ build_cpp_tests }}" # Force GCC because clang/bazel has issues. CC: gcc-10 CXX: g++-10 @@ -33,7 +33,7 @@ build_env: BAZEL_REMOTE_CACHE: 1 SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}{{ cache_suffix }}" _GLIBCXX_USE_CXX11_ABI: 0 - GIT_VERSIONED_XLA_BUILD: "{{ nightly_release }}" + GIT_VERSIONED_XLA_BUILD: "{{ nightly_release or git_versioned_xla_build }}" amd64: ARCH: amd64 diff --git a/infra/ansible/config/vars.yaml b/infra/ansible/config/vars.yaml index c1ca7a93d27..e5851d0cc77 100644 --- a/infra/ansible/config/vars.yaml +++ b/infra/ansible/config/vars.yaml @@ -14,3 +14,7 @@ nightly_release: false bundle_libtpu: 1 # Suffix for bazel remote cache key cache_suffix: "" +# Whether to build C++ tests with `torch_xla` wheel +build_cpp_tests: 0 +# Whether to tag wheels with git hash, e.g. X.Y.Z+git123abc +git_versioned_xla_build: false diff --git a/infra/ansible/roles/build_srcs/tasks/main.yaml b/infra/ansible/roles/build_srcs/tasks/main.yaml index d945f150d38..da09a695453 100644 --- a/infra/ansible/roles/build_srcs/tasks/main.yaml +++ b/infra/ansible/roles/build_srcs/tasks/main.yaml @@ -1,3 +1,27 @@ +- name: Read PyTorch pin + ansible.builtin.command: cat {{ (src_root, 'pytorch/xla/.torch_pin') | path_join }} + register: torch_pin + # Pin may not exist + ignore_errors: true + +- name: Checkout PyTorch pin + # ansible.builtin.git wants to fetch the entire history, so check out the pin manually + ansible.builtin.shell: + cmd: | + set -xe + PIN="{{ torch_pin.stdout }}" + if [[ $PIN = \#* ]]; then + PRNUM="${PIN//[!0-9]/}" + git fetch origin "pull/$PRNUM/head" + else + git fetch origin {{ torch_pin.stdout }} + fi + git checkout --recurse-submodules FETCH_HEAD + chdir: "{{ (src_root, 'pytorch') | path_join }}" + args: + executable: /bin/bash + when: torch_pin is succeeded + - name: Build PyTorch ansible.builtin.command: cmd: python setup.py bdist_wheel @@ -77,6 +101,22 @@ state: absent mode: '0755' +- name: Create temp directory for C++ tests + ansible.builtin.file: + path: /tmp/test/bin + state: directory + mode: '0755' + when: build_cpp_tests + +- name: Collect C++ test files + ansible.builtin.shell: | + cd pytorch/xla/build/temp* + bazel query 'kind(".*_test", tests(//:cpp_tests))' --output=label | xargs -n 1 bazel cquery --output=files | xargs cp -t /tmp/test/bin + args: + executable: bash + chdir: "{{ src_root }}" + when: build_cpp_tests + - name: Read Torchvision pin ansible.builtin.command: cat {{ (src_root, 'pytorch') | path_join }}/.github/ci_commit_pins/vision.txt register: torchvision_pin diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index 0229a79c190..16902f663fd 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -35,58 +35,58 @@ nightly_builds = [ versioned_builds = [ # Remove libtpu from PyPI builds { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.8" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.9" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "0" }, { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.11" bundle_libtpu = "0" }, # Bundle libtpu for Kaggle { - git_tag = "v2.3.0-rc12" - package_version = "2.3.0-rc12+libtpu" - pytorch_git_rev = "v2.3.0-rc12" + git_tag = "v2.3.0" + package_version = "2.3.0+libtpu" + pytorch_git_rev = "v2.3.0" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "1" }, { - git_tag = "v2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" - package_version = "2.3.0-rc12" + git_tag = "v2.3.0" + pytorch_git_rev = "v2.3.0" + package_version = "2.3.0" accelerator = "cuda" cuda_version = "12.1" python_version = "3.8" }, { - git_tag = "v2.3.0-rc12" - pytorch_git_rev = "v2.3.0-rc12" - package_version = "2.3.0-rc12" + git_tag = "v2.3.0" + pytorch_git_rev = "v2.3.0" + package_version = "2.3.0" accelerator = "cuda" cuda_version = "12.1" python_version = "3.10" diff --git a/infra/tpu-pytorch-releases/dev_images.tf b/infra/tpu-pytorch-releases/dev_images.tf index 023ac8b870a..54c340809ef 100644 --- a/infra/tpu-pytorch-releases/dev_images.tf +++ b/infra/tpu-pytorch-releases/dev_images.tf @@ -36,8 +36,6 @@ module "dev_images" { image_name = "development" image_tags = concat(each.value.extra_tags, [ each.key, - # Append _YYYYMMDD suffix to the dev image name. - "${each.key}_$(date +%Y%m%d)", ]) dockerfile = "development.Dockerfile" diff --git a/scripts/apply_patches.sh b/scripts/apply_patches.sh index 923b68c79d4..7ba0a3ef8e3 100755 --- a/scripts/apply_patches.sh +++ b/scripts/apply_patches.sh @@ -7,7 +7,7 @@ XDIR=$CDIR/.. PTDIR=$XDIR/.. OPENXLADIR=$XDIR/third_party/xla -TORCH_PIN="$XDIR/torch_patches/.torch_pin" +TORCH_PIN="$XDIR/.torch_pin" if [ -f "$TORCH_PIN" ]; then CID=$(cat "$TORCH_PIN") # If starts with # and it's not merged into master, fetch from origin diff --git a/setup.py b/setup.py index dbe47007aff..df54224ed98 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240418' +_date = '20240502' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.27.dev{_date}' @@ -223,6 +223,10 @@ def bazel_build(self, ext): f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}" ] + build_cpp_tests = build_util.check_env_flag('BUILD_CPP_TESTS', default='0') + if build_cpp_tests: + bazel_argv.append('//:cpp_tests') + import torch cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None) diff --git a/test/benchmarks/run_tests.sh b/test/benchmarks/run_tests.sh index 3832b21ed22..fce6140a4fe 100755 --- a/test/benchmarks/run_tests.sh +++ b/test/benchmarks/run_tests.sh @@ -9,7 +9,7 @@ export PYTHONPATH=$PYTHONPATH:$CDIR/../../benchmarks/ # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" diff --git a/test/cpp/test_aten_xla_tensor_5.cpp b/test/cpp/test_aten_xla_tensor_5.cpp index 4070779529f..07e4c2dae86 100644 --- a/test/cpp/test_aten_xla_tensor_5.cpp +++ b/test/cpp/test_aten_xla_tensor_5.cpp @@ -267,6 +267,27 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) { }); } +TEST_F(AtenXlaTensorTest, TestEmbeddingBag) { + torch::Tensor weight = + torch::rand({32, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor indices = + torch::randint(0, 31, {10}, torch::TensorOptions(torch::kLong)); + torch::Tensor offsets = torch::arange(0, 10, 3); + auto out = torch::embedding_bag(weight, indices, offsets); + torch::Tensor result = std::get<0>(out); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_weight = CopyToDevice(weight, device); + torch::Tensor xla_indices = CopyToDevice(indices, device); + torch::Tensor xla_offsets = CopyToDevice(offsets, device); + auto xla_out = torch::embedding_bag(xla_weight, xla_indices, xla_offsets); + torch::Tensor xla_result = std::get<0>(xla_out); + AllClose(result, xla_result); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_embedding_bag_forward_only", + cpp_test::GetIgnoredCounters()); + }); +} + TEST_F(AtenXlaTensorTest, TestOneHot) { int num_classes = 5; torch::Tensor input = diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index c3dfe6bbed1..3a3eb3d43f1 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -489,13 +489,13 @@ def test_resnet18(self): # Graph 1: forward # Graph 2: backward # Graph 3: sync input for backward - self.assertEqual(met.metric_data('CompileTime')[0], 3) + self.assertLessEqual(met.metric_data('CompileTime')[0], 3) # We execute 3 graphs per step. - self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3) + self.assertLessEqual(met.metric_data('ExecuteTime')[0], sample_count * 3) # one for each forward and one for each backward - self.assertEqual( + self.assertLessEqual( met.metric_data('RunCachedGraphInputData')[0], sample_count * 2) - self.assertEqual( + self.assertLessEqual( met.metric_data('RunCachedGraphOutputData')[0], sample_count * 2) @@ -641,10 +641,7 @@ def test_all_cpu_tensor(self): # there should be 18 paramters + 1 input self.assertGreater(len(w), 15) self.assertIn('Found tensor with shape torch.Size', str(w[0].message)) - # no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU - # ops to CPU. - self.assertEqual(len(met.counter_names()), 1) - self.assertIn('MarkStep', met.counter_names()) + self.assertLessEqual(len(met.counter_names()), 1) class DynamoOperationsTests(test_utils.XlaTestCase): @@ -671,6 +668,21 @@ def foo(x): self.assertEqual(expected.dtype, actual.dtype) self.assertEqual(expected.device, actual.device) + def test_return_expand(self): + + def foo(x): + return x.expand(2, -1) + + optfoo = torch.compile(backend="openxla")(foo) + + t = torch.arange(10) + Xt = t.to(xm.xla_device()) + + expected = foo(t) + actual = optfoo(Xt) + + self.assertEqual(expected, actual.cpu()) + if __name__ == '__main__': test = unittest.main() diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 0def33ae275..744039a4f58 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -206,7 +206,8 @@ def _runtime_device_attributes(): def test_runtime_device_attributes(self): result = pjrt.run_multiprocess(self._runtime_device_attributes) for device in result.values(): - self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys())) + self.assertCountEqual(['coords', 'core_on_chip', 'num_cores'], + list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) @@ -218,7 +219,7 @@ def test_global_runtime_device_attributes(self): results = pjrt.run_multiprocess(self._global_runtime_device_attributes) for result in results.values(): for device in result: - self.assertCountEqual(['coords', 'core_on_chip', 'name'], + self.assertCountEqual(['coords', 'core_on_chip', 'name', 'num_cores'], list(device.keys())) self.assertIsInstance(device['coords'], list) self.assertIsInstance(device['core_on_chip'], int) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 88ad0f6bc3d..3a6dcdd96c6 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -70,6 +70,7 @@ 'test_pdist_norm_backward_xla', # pdist_single 'test_pdist_norm_forward_xla', # pdist_single 'test_nuclear_norm_axes_small_brute_force', + 'test_nondeterministic_alert_EmbeddingBag_max_xla', # FIXME: implement embedding_bag_backward 'test_mul_intertype_scalar', 'test_masked_select_discontiguous', # FIXME: wrong result 'test_memory_format_type', diff --git a/test/run_tests.sh b/test/run_tests.sh index 8926318dc38..4d4d3bfa99e 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -8,7 +8,7 @@ VERBOSITY=2 # Note [Keep Going] # -# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. # This will allow you to see all the failures on your PR, not stopping with the first # test failure like the default behavior. CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" @@ -162,7 +162,6 @@ function run_xla_op_tests1 { run_dynamic "$CDIR/ds/test_dynamic_shapes.py" run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_test "$CDIR/test_grad_checkpoint.py" run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" @@ -210,8 +209,9 @@ function run_xla_op_tests3 { run_test "$CDIR/stablehlo/test_exports.py" run_test "$CDIR/stablehlo/test_export_fx_passes.py" run_test "$CDIR/stablehlo/test_implicit_broadcasting.py" - run_test "$CDIR/stablehlo/test_mark_pattern.py" + run_test "$CDIR/stablehlo/test_composite.py" run_test "$CDIR/stablehlo/test_pt2e_qdq.py" + run_test "$CDIR/stablehlo/test_stablehlo_custom_call.py" run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py" run_test "$CDIR/stablehlo/test_stablehlo_compile.py" run_test "$CDIR/stablehlo/test_unbounded_dynamism.py" diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_composite.py similarity index 100% rename from test/stablehlo/test_mark_pattern.py rename to test/stablehlo/test_composite.py diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index d1e731abd6e..82650997316 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -18,7 +18,7 @@ class ExportFxPassTest(unittest.TestCase): def test_decompose_dynamic_shape_select(self): args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = ([{0: Dim("bs")}, None, None],) + dynamic_shapes = (({0: Dim("bs")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.select.int) ep = export(m, args, dynamic_shapes=dynamic_shapes) out1 = ep.module()(*args) @@ -55,7 +55,7 @@ def forward(self, x): def test_embedding_indices_flatten(self): args = (torch.rand((20, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64)) - dynamic_shapes = ([None, {0: Dim("bs")}],) + dynamic_shapes = ((None, {0: Dim("bs")}),) m = wrap_func_as_nn_module(torch.ops.aten.embedding.default) ep = export(m, args, dynamic_shapes=dynamic_shapes) print(ep) diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py new file mode 100644 index 00000000000..7291608e506 --- /dev/null +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -0,0 +1,121 @@ +import re +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.experimental.stablehlo_custom_call +from torch.library import Library, impl, impl_abstract +from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call +from torch_xla.stablehlo import (StableHLOExportOptions, + exported_program_to_stablehlo) + +m = Library("my_custom_library", "DEF") + + +class StableHLOCustomCallExportTest(unittest.TestCase): + + def test_single_output(self): + + m.define("custom_op(Tensor input) -> Tensor") + + @impl(m, "custom_op", "Meta") + def custom_op_meta(x): + return torch.empty_like(x) + + class M(torch.nn.Module): + + def forward(self, x): + x = torch.sin(x) + x = torch.ops.my_custom_library.custom_op(x) + x = torch.cos(x) + x = torch.ops.my_custom_library.custom_op(x) + x = torch.sin(x) + return x + + options = StableHLOExportOptions() + options.custom_ops_allowed_in_graph.add("my_custom_library") + ep = torch.export.export(M(), (torch.randn(3, 3),)) + shlo_module = exported_program_to_stablehlo(ep, options) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r"stablehlo.custom_call.*@my_custom_library\.custom_op\.default", + shlo_text) is not None) + self.assertTrue( + re.search(r"tensor<3x3xf32>.*->.*tensor<3x3xf32>", shlo_text) + is not None) + self.assertTrue(shlo_text.count("@my_custom_library.custom_op.default", 2)) + + def test_multiple_input_output(self): + + m.define("custom_op2(Tensor input, Tensor input) -> (Tensor, Tensor)") + + @impl(m, "custom_op2", "Meta") + def custom_op2_meta(x, y): + return torch.empty_like(x), torch.empty(y.shape[1:], device='meta') + + class M(torch.nn.Module): + + def forward(self, x, y): + x = torch.sin(x) + x, y = torch.ops.my_custom_library.custom_op2(x, y) + x = torch.cos(x) + x, y = torch.ops.my_custom_library.custom_op2(x, y) + y = torch.sin(y) + return x, y + + options = StableHLOExportOptions() + options.custom_ops_allowed_in_graph.add("my_custom_library") + ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(5, 5))) + shlo_module = exported_program_to_stablehlo(ep, options) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r"stablehlo.custom_call.*@my_custom_library\.custom_op2\.default", + shlo_text) is not None) + self.assertTrue( + re.search( + r"tensor<3x3xf32>.*tensor<5x5xf32>.*->.*tuple, tensor<5xf32>>", + shlo_text) is not None) + self.assertTrue(shlo_text.count("@my_custom_library.custom_op2.default", 2)) + + def test_stable_custom_call_api(self): + + m.define("custom_op3(Tensor input) -> Tensor") + + @impl(m, "custom_op3", "Meta") + def custom_op3_meta(x): + return torch.empty(x.shape[1:], device='meta') + + @impl(m, "custom_op3", "XLA") + def custom_op3_xla(x): + res = stablehlo_custom_call((x,), "custom_op3", [x.shape[1:]], + [torch.int8], True, "backend_config", 1) + return res + + class M(torch.nn.Module): + + def forward(self, x): + x = torch.sin(x) + x = torch.ops.my_custom_library.custom_op3(x) + x = torch.cos(x) + return x + + ep = torch.export.export(M(), (torch.randn(3, 3),)) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r"stablehlo.custom_call.*@custom_op3", shlo_text) is not None) + self.assertTrue( + re.search(r"tensor<3x3xf32>.*->.*tensor<3xi8>", shlo_text) is not None) + self.assertTrue("backend_config = \"backend_config\"" in shlo_text) + self.assertTrue("has_side_effect = true" in shlo_text) + # TODO: api version lost during conversion, or not shown in txt format. + # self.assertTrue("api_version = 1" in shlo_text) + + +if __name__ == "__main__": + + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index e185a47007e..3cd17a7fe34 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -27,7 +27,7 @@ class UnboundedDynamismExportTest(unittest.TestCase): def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -45,7 +45,7 @@ def test_add(self): def test_add_scalar(self): args = (torch.rand((10, 197, 768)), 0.345) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -62,7 +62,7 @@ def test_add_scalar(self): def test_addmm(self): args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5))) - dynamic_shapes = ([None, {0: Dim("dim")}, None],) + dynamic_shapes = ((None, {0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.addmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -83,7 +83,7 @@ def test_bmm(self): torch.rand((24, 197, 64)), torch.rand((24, 64, 197)), ) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -104,7 +104,7 @@ def test_bmm_dynamic_out_dim(self): torch.rand((8, 128, 256)), torch.rand((8, 256, 3)), ) - dynamic_shapes = ([None, {2: Dim("dim")}],) + dynamic_shapes = ((None, {2: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -125,7 +125,7 @@ def test_bmm_dynamic_reduction_dim(self): torch.rand((8, 128, 3)), torch.rand((8, 3, 256)), ) - dynamic_shapes = ([{2: Dim("dim")}, {1: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")}, {1: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -143,7 +143,7 @@ def test_bmm_dynamic_reduction_dim(self): def test_cat(self): args = (torch.rand((10, 1, 768)), torch.rand((10, 196, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module( lambda x, y: torch.ops.aten.cat.default([x, y], 1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -166,7 +166,7 @@ def test_conv(self): torch.rand((5, 3, 16, 16)), torch.rand((5)), ) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module( lambda x, y, z: torch.ops.aten.convolution.default( x, @@ -197,7 +197,7 @@ def test_conv1d(self): torch.rand((3, 1, 800)), torch.rand((512, 1, 10)), ) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) # dynamic_shapes = None m = wrap_func_as_nn_module(lambda x, y: torch.ops.aten.convolution.default( x, @@ -225,7 +225,7 @@ def test_conv1d(self): def test_cumsum(self): args = (torch.rand((10, 5)), 1) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.cumsum.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -242,7 +242,7 @@ def test_cumsum(self): def test_div(self): args = (torch.rand((10, 12, 197)), torch.rand((10, 12, 197))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -260,7 +260,7 @@ def test_div(self): def test_div_scalar(self): args = (torch.rand((10, 12, 197)), 8.0) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -277,7 +277,7 @@ def test_div_scalar(self): def test_gelu(self): args = (torch.rand((3, 5)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module(torch.ops.aten.gelu) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -342,7 +342,7 @@ def forward(self, x): def test_mul(self): args = (torch.rand((10, 2, 768)), torch.rand((10, 2, 768))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -360,7 +360,7 @@ def test_mul(self): def test_mul_scalar(self): args = (torch.rand((10, 2, 768)), 0.125) - dynamic_shapes = ([{0: Dim("dim")}, None],) + dynamic_shapes = (({0: Dim("dim")}, None),) m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -483,7 +483,7 @@ def forward(self, x, weight, bias): def test_permute(self): args = (torch.rand((10, 197, 12, 64)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module( lambda x: torch.ops.aten.permute.default(x, [0, 2, 1, 3])) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -502,7 +502,7 @@ def test_permute(self): def test_select(self): args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.select.int) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -519,7 +519,7 @@ def test_select(self): def test_slice(self): args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) - dynamic_shapes = ([{0: Dim("dim")}, None, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -537,7 +537,7 @@ def test_slice(self): def test_slice_2(self): args = (torch.rand((10, 3, 224, 224)), 1, 0, 2) - dynamic_shapes = ([{0: Dim("dim")}, None, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -555,7 +555,7 @@ def test_slice_2(self): def test_softmax(self): args = (torch.rand((10, 12, 197, 197)), -1, False) - dynamic_shapes = ([{0: Dim("dim")}, None, None],) + dynamic_shapes = (({0: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -573,7 +573,7 @@ def test_softmax(self): def test_sub(self): args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10))) - dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -591,7 +591,7 @@ def test_sub(self): def test_softmax_reduce_on_dynamic_dim(self): args = (torch.rand((1, 8, 128, 3)), -1, False) - dynamic_shapes = ([{3: Dim("dim")}, None, None],) + dynamic_shapes = (({3: Dim("dim")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -609,7 +609,7 @@ def test_softmax_reduce_on_dynamic_dim(self): @unittest.skip("Converted StableHLO contains i1 dtype, not expected.") def test_index(self): args = (torch.rand((2, 10)), torch.arange(5)) - dynamic_shapes = ([None, {0: Dim("dim")}],) + dynamic_shapes = ((None, {0: Dim("dim")}),) m = wrap_func_as_nn_module( lambda x, y: torch.ops.aten.index.Tensor(x, [None, y])) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -628,7 +628,7 @@ def test_index(self): def test_sub_scalar(self): args = (1.0, torch.rand((10, 1, 1, 10))) - dynamic_shapes = ([None, {0: Dim("dim")}],) + dynamic_shapes = ((None, {0: Dim("dim")}),) m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -670,7 +670,7 @@ def forward(self, x): def test_transpose_on_dynamic_dim(self): args = (torch.rand((1, 8, 3, 256)),) - dynamic_shapes = ([{2: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")},),) m = wrap_func_as_nn_module( lambda x: torch.ops.aten.transpose.int(x, -2, -1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) @@ -688,7 +688,7 @@ def test_transpose_on_dynamic_dim(self): def test_unsqueeze_1(self): args = (torch.rand((3, 10)),) - dynamic_shapes = ([{0: Dim("dim")}],) + dynamic_shapes = (({0: Dim("dim")},),) m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 1)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -705,7 +705,7 @@ def test_unsqueeze_1(self): def test_unsqueeze_2(self): args = (torch.rand((1, 1, 3, 256)),) - dynamic_shapes = ([{2: Dim("dim")}],) + dynamic_shapes = (({2: Dim("dim")},),) m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 2)) ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index c7c04f781c3..b2c5fc50b21 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,6 +38,50 @@ def test_aliasing_with_cloned(self): torch.allclose(t1 - 1, t1_cloned) self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + def test_aliasing_across_mark_step(self): + xla_device = xm.xla_device() + met.clear_all() + t1 = torch.randn(4, 5).to(xla_device) + t1 += 1 + xm.mark_step() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + t1 *= 100 + xm.mark_step() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) + + def test_aliasing_with_multiple_inplace_update(self): + BATCH_SIZE = 1 + SEQ_LEN = 128 + NUM_KV_HEADS = 16 + HEAD_SIZE = 256 + BLOCK_SIZE = 16 + DTYPE = torch.bfloat16 + num_blocks = 1024 + device = xm.xla_device() + key = torch.randn( + BATCH_SIZE * SEQ_LEN, + NUM_KV_HEADS, + HEAD_SIZE, + device=device, + dtype=DTYPE) + k_cache = torch.randn( + num_blocks * BLOCK_SIZE, + NUM_KV_HEADS, + HEAD_SIZE, + device=device, + dtype=DTYPE) + slot_mapping = torch.randint( + 0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64) + # materalize k_cache to device data + xm.mark_step() + met.clear_all() + for _ in range(10): + k_cache.index_copy_(0, slot_mapping.flatten(), key) + xm.mark_step() + xm.wait_device_ops() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu()) + if __name__ == '__main__': test = unittest.main() diff --git a/test/test_megablox.py b/test/test_megablox.py new file mode 100644 index 00000000000..af4ecf76b31 --- /dev/null +++ b/test/test_megablox.py @@ -0,0 +1,161 @@ +"""Grouped matrix multiplication kernels for TPU written in Pallas.""" + +import logging +import unittest + +from typing import Optional, Union, Callable + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.experimental.megablox as megablox +from torch_xla import runtime as xr +from torch_xla._internal import tpu + +import numpy as np + +if xr.device_type() == 'TPU': + from torch_xla.experimental.custom_kernel import jax_import_guard + jax_import_guard() + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + + +class MegabloxTest(unittest.TestCase): + + def _reference_gmm( + self, + lhs: np.array, + rhs: np.array, + group_sizes: np.array, + preferred_element_type: np.dtype = np.float32, + ) -> np.array: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = np.dot(lhs[start:start + size, :], rhs[i, :, :]) + + result = result.astype(preferred_element_type) + out.append(result) + start += group_sizes[i] + return np.array(np.concatenate(out, axis=0)) + + def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [np.random.randint(low=0, high=m) for _ in range(num_groups - 1)], + dtype=np.int32, + ),) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return torch.from_numpy(ends - starts).to(torch.int32) + + def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype, + out_dtype: torch.dtype) -> tuple[float, float]: + if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or + out_dtype == torch.bfloat16): + return 1e-3, 1e-2 # atol, rtol + return 1e-4, 1e-2 # atol, rtol + + LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]] + + def _init_test_cases(self): + self.tests_cases = [] + self.tests_cases.append({ + 'dtype': torch.float32, + 'm': 128, + 'k': 128, + 'n': 128, + 'num_groups': 1 + }) + self.tests_cases.append({ + 'dtype': torch.float32, + 'm': 256, + 'k': 128, + 'n': 128, + 'num_groups': 1 + }) + self.tests_cases.append({ + 'dtype': torch.float32, + 'm': 128, + 'k': 256, + 'n': 128, + 'num_groups': 8 + }) + self.tests_cases.append({ + 'dtype': torch.float32, + 'm': 512, + 'k': 128, + 'n': 256, + 'num_groups': 2 + }) + self.tests_cases.append({ + 'dtype': torch.bfloat16, + 'm': 128, + 'k': 128, + 'n': 128, + 'num_groups': 1 + }) + self.tests_cases.append({ + 'dtype': torch.bfloat16, + 'm': 256, + 'k': 128, + 'n': 128, + 'num_groups': 1 + }) + self.tests_cases.append({ + 'dtype': torch.bfloat16, + 'm': 128, + 'k': 256, + 'n': 128, + 'num_groups': 8 + }) + self.tests_cases.append({ + 'dtype': torch.bfloat16, + 'm': 512, + 'k': 128, + 'n': 256, + 'num_groups': 2 + }) + + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_gmm(self): + self._init_test_cases() + for test_case in self.tests_cases: + num_groups = test_case['num_groups'] + k = test_case['k'] + m = test_case['m'] + n = test_case['n'] + lhs_dtype = rhs_dtype = test_case['dtype'] + out_dtype = torch.float32 + + lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla') + rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla') + group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) + out = megablox.gmm(lhs, rhs, group_sizes) + + ref_out = self._reference_gmm(lhs.cpu().float().numpy(), + rhs.cpu().float().numpy(), + group_sizes.numpy()) + + atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype) + np.testing.assert_allclose( + ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + torch_xla._XLAC._xla_set_use_full_mat_mul_precision( + use_full_mat_mul_precision=True) + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index ff32c268927..ed8f5a88151 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2422,6 +2422,31 @@ def test_aten_move_scalar_cuda_to_xla(self): # Has a different execution path than other tensors. self._test_move_tensor_cuda_to_xla(torch.tensor(42)) + def test_unsafe_buffer_pointer(self): + xla_device = xm.xla_device() + xla_tensor_0 = torch.tensor(42).to(xla_device) + # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr + xm.mark_step() + buf_ptr_0 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_0) + self.assertGreaterEqual(buf_ptr_0, 0) + + # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr + xla_tensor_1 = torch.tensor(42, device=xm.xla_device()) + buf_ptr_1 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_1) + self.assertGreaterEqual(buf_ptr_1, 0) + + # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr + xla_tensor_2 = torch.ones((5, 5)).to(xla_device) + buf_ptr_2 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_2) + self.assertGreaterEqual(buf_ptr_2, 0) + + xla_tensor_3 = torch.arange(5, device=xm.xla_device()) + xm.mark_step() + # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. + xm.wait_device_ops() + buf_ptr_3 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_3) + self.assertGreaterEqual(buf_ptr_3, 0) + class SimpleModelWithDropout(torch.nn.Module): diff --git a/test/test_ops.py b/test/test_ops.py index 12b874593bd..3b098e85f93 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -225,6 +225,7 @@ def __new__(cls, name, variant_test_name=""): AllowedOpInfoEntry('norm', 'fro'), AllowedOpInfoEntry('special.erfcx'), AllowedOpInfoEntry('_native_batch_norm_legit'), + AllowedOpInfoEntry('full'), # Duplicate Redundant entries for this test. # AllowedOpInfoEntry('polygamma', 'polygamma_n_1'), @@ -393,7 +394,7 @@ def _cpu(t): return tuple(map(to_cpu, x)) elif isinstance(x, dict): return {k: to_cpu(v) for k, v in x.items()} - elif isinstance(x, (numbers.Number, bool, str)): + elif isinstance(x, (numbers.Number, bool, str, torch.dtype)): return x # Passthrough None because some functions wrapped with type promotion @@ -426,5 +427,4 @@ def test_reference_eager(self, device, dtype, op): instantiate_device_type_tests(TestOpInfo, globals()) if __name__ == '__main__': - #run_tests() unittest.main() diff --git a/test/test_pallas.py b/test/test_pallas.py index f8480782094..f686816034f 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -10,6 +10,8 @@ from torch_xla import runtime as xr from torch_xla._internal import tpu +import numpy as np + if xr.device_type() == 'TPU': from torch_xla.experimental.custom_kernel import jax_import_guard jax_import_guard() @@ -20,12 +22,52 @@ class PallasTest(unittest.TestCase): - def _attention(self, q, k, v): + # This is to create a diagonal mask where only elements within the same segment + # can attend to each other. Since the mask is to mask out the unrelevant parts, + # therefore we use != instead of ==. + def _make_attention_mask_from_segment_ids(self, q_segment_ids, + kv_segment_ids): + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) + + def _attention(self, q, k, v, *, attn_mask=None): attn_weight = q @ k.transpose(-2, -1) + if attn_mask is not None: + # Masked out the unrelevant parts. + attn_weight = attn_weight.masked_fill(attn_mask, + torch.finfo(attn_weight.dtype).min) attn_weight = nn.functional.softmax(attn_weight, dim=-1) attn_output = attn_weight @ v return attn_output + # The following helper functions prefixed with _pagedattention are used for PagedAttention unit tests + # Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py + def _pagedattention_generate_qkv( + self, + seq_lens, + page_size, + max_seq_len, + num_kv_heads, + num_heads, + head_dim, + dtype=torch.float32, + ): + assert max_seq_len % page_size == 0 + pages_per_sequence = max_seq_len // page_size + batch_size = len(seq_lens) + total_pages = batch_size * pages_per_sequence + k_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + v_pages = torch.randn( + num_kv_heads, total_pages, page_size, head_dim, dtype=dtype) + page_indices = torch.randperm( + batch_size * pages_per_sequence, dtype=torch.int32) + page_indices = page_indices.reshape(batch_size, pages_per_sequence) + q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype) + return q, k_pages, v_pages, page_indices + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_tpu_custom_call_pallas_add(self): # This payload is generated by the following Pallas code: @@ -454,6 +496,303 @@ def test_flash_attention_backward(self): self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_paged_attention_wrapper(self): + from torch_xla.experimental.custom_kernel import paged_attention + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention + + max_kv_len = 2048 + block_size = 512 + page_size = 64 + num_kv_heads = 8 + q_kv_head_ratio = 8 + head_dim = 256 + dtype = torch.float32 + seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + + output = paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + + q_jax = jnp.array(q.numpy(), dtype=jnp.float32) + k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) + v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32) + seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32) + page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) + expected_output = torch.from_numpy( + np.array( + jax_paged_attention( + q_jax, + k_pages_jax, + v_pages_jax, + seq_lens_jax, + page_indices_jax, + pages_per_compute_block=block_size // page_size, + ))) + + self.assertTrue( + torch.allclose( + output.cpu()[seq_lens > 0], + expected_output.cpu()[seq_lens > 0], + atol=1e-5, + rtol=1e-5)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, + "This test only works on TPUv4+.") + def test_paged_attention_wrapper_with_dynamo(self): + from torch_xla.experimental.custom_kernel import paged_attention + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention + + max_kv_len = 2048 + block_size = 512 + page_size = 64 + num_kv_heads = 8 + q_kv_head_ratio = 8 + head_dim = 256 + dtype = torch.float32 + seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) + + q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( + seq_lens, + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + ) + + q_xla = q.to("xla") + k_pages_xla = k_pages.to("xla") + v_pages_xla = v_pages.to("xla") + seq_lens_xla = seq_lens.to("xla") + page_indices_xla = page_indices.to("xla") + + def paged_attention_wrapper(q, k, v, seq_lens, page_indices, + pages_per_compute_block): + return torch.ops.xla.paged_attention( + q, + k, + v, + seq_lens, + page_indices, + pages_per_compute_block=pages_per_compute_block, + ) + + compiled_paged_attention = torch.compile( + paged_attention_wrapper, backend="openxla") + + output = compiled_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + seq_lens_xla, + page_indices_xla, + pages_per_compute_block=block_size // page_size, + ) + + q_jax = jnp.array(q.numpy(), dtype=jnp.float32) + k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32) + v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32) + seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32) + page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) + expected_output = torch.from_numpy( + np.array( + jax_paged_attention( + q_jax, + k_pages_jax, + v_pages_jax, + seq_lens_jax, + page_indices_jax, + pages_per_compute_block=block_size // page_size, + ))) + + self.assertTrue( + torch.allclose( + output.cpu()[seq_lens > 0], + expected_output.cpu()[seq_lens > 0], + atol=1e-5, + rtol=1e-5)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_segment_ids_1(self): + from torch_xla.experimental.custom_kernel import flash_attention + from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds + + q = torch.randn(3, 2, 128, 4) + k = torch.randn(3, 2, 128, 4) + v = torch.randn(3, 2, 128, 4) + zeros = torch.zeros(3, 32) + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + o = flash_attention( + q.to("xla"), k.to("xla"), v.to("xla"), False, segment_ids.to("xla"), + segment_ids.to("xla")) + + jax_q = jnp.array(q.numpy(), dtype=jnp.float32) + jax_k = jnp.array(k.numpy(), dtype=jnp.float32) + jax_v = jnp.array(v.numpy(), dtype=jnp.float32) + jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32) + expected_o = torch.from_numpy( + np.array( + jax_flash_attention( + jax_q, + jax_k, + jax_v, + segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids), + ))) + + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_segment_ids_2(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + q = torch.randn(3, 2, 128, 4).to("xla") + k = torch.randn(3, 2, 128, 4).to("xla") + v = torch.randn(3, 2, 128, 4).to("xla") + zeros = torch.zeros(3, 32).to("xla") + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + o = flash_attention(q, k, v, False, segment_ids, segment_ids) + + expected_o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + segment_ids, segment_ids)) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_backward_segment_ids(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention(q, k, v, False, segment_ids, segment_ids) + loss = o.sum() + loss.backward() + xm.mark_step() + + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + segment_ids, segment_ids)) + loss = o.sum() + loss.backward() + xm.mark_step() + + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_sm_scale(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + q = torch.randn(3, 2, 128, 4).to("xla") + k = torch.randn(3, 2, 128, 4).to("xla") + v = torch.randn(3, 2, 128, 4).to("xla") + sm_scale = 0.7 + o = flash_attention(q, k, v, False, None, None, sm_scale) + + expected_o = self._attention(q * sm_scale, k, v) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_sm_scale_backward(self): + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) + from torch_xla.experimental.custom_kernel import flash_attention + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + sm_scale = 0.7 + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention(q, k, v, False, None, None, sm_scale) + loss = o.sum() + loss.backward() + xm.mark_step() + + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention(q * sm_scale, k, v) + loss = o.sum() + loss.backward() + xm.mark_step() + + # Hmm, the gradients are the same even the autograd graph seems different. + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index dc2f4e96dba..2653d0ed8c2 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -14,12 +14,15 @@ python3 test/spmd/test_xla_auto_sharding.py XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v python3 test/test_autocast.py +python3 test/test_grad_checkpoint.py python3 test/dynamo/test_dynamo.py python3 test/spmd/test_spmd_debugging.py python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py python3 test/test_pallas.py +python3 test/test_input_output_aliases.py +python3 test/test_megablox.py python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py diff --git a/torch_patches/README.md b/torch_patches/README.md deleted file mode 100644 index f6476f64ca5..00000000000 --- a/torch_patches/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Guidelines For Patch File Names - -Files with extension '.diff' are consider as git patches by apply script. - -A file for PyTorch PR _N_ needs to be named 'N.diff'. - -Patch files which are not related to PyTorch PRs, should begin with an 'X' character, -followed by a two digit number, followed by a dash ('-'), a name, and '.diff'. -Example: - -``` -X10-optimizer.diff -``` - -Patch file are alphabetically ordered, so PyTorch PR patches are always applied -before the non PyTorch ones. - - -There's a special file `torch_patches/.torch_pin`, which is used to coordinate landing PRs in -`pytorch/pytorch` and `pytorch/xla`. - -To test a `pytorch/xla` PR against a `pytorch/pytorch` PR or branch, -put the PR number or branch name in this file. -Example: - -``` -#32451 -# or -my_awesome_branch # (must live in `pytorch/pytorch`) -``` - -In the case where the pytorch/pytorch PR also depends on the pytorch/xla PR, you will also need to update the https://github.com/pytorch/pytorch/blob/main/.github/ci_commit_pins/xla.txt to match the latest hash of your pytorch/xla PR. To be noted, the hash from a PR produced by a fork won't work in this case. Then you need to find someone from the pytorch/xla team to produe a branch PR for you. diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index ebc0af6c7ad..2841a65c885 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -6,6 +6,7 @@ import torch import _XLAC from ._internal import tpu +from .version import __version__ logging.basicConfig() logger = logging.getLogger(__name__) @@ -76,6 +77,8 @@ def _setup_default_env(): os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA') + # This is used for ML Framework Telemetry. + os.environ.setdefault('TPU_ML_PLATFORM_VERSION', __version__) if tpu.version() == 4: os.environ.setdefault('TPU_MEGACORE', 'megacore_dense') @@ -149,7 +152,6 @@ def _setup_tpu_vm_library_path() -> bool: import atexit from ._patched_functions import _apply_patches -from .version import __version__ _found_libtpu = _setup_tpu_vm_library_path() @@ -186,6 +188,27 @@ def _init_xla_lazy_backend(): # TODO @wonjoo come up with a long term fix in Dynamo. torch._dynamo.config.automatic_dynamic_shapes = False +# Activate view-replay on AOTAutograd. +# See: https://github.com/pytorch/pytorch/pull/124488 +import torch._functorch.config + +torch._functorch.config.view_replay_for_aliased_outputs = True + +import importlib.metadata +import warnings + +try: + # TensorFlow TPU distribution has the same package name as GPU, but not CPU + dist = importlib.metadata.distribution('tensorflow') + warnings.warn( + "`tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when" + " using PyTorch/XLA. To silence this warning, `pip uninstall -y " + "tensorflow && pip install tensorflow-cpu`. If you are in a notebook " + "environment such as Colab or Kaggle, restart your notebook runtime " + "afterwards.") +except importlib.metadata.PackageNotFoundError: + pass + from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo from .experimental import plugins diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index a7ae1c47964..12a49a91ad9 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -29,6 +29,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" #include "torch_xla/csrc/ops/as_strided_view_update.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal_view_update.h" #include "torch_xla/csrc/ops/einsum_utilities.h" #include "torch_xla/csrc/ops/index_ops.h" @@ -1290,6 +1291,38 @@ at::Tensor XLANativeFunctions::embedding_dense_backward( num_weights, padding_idx, scale_grad_by_freq)); } +std::tuple +XLANativeFunctions::_embedding_bag_forward_only( + const at::Tensor& weight, const at::Tensor& indices, + const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, + bool sparse, const c10::optional& per_sample_weights, + bool include_last_offset, int64_t padding_idx) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) { + return at::native::call_fallback_fn< + &xla_cpu_fallback, + ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets, + scale_grad_by_freq, mode, + sparse, per_sample_weights, + include_last_offset, + padding_idx); + } + auto indices_tensor = bridge::GetXlaTensor(indices); + auto sample_weights = + per_sample_weights.has_value() && per_sample_weights.value().defined() + ? bridge::GetXlaTensor(per_sample_weights.value()) + : tensor_methods::full_like(indices_tensor, 1.0, + *torch_xla::bridge::GetXlaDevice(weight), + at::ScalarType::Float); + auto result = tensor_methods::embedding_bag( + bridge::GetXlaTensor(weight), indices_tensor, + bridge::GetXlaTensor(offsets), mode, sample_weights, include_last_offset); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(result)), + bridge::AtenFromXlaTensor(std::get<1>(result)), + bridge::AtenFromXlaTensor(std::get<2>(result)), + bridge::AtenFromXlaTensor(std::get<3>(result))); +} + at::Tensor XLANativeFunctions::empty_symint( at::SymIntArrayRef sym_size, c10::optional dtype, c10::optional layout, c10::optional device, @@ -1438,9 +1471,18 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size, return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call( size, fill_value, dtype, layout, device, pin_memory); } - return bridge::AtenFromXlaTensor(tensor_methods::full( - absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype))); + at::ScalarType intend_dtype; + if (dtype || fill_value.isFloatingPoint()) { + // Respect the dtype if it is being explictlly passed in. + // All python scalar will be passed in as float64 to the backend, but the + // default behavior for pytorch is to return a float32 tensor in this case. + intend_dtype = at::dtype_or_default(dtype); + } else { + intend_dtype = fill_value.type(); + } + return bridge::AtenFromXlaTensor( + tensor_methods::full(absl::Span(size), fill_value, + GetXlaDeviceOrCurrent(device), intend_dtype)); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, @@ -2497,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, // 1) Aid XLA's InputOutputAlias. auto input_tensor = bridge::GetXlaTensor(input); auto output_tensor = bridge::GetXlaTensor(output); - output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + if (input_tensor->CurrentDataHandle() != nullptr || + (input_tensor->CurrentIrValue().node != nullptr && + torch_xla::DeviceData::Cast( + input_tensor->CurrentIrValue().node.get()))) { + /* + if input has a XLAData or holds a devicedata node, set alias_id to + tensor_id. Consider the case. + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + xm.mark_step() + // x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2 + // for this graph + x *= 1 of 1 + */ + output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + } else { + /* + Consider the case + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + // x.tensor_id = 3, x.alias_id should still be 1 + x * = 2 + xm.mark_step() + */ + output_tensor->data()->alias_id = input_tensor->data()->alias_id; + } // 2) Aid SPMD. XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec(); @@ -3709,6 +3782,7 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, scale_grad_by_freq, sparse); } + // TODO: We need to make use of the TPU embedding core here eventually. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::embedding( bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices))); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 9834a25780e..3943e3f451b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2323,6 +2323,24 @@ void InitXlaModuleBindings(py::module m) { [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { return XLANativeFunctions::set_(self, source); }); + m.def("_xla_custom_call", + [](const std::vector& inputs, const std::string& target, + const std::vector>& output_shapes, + const std::vector& output_dtypes, bool has_side_effect, + const std::string& backend_config, + const int api_version) -> std::vector { + std::vector dtypes; + dtypes.reserve(output_dtypes.size()); + for (auto& dtype : output_dtypes) { + dtypes.push_back( + reinterpret_cast(dtype.ptr())->scalar_type); + } + + auto xtensors = tensor_methods::custom_call( + bridge::GetXlaTensors(inputs), target, output_shapes, dtypes, + has_side_effect, backend_config, api_version); + return bridge::AtenFromXlaTensors(std::move(xtensors)); + }); m.def("_xla_tpu_custom_call", [](const std::vector& inputs, const std::string& payload, const std::vector>& output_shapes, @@ -2488,6 +2506,31 @@ void InitXlaModuleBindings(py::module m) { return false; }); + m.def("_unsafe_buffer_pointer", + [](const at::Tensor& input) -> std::uintptr_t { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor) << "The input is not an XLA tensor."; + if (xtensor->CurrentDataHandle() != nullptr) { + std::shared_ptr data = + std::dynamic_pointer_cast( + xtensor->CurrentDataHandle()); + return runtime::GetComputationClient()->UnsafeBufferPointer(data); + } else if (xtensor->CurrentIrValue().node != nullptr) { + DeviceData* device_data = + DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (device_data != nullptr) { + torch::lazy::BackendDataPtr data = device_data->data(); + return runtime::GetComputationClient()->UnsafeBufferPointer( + UnwrapXlaData(data)); + } else { + XLA_ERROR() << "Could not get the buffer pointer for XLATensor " + "with IR that's not DeviceData"; + } + } + XLA_ERROR() << "Could not get the buffer pointer for XLATensor " + "without a data handle or an IR."; + }); + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/ops/custom_call.cpp b/torch_xla/csrc/ops/custom_call.cpp new file mode 100644 index 00000000000..00347e0c975 --- /dev/null +++ b/torch_xla/csrc/ops/custom_call.cpp @@ -0,0 +1,70 @@ +#include "torch_xla/csrc/ops/custom_call.h" + +#include + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/shape_helper.h" + +namespace torch_xla { + +CustomCall::CustomCall(torch::lazy::OpList inputs, + const std::string& call_target, xla::Shape output_shape, + bool has_side_effect, const std::string& backend_config, + const int api_version) + : XlaNode(xla_custom_call, inputs, std::move(output_shape), + /*num_outputs=*/output_shape.tuple_shapes_size(), + torch::lazy::MHash(call_target)), + call_target_(call_target), + has_side_effect_(has_side_effect), + backend_config_(backend_config), + api_version_(api_version) {} + +torch::lazy::NodePtr CustomCall::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands, call_target_, + this->xla_shape(), has_side_effect_, + backend_config_, api_version_); +} + +XlaOpVector CustomCall::Lower(LoweringContext* loctx) const { + std::vector inputs; + inputs.reserve(this->operands().size()); + for (auto& operand : operands()) { + inputs.push_back(loctx->GetOutputOp(operand)); + } + xla::Shape output_shape = this->xla_shape(); + const int n_outputs = output_shape.tuple_shapes_size(); + if (n_outputs == 1) { + output_shape = output_shape.tuple_shapes(0); + } + XLA_CHECK(api_version_ >= 0 && api_version_ < 5); + xla::XlaOp output = xla::CustomCall( + inputs[0].builder(), call_target_, inputs, output_shape, + /*opaque=*/backend_config_, + /*has_side_effect=*/has_side_effect_, + /*output_operand_aliasing=*/{}, + /*literal=*/nullptr, + /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/static_cast(api_version_)); + std::vector result; + if (n_outputs == 1) { + result = {output}; + } else { + result.reserve(n_outputs); + for (int i = 0; i < n_outputs; ++i) { + result.push_back(xla::GetTupleElement(output, i)); + } + } + return ReturnOps(result, loctx); +} + +std::string CustomCall::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", call_target=" << call_target_ + << ", has_side_effect=" << has_side_effect_ + << ", backend_config=" << backend_config_ + << ", api_version=" << api_version_; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/custom_call.h b/torch_xla/csrc/ops/custom_call.h new file mode 100644 index 00000000000..69bb613d4b6 --- /dev/null +++ b/torch_xla/csrc/ops/custom_call.h @@ -0,0 +1,29 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_ +#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class CustomCall : public XlaNode { + public: + CustomCall(torch::lazy::OpList inputs, const std::string& call_target, + xla::Shape output_shape, bool has_side_effect, + const std::string& backend_config, const int api_version); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + std::string call_target_; + bool has_side_effect_; + std::string backend_config_; + int api_version_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_ diff --git a/torch_xla/csrc/ops/embedding_bag.cpp b/torch_xla/csrc/ops/embedding_bag.cpp new file mode 100644 index 00000000000..d2bb034a005 --- /dev/null +++ b/torch_xla/csrc/ops/embedding_bag.cpp @@ -0,0 +1,192 @@ +#include "torch_xla/csrc/ops/embedding_bag.h" + +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/shape_helper.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "tsl/platform/stacktrace.h" +#include "xla/client/lib/constants.h" +#include "xla/client/lib/loops.h" +#include "xla/client/lib/slicing.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace { +const int MODE_SUM = 0; +const int MODE_MEAN = 1; +const int MODE_MAX = 2; +std::vector BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices, + xla::XlaOp offsets, + xla::XlaOp per_sample_weights, + bool include_last_offset, int mode) { + xla::Shape offset_shape = ShapeHelper::ShapeOfXlaOp(offsets); + int64_t n = offset_shape.dimensions(0); + xla::Shape weight_shape = ShapeHelper::ShapeOfXlaOp(weight); + int64_t weight_dim = weight_shape.dimensions(1); + xla::Shape indices_shape = ShapeHelper::ShapeOfXlaOp(indices); + int64_t num_embeddings = indices_shape.dimensions(0); + XLA_CHECK(indices_shape.rank() == 1 || indices_shape.rank() == 2) + << "input has to be a 1D or 2D Tensor, but got Tensor of dimension " + << indices_shape.rank(); + if (indices_shape.rank() == 1) { + XLA_CHECK(offset_shape.rank() == 1) + << "offsets has to be a 1D Tensor, but got Tensor of dimension " + << offset_shape.rank(); + } + XLA_CHECK(weight_shape.rank() == 2) + << "weight has to be a 2D Tensor, but got Tensor of dimension " + << weight_shape.rank(); + + xla::XlaOp output2 = xla::ZerosLike(indices); + xla::XlaOp output3 = xla::ZerosLike(offsets); + std::vector sizes = {n, weight_dim}; + xla::XlaOp output4 = + xla::Zeros(offsets.builder(), + xla::ShapeUtil::MakeShape(offset_shape.element_type(), sizes)); + + xla::XlaOp embeddings = xla::TorchIndexSelect(weight, indices, 0); + xla::XlaOp embeddings_weighted = xla::Mul( + embeddings, xla::ConvertElementType( + xla::BroadcastInDim(per_sample_weights, + {num_embeddings, weight_dim}, {0}), + weight_shape.element_type())); + + std::vector shape_elements = { + xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}), + xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}), + xla::ShapeUtil::MakeShape(weight_shape.element_type(), + {num_embeddings, weight_dim}), + xla::ShapeUtil::MakeShape(weight_shape.element_type(), {1, weight_dim})}; + xla::Shape result_shape = xla::ShapeUtil::MakeTupleShape(shape_elements); + + xla::XlaComputation condition; + { + xla::XlaBuilder builder("condition"); + auto prev = xla::Parameter(&builder, 0, result_shape, "prev"); + auto index = xla::GetTupleElement(prev, 0); + auto final_value = xla::GetTupleElement(prev, 1); + xla::Lt(index, final_value); + condition = builder.Build().value(); + } + + xla::XlaComputation body; + { + xla::XlaBuilder builder("body"); + auto prev = xla::Parameter(&builder, 0, result_shape, "prev"); + auto index = xla::GetTupleElement(prev, 0); + auto emb = xla::GetTupleElement(prev, 2); + auto w = xla::GetTupleElement(prev, 3); + + xla::XlaOp slice = xla::DynamicSlice( + emb, + {index, xla::ConvertElementType(xla::ConstantR0(&builder, 0), + offset_shape.element_type())}, + {1, weight_dim}); + xla::XlaOp result = + mode == MODE_SUM ? xla::Add(w, slice) : xla::Max(w, slice); + + xla::Tuple(&builder, + { + xla::Add(index, xla::ConvertElementType( + xla::ConstantR0(&builder, 1), + offset_shape.element_type())), + xla::GetTupleElement(prev, 1), + xla::GetTupleElement(prev, 2), + result, + }); + body = builder.Build().value(); + } + + xla::Array initial_vector({1, weight_dim}, 0.f); + std::vector results; + for (int64_t i = 0; i < n; i++) { + xla::XlaOp start = xla::DynamicSlice( + offsets, {xla::ConstantR0(offsets.builder(), i)}, {1}); + if (i == n - 1 && include_last_offset) continue; + xla::XlaOp end = + i == n - 1 && !include_last_offset + ? xla::ConvertElementType(xla::ConstantR1( + offsets.builder(), 1, num_embeddings), + offset_shape.element_type()) + : xla::DynamicSlice( + offsets, {xla::ConstantR0(offsets.builder(), i + 1)}, + {1}); + // Create a While node with computations for the condition and the body. + auto init_tuple = xla::Tuple( + offsets.builder(), + {xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}), + embeddings_weighted, + xla::ConvertElementType( + xla::ConstantFromArray(offsets.builder(), initial_vector), + weight_shape.element_type())}); + auto result = xla::While(condition, body, init_tuple); + results.push_back(xla::GetTupleElement(result, 3)); + }; + xla::XlaOp output1 = xla::ConcatInDim(offsets.builder(), results, 0); + return {output1, output2, output3, output4}; +} + +xla::Shape NodeOutputShapes(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset, bool mode) { + auto lower_for_shapes_fn = + [&](absl::Span operands) -> xla::XlaOp { + return xla::Tuple( + operands[0].builder(), + BuildEmbeddingBag(operands[0], operands[1], operands[2], operands[3], + include_last_offset, mode)); + }; + + std::vector input_shapes = { + GetXlaShape(weight), GetXlaShape(indices), GetXlaShape(offsets), + GetXlaShape(per_sample_weights)}; + + return InferOutputShape(absl::MakeSpan(input_shapes), lower_for_shapes_fn); +} +} // namespace + +std::string EmbeddingBag::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString(); + return ss.str(); +} + +EmbeddingBag::EmbeddingBag(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, int64_t mode, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset) + : XlaNode( + torch::lazy::OpKind(at::aten::embedding_bag), + {weight, indices, offsets, per_sample_weights}, + [&]() { + return NodeOutputShapes(weight, indices, offsets, + per_sample_weights, include_last_offset, + mode); + }, + /*num_outputs=*/4, torch::lazy::MHash(mode, include_last_offset)), + mode_(mode), + include_last_offset_(include_last_offset) {} + +torch::lazy::NodePtr EmbeddingBag::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0), operands.at(1), + operands.at(2), mode_, + operands.at(3), false); +} + +XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const { + xla::XlaOp weight = loctx->GetOutputOp(operand(0)); + xla::XlaOp indices = loctx->GetOutputOp(operand(1)); + xla::XlaOp offsets = loctx->GetOutputOp(operand(2)); + xla::XlaOp per_sample_weights = loctx->GetOutputOp(operand(3)); + std::vector ops = + BuildEmbeddingBag(weight, indices, offsets, per_sample_weights, + include_last_offset_, mode_); + return ReturnOps(absl::MakeSpan(ops), loctx); +} + +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/embedding_bag.h b/torch_xla/csrc/ops/embedding_bag.h new file mode 100644 index 00000000000..4d9b0a6eecb --- /dev/null +++ b/torch_xla/csrc/ops/embedding_bag.h @@ -0,0 +1,31 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ +#define XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ + +#include + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class EmbeddingBag : public XlaNode { + public: + EmbeddingBag(const torch::lazy::Value& weight, + const torch::lazy::Value& indices, + const torch::lazy::Value& offsets, int64_t mode, + const torch::lazy::Value& per_sample_weights, + bool include_last_offset); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + int64_t mode_; + bool include_last_offset_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index fa7adbc5ce6..d61c3cbc839 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -9,6 +9,7 @@ const OpKindWrapper xla_as_strided_view_update("xla::as_strided_view_update"); const OpKindWrapper xla_cast("xla::cast"); const OpKindWrapper xla_collective_permute("xla::collective_permute"); const OpKindWrapper xla_cross_replica_sum("xla::cross_replica_sum"); +const OpKindWrapper xla_custom_call("xla::custom_call"); const OpKindWrapper xla_device_data("xla::device_data"); const OpKindWrapper xla_dequantize_tensor("xla::dequantize_tensor"); const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index ac8e54eb407..8d8d7874364 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -35,6 +35,7 @@ extern const OpKindWrapper xla_as_strided_view_update; extern const OpKindWrapper xla_cast; extern const OpKindWrapper xla_collective_permute; extern const OpKindWrapper xla_cross_replica_sum; +extern const OpKindWrapper xla_custom_call; extern const OpKindWrapper xla_device_data; extern const OpKindWrapper xla_dequantize_tensor; extern const OpKindWrapper xla_diagonal_view_update; diff --git a/torch_xla/csrc/runtime/cache.h b/torch_xla/csrc/runtime/cache.h index bef5b099ec6..9557b2353b7 100644 --- a/torch_xla/csrc/runtime/cache.h +++ b/torch_xla/csrc/runtime/cache.h @@ -173,6 +173,7 @@ class PersistentCache : public AbstractCache { TORCH_LAZY_COUNTER("PersistentCacheMiss", 1); return nullptr; } + TORCH_LAZY_TIMED("PersistentCacheLoad"); std::stringstream ss; std::ifstream in(path, std::ios::binary); ss << in.rdbuf(); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 33b48255baf..cc58736e8dc 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -302,6 +302,8 @@ class ComputationClient { virtual std::vector TransferFromDevice( absl::Span handles) = 0; + virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; + // Compiles a set of computations. virtual std::vector Compile( std::vector instances) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 20ee9b0bfa6..842398126d0 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -96,7 +96,7 @@ std::string IfrtComputationClient::IfrtDeviceToString( xla::ifrt::Device* const device) const { std::string platform = absl::AsciiStrToUpper(device->client()->platform_name()); - int ordinal = global_ordinals_.at(device->id()); + int ordinal = global_ordinals_.at(device->Id().value()); std::string str = absl::StrFormat("%s:%d", platform, ordinal); return str; } @@ -124,11 +124,12 @@ IfrtComputationClient::IfrtComputationClient() { // a device's global ordinal separately from its device ID. Order the // devices by increasing ID to assign global ordinals. std::vector ordered_devices(client_->device_count()); - std::partial_sort_copy(client_->devices().begin(), client_->devices().end(), - ordered_devices.begin(), ordered_devices.end(), - [](auto& a, auto& b) { return a->id() < b->id(); }); + std::partial_sort_copy( + client_->devices().begin(), client_->devices().end(), + ordered_devices.begin(), ordered_devices.end(), + [](auto& a, auto& b) { return a->Id().value() < b->Id().value(); }); for (auto* device : ordered_devices) { - global_ordinals_[device->id()] = global_ordinals_.size(); + global_ordinals_[device->Id().value()] = global_ordinals_.size(); std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } @@ -396,6 +397,11 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( return *replicated_output; } +std::uintptr_t IfrtComputationClient::UnsafeBufferPointer( + const DataPtr handle) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + std::vector IfrtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); @@ -615,7 +621,7 @@ std::vector IfrtComputationClient::GetAllDevices() const { int IfrtComputationClient::GetNumProcesses() const { int max_process_index = client_->process_index(); for (auto* device : client_->devices()) { - max_process_index = std::max(max_process_index, device->process_index()); + max_process_index = std::max(max_process_index, device->ProcessIndex()); } return max_process_index + 1; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index d6d914ad8da..4c10be9d1ca 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -54,6 +54,8 @@ class IfrtComputationClient : public ComputationClient { std::vector TransferFromDevice( absl::Span handles) override; + std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; @@ -134,7 +136,7 @@ class IfrtComputationClient : public ComputationClient { // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; - std::unordered_map string_to_device_; + std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index a129a476a2e..3d5fbaf1f8e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -458,6 +458,17 @@ std::vector PjRtComputationClient::ReshardData( return resharded_results; } +std::uintptr_t PjRtComputationClient::UnsafeBufferPointer( + const DataPtr handle) { + std::shared_ptr pjrt_data = + std::dynamic_pointer_cast(handle); + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); + xla::StatusOr ptr = + client_->UnsafeBufferPointer(pjrt_data->buffer.get()); + XLA_CHECK(ptr.ok()); + return ptr.value(); +} + std::vector PjRtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 9a911c0139b..350b1193ef7 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -55,6 +55,8 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferFromDevice( absl::Span handles) override; + std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 99e23f4b555..52b06d89cb4 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -21,8 +21,24 @@ namespace runtime { namespace { +// Placeholder plugin for testing only. Does not implement multiprocessing or +// configuration. Very likely will not work from Python code. +class LibraryPlugin : public PjRtPlugin { + public: + std::string library_path() const override { + return sys_util::GetEnvString("PJRT_LIBRARY_PATH", ""); + } + + const std::unordered_map + client_create_options() const override { + return {}; + } + + bool requires_xla_coordinator() const override { return false; } +}; + std::unordered_map> - pjrt_plugins_; + pjrt_plugins_ = {{"LIBRARY", std::make_shared()}}; xla::GpuAllocatorConfig GetGpuAllocatorConfig() { auto allocator_config = xla::GpuAllocatorConfig{}; diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 5d7937c0e48..7bd954844ec 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -40,6 +40,7 @@ #include "torch_xla/csrc/ops/count_nonzero.h" #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" +#include "torch_xla/csrc/ops/custom_call.h" #include "torch_xla/csrc/ops/dequant_tensor.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" @@ -48,6 +49,7 @@ #include "torch_xla/csrc/ops/dynamic_view.h" #include "torch_xla/csrc/ops/einsum.h" #include "torch_xla/csrc/ops/einsum_backward.h" +#include "torch_xla/csrc/ops/embedding_bag.h" #include "torch_xla/csrc/ops/expand.h" #include "torch_xla/csrc/ops/expand_symint.h" #include "torch_xla/csrc/ops/exponential.h" @@ -520,6 +522,41 @@ std::pair collective_permute( torch::lazy::Value(node, 1)}; } +std::vector custom_call( + const std::vector& inputs, const std::string& target, + const std::vector>& output_shapes, + const std::vector& output_dtypes, bool has_side_effect, + const std::string& backend_config, const int api_version) { + XLA_CHECK(inputs.size() > 0) << "inputs are empty"; + + std::vector values; + values.reserve(inputs.size()); + for (const auto& input : inputs) { + values.push_back(input->GetIrValue()); + } + + XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size()); + std::vector output_xla_shapes; + output_xla_shapes.reserve(output_shapes.size()); + for (size_t i = 0; i < output_shapes.size(); ++i) { + output_xla_shapes.push_back(xla::ShapeUtil::MakeShape( + MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())), + output_shapes[i])); + } + + auto node = torch::lazy::MakeNode( + values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), + has_side_effect, backend_config, api_version); + + std::vector outputs; + outputs.reserve(output_shapes.size()); + for (size_t i = 0; i < output_shapes.size(); ++i) { + outputs.push_back( + inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i])); + } + return outputs; +} + void custom_sharding_( const XLATensorPtr& input, const std::shared_ptr& sharding_spec, @@ -1326,6 +1363,20 @@ XLATensorPtr embedding(const XLATensorPtr& weight, return tensor_ops::Embedding(weight, indices); } +std::tuple +embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices, + const XLATensorPtr& offsets, int64_t mode, + const XLATensorPtr& per_sample_weights, + bool include_last_offset) { + torch::lazy::NodePtr node = torch::lazy::MakeNode( + weight->GetIrValue(), indices->GetIrValue(), offsets->GetIrValue(), mode, + per_sample_weights->GetIrValue(), include_last_offset); + return std::make_tuple(weight->CreateFrom(torch::lazy::Value(node, 0)), + weight->CreateFrom(torch::lazy::Value(node, 1)), + weight->CreateFrom(torch::lazy::Value(node, 2)), + weight->CreateFrom(torch::lazy::Value(node, 3))); +} + XLATensorPtr exp(const XLATensorPtr& input) { return input->CreateFrom(Exp(input->GetIrValue())); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 14163c49cda..11df2c6eb74 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -80,6 +80,12 @@ std::pair collective_permute( const XLATensorPtr& input, const torch::lazy::Value& token, std::vector> source_target_pairs); +std::vector custom_call( + const std::vector& inputs, const std::string& target, + const std::vector>& output_shapes, + const std::vector& output_dtypes, bool has_side_effect, + const std::string& backend_config, const int api_version); + void custom_sharding_( const XLATensorPtr& input, const std::shared_ptr& spec, @@ -386,6 +392,11 @@ XLATensorPtr embedding_dense_backward(const XLATensorPtr& grad_output, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); +std::tuple +embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices, + const XLATensorPtr& offsets, int64_t mode, + const XLATensorPtr& per_sample_weights, bool include_last_offset); + XLATensorPtr embedding(const XLATensorPtr& weight, const XLATensorPtr& indices); XLATensorPtr eq(const XLATensorPtr& input, const at::Scalar& other); diff --git a/torch_xla/csrc/unwrap_data.h b/torch_xla/csrc/unwrap_data.h index 7d5080e84bf..6bf1cc60e0a 100644 --- a/torch_xla/csrc/unwrap_data.h +++ b/torch_xla/csrc/unwrap_data.h @@ -11,6 +11,9 @@ namespace torch_xla { +runtime::ComputationClient::DataPtr UnwrapXlaData( + const torch::lazy::BackendDataPtr& data); + std::vector UnwrapXlaData( absl::Span datas); diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index bb4ce0c4e23..7e4f2b4c3b3 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -18,7 +18,7 @@ def _extract_backend_config( module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> str | None: """ This algorithm intends to extract the backend config from the compiler IR like the following, - and it is designed to traverse any generic MLIR module. + and it is not designed to traverse any generic MLIR module. module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { @@ -55,17 +55,12 @@ def jax_import_guard(): torch_xla._XLAC._init_computation_client() -def trace_pallas(kernel: Callable, - *args, - static_argnums=None, - static_argnames=None, - **kwargs): +def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct": # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() import jax import jax.numpy as jnp - import jax._src.pallas.mosaic.pallas_call_registration def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: if dtype == torch.float32: @@ -93,14 +88,28 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: else: raise ValueError(f"Unsupported dtype: {dtype}") + return jax.ShapeDtypeStruct(tensor.shape, + convert_torch_dtype_to_jax(tensor.dtype)) + + +def trace_pallas(kernel: Callable, + *args, + static_argnums=None, + static_argnames=None, + **kwargs): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax._src.pallas.mosaic.pallas_call_registration + jax_args = [] # for tracing tensor_args = [] # for execution for i, arg in enumerate(args): # TODO: Could the args be a tuple of tensors or a list of tensors? Flattern them? if torch.is_tensor(arg): # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload. - jax_meta_tensor = jax.ShapeDtypeStruct( - arg.shape, convert_torch_dtype_to_jax(arg.dtype)) + jax_meta_tensor = to_jax_shape_dtype_struct(arg) jax_args.append(jax_meta_tensor) tensor_args.append(arg) else: @@ -167,15 +176,38 @@ class FlashAttention(torch.autograd.Function): "block_k_dq": 256, "block_k_major_dq": 512, } + NUM_LANES = 128 + NUM_SUBLANES = 8 @staticmethod - def forward(ctx, q, k, v, causal=False, partition_spec=None, mesh=None): + def prepare_segment_ids(q_segment_ids, kv_segment_ids): + from jax.experimental.pallas.ops.tpu.flash_attention import SegmentIds + if q_segment_ids is None or kv_segment_ids is None: + return None, None, None + + assert q_segment_ids is not None and kv_segment_ids is not None, "Both q_segment_ids and kv_segment_ids should be provided." + segment_ids = SegmentIds( + to_jax_shape_dtype_struct(q_segment_ids), + to_jax_shape_dtype_struct(kv_segment_ids)) + q_segment_ids = q_segment_ids.unsqueeze(-1).expand( + [-1 for _ in q_segment_ids.shape] + [FlashAttention.NUM_LANES]) + kv_segment_ids = kv_segment_ids.unsqueeze(1).expand([ + kv_segment_ids.shape[0], FlashAttention.NUM_SUBLANES, + kv_segment_ids.shape[1] + ]) + return segment_ids, q_segment_ids, kv_segment_ids + + @staticmethod + def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, + partition_spec, mesh): # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() + import jax from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl ctx.causal = causal + ctx.sm_scale = sm_scale ctx.partition_spec = partition_spec ctx.mesh = mesh ctx.full_shape = None @@ -192,37 +224,49 @@ def forward(ctx, q, k, v, causal=False, partition_spec=None, mesh=None): k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor - # It returns the shape and type of o, l, m. - def shape_dtype(q, *arg): - if not save_residuals: - return [(q.shape, q.dtype)] + # It computes the shape and type of o, l, m. + shapes = [q.shape] + dtypes = [q.dtype] + if save_residuals: res_shape = list(q.shape) res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE - return [(q.shape, q.dtype), (res_shape, torch.float32), - (res_shape, torch.float32)] - - # We can't directly use flash_attention as we need to override the save_residuals flag which returns - # l and m that is needed for the backward. Then we lose all the shape checks. - # TODO: replicate the shape checks on flash_attention. - _flash_attention_impl = make_kernel_from_pallas(_flash_attention_impl, - shape_dtype) + for _ in range(2): + shapes.append(res_shape) + dtypes.append(torch.float32) + with torch.no_grad(): - o = _flash_attention_impl( + segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids( + q_segment_ids, kv_segment_ids) + ctx.segment_ids = segment_ids + + # We can't directly use flash_attention as we need to override the save_residuals flag which returns + # l and m that is needed for the backward. Then we lose all the shape checks. + # TODO: replicate the shape checks on flash_attention. + # Here we seperate the tracing and execution part just to support SegmentIds. + payload, _ = trace_pallas( + _flash_attention_impl, q, k, v, None, - None, + segment_ids, save_residuals, causal, - 1.0, + sm_scale, min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]), min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]), min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]), min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]), False, static_argnums=range(5, 13)) + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes) + if not save_residuals: + o = o[0] # SPMD integration if partition_spec is not None: o = xs.disable_manual_sharding( @@ -240,18 +284,21 @@ def shape_dtype(q, *arg): m = xs.disable_manual_sharding( m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor - ctx.save_for_backward(full_q, full_k, full_v, o, l, m) + ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids, + kv_segment_ids) return o @staticmethod def backward(ctx, grad_output): from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv - q, k, v, o, l, m = ctx.saved_tensors + q, k, v, o, l, m, q_segment_ids, kv_segment_ids = ctx.saved_tensors causal = ctx.causal + sm_scale = ctx.sm_scale partition_spec = ctx.partition_spec mesh = ctx.mesh full_shape = ctx.full_shape + segment_ids = ctx.segment_ids grad_q = grad_k = grad_v = None grad_i = torch.sum( @@ -286,7 +333,7 @@ def backward(ctx, grad_output): k, v, None, - None, + segment_ids, l, m, grad_output, @@ -298,7 +345,7 @@ def backward(ctx, grad_output): k.shape[2]), block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"], k.shape[2]), - sm_scale=1.0, + sm_scale=sm_scale, causal=causal, mask_value=FlashAttention.DEFAULT_MASK_VALUE, debug=False, @@ -306,9 +353,13 @@ def backward(ctx, grad_output): "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", "mask_value", "debug" ]) - grad_q = torch_xla._XLAC._xla_tpu_custom_call( - [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i], - payload, [q.shape], [q.dtype])[0] + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + args += [expanded_l, expanded_m, grad_output, expanded_grad_i] + grad_q = torch_xla._XLAC._xla_tpu_custom_call(args, payload, [q.shape], + [q.dtype])[0] if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]: payload, _ = trace_pallas( @@ -317,7 +368,7 @@ def backward(ctx, grad_output): k, v, None, - None, + segment_ids, l, m, grad_output, @@ -332,7 +383,7 @@ def backward(ctx, grad_output): k.shape[2]), block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"], q.shape[2]), - sm_scale=1.0, + sm_scale=sm_scale, causal=causal, mask_value=FlashAttention.DEFAULT_MASK_VALUE, debug=False, @@ -340,9 +391,14 @@ def backward(ctx, grad_output): "block_q_major", "block_k_major", "block_k", "block_q", "sm_scale", "causal", "mask_value", "debug" ]) - grads = torch_xla._XLAC._xla_tpu_custom_call( - [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i], - payload, [k.shape, v.shape], [k.dtype, v.dtype]) + + args = [q, k, v] + if segment_ids is not None: + args += [q_segment_ids, kv_segment_ids] + args += [expanded_l, expanded_m, grad_output, expanded_grad_i] + grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload, + [k.shape, v.shape], + [k.dtype, v.dtype]) if ctx.needs_input_grad[1]: grad_k = grads[0] if ctx.needs_input_grad[2]: @@ -357,7 +413,7 @@ def backward(ctx, grad_output): grad_v = xs.disable_manual_sharding( grad_v, partition_spec, full_shape, mesh=mesh).global_tensor - return grad_q, grad_k, grad_v, None, None, None + return grad_q, grad_k, grad_v, None, None, None, None, None, None def flash_attention( @@ -365,10 +421,73 @@ def flash_attention( k, # [batch_size, num_heads, kv_seq_len, d_model] v, # [batch_size, num_heads, kv_seq_len, d_model] causal=False, + q_segment_ids=None, # [batch_size, q_seq_len] + kv_segment_ids=None, # [batch_size, kv_seq_len] + sm_scale=1.0, *, partition_spec=None, mesh=None): - return FlashAttention.apply(q, k, v, causal, partition_spec, mesh) + # TODO: support SPMD and Dynamo with segment_ids. + return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids, + sm_scale, partition_spec, mesh) + + +def paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention + + payload, tensor_args = trace_pallas( + paged_attention, + q, + k_pages, + v_pages, + lengths, + page_indices, + pages_per_compute_block=pages_per_compute_block, + static_argnames=["pages_per_compute_block"], + ) + + batch_size, num_heads, head_dim = q.shape + num_kv_heads, _, page_size, head_dim_k = k_pages.shape + batch_size_paged_indices, pages_per_sequence = page_indices.shape + q_dtype_for_kernel_launch = q.dtype + if (num_heads // num_kv_heads) % 8 != 0: + q = q.reshape(batch_size, num_heads, 1, head_dim) + q_dtype_for_kernel_launch = torch.float32 + + page_indices_reshaped = page_indices.reshape(-1) + buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") + step = torch.zeros((1,), dtype=torch.int32).to("xla") + output_shape = torch.Size(list(q.shape[:-1]) + [1]) + + output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( + [ + lengths, + page_indices_reshaped, + buffer_index, + step, + q.to(q_dtype_for_kernel_launch), + k_pages, + v_pages, + ], payload, [q.shape, output_shape, output_shape], + [q_dtype_for_kernel_launch, torch.float32, torch.float32]) + + return output.reshape(batch_size, num_heads, head_dim).to(q.dtype) + + +def non_xla_attetion(q, k, v, attention_type): + # This will be called when dynamo use fake tensor to construct the fake output. + # We need to make sure output tensor's shape is correct. + if k.device != torch.device("meta"): + warnings.warn( + f'XLA {attention_type} attention should only be applied to tensors on XLA device' + ) + + # Return orignal shape of q. + return torch.empty_like(q) XLA_LIB.define( @@ -389,14 +508,26 @@ def flash_attention_non_xla(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False): - # This will be called when dynamo use fake tensor to construct the fake output. - # We need to make sure output tensor's shape is correct. - if k.device != torch.device("meta"): - warnings.warn( - 'XLA flash attention should only be applied to tensors on XLA device') + return non_xla_attetion(q, k, v, "flash") + + +XLA_LIB.define( + "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor", +) + + +@impl(XLA_LIB, "paged_attention", "XLA") +def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): + return paged_attention(q, k_pages, v_pages, lengths, page_indices, + pages_per_compute_block) + - # perform a regular attention if input tensors are not on XLA device. - attn_weight = q @ k.transpose(-2, -1) - attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) - attn_output = attn_weight @ v - return attn_output +@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd") +def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor, + v_pages: torch.Tensor, lengths: torch.Tensor, + page_indices: torch.Tensor, + pages_per_compute_block: int): + return non_xla_attetion(q, k_pages, v_pages, "paged") diff --git a/torch_xla/experimental/megablox/__init__.py b/torch_xla/experimental/megablox/__init__.py new file mode 100644 index 00000000000..63f60f808cd --- /dev/null +++ b/torch_xla/experimental/megablox/__init__.py @@ -0,0 +1 @@ +from .gmm import gmm diff --git a/torch_xla/experimental/megablox/common.py b/torch_xla/experimental/megablox/common.py new file mode 100644 index 00000000000..19555ab208a --- /dev/null +++ b/torch_xla/experimental/megablox/common.py @@ -0,0 +1,22 @@ +"""Common utilities for Pallas kernels.""" + +from typing import Union +import torch +from torch_xla._internal import tpu + + +def assert_is_supported_dtype(dtype: torch.dtype) -> None: + if dtype != torch.bfloat16 and dtype != torch.float32: + raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.") + + +def select_input_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype: + """A type to which both input should be adapted to before dot product.""" + # bf16xbf16 matmul is only supported since TPU v4 generation. In + # case of mixed input precision, we need to convert bf16 argument to fp32 + # beforehand. + if (tpu.version() >= 4 and lhs.dtype == torch.bfloat16 and + rhs.dtype == torch.bfloat16): + return torch.bfloat16 + else: + return torch.float32 diff --git a/torch_xla/experimental/megablox/gmm.py b/torch_xla/experimental/megablox/gmm.py new file mode 100644 index 00000000000..518553d474e --- /dev/null +++ b/torch_xla/experimental/megablox/gmm.py @@ -0,0 +1,395 @@ +"""Grouped matrix multiplication kernels for TPU written in Pallas.""" + +from typing import Any, Callable, Optional, Union +from torch_xla.experimental.megablox import common +from torch_xla.experimental.custom_kernel import jax_import_guard +import torch +import torch_xla +import numpy as np + + +def _validate_args( + *, + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + expected_rhs_dims: int = 3, +) -> 'tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]': + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax.numpy as jnp + """Validates the arguments for the gmm function.""" + # Validate 'lhs'. + if lhs.dim() != 2: + raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.dim()}-tensor.") + common.assert_is_supported_dtype(lhs.dtype) + + # Validate 'rhs'. + if rhs.dim() != expected_rhs_dims: + raise ValueError(f"Expected {expected_rhs_dims}-tensor for 'rhs' but got" + f" {rhs.dim()}-tensor.") + common.assert_is_supported_dtype(rhs.dtype) + + # Validate 'group_sizes'. + if group_sizes.dtype != torch.int32: + raise ValueError( + f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}.") + + return lhs, group_sizes, common.select_input_dtype(lhs, rhs) + + +def _calculate_num_tiles(x: int, tx: int) -> int: + tiles, rem = divmod(x, tx) + if rem: + raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") + return tiles + + +def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]: + tiles, rem = divmod(x, tx) + if rem: + tiles += 1 + return tiles, rem + + +GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple + + +def _make_group_metadata( + *, + group_sizes: 'jnp.ndarray', + m: int, + tm: int, + start_group: 'jnp.ndarray', + num_nonzero_groups: int, + visit_empty_groups: bool = True, +) -> GroupMetadata: + """Create the metadata needed for grouped matmul computation. + + Args: + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + m: The number of rows in lhs. + tm: The m-dimension tile size being used. + start_group: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + num_nonzero_groups: Number of groups in group sizes to compute on. Useful in + combination with group_offset. + visit_empty_groups: If True, do not squeeze tiles for empty groups out of + the metadata. This is necessary for tgmm, where we at least need to zero + the output for each group. + + Returns: + tuple of: + group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 + dtype. group_offsets[i] indicates the row at which group [i] starts in + the lhs matrix and group_offsets[i-1] = m. + group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will + work on. + m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' + will work on. + num_tiles: The number of m-dimension tiles to execute. + """ + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax.numpy as jnp + + num_groups = group_sizes.shape[0] + end_group = start_group + num_nonzero_groups - 1 + + # Calculate the offset of each group, starting at zero. This metadata is + # similar to row offsets in a CSR matrix. The following properties hold: + # + # group_offsets.shape = [num_groups + 1] + # group_offsets[0] = 0 + # group_offsets[num_groups] = m + # + # The row at which group 'i' starts is group_offsets[i]. + group_ends = jnp.cumsum(group_sizes) + group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends]) + + # Assign a group id to each grid index. + # + # If a group starts somewhere other than the start of a tile or ends somewhere + # other than the end of a tile we need to compute that full tile. Calculate + # the number of tiles for each group by rounding their end up to the nearest + # 'tm' and their start down to the nearest 'tm'. + + # (1) Round the group_ends up to the nearest multiple of 'tm'. + # + # NOTE: This does not change group_offsets[num_groups], which is m + # (because we enforce m is divisible by tm). + rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) + + # (2) Round the group_starts down to the nearest multiple of 'tm'. + group_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]) + rounded_group_starts = group_starts // tm * tm + + # (3) Calculate the number of rows in each group. + # + # NOTE: Handle zero-sized groups as a special case. If the start for a + # zero-sized group is not divisible by 'tm' its start will be rounded down and + # its end will be rounded up such that its size will become 1 tile here. + rounded_group_sizes = rounded_group_ends - rounded_group_starts + rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes) + + # (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles. + # + # An m-dimension tile is 'owned' by group 'i' if the first row of the tile + # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1 + # initial partial tiles if it's first row does not occur in the first row of a + # tile. The '0-th' group never has a partial tile because it always starts at + # the 0-th row. + # + # If no group has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every group has a partial except the 0-th group, the total + # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that + # + # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1 + # + # Where tiles_m = m // tm. + # + # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps + # (1) and (2) so this division is exact. + group_tiles = rounded_group_sizes // tm + + if visit_empty_groups: + # Insert one tile for empty groups. + group_tiles = jnp.where(group_sizes == 0, 1, group_tiles) + + # Create the group ids for each grid index based on the tile counts for each + # group. + # + # NOTE: This repeat(...) will pad group_ids with the final group id if + # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_tiles, + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the group that owns the tile. + # The remaining possible visits occur when a group starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each group starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. + # + # To avoid double counting tile visits from the group that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. Also filter out any + # group which is empty. + # + # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear. + partial_tile_mask = jnp.logical_or((group_offsets[:-1] % tm) == 0, + group_sizes == 0) + + # Explicitly enable tiles for zero sized groups, if specified. This covers + # zero sized groups that start on a tile-aligned row and those that do not. + if visit_empty_groups: + partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask) + + partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, + group_offsets[:-1] // tm) + + tile_visits = ( + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1) + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Account for sharding. + # + # Find the start of the groups owned by our shard and shift the group_ids and + # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. + # + # TODO(tgale): Move this offset into the kernel to avoid these rolls. + first_tile_in_shard = (group_ids < start_group).sum() + group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0) + m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) + + # Calculate the number of tiles we need to compute for our shard. + # + # Remove tile visits that belong to a group not in our shard. + iota = jnp.arange(num_groups, dtype=jnp.int32) + active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group) + group_tiles = jnp.where(active_group_mask, group_tiles, 0) + num_tiles = group_tiles.sum() + return (group_offsets, group_ids, m_tile_ids), num_tiles + + +def _zero_uninitialized_memory( + out: 'jnp.ndarray', + *, + start_group: 'jnp.ndarray', + num_nonzero_groups: int, + group_metadata: GroupMetadata, +) -> torch.Tensor: + """Zero out uninitialized memory from output.""" + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax.numpy as jnp + + group_offsets = group_metadata[0] + group_start = group_offsets[start_group] + group_end = group_offsets[start_group + num_nonzero_groups] + valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0) + valid_mask = (valid_mask >= group_start) & (valid_mask < group_end) + return torch.from_numpy(np.array(jnp.where(valid_mask[:, None], out, + 0))).to('xla') + + +LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]] + + +def _gmm( + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + payload: str, + preferred_element_type: torch.dtype = torch.float32, + tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128), + group_offset: Optional[torch.Tensor] = None, + existing_out: Optional[torch.Tensor] = None, + transpose_rhs: bool = False, + interpret: bool = False, +) -> torch.Tensor: + """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'. + + Args: + lhs: A 2d, torch.Tensor with shape [m, k]. + rhs: A 3d, torch.Tensor with shape [num_groups, k, n]. + group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype. + payload: pallas payload extracted from the pallas code on JAX. + preferred_element_type: torch.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. + group_offset: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + existing_out: Existing output to write to. + transpose_rhs: True if the rhs needs to be transposed. + interpret: Whether or not to run the kernel in interpret mode, helpful for + testing and debugging. + + Returns: + A 2d, torch.Tensor with shape [m, n]. + """ + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + import jax.numpy as jnp + + if existing_out is not None: + assert isinstance(existing_out, jax.Array) + expected_dtype = existing_out.dtype + if expected_dtype != preferred_element_type: + raise ValueError( + "Existing output dtype must match preferred_element_type.") + if group_offset is None: + group_offset = jnp.array([0], dtype=jnp.int32) + else: + group_offset = jnp.ndarray(group_offset.numpy()) + if group_offset.shape: + raise ValueError( + f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.") + group_offset = group_offset[None] + num_current_groups = rhs.shape[0] + num_total_groups = group_sizes.shape[0] + lhs, group_sizes, input_dtype = _validate_args( + lhs=lhs, rhs=rhs, group_sizes=group_sizes) + + # Gather shape information. + m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2]) + if transpose_rhs: + n = rhs.shape[1] + + # If tiling is callable, look up the problem dimensions in the LUT. If no tuned + # tile dimensions are available throw an error. + if callable(tiling): + tiling = tiling(m, k, n) + + if tiling is None: + raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})") + + tm, tk, tn = tiling + tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk) + tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn) + del n_rem + + # Create the metadata we need for computation. + group_sizes = jnp.asarray(group_sizes.numpy()) + group_metadata, num_active_tiles = _make_group_metadata( # pylint: disable=unbalanced-tuple-unpacking + group_sizes=group_sizes, + m=m, + tm=tm, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + visit_empty_groups=False, + ) + group_metadata0 = torch.from_numpy(np.array(group_metadata[0])).to( + torch.int32).to("xla") + group_metadata1 = torch.from_numpy(np.array(group_metadata[1])).to("xla") + group_metadata2 = torch.from_numpy(np.array(group_metadata[2])).to("xla") + num_active_tiles = torch.tensor(np.array(num_active_tiles)).to("xla") + group_offset_torch = torch.from_numpy(np.array(group_offset)).to("xla") + output_shape = torch.Size([m, n]) + out = torch_xla._XLAC._xla_tpu_custom_call([ + num_active_tiles, group_metadata0, group_metadata1, group_metadata2, + group_offset_torch, lhs, rhs + ], payload, [output_shape], [preferred_element_type]) + + if existing_out is None and num_current_groups < num_total_groups: + out = jnp.asarray(out.cpu().float().numpy()) + out = _zero_uninitialized_memory( + out, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + group_metadata=group_metadata, + ) + return out + + +def gmm( + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + preferred_element_type: torch.dtype = torch.float32, + tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128), + group_offset: Optional[torch.Tensor] = None, + existing_out: Optional[torch.Tensor] = None, + transpose_rhs: bool = False, + interpret: bool = False, +): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + import jax + from jax.experimental.pallas.ops.tpu.megablox import gmm + from torch_xla.experimental.custom_kernel import trace_pallas + + payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes) + out = _gmm(lhs, rhs, group_sizes, payload, preferred_element_type, tiling, + group_offset, existing_out, transpose_rhs, interpret) + return out diff --git a/torch_xla/experimental/stablehlo_custom_call.py b/torch_xla/experimental/stablehlo_custom_call.py new file mode 100644 index 00000000000..e729d0b7791 --- /dev/null +++ b/torch_xla/experimental/stablehlo_custom_call.py @@ -0,0 +1,31 @@ +import torch +import torch_xla + + +# TODO(lsy323): Register as a torch op, cannot do that because parameter +# `ScalarType[] output_dtypes` in the op schema has some problem. +def stablehlo_custom_call(args, + call_target, + output_shapes, + output_dtypes, + has_side_effect=False, + backend_config="", + api_version=0): + res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes, + output_dtypes, has_side_effect, + backend_config, api_version) + if len(output_shapes) == 1: + return res[0] + return res + + +def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node): + assert 'val' in n.meta + if isinstance(n.meta['val'], torch.Tensor): + return [n.meta['val'].shape], [n.meta['val'].dtype] + output_shape_dtype = [(t.shape, + t.dtype) if isinstance(t, torch.Tensor) else None + for t in n.meta['val']] + assert None not in output_shape_dtype + output_shape, output_dtype = zip(*output_shape_dtype) + return output_shape, output_dtype diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 3642354ab91..103cb7161be 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -1,24 +1,26 @@ import copy -from dataclasses import dataclass +import dataclasses import enum import json import os -from typing import List, Tuple, Optional, Mapping, Any, Dict -import dataclasses +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple import numpy as np import torch -from torch.fx import _pytree as fx_pytree import torch_xla -from torch_xla.core import xla_model as xm -from torch_xla.core import dynamo_bridge -from torch_xla.debug import metrics import torch_xla.experimental.quantized -from torch_xla.experimental.unbounded_dynamism_export import exported_program_has_symbolic_input_shape, process_exported_program_with_symbolic_input -from torch.utils import _pytree as pytree from torch._decomp import get_decompositions - -from typing import Tuple +from torch.fx import _pytree as fx_pytree +from torch.utils import _pytree as pytree +from torch_xla.core import dynamo_bridge +from torch_xla.core import xla_model as xm +from torch_xla.debug import metrics +from torch_xla.experimental.stablehlo_custom_call import ( + extract_custom_call_outputs_shape_dtype, stablehlo_custom_call) +from torch_xla.experimental.unbounded_dynamism_export import ( + exported_program_has_symbolic_input_shape, + process_exported_program_with_symbolic_input) def _get_numpy_dtype(dtype): @@ -59,6 +61,9 @@ class StableHLOExportOptions: # Whether to export the weights export_weights: bool = True + # Ops that will be mapped to stablehlo.custom_call in the + # exported StableHLO graph. + custom_ops_allowed_in_graph: Set[str] = field(default_factory=set) class StableHLOGraphModule: @@ -214,10 +219,11 @@ class StableHLOModelBundle: class XLAExportInterpreter(torch.fx.Interpreter): - def __init__(self, module, device): + def __init__(self, module, device, custom_ops_allowed_in_graph): self._device = device super().__init__(module) self.tensor_id_to_dynamic_dims = {} + self.custom_ops_allowed_in_graph = custom_ops_allowed_in_graph def _mark_dynamic(self, tensor, dynamic_dims): tid = torch_xla._XLAC._xla_get_tensor_id(tensor) @@ -262,6 +268,14 @@ def run_node(self, n) -> Any: ] self._mark_dynamic(res, dynamic_dims) return res + if n.op == 'call_function': + if hasattr(n.target, 'namespace' + ) and n.target.namespace in self.custom_ops_allowed_in_graph: + output_shapes, output_dtypes = extract_custom_call_outputs_shape_dtype( + n) + call_name = str(n.target) + n.target = stablehlo_custom_call + n.args = (n.args, call_name, output_shapes, output_dtypes) return super().run_node(n) @@ -320,7 +334,8 @@ def _exported_program_to_stablehlo_bundle(exported_model, if options.inline_all_constant: # Inline all constants. torch_xla._XLAC._set_xla_all_numbers_special_scalars(True) - xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device) + xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device, + options.custom_ops_allowed_in_graph) with torch.no_grad(): res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False) res = res[num_mutations:]