diff --git a/.devcontainer/gpu-internal/devcontainer.json b/.devcontainer/gpu-internal/devcontainer.json
new file mode 100644
index 00000000000..ce06bab9e2e
--- /dev/null
+++ b/.devcontainer/gpu-internal/devcontainer.json
@@ -0,0 +1,30 @@
+{
+ "name": "gpu-internal",
+ "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1",
+ "runArgs": [
+ "--gpus=all",
+ "--net=host",
+ "--shm-size=16G"
+ ],
+ "containerEnv": {
+ "BAZEL_REMOTE_CACHE": "1",
+ "SILO_NAME": "cache-silo-${localEnv:USER}-gpuvm"
+ },
+ "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1",
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "llvm-vs-code-extensions.vscode-clangd",
+ "ms-vscode.cpptools-themes",
+ "BazelBuild.vscode-bazel",
+ "DevonDCarew.bazel-code",
+ "StackBuild.bazel-stack-vscode",
+ "StackBuild.bazel-stack-vscode-cc",
+ "xaver.clang-format",
+ "ryanluker.vscode-coverage-gutters",
+ "ms-azuretools.vscode-docker",
+ "ms-python.python"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/.github/workflows/_build_plugin.yml b/.github/workflows/_build_plugin.yml
new file mode 100644
index 00000000000..e30b88aed1e
--- /dev/null
+++ b/.github/workflows/_build_plugin.yml
@@ -0,0 +1,47 @@
+name: build-cuda-plugin
+on:
+ workflow_call:
+ inputs:
+ dev-image:
+ required: true
+ type: string
+ description: Base image for builds
+ runner:
+ required: false
+ type: string
+ description: Runner type for the test
+ default: linux.12xlarge
+
+ secrets:
+ gcloud-service-key:
+ required: true
+ description: Secret to access Bazel build cache
+jobs:
+ build:
+ runs-on: ${{ inputs.runner }}
+ container:
+ image: ${{ inputs.dev-image }}
+ env:
+ GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
+ GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
+ BAZEL_JOBS: 16
+ BAZEL_REMOTE_CACHE: 1
+ steps:
+ - name: Setup gcloud
+ shell: bash
+ run: |
+ echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS
+ - name: Checkout repo
+ uses: actions/checkout@v4
+ with:
+ path: pytorch/xla
+ - name: Build
+ shell: bash
+ run: |
+ cd pytorch/xla/infra/ansible
+ ansible-playbook playbook.yaml -vvv -e "stage=build_plugin arch=amd64 accelerator=cuda cuda_compute_capabilities=5.2,7.5 src_root=${GITHUB_WORKSPACE} cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
+ - name: Upload wheel
+ uses: actions/upload-artifact@v4
+ with:
+ name: cuda-plugin
+ path: /dist/*.whl
diff --git a/.github/workflows/_build_torch_xla.yml b/.github/workflows/_build_torch_xla.yml
new file mode 100644
index 00000000000..3e85b7c4c98
--- /dev/null
+++ b/.github/workflows/_build_torch_xla.yml
@@ -0,0 +1,55 @@
+name: build-cuda-plugin
+on:
+ workflow_call:
+ inputs:
+ dev-image:
+ required: true
+ type: string
+ description: Base image for builds
+ runner:
+ required: false
+ type: string
+ description: Runner type for the test
+ default: linux.12xlarge
+
+ secrets:
+ gcloud-service-key:
+ required: true
+ description: Secret to access Bazel build cache
+jobs:
+ build:
+ runs-on: ${{ inputs.runner }}
+ container:
+ image: ${{ inputs.dev-image }}
+ env:
+ GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
+ GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
+ BAZEL_JOBS: 16
+ BAZEL_REMOTE_CACHE: 1
+ # BUILD_CPP_TESTS: 1
+ steps:
+ - name: Setup gcloud
+ shell: bash
+ run: |
+ echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS
+ - name: Checkout PyTorch Repo
+ uses: actions/checkout@v4
+ with:
+ repository: pytorch/pytorch
+ path: pytorch
+ submodules: recursive
+ # TODO: correct pin
+ - name: Checkout PyTorch/XLA Repo
+ uses: actions/checkout@v4
+ with:
+ path: pytorch/xla
+ - name: Build
+ shell: bash
+ run: |
+ cd pytorch/xla/infra/ansible
+ ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
+ - name: Upload wheel
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch-xla-wheels
+ path: /dist/*.whl
diff --git a/.github/workflows/_test.yml b/.github/workflows/_test_cpp.yml
similarity index 90%
rename from .github/workflows/_test.yml
rename to .github/workflows/_test_cpp.yml
index 0f9e96e31e5..d0056d34963 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test_cpp.yml
@@ -45,17 +45,8 @@ jobs:
matrix:
include:
# Use readable strings as they define the workflow titles.
- - run_benchmark_tests: 'benchmark_tests'
- run_cpp_tests1: 'cpp_tests1'
- run_cpp_tests2: 'cpp_tests2'
- - run_python_tests: 'python_tests'
- run_xla_op_tests1: 'xla_op1'
- - run_python_tests: 'python_tests'
- run_xla_op_tests2: 'xla_op2'
- - run_python_tests: 'python_tests'
- run_xla_op_tests3: 'xla_op3'
- - run_python_tests: 'python_tests'
- run_torch_mp_op_tests: 'torch_mp_op'
timeout-minutes: ${{ inputs.timeout-minutes }}
env:
DOCKER_IMAGE: ${{ inputs.docker-image }}
@@ -64,14 +55,8 @@ jobs:
USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }}
XLA_SKIP_TORCH_OP_TESTS: ${{ inputs.disable-pjrt }}
XLA_SKIP_MP_OP_TESTS: ${{ inputs.disable-pjrt }}
- RUN_BENCHMARK_TESTS: ${{ matrix.run_benchmark_tests }}
RUN_CPP_TESTS1: ${{ matrix.run_cpp_tests1 }}
RUN_CPP_TESTS2: ${{ matrix.run_cpp_tests2 }}
- RUN_PYTHON_TESTS: ${{ matrix.run_python_tests }}
- RUN_XLA_OP_TESTS1: ${{ matrix.run_xla_op_tests1 }}
- RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }}
- RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }}
- RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }}
steps:
- name: Setup Linux
uses: pytorch/test-infra/.github/actions/setup-linux@main
diff --git a/.github/workflows/_test_python.yml b/.github/workflows/_test_python.yml
new file mode 100644
index 00000000000..bd260cdb2d1
--- /dev/null
+++ b/.github/workflows/_test_python.yml
@@ -0,0 +1,176 @@
+name: xla-test
+on:
+ workflow_call:
+ inputs:
+ dev-image:
+ required: true
+ type: string
+ description: Base image for builds
+ runner:
+ required: false
+ type: string
+ description: Runner type for the test
+ default: linux.12xlarge
+ collect-coverage:
+ required: false
+ type: boolean
+ description: Set to true to collect coverage information
+ default: false
+ timeout-minutes:
+ required: false
+ type: number
+ default: 270
+ description: |
+ Set the maximum (in minutes) how long the workflow should take to finish
+ timeout-minutes:
+ install-cuda-plugin:
+ required: false
+ type: boolean
+ default: false
+ description: Whether to install CUDA plugin package
+
+ secrets:
+ gcloud-service-key:
+ required: true
+ description: Secret to access Bazel build cache
+jobs:
+ test:
+ runs-on: ${{ inputs.runner }}
+ container:
+ image: ${{ inputs.dev-image }}
+ options: "${{ inputs.install-cuda-plugin && '--gpus all' || '' }} --shm-size 16g"
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ # Use readable strings as they define the workflow titles.
+ - run_benchmark_tests: 'benchmark_tests'
+ - run_python_tests: 'python_tests'
+ run_xla_op_tests1: 'xla_op1'
+ - run_python_tests: 'python_tests'
+ run_xla_op_tests2: 'xla_op2'
+ - run_python_tests: 'python_tests'
+ run_xla_op_tests3: 'xla_op3'
+ - run_python_tests: 'python_tests'
+ run_torch_mp_op_tests: 'torch_mp_op'
+ timeout-minutes: ${{ inputs.timeout-minutes }}
+ env:
+ GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
+ GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
+ USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }}
+ RUN_BENCHMARK_TESTS: ${{ matrix.run_benchmark_tests }}
+ RUN_PYTHON_TESTS: ${{ matrix.run_python_tests }}
+ RUN_XLA_OP_TESTS1: ${{ matrix.run_xla_op_tests1 }}
+ RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }}
+ RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }}
+ RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }}
+ BAZEL_JOBS: 16
+ BAZEL_REMOTE_CACHE: 1
+ steps:
+ - name: Setup gcloud
+ shell: bash
+ run: |
+ echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS
+ - name: Fetch wheels
+ uses: actions/download-artifact@v4
+ with:
+ name: torch-xla-wheels
+ path: /tmp/wheels/
+ - name: Fetch CUDA plugin
+ uses: actions/download-artifact@v4
+ with:
+ name: cuda-plugin
+ path: /tmp/wheels/
+ if: ${{ inputs.install-cuda-plugin }}
+ - name: Setup CUDA environment
+ shell: bash
+ run: |
+ # TODO: Make PJRT_DEVICE=CPU work with XLA_REGISTER_INSTALLED_PLUGINS=1
+ echo "XLA_REGISTER_INSTALLED_PLUGINS=1" >> $GITHUB_ENV
+
+ echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
+ echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
+ if: ${{ inputs.install-cuda-plugin }}
+ - name: Check GPU
+ run: nvidia-smi
+ if: ${{ inputs.install-cuda-plugin }}
+ - name: Install wheels
+ shell: bash
+ run: |
+ pip install /tmp/wheels/*.whl
+ # TODO: Add these in setup.py
+ pip install fsspec
+ pip install rich
+ - name: Record PyTorch commit
+ run: echo "PYTORCH_COMMIT=$(python -c 'import torch_xla.version; print(torch_xla.version.__torch_gitrev__)')" >> $GITHUB_ENV
+ - name: Checkout PyTorch Repo
+ uses: actions/checkout@v4
+ with:
+ repository: pytorch/pytorch
+ path: pytorch
+ ref: ${{ env.PYTORCH_COMMIT }}
+ - name: Checkout PyTorch/XLA Repo
+ uses: actions/checkout@v4
+ with:
+ path: pytorch/xla
+ - name: Extra CI deps
+ shell: bash
+ run: |
+ set -x
+
+ pip install expecttest unittest-xml-reporting
+
+ if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
+ pip install -r pytorch/xla/benchmarks/requirements.txt
+ fi
+ - name: Test
+ shell: bash
+ run: |
+ source pytorch/xla/.circleci/common.sh
+
+ run_torch_xla_tests pytorch/ pytorch/xla/ $USE_COVERAGE
+ - name: Upload coverage results
+ if: ${{ inputs.collect-coverage }}
+ shell: bash
+ env:
+ CIRCLE_WORKFLOW_ID: ${{ github.run_id }}
+ CIRCLE_BUILD_NUM: ${{ github.run_number }}
+ BENCHMARK_TEST_NAME: ${{ env.RUN_BENCHMARK_TESTS }}
+ PYTHON_TEST_NAME: ${{ env.RUN_PYTHON_TESTS }}${{ env.RUN_XLA_OP_TESTS1 }}${{ env.RUN_XLA_OP_TESTS2 }}${{ env.RUN_XLA_OP_TESTS3 }}${{ env.RUN_TORCH_MP_OP_TESTS }}
+ CPP_TEST_NAME: ${{ env.RUN_CPP_TESTS1 }}${{ env.RUN_CPP_TESTS2 }}
+ run: |
+ # TODO(yeounoh) collect coverage report as needed.
+ if [ -n "${BENCHMARK_TEST_NAME}" ]; then
+ exit 0
+ fi
+ docker cp "${pid}":/home/jenkins/htmlcov "${GITHUB_WORKSPACE}"
+ if [ -n "${GPU_FLAG:-}" ]; then
+ if [ -n "${PYTHON_TEST_NAME}" ]; then
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out
+ fi
+ if [ -n "${CPP_TEST_NAME}" ]; then
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out
+ fi
+ else
+ if [ -n "${PYTHON_TEST_NAME}" ]; then
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out
+ fi
+
+ if [ -n "${CPP_TEST_NAME}" ]; then
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out
+ gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out
+ fi
+
+ if [ "${CPP_TEST_NAME}" == "cpp_tests1" ]; then
+ ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}'
+ echo $ABS_METADATA > abs_metadata.json
+ gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json
+
+ INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}'
+ echo $INC_METADATA > inc_metadata.json
+ gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json
+ fi
+ fi
diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml
new file mode 100644
index 00000000000..bfe9359cf15
--- /dev/null
+++ b/.github/workflows/_tpu_ci.yml
@@ -0,0 +1,35 @@
+name: TPU Integration Test
+on:
+ workflow_call:
+jobs:
+ tpu-test:
+ runs-on: v4-runner-set
+ steps:
+ - name: Checkout repo
+ uses: actions/checkout@v4
+ with:
+ path: pytorch/xla
+ - name: Fetch wheels
+ uses: actions/download-artifact@v4
+ with:
+ name: torch-xla-wheels
+ path: /tmp/wheels/
+ - name: Install wheels
+ shell: bash
+ run: |
+ pip install /tmp/wheels/*.whl
+ - name: Install test dependencies
+ shell: bash
+ run: |
+ # TODO: Add these in setup.py
+ pip install fsspec
+ pip install rich
+ # Jax nightly is needed for pallas tests.
+ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+ - name: Run Tests
+ env:
+ PJRT_DEVICE: TPU
+ run: |
+ cd pytorch/xla
+ test/tpu/run_tests.sh
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 41bca83b5cb..e5738b5a6af 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -1,3 +1,4 @@
+name: Build and test
on:
pull_request:
branches:
@@ -18,8 +19,9 @@ concurrency:
cancel-in-progress: true
jobs:
+ # Old CI workflow
build:
- name: "Build XLA"
+ name: "Build PyTorch/XLA (GPU)"
uses: ./.github/workflows/_build.yml
with:
ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base
@@ -28,20 +30,20 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
- test-cpu:
- name: "CPU tests"
- uses: ./.github/workflows/_test.yml
+ test-cpp-cpu:
+ name: "CPU C++ tests"
+ uses: ./.github/workflows/_test_cpp.yml
needs: build
with:
docker-image: ${{ needs.build.outputs.docker-image }}
timeout-minutes: 120
- collect-coverage: false
+ collect-coverage: false # TODO(yeounoh) separate from CPU coverage metrics
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
- test-cuda:
- name: "GPU tests"
- uses: ./.github/workflows/_test.yml
+ test-cpp-cuda:
+ name: "GPU C++ tests"
+ uses: ./.github/workflows/_test_cpp.yml
needs: build
with:
docker-image: ${{ needs.build.outputs.docker-image }}
@@ -60,3 +62,52 @@ jobs:
docker-image: ${{ needs.build.outputs.docker-image }}
secrets:
torchxla-bot-token: ${{ secrets.TORCH_XLA_BOT_TOKEN }}
+
+ # New CI workflow
+ build-torch-xla:
+ name: "Build PyTorch/XLA (TPU)"
+ uses: ./.github/workflows/_build_torch_xla.yml
+ with:
+ dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm
+ secrets:
+ gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
+ build-cuda-plugin:
+ name: "Build XLA CUDA plugin"
+ uses: ./.github/workflows/_build_plugin.yml
+ with:
+ dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
+ secrets:
+ gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
+ test-python-cpu:
+ name: "CPU Python tests"
+ uses: ./.github/workflows/_test_python.yml
+ needs: build-torch-xla
+ with:
+ dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_tpuvm
+ timeout-minutes: 120
+ collect-coverage: false
+ secrets:
+ gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
+ test-python-cuda:
+ name: "GPU Python tests"
+ uses: ./.github/workflows/_test_python.yml
+ needs: [build-torch-xla, build-cuda-plugin]
+ with:
+ dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
+ runner: linux.8xlarge.nvidia.gpu
+ timeout-minutes: 300
+ collect-coverage: false
+ install-cuda-plugin: true
+ secrets:
+ gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
+
+ test-tpu:
+ name: "TPU tests"
+ uses: ./.github/workflows/_tpu_ci.yml
+ needs: build-torch-xla
+ # Only run this for HEAD and releases
+ if: github.event_name == 'push'
+
diff --git a/.github/workflows/tpu_ci.yml b/.github/workflows/tpu_ci.yml
deleted file mode 100644
index 009689ea7c5..00000000000
--- a/.github/workflows/tpu_ci.yml
+++ /dev/null
@@ -1,52 +0,0 @@
-name: TPU Integration Test
-run-name: TPU Testing
-on:
- workflow_dispatch:
- pull_request:
- branches:
- - r[0-9]+.[0-9]+
- paths-ignore:
- - 'experimental/torch_xla2/**'
- push:
- branches:
- - master
- - r[0-9]+.[0-9]+
- paths-ignore:
- - 'experimental/torch_xla2/**'
-jobs:
- tpu-test:
- runs-on: v4-runner-set
- steps:
- - name: Checkout and Setup PyTorch Repo
- env:
- _GLIBCXX_USE_CXX11_ABI: 0
- run: |
- git clone --recursive https://github.com/pytorch/pytorch
- cd pytorch/
- python3 setup.py install --user
- - name: Install torchvision
- run: |
- cd pytorch/
- pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
- - name: Checkout PyTorch/XLA Repo
- uses: actions/checkout@v4
- with:
- path: pytorch/xla
- - name: Run PyTorch/XLA Setup
- env:
- BAZEL_VERBOSE: 1
- TPUVM_MODE: 1
- run: |
- cd pytorch/xla
- python3 setup.py install --user
- - name: Run Tests
- env:
- PJRT_DEVICE: TPU
- # Jax is needed for pallas tests.
- run: |
- pip install fsspec
- pip install rich
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
- pip install torch_xla[tpuvm]
- cd pytorch/xla
- test/tpu/run_tests.sh
diff --git a/README.md b/README.md
index 289f0017ac6..d1653eb7b53 100644
--- a/README.md
+++ b/README.md
@@ -158,7 +158,7 @@ bucket.
| Version | Cloud TPU VMs Wheel |
|---------|-------------------|
| 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` |
-| 2.1 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.1-cp38-cp38-linux_x86_64.whl` |
+| 2.1 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl` |
| 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` |
| 1.13 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl` |
| 1.12 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl` |
@@ -195,7 +195,6 @@ wheels for `torch` and `torch_xla` at
| 1.13 | `https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl` |
| 1.12 | `https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl` |
| 1.11 | `https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl` |
-| nightly | `https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-nightly-cp37-cp37-linux_x86_64.whl` |
diff --git a/WORKSPACE b/WORKSPACE
index a7f8b7762e7..9c6963dae65 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -50,9 +50,9 @@ http_archive(
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
],
- strip_prefix = "xla-4386e9238b12df5fcba2220e698bf259cbfea27a",
+ strip_prefix = "xla-54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f",
urls = [
- "https://github.com/openxla/xla/archive/4386e9238b12df5fcba2220e698bf259cbfea27a.tar.gz",
+ "https://github.com/openxla/xla/archive/54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.tar.gz",
],
)
diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py
index 443a4067ac1..da952f3c079 100644
--- a/benchmarks/experiment_runner.py
+++ b/benchmarks/experiment_runner.py
@@ -319,8 +319,8 @@ def loop(pytorch_profile=None, iter_fn=None):
self._args.profile_cuda_cpu or \
self._args.profile_cuda_cpu_individual_ops
enable_xla_profiling = self._args.profile_xla
- assert not (enable_pytorch_profiling and enable_pytorch_profiling
- ), "More than one profiling path enabled."
+ assert not (enable_pytorch_profiling and
+ enable_xla_profiling), "More than one profiling path enabled."
if enable_xla_profiling:
logdir = self._get_results_dir_path(experiment_config, model_config,
diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt
new file mode 100644
index 00000000000..14e2549fec3
--- /dev/null
+++ b/benchmarks/requirements.txt
@@ -0,0 +1,3 @@
+tabulate
+scipy
+pandas
diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 199025dc7e1..de5500a0c5b 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -361,6 +361,7 @@ supported:
- zero_
- _native_batch_norm_legit
- _native_batch_norm_legit.no_stats
+ - _embedding_bag_forward_only
# Note: [functionalization and CompositeExplicitAutograd]
# Below are all operators that are "composite" in core,
# but require us to explicitly re-enable functionalization in order to use them.
diff --git a/docs/fori_loop.md b/docs/fori_loop.md
new file mode 100644
index 00000000000..0c9f85af399
--- /dev/null
+++ b/docs/fori_loop.md
@@ -0,0 +1,114 @@
+# Fori_loop
+`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation
+like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps
+of each iteration. `fori_loop` might help memory utilization and might help faster compilation.
+
+User could use `fori_loop` like this:
+```python
+from torch_xla.experimental.fori_loop import fori_loop
+res = fori_loop(upper, lower, /*user defined*/body_fun, init)
+```
+
+current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too.
+
+For detailed implementation:
+- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
+like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the
+native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only.
+
+- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
+like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator.
+This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference.
+
+# while_loop
+`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in
+[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69).
+PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While.
+
+User could use `while_loop` like this:
+```python
+import torch_xla.experimental.fori_loop
+from torch._higher_order_ops.while_loop import while_loop
+res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init)
+```
+current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too.
+
+
+# [WIP]scan
+like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd.
+`scan` is WIP.
+
+
+# Simple user guide
+User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times:
+
+### simple example with pure python for loop
+```bash
+# python
+>>> import torch
+>>> init = torch.tensor([0], dtype=torch.int32)
+>>> one_value = torch.ones(1, dtype=torch.int32)
+>>>
+>>> for i in range(10):
+... init = init + one_value
+...
+>>> init
+tensor([10], dtype=torch.int32)
+```
+
+### simple example with `while_loop`:
+```bash
+# PJRT_DEVICE=TPU python
+>>> import torch
+>>> import torch_xla
+>>> import torch_xla.experimental.fori_loop
+>>> from torch_xla.experimental.fori_loop import fori_loop
+>>> from torch._higher_order_ops.while_loop import while_loop
+>>> import torch_xla.core.xla_model as xm
+>>> import torch_xla.core.xla_builder as xb
+>>>
+>>> device = xm.xla_device()
+>>>
+>>> def cond_fn(init, limit_value):
+... return limit_value[0] >= init[0]
+...
+>>> def body_fn(init, limit_value):
+... one_value = torch.ones(1, dtype=torch.int32, device=device)
+... return (torch.add(init, one_value), limit_value.clone())
+...
+>>> init = torch.tensor([0], dtype=torch.int32, device=device)
+>>> limit_value = torch.tensor([10], dtype=torch.int32, device=device)
+>>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value))
+>>> res_
+FunctionalTensor(lvl=0, value=\
+tensor([11], device='xla:0', dtype=torch.int32))
+```
+
+### simple example with `fori_loop`:
+```bash
+# PJRT_DEVICE=TPU python
+>>> import torch
+>>> import torch_xla
+>>> import torch_xla.experimental.fori_loop
+>>> from torch_xla.experimental.fori_loop import fori_loop
+>>> from torch._higher_order_ops.while_loop import while_loop
+>>> import torch_xla.core.xla_model as xm
+>>> import torch_xla.core.xla_builder as xb
+>>>
+>>> device = xm.xla_device()
+>>>
+>>> lower = torch.tensor([2], dtype=torch.int32, device=device)
+>>> upper = torch.tensor([52], dtype=torch.int32, device=device)
+>>> plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+>>> init_val = torch.tensor([1], dtype=torch.int32, device=device)
+>>>
+>>> def body_fun(*argus):
+... plus_value, init_val = argus
+... return plus_value, torch.add(plus_value, init_val)
+...
+>>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val)
+>>> res_
+tensor([51], device='xla:0', dtype=torch.int32)
+```
+
+For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3
diff --git a/docs/gpu.md b/docs/gpu.md
index 3222f4174cf..c2678164f4e 100644
--- a/docs/gpu.md
+++ b/docs/gpu.md
@@ -64,8 +64,8 @@ Thu Dec 8 06:24:29 2022
Make sure `PATH` and `LD_LIBRARY_PATH` environment variables account for cuda. Please do a `echo $PATH` and `echo $LD_LIBRARY_PATH` to verify. If not, please follow [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#mandatory-actions) to do so. Example:
```
-echo "export PATH=/usr/local/cuda-12.1/bin${PATH:+:${PATH}}" >> ~/.bashrc
-echo "export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> ~/.bashrc
+echo "export PATH=\$PATH:/usr/local/cuda-12.1/bin" >> ~/.bashrc
+echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc
source ~/.bashrc
```
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 26f491f6c15..411e6642ff7 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
mistune==0.8.4
-sphinx==2.4.4
+sphinx==5.0.0
docutils==0.16
-Jinja2<3.1
+Jinja2==3.1.3
m2r
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
diff --git a/experimental/torch_xla2/docs/dispatch.png b/experimental/torch_xla2/docs/dispatch.png
new file mode 100644
index 00000000000..fcdd5e9e58a
Binary files /dev/null and b/experimental/torch_xla2/docs/dispatch.png differ
diff --git a/experimental/torch_xla2/docs/how_it_works.md b/experimental/torch_xla2/docs/how_it_works.md
new file mode 100644
index 00000000000..e4098ca0096
--- /dev/null
+++ b/experimental/torch_xla2/docs/how_it_works.md
@@ -0,0 +1,134 @@
+How it works
+============
+
+
+## Tensor subclass and eager mode
+
+The class `XLATensor2` is a `torch.Tensor` subclass
+that overrides `__torch_dispatch__`.
+
+It roughly looks like this (with some details removed):
+
+The complete class impl is at [tensor.py](../torch_xla2/tensor.py).
+
+```python
+class XLATensor2(torch.Tensor):
+
+ @staticmethod
+ def __new__(cls, elem):
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ shape,
+ dtype=dtype,
+ device='meta',
+ requires_grad=False,
+ )
+
+ def __init__(self, elem: jax.Array):
+ super().__init__()
+ self._elem = elem
+
+ __torch_function__ = torch._C._disabled_torch_function_impl
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ # here assumes ALL tensors in args / kwargs are
+ # instances of XLATensor2
+ args, kwargs = unwrap((args, kwargs))
+ jax_func = some_registry[func]
+ res = jax_func(*args, **kwargs)
+ return wrap(res)
+
+def wrap(tree):
+ # wrap jax.Array with XLATensor2
+ return pytree.tree_map_only(
+ jax.Array, XLATensor2, tree)
+
+def unwrap(tree):
+ # get jax.Array out ofXLATensor2
+ return pytree.tree_map_only(
+ XLATensor2, lambda x: x._elem, tree)
+```
+
+In other words, assuming that we have a function
+that takes `jax.Array` as input and returns `jax.Array`
+but otherwise implement the same semantics
+as a `ATen` op; then, using this tensor we would
+be able to route the call to this jax function.
+
+[_ops.py](../torch_xla2/_ops.py) files defines some of those ops.
+
+Let's take `aten::add` as example:
+
+```python
+@op(torch.ops.aten.add)
+def _aten_add(x, y, *, alpha=1):
+ """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray):
+
+ assert x.dtype == y.dtype, (x.dtype, y.dtype)
+ """
+ return x + y * alpha
+```
+
+The `@op` decorator just puts this function into `some_registry` dictionary.
+
+`_aten_add` has same signature as `torch.ops.aten.add` but takes `jax.Array` as
+input.
+
+![](dispatch.png)
+
+
+## fx Interpreter and dynamo mode
+
+Now, assuming we have this `some_registry` dict with key core Aten ops,
+and value the equivalent python Jax functions. We can also build a `fx.Interpreter`
+subclass that executes the jax function given a `fx.GraphModule`.
+
+
+```python
+class JaxInterpreter(torch.fx.Interpreter):
+
+ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
+ if not isinstance(target,
+ (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+ return super().call_function(target, args, kwargs)
+
+ op = some_registry[target]
+ return op.func(*args, **kwargs)
+```
+
+There is no wrapping and unwrapping needed because `args` and `kwargs` are
+already `jax.Array`'s.
+
+Using this interpreter we can build a dynamo backend:
+
+```python
+def backend(fxgraph):
+
+ def tojit(*args, *kwargs):
+ return JaxInterpreter(fxgraph).run(*args, **kwargs)
+ jitted = jax.jit(to_jit)
+
+ def f(*torchtensor):
+ jaxarrays = unwrap(torchtensors)
+ res = jitted(jax_array)
+ return wrap(res)
+
+ return f
+```
+
+The inner function `tojit` is a function that takes and returns
+`jax.Array`'s. So it's suitable to be jitted with `jax.jit`.
+
+`f` is returned callable that takes `XLATensor2`; so can interop with
+other torch codes.
+
+## nn.Modules and state management
+
+See [README.md](../README.md) for using `torch.func.functional_call` to
+make `nn.Module`s interact well with `jax.jit`.
+
+See [Examples](../examples/README.md) for training using torch's optimizers or jax's
+optimizers.
+
+[def]: dispatch.png
\ No newline at end of file
diff --git a/experimental/torch_xla2/examples/README.md b/experimental/torch_xla2/examples/README.md
new file mode 100644
index 00000000000..0e22d28c531
--- /dev/null
+++ b/experimental/torch_xla2/examples/README.md
@@ -0,0 +1,115 @@
+## Intro
+
+This readme will have a subsection for every example *.py file.
+
+Please follow the instructions in [README.md](../README.md) to install torch_xla2,
+then install requirements for all of the examples with
+
+```bash
+pip install -r requirements.txt
+```
+
+
+
+## basic_training.py
+
+This file constructed by first copy & paste code fragments from this pytorch training tutorial:
+https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
+
+Then adding few lines of code that serves the purpose of moving `torch.Tensor` into
+`XLA devices`.
+
+Example:
+
+```python
+state_dict = pytree.tree_map_only(torch.Tensor,
+ torch_xla2.tensor.move_to_device, state_dict)
+```
+
+This fragment moves the state_dict to XLA devices; then the state_dict is passed
+back to model via `load_state_dict`.
+
+Then, you can train the model. This shows what is minimum to train a model on XLA
+devices. The perf is not as good because we didn't use `jax.jit`, this is intentional
+as it is meant to showcase the minimum code change.
+
+Example run:
+```bash
+(xla2) hanq-macbookpro:examples hanq$ python basic_training.py
+Training set has 60000 instances
+Validation set has 10000 instances
+Bag Dress Sneaker T-shirt/top
+tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035,
+ 0.2279],
+ [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561,
+ 0.5985],
+ [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504,
+ 0.8738],
+ [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086,
+ 0.5709]])
+tensor([1, 5, 3, 7])
+Total loss for this batch: 2.325265645980835
+EPOCH 1:
+ batch 1000 loss: 1.041275198560208
+ batch 2000 loss: 0.6450189483696595
+ batch 3000 loss: 0.5793989677671343
+ batch 4000 loss: 0.5170258888280951
+ batch 5000 loss: 0.4920090722264722
+ batch 6000 loss: 0.48910293977567926
+ batch 7000 loss: 0.48058812761632724
+ batch 8000 loss: 0.47159107415075413
+ batch 9000 loss: 0.4712311488997657
+ batch 10000 loss: 0.4675815168160479
+ batch 11000 loss: 0.43210567891132085
+ batch 12000 loss: 0.445208148030797
+ batch 13000 loss: 0.4119230824254337
+ batch 14000 loss: 0.4190662656680215
+ batch 15000 loss: 0.4094535468676477
+LOSS train 0.4094535468676477 valid XLA
+```
+
+## basic_training_jax.py
+
+This file constructed by first copy & paste code fragments from this pytorch training tutorial:
+https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
+
+Then replacing torch optimizer with `optax` optimizer; and use `jax.grad` for
+gradient instead of `torch.Tensor.backward()`.
+
+Then, you can train the model using jax ecosystem's training loop. This is meant to
+showcase how easy is to integrate with Jax.
+
+Example run:
+```bash
+(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py
+Training set has 60000 instances
+Validation set has 10000 instances
+Pullover Ankle Boot Pullover Ankle Boot
+tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872,
+ 0.1854],
+ [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010,
+ 0.2766],
+ [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607,
+ 0.6450],
+ [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377,
+ 0.7453]])
+tensor([1, 5, 3, 7])
+Total loss for this batch: 2.4054245948791504
+EPOCH 1:
+ batch 1000 loss: 1.0705260595591972
+ batch 2000 loss: 1.0997755021179327
+ batch 3000 loss: 1.0186579653513108
+ batch 4000 loss: 0.9090727646966116
+ batch 5000 loss: 0.8309370622411024
+ batch 6000 loss: 0.8702225417760783
+ batch 7000 loss: 0.8750176187023462
+ batch 8000 loss: 0.9652624803795453
+ batch 9000 loss: 0.8688667197711766
+ batch 10000 loss: 0.8021814124770199
+ batch 11000 loss: 0.8000540231048071
+ batch 12000 loss: 0.9150884484921057
+ batch 13000 loss: 0.819690621060171
+ batch 14000 loss: 0.8569030471532278
+ batch 15000 loss: 0.8740896808278603
+LOSS train 0.8740896808278603 valid 2.3132264614105225
+```
\ No newline at end of file
diff --git a/experimental/torch_xla2/examples/_diffusion.py b/experimental/torch_xla2/examples/_diffusion.py
new file mode 100644
index 00000000000..5eae15edf25
--- /dev/null
+++ b/experimental/torch_xla2/examples/_diffusion.py
@@ -0,0 +1,112 @@
+import functools
+
+import torch
+from time import time
+from diffusers import DiffusionPipeline
+from torch.utils import _pytree as pytree
+
+
+import torch_xla2
+import torch_xla2.functions
+from torch_xla2.extra import torch_view, jax_view
+
+import jax
+import torch.func
+
+
+class CompiledModule:
+
+ def __init__(self, model):
+ weights = model.state_dict()
+ weights.update(model.named_parameters())
+ self._weights = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.move_to_device, weights)
+ self._model = model
+
+ self._func_jitted_torch = None #torch_view(func_mod_jitted)
+
+
+ def _maybe_move_tensor(self, tensor):
+ if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torch_xla2.tensor.XLATensor2):
+ return torch_xla2.tensor.move_to_device(tensor)
+ return tensor
+
+ def _make_jitted(self, args, kwargs):
+ static = []
+ for i, a in enumerate(args):
+ if not isinstance(a, torch.Tensor):
+ static.append(i + 1) # weight is 0
+ static_argnames = []
+ for k, v in kwargs.items():
+ if not isinstance(v, torch.Tensor):
+ static_argnames.append(k)
+
+ def f(weights, *args, **kwargs):
+ weights, args, kwargs = torch_xla2.tensor.wrap((weights, args, kwargs))
+ with torch_xla2.functions.XLAFunctionMode(), torch_xla2.tensor.XLADispatchMode():
+ res = torch.func.functional_call(self._model, weights, args, kwargs)
+ if isinstance(res, tuple) and len(res) == 1:
+ res = res[0]
+ return torch_xla2.tensor.unwrap(res)
+
+ fjit = jax.jit(f, static_argnames=tuple(static_argnames))
+ return torch_view(fjit)
+
+
+ def forward(self, *args, **kwargs):
+ (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs))
+ if self._func_jitted_torch is None:
+ self._func_jitted_torch = self._make_jitted(args, kwargs)
+ return self._func_jitted_torch(
+ self._weights,
+ *args,
+ **kwargs
+ )
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def __getattr__(self, key):
+ return getattr(self._model, key)
+
+
+def compile_pipe(pipe):
+ pipe.text_encoder = CompiledModule(pipe.text_encoder)
+ pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2)
+ pipe.unet = CompiledModule(pipe.unet)
+ pipe.vae = CompiledModule(pipe.vae)
+
+
+def main():
+ pipe = DiffusionPipeline.from_pretrained(
+ # "stabilityai/stable-diffusion-xl-base-0.9",
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ use_safetensors=True,
+
+ )
+ compile_pipe(pipe)
+
+ global_bs = 10
+ inference_steps = 20
+ resol = 1024
+ prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs
+ print(f'global batch size {global_bs}',
+ f'inference steps {inference_steps}',
+ f'Image resolution {resol}',
+ flush=True
+ )
+
+ iters = 5
+ for i in range(iters):
+ prompt = prompts
+ # print('per device prompts len',len(prompt))
+ # prompt = prompts[rank]
+ start = time()
+ image = pipe(prompt,
+ num_inference_steps=inference_steps,
+ height=resol,
+ width=resol).images[0]
+ print(f'Step {i} inference time {time()-start} sec', flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py
new file mode 100644
index 00000000000..5d3f5a734c5
--- /dev/null
+++ b/experimental/torch_xla2/examples/basic_training.py
@@ -0,0 +1,197 @@
+"""
+This is the script from this tutorial:
+https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
+
+Then, it's modified to make the training loop using Jax's grad
+and optimizer
+"""
+
+import torch
+from torch.utils import _pytree as pytree
+import torchvision
+import torchvision.transforms as transforms
+import torch_xla2
+
+# PyTorch TensorBoard support
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+
+
+transform = transforms.Compose(
+ [transforms.ToTensor(),
+ transforms.Normalize((0.5,), (0.5,))])
+
+# Create datasets for training & validation, download if necessary
+training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
+validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
+
+# Create data loaders for our datasets; shuffle for training, not for validation
+training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
+validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
+
+# Class labels
+classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
+ 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
+
+# Report split sizes
+print('Training set has {} instances'.format(len(training_set)))
+print('Validation set has {} instances'.format(len(validation_set)))
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+# Helper function for inline image display
+def matplotlib_imshow(img, one_channel=False):
+ if one_channel:
+ img = img.mean(dim=0)
+ img = img / 2 + 0.5 # unnormalize
+ npimg = img.numpy()
+ if one_channel:
+ plt.imshow(npimg, cmap="Greys")
+ else:
+ plt.imshow(np.transpose(npimg, (1, 2, 0)))
+
+dataiter = iter(training_loader)
+images, labels = next(dataiter)
+
+# Create a grid from the images and show them
+img_grid = torchvision.utils.make_grid(images)
+matplotlib_imshow(img_grid, one_channel=True)
+print(' '.join(classes[labels[j]] for j in range(4)))
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+# PyTorch models inherit from torch.nn.Module
+class GarmentClassifier(nn.Module):
+ def __init__(self):
+ super(GarmentClassifier, self).__init__()
+ self.fc1 = nn.Linear(28 * 28, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = x.view(-1, 28 * 28)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+
+model = GarmentClassifier()
+
+loss_fn = torch.nn.CrossEntropyLoss()
+
+# NB: Loss functions expect data in batches, so we're creating batches of 4
+# Represents the model's confidence in each of the 10 classes for a given input
+dummy_outputs = torch.rand(4, 10)
+# Represents the correct class among the 10 being tested
+dummy_labels = torch.tensor([1, 5, 3, 7])
+
+print(dummy_outputs)
+print(dummy_labels)
+
+loss = loss_fn(dummy_outputs, dummy_labels)
+print('Total loss for this batch: {}'.format(loss.item()))
+
+# Optimizers specified in the torch.optim package
+
+# NEW: Move model to XLA device
+state_dict = model.state_dict()
+state_dict = pytree.tree_map_only(torch.Tensor,
+ torch_xla2.tensor.move_to_device, state_dict)
+model.load_state_dict(state_dict, strict=False, assign=True)
+
+optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+def train_one_epoch(epoch_index, tb_writer):
+ running_loss = 0.
+ last_loss = 0.
+
+ # Here, we use enumerate(training_loader) instead of
+ # iter(training_loader) so that we can track the batch
+ # index and do some intra-epoch reporting
+ for i, data in enumerate(training_loader):
+ # Every data instance is an input + label pair
+ # NEW: Move model to XLA device
+ data = pytree.tree_map_only(torch.Tensor,
+ torch_xla2.tensor.move_to_device, data)
+ inputs, labels = data
+
+ # Zero your gradients for every batch!
+ optimizer.zero_grad()
+
+ # Make predictions for this batch
+ outputs = model(inputs)
+
+ # Compute the loss and its gradients
+ loss = loss_fn(outputs, labels)
+ loss.backward()
+
+ # Adjust learning weights
+ optimizer.step()
+
+ # Gather data and report
+ running_loss += loss.item()
+ if i % 1000 == 999:
+ last_loss = running_loss / 1000 # loss per batch
+ print(' batch {} loss: {}'.format(i + 1, last_loss))
+ tb_x = epoch_index * len(training_loader) + i + 1
+ tb_writer.add_scalar('Loss/train', last_loss, tb_x)
+ running_loss = 0.
+
+ return last_loss
+
+
+
+# Initializing in a separate cell so we can easily add more epochs to the same run
+timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
+epoch_number = 0
+EPOCHS = 2
+best_vloss = 1_000_000.
+
+for epoch in range(EPOCHS):
+ print('EPOCH {}:'.format(epoch_number + 1))
+
+ # Make sure gradient tracking is on, and do a pass over the data
+ model.train(True)
+
+
+ avg_loss = train_one_epoch(epoch_number, writer)
+
+ running_vloss = 0.0
+ # Set the model to evaluation mode, disabling dropout and using population
+ # statistics for batch normalization.
+ model.eval()
+
+ # Disable gradient computation and reduce memory consumption.
+ with torch.no_grad():
+ for i, vdata in enumerate(validation_loader):
+ # NOTE: move to XLA device
+ vinputs, vlabels = pytree.tree_map_only(
+ torch.Tensor,
+ torch_xla2.tensor.move_to_device,
+ vdata)
+ voutputs = model(vinputs) # call model's forward
+ vloss = loss_fn(voutputs, vlabels)
+ running_vloss += vloss
+
+ avg_vloss = running_vloss / (i + 1)
+ print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
+
+ # Log the running loss averaged per batch
+ # for both training and validation
+ writer.add_scalars('Training vs. Validation Loss',
+ { 'Training' : avg_loss, 'Validation' : avg_vloss },
+ epoch_number + 1)
+ writer.flush()
+
+ # Track best performance, and save the model's state
+ if avg_vloss < best_vloss:
+ best_vloss = avg_vloss
+ model_path = 'model_{}_{}'.format(timestamp, epoch_number)
+ torch.save(model.state_dict(), model_path)
+
+ epoch_number += 1
\ No newline at end of file
diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/experimental/torch_xla2/examples/basic_training_jax.py
new file mode 100644
index 00000000000..3941fcdf8fe
--- /dev/null
+++ b/experimental/torch_xla2/examples/basic_training_jax.py
@@ -0,0 +1,196 @@
+"""
+This is the script from this tutorial:
+https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
+"""
+
+import torch
+from torch.utils import _pytree as pytree
+import torchvision
+import torchvision.transforms as transforms
+import torch_xla2
+import torch_xla2.extra
+import jax
+import optax
+import numpy as np
+
+# PyTorch TensorBoard support
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+
+
+transform = transforms.Compose(
+ [transforms.ToTensor(),
+ transforms.Normalize((0.5,), (0.5,))])
+
+# Create datasets for training & validation, download if necessary
+training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
+validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
+
+# Create data loaders for our datasets; shuffle for training, not for validation
+training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
+validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
+
+# Class labels
+classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
+ 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
+
+# Report split sizes
+print('Training set has {} instances'.format(len(training_set)))
+print('Validation set has {} instances'.format(len(validation_set)))
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+# Helper function for inline image display
+def matplotlib_imshow(img, one_channel=False):
+ if one_channel:
+ img = img.mean(dim=0)
+ img = img / 2 + 0.5 # unnormalize
+ npimg = img.numpy()
+ if one_channel:
+ plt.imshow(npimg, cmap="Greys")
+ else:
+ plt.imshow(np.transpose(npimg, (1, 2, 0)))
+
+dataiter = iter(training_loader)
+images, labels = next(dataiter)
+
+# Create a grid from the images and show them
+img_grid = torchvision.utils.make_grid(images)
+matplotlib_imshow(img_grid, one_channel=True)
+print(' '.join(classes[labels[j]] for j in range(4)))
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+# PyTorch models inherit from torch.nn.Module
+class GarmentClassifier(nn.Module):
+ def __init__(self):
+ super(GarmentClassifier, self).__init__()
+ self.fc1 = nn.Linear(28 * 28, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = x.view(-1, 28 * 28)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+
+model = GarmentClassifier()
+loss_fn = torch.nn.CrossEntropyLoss()
+
+jax_weights, jax_func = torch_xla2.extract_jax(model)
+jax_func = jax.jit(jax_func, inline=True)
+jax_optimizer = optax.adam(0.01)
+opt_state = jax_optimizer.init(jax_weights)
+
+
+def jax_loss(weights, data, label):
+ pred = jax_func(weights, data)
+ loss = torch_xla2.extra.call_torch(loss_fn, pred, label)
+ return loss
+
+grad_fn = jax.jit(jax.value_and_grad(jax_loss))
+
+
+# NB: Loss functions expect data in batches, so we're creating batches of 4
+# Represents the model's confidence in each of the 10 classes for a given input
+dummy_outputs = torch.rand(4, 10)
+# Represents the correct class among the 10 being tested
+dummy_labels = torch.tensor([1, 5, 3, 7])
+
+print(dummy_outputs)
+print(dummy_labels)
+
+loss = loss_fn(dummy_outputs, dummy_labels)
+print('Total loss for this batch: {}'.format(loss.item()))
+
+
+def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
+
+ running_loss = 0.
+ last_loss = 0.
+
+ # Here, we use enumerate(training_loader) instead of
+ # iter(training_loader) so that we can track the batch
+ # index and do some intra-epoch reporting
+ for i, data in enumerate(training_loader):
+ # Every data instance is an input + label pair
+ # NEW: Move model to XLA device
+ data = pytree.tree_map_only(torch.Tensor,
+ torch_xla2.tensor.t2j, data)
+ inputs, labels = data
+
+ val, grads = grad_fn(jax_weights, (inputs, ), labels)
+ updates, opt_state = jax_optimizer.update(grads, opt_state)
+ jax_weights = optax.apply_updates(jax_weights, updates)
+
+ # Gather data and report
+ running_loss += val.item()
+ if i % 1000 == 999:
+ last_loss = running_loss / 1000 # loss per batch
+ print(' batch {} loss: {}'.format(i + 1, last_loss))
+ tb_x = epoch_index * len(training_loader) + i + 1
+ tb_writer.add_scalar('Loss/train', last_loss, tb_x)
+ running_loss = 0.
+
+ return last_loss, opt_state
+
+
+
+# Initializing in a separate cell so we can easily add more epochs to the same run
+timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
+epoch_number = 0
+EPOCHS = 2
+best_vloss = 1_000_000.
+
+for epoch in range(EPOCHS):
+ print('EPOCH {}:'.format(epoch_number + 1))
+
+ # Make sure gradient tracking is on, and do a pass over the data
+ model.train(True)
+
+ # NEW: Move model to XLA device
+ state_dict = model.state_dict()
+ state_dict = pytree.tree_map_only(torch.Tensor,
+ torch_xla2.tensor.move_to_device, state_dict)
+ model.load_state_dict(state_dict, strict=False, assign=True)
+
+ avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer)
+
+ running_vloss = 0.0
+ # Set the model to evaluation mode, disabling dropout and using population
+ # statistics for batch normalization.
+ model.eval()
+
+ # Disable gradient computation and reduce memory consumption.
+ with torch.no_grad():
+ for i, vdata in enumerate(validation_loader):
+
+ vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata)
+ voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward
+ vloss = torch_xla2.extra.call_torch(loss_fn, voutputs, vlabels)
+ running_vloss += vloss
+
+ avg_vloss = running_vloss / (i + 1)
+ print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
+
+ # Log the running loss averaged per batch
+ # for both training and validation
+ writer.add_scalars('Training vs. Validation Loss',
+ { 'Training' : np.asarray(avg_loss), 'Validation' : np.asarray(avg_vloss) },
+ epoch_number + 1)
+ writer.flush()
+
+ # Track best performance, and save the model's state
+ if avg_vloss < best_vloss:
+ best_vloss = avg_vloss
+ model_path = 'model_{}_{}'.format(timestamp, epoch_number)
+ torch.save(model.state_dict(), model_path)
+
+ epoch_number += 1
\ No newline at end of file
diff --git a/experimental/torch_xla2/examples/eager_mode.py b/experimental/torch_xla2/examples/eager_mode.py
new file mode 100644
index 00000000000..358ee6256c6
--- /dev/null
+++ b/experimental/torch_xla2/examples/eager_mode.py
@@ -0,0 +1,42 @@
+
+from torch_xla2.tensor import move_to_device
+import torch_xla2
+from torch import nn
+from torch.nn import functional as F
+import torch
+from torch.utils import _pytree as pytree
+
+
+class MyModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc1 = nn.Linear(28 * 28, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = x.view(-1, 28 * 28)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+m = MyModel()
+
+# Execute this model using torch
+inputs = (torch.randn(3, 3, 28, 28), )
+
+inputs, state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, (inputs, m.state_dict()))
+m.load_state_dict(state_dict, strict=False, assign=True)
+print(m(*inputs))
+print('---=====')
+
+from torch_xla2.extra import jax_jit
+
+@jax_jit
+def model_func(param, inputs):
+ return torch.func.functional_call(m, param, inputs)
+
+print(model_func(state_dict, inputs))
+
+
diff --git a/experimental/torch_xla2/examples/requirements.txt b/experimental/torch_xla2/examples/requirements.txt
new file mode 100644
index 00000000000..69e01ff3dd0
--- /dev/null
+++ b/experimental/torch_xla2/examples/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+matplotlib
+optax
\ No newline at end of file
diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py
index 69ec3c33aef..dae7bf0cc5c 100644
--- a/experimental/torch_xla2/test/llama/test_llama.py
+++ b/experimental/torch_xla2/test/llama/test_llama.py
@@ -33,7 +33,6 @@ def test_can_run(self):
# NOTE: this API does NOT use torch export
weights, jax_func = torch_xla2.extract_jax(m)
-
print(jax_func(weights, sample_args))
def test_can_run_exportable(self):
diff --git a/experimental/torch_xla2/test/test_base.py b/experimental/torch_xla2/test/test_base.py
index 71f4dc97b67..d8b409306b7 100644
--- a/experimental/torch_xla2/test/test_base.py
+++ b/experimental/torch_xla2/test/test_base.py
@@ -1,4 +1,55 @@
import unittest
+import torch
+from torch.utils import _pytree as pytree
+
+from torch_xla2 import tensor
TestCase = unittest.TestCase
main = unittest.main
+
+
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
+ if isinstance(output1, torch.Tensor):
+ testcase.assertIsInstance(output2, torch.Tensor)
+ output2_cpu = output2.detach().cpu()
+ if output2_cpu.dtype != output1.dtype:
+ output2_cpu = output2_cpu.to(output1.dtype)
+ testcase.assertTrue(
+ torch.allclose(
+ output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
+ elif isinstance(output1, (tuple, list)):
+ testcase.assertIsInstance(output2, (tuple, list))
+ testcase.assertEqual(len(output1), len(output2))
+ for o1, o2 in zip(output1, output2):
+ diff_output(testcase, o1, o2, rtol, atol)
+ else:
+ testcase.assertEqual(output1, output2)
+
+
+def run_function_and_compare(testcase,
+ func,
+ args,
+ kwargs,
+ atol=1e-3,
+ rtol=1e-5,
+ equal_nan=True,
+ ignore_indices=False):
+ with testcase.subTest("torch_eval"):
+ res = func(*args, **kwargs)
+ with testcase.subTest("torch_xla2_eval"):
+ args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
+ (args, kwargs))
+ res2 = func(*args2, **kwargs2)
+ res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
+ with testcase.subTest("torch_xla2_diff:" + str(atol)):
+ if ignore_indices and isinstance(res, tuple) and len(res) == 2:
+ diff_output(
+ testcase,
+ res[0],
+ res2[0],
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan)
+ else:
+ diff_output(
+ testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan)
diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py
new file mode 100644
index 00000000000..1a75a7d23d0
--- /dev/null
+++ b/experimental/torch_xla2/test/test_context.py
@@ -0,0 +1,31 @@
+import unittest
+
+import torch
+import torch_xla2
+from torch_xla2 import tensor
+
+
+class TestContext(unittest.TestCase):
+ def test_mode_context_manager(self):
+ with torch_xla2.mode():
+ x = torch.full((3, 3), -1)
+ self.assertIsInstance(x, tensor.XLATensor2)
+ y = x.abs()
+ self.assertIsInstance(y, tensor.XLATensor2)
+
+ @staticmethod
+ @torch_xla2.mode()
+ def _test_mode_decorator():
+ x = torch.full((3, 3), -1)
+ y = x.abs()
+
+ return x, y
+
+ def test_mode_decorator(self):
+ x, y = self._test_mode_decorator()
+ self.assertIsInstance(x, tensor.XLATensor2)
+ self.assertIsInstance(y, tensor.XLATensor2)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/experimental/torch_xla2/test/test_mutations.py b/experimental/torch_xla2/test/test_mutations.py
new file mode 100644
index 00000000000..2f9ddca975b
--- /dev/null
+++ b/experimental/torch_xla2/test/test_mutations.py
@@ -0,0 +1,52 @@
+import unittest
+import torch_xla2
+import torch
+from torch.testing._internal.common_utils import TestCase
+
+
+class TestMutations(TestCase):
+
+ def test_add(self):
+ x = torch.tensor([1, 2, 3], dtype=torch.int32)
+ y = torch.tensor([4, 5, 6], dtype=torch.int32)
+
+ x = torch_xla2.tensor.move_to_device(x)
+ y = torch_xla2.tensor.move_to_device(y)
+ x.add_(y)
+ xt = torch_xla2.tensor.j2t(x._elem)
+ self.assertEqual(xt, torch.tensor([5, 7, 9], dtype=torch.int32))
+
+ def test_sub(self):
+ x = torch.tensor([1, 2, 3], dtype=torch.int32)
+ y = torch.tensor([4, 5, 6], dtype=torch.int32)
+
+ x = torch_xla2.tensor.move_to_device(x)
+ y = torch_xla2.tensor.move_to_device(y)
+ x.sub_(y)
+ xt = torch_xla2.tensor.j2t(x._elem)
+ self.assertEqual(xt, torch.tensor([-3, -3, -3], dtype=torch.int32))
+
+ def test_mul(self):
+ x = torch.tensor([1, 2, 3], dtype=torch.int32)
+ y = torch.tensor([4, 5, 6], dtype=torch.int32)
+
+ x = torch_xla2.tensor.move_to_device(x)
+ y = torch_xla2.tensor.move_to_device(y)
+ x.mul_(y)
+ xt = torch_xla2.tensor.j2t(x._elem)
+ self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))
+
+ def test_div(self):
+ x = torch.tensor([1, 2, 3], dtype=torch.int32)
+ y = torch.tensor([4, 5, 6], dtype=torch.int32)
+
+ x = torch_xla2.tensor.move_to_device(x)
+ y = torch_xla2.tensor.move_to_device(y)
+ x.div_(y)
+ xt = torch_xla2.tensor.j2t(x._elem)
+ self.assertEqual(xt,
+ torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/experimental/torch_xla2/torch_xla2/CONTRIBUTING.md b/experimental/torch_xla2/torch_xla2/CONTRIBUTING.md
new file mode 100644
index 00000000000..ee640dac836
--- /dev/null
+++ b/experimental/torch_xla2/torch_xla2/CONTRIBUTING.md
@@ -0,0 +1,38 @@
+# Contributing to TorchXLA2
+
+We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels.
+
+If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
+
+
+# Developer setup
+
+## Mac setup:
+@qihqi
+
+I am able to develop directly on mac (m1) laptop for most of parts. Using steps
+in README.md works. The condensed version for easy copy & paste:
+
+```bash
+conda create --name python=3.10
+conda activate
+pip install --upgrade "jax[cpu]" torch
+pip install -r test_requirements.txt
+pip install -e .
+pytest test
+```
+
+### VSCode
+
+I use vscode on my Mac. I loosely followed instruction in
+https://code.visualstudio.com/docs/python/python-tutorial
+to setup a proper python environment.
+
+The plugins I installed (a subset of the ones listed above) are:
+* VSCode's official Python plugin
+* Ruff formatter
+* Python Debugger
+
+I also changed Python interpreter to point at the one in my conda env.
+That is all the changes I have.
+
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index 4d07006fcd0..b0bb20712d4 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,13 +1,19 @@
+import contextlib
import jax
import torch
from torch._functorch import make_functional
from torch.utils import _pytree as pytree
-from torch_xla2 import tensor
-from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration
+from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions
jax.config.update('jax_enable_x64', True)
+@contextlib.contextmanager
+def mode():
+ with tensor.XLADispatchMode(), functions.XLAFunctionMode():
+ yield
+
+
def extract_jax(mod: torch.nn.Module):
"""Returns a pytree of jax.ndarray and a jax callable."""
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py
index e4aece8f6a8..fe0f97a0f01 100644
--- a/experimental/torch_xla2/torch_xla2/_ops.py
+++ b/experimental/torch_xla2/torch_xla2/_ops.py
@@ -60,20 +60,11 @@ def _aten_add(x, y, *, alpha=1):
return x + y * alpha
-@op(torch.ops.aten.add_, is_jax_func=False)
-def _aten_add_inplace(self, other, *, alpha):
- if isinstance(other, XLATensor2):
- self._elem += alpha * other._elem
- else:
- self._elem += alpha * other
- return self
-
-
@op(torch.ops.aten.copy_, is_jax_func=False)
def _aten_copy(x, y, memory_format=None):
- if isinstance(x, XLATensor2):
+ if isinstance(x, tensor.XLATensor2):
x._elem = y._elem
- elif isinstance(x, SliceView):
+ elif isinstance(x, tensor.SliceView):
x.mutate(y)
return x
@@ -1693,3 +1684,62 @@ def _aten_scalar_tensor(s,
dtype = tensor.t2j_dtype(dtype)
return jnp.array(s, dtype=dtype)
return jnp.array(s)
+
+
+@op(torch.ops.aten.to.device)
+def _aten_to_device(x,device, dtype):
+ return x
+
+
+@op(torch.ops.aten.max_pool2d_with_indices_backward)
+def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices):
+
+ """
+ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
+
+ Args:
+ grad_output: The gradient tensor from the preceding layer.
+ self: The input tensor on which the original max pooling was performed.
+ kernel_size: The size of the pooling window.
+ stride: The stride of the pooling window.
+ padding: The padding applied during max pooling.
+ dilation: The dilation factor for the pooling operation.
+ ceil_mode: Whether to use ceil or floor when calculating output shapes.
+ indices: The indices of the maximum values, as produced by max_pool2d_with_indices.
+
+ Returns:
+ The calculated gradient with respect to the input (grad_input).
+ """
+
+ kH, kW = kernel_size
+ dH, dW = stride
+ padH, padW = padding
+ dilH, dilW = dilation
+
+ # Calculate output shape (may need adjustment based on ceil_mode)
+ out_shape = jnp.array(self.shape)
+ grad_input = jnp.zeros_like(self)
+
+ # Iterate over the flattened input and output tensors
+ for i, idx in enumerate(indices.flatten()):
+ # Calculate input coordinates corresponding to the maximum value
+ out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3]
+ in_y = out_y * dH - padH + out_y * (dilH - 1)
+ in_x = out_x * dW - padW + out_x * (dilW - 1)
+
+ # Scatter the gradient to the appropriate input locations (handling potential overlaps)
+ for y in range(in_y, in_y + kH):
+ for x in range(in_x, in_x + kW):
+ if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]:
+ grad_input = grad_input.at[y, x].add(grad_output.flatten()[i])
+
+ return grad_input
+
+
+@op(torch.ops.aten._local_scalar_dense)
+def _aten_local_scalar_dense(x):
+ return x.item()
+
+@op(torch.ops.aten.tensor_split.sections)
+def _aten_tensor_split(ary, indices_or_sections, axis=0):
+ return jnp.array_split(ary, indices_or_sections, axis)
\ No newline at end of file
diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py
index 6c40455959c..9fcd5653a86 100644
--- a/experimental/torch_xla2/torch_xla2/functions.py
+++ b/experimental/torch_xla2/torch_xla2/functions.py
@@ -103,7 +103,6 @@ def __torch_function__(self,
kwargs=None) -> torch.Tensor:
jax_func = registry.get(func)
if not jax_func:
- logging.warning(f'Falling back to default implementation of {func.__name__}')
return func(*args, **(kwargs or {}))
# TODO: unwrap args here or in implementations?
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 6adacedbbc0..a30fae82de8 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,9 +1,37 @@
"""This module contains implementation of ATen ops."""
import torch
-import jax
-import jax.numpy as jnp
-from torch_xla2.ops import op_base
# Keys are OpOverload, value is a callable that takes
# XLATensor2
all_ops = {}
+
+# list all Aten ops from pytorch that does mutation
+# and need to be implemented in jax
+
+mutation_ops_to_functional = {
+ torch.ops.aten.add_: torch.ops.aten.add,
+ torch.ops.aten.sub_: torch.ops.aten.sub,
+ torch.ops.aten.mul_: torch.ops.aten.mul,
+ torch.ops.aten.div_: torch.ops.aten.div,
+ torch.ops.aten.pow_: torch.ops.aten.pow,
+ torch.ops.aten.lt_: torch.ops.aten.lt,
+ torch.ops.aten.le_: torch.ops.aten.le,
+ torch.ops.aten.gt_: torch.ops.aten.gt,
+ torch.ops.aten.ge_: torch.ops.aten.ge,
+ torch.ops.aten.eq_: torch.ops.aten.eq,
+ torch.ops.aten.ne_: torch.ops.aten.ne,
+}
+
+
+def make_mutation(op):
+
+ def f(*args, **kwargs):
+ res = mutation_ops_to_functional[op](*args, **kwargs)
+ args[0].copy_(res)
+ return args[0]
+
+ return f
+
+
+for op in mutation_ops_to_functional.keys():
+ all_ops[op] = make_mutation(op)
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 04ca3aed58b..98953a8b04c 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -4,13 +4,16 @@
import jax.numpy as jnp
import numpy
import torch
-import torch._decomp as decomp
+import torch.func
import torch._decomp.decompositions
from torch_xla2 import ops_registry
import torch.utils._python_dispatch as torch_dispatch
import torch.utils._pytree as torch_pytree
import torch.utils.dlpack as torchdl
from torch_xla2.ops import jaten
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+fake_mode = FakeTensorMode()
class XLADispatchMode(torch_dispatch.TorchDispatchMode):
@@ -36,8 +39,8 @@ def _aten_arange(start,
def _aten_scalar_tensor(val, **kwargs):
- p = torch.ops.aten.scalar_tensor(val)
- return wrap(t2j(p))
+ p = torch.ops.aten.scalar_tensor(val)
+ return wrap(t2j(p))
constructors = {
@@ -192,11 +195,17 @@ def type_as(self, other):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
with jax.named_scope(func.name()):
+
if isinstance(func, torch._ops.OpOverloadPacket):
return func(*args, **kwargs)
- if func in jaten.all_ops:
- return jaten.all_ops[func](*args, **kwargs)
+ if func.name() == 'aten::copy_':
+ x, y = args
+ x._elem = y._elem
+ return
+
+ if func.overloadpacket in jaten.all_ops:
+ return jaten.all_ops[func.overloadpacket](*args, **kwargs)
lowering = ops_registry.lowerings.lookup(func)
diff --git a/infra/ansible/Dockerfile b/infra/ansible/Dockerfile
index acd726061ad..36954841372 100644
--- a/infra/ansible/Dockerfile
+++ b/infra/ansible/Dockerfile
@@ -9,16 +9,10 @@ COPY . /ansible
ARG ansible_vars
RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}"
+RUN ansible-playbook -vvv playbook.yaml -e "stage=build_plugin" -e "${ansible_vars}" --skip-tags=fetch_srcs,install_deps
FROM python:${python_version}-${debian_version} AS release
-WORKDIR /ansible
-RUN pip install ansible
-COPY . /ansible
-
-ARG ansible_vars
-RUN ansible-playbook -vvv playbook.yaml -e "stage=release" -e "${ansible_vars}" --tags "install_deps"
-
WORKDIR /tmp/wheels
COPY --from=build /src/pytorch/dist/*.whl ./
COPY --from=build /src/pytorch/xla/dist/*.whl ./
@@ -27,6 +21,14 @@ COPY --from=build /dist/torchvision*.whl ./
RUN echo "Installing the following wheels" && ls *.whl
RUN pip install *.whl
+# Install the dependencies including libtpu, that's why this needs to happen after wheels are installed.
+WORKDIR /ansible
+RUN pip install ansible
+COPY . /ansible
+
+ARG ansible_vars
+RUN ansible-playbook -vvv playbook.yaml -e "stage=release" -e "${ansible_vars}" --tags "install_deps"
+
WORKDIR /
RUN rm -rf /ansible /tmp/wheels
diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml
index 15e8dc79d6c..d324729ce11 100644
--- a/infra/ansible/config/env.yaml
+++ b/infra/ansible/config/env.yaml
@@ -14,7 +14,7 @@ release_env:
TPUVM_MODE: 1
cuda:
- TF_CUDA_COMPUTE_CAPABILITIES: 7.0,7.5,8.0,9.0
+ TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}"
XLA_CUDA: 1
# Variables that will be passed to shell environment only for building PyTorch and XLA libs.
@@ -31,7 +31,7 @@ build_env:
PYTORCH_BUILD_VERSION: "{{ package_version }}"
XLA_SANDBOX_BUILD: 1
BAZEL_REMOTE_CACHE: 1
- SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}"
+ SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}{{ cache_suffix }}"
_GLIBCXX_USE_CXX11_ABI: 0
GIT_VERSIONED_XLA_BUILD: "{{ nightly_release }}"
@@ -41,7 +41,7 @@ build_env:
aarch64:
cuda:
- TF_CUDA_COMPUTE_CAPABILITIES: 7.0,7.5,8.0,9.0
+ TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}"
XLA_CUDA: 1
tpu:
diff --git a/infra/ansible/config/pip.yaml b/infra/ansible/config/pip.yaml
index 1c82075abef..5941f6e2e0b 100644
--- a/infra/ansible/config/pip.yaml
+++ b/infra/ansible/config/pip.yaml
@@ -45,6 +45,7 @@ pip:
- mkl-include
release_tpu:
+ - torch_xla[tpuvm]
# Packages that will be installed with the `--nodeps` flag.
pkgs_nodeps:
diff --git a/infra/ansible/config/vars.yaml b/infra/ansible/config/vars.yaml
index 2347d066e84..c1ca7a93d27 100644
--- a/infra/ansible/config/vars.yaml
+++ b/infra/ansible/config/vars.yaml
@@ -1,6 +1,8 @@
# Used for fetching cuda from the right repo, see apt.yaml.
cuda_repo: debian11
cuda_version: "11.8"
+# Determines supported GPUs. See https://developer.nvidia.com/cuda-gpus
+cuda_compute_capabilities: 7.0,7.5,8.0,9.0
# Used for fetching clang from the right repo, see apt.yaml.
llvm_debian_repo: bullseye
clang_version: 17
@@ -10,3 +12,5 @@ package_version: 2.4.0
nightly_release: false
# Whether to preinstall libtpu in the PyTorch/XLA wheel. Ignored for GPU build.
bundle_libtpu: 1
+# Suffix for bazel remote cache key
+cache_suffix: ""
diff --git a/infra/ansible/e2e_tests.Dockerfile b/infra/ansible/e2e_tests.Dockerfile
index ab29d2dc76c..780c0f4abdb 100644
--- a/infra/ansible/e2e_tests.Dockerfile
+++ b/infra/ansible/e2e_tests.Dockerfile
@@ -10,6 +10,7 @@ COPY . /ansible
# Build PyTorch and PyTorch/XLA wheels.
ARG ansible_vars
RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}"
+RUN ansible-playbook -vvv playbook.yaml -e "stage=build_plugin" -e "${ansible_vars}" --skip-tags=fetch_srcs,install_deps
FROM python:${python_version}-${debian_version}
WORKDIR /ansible
diff --git a/infra/ansible/playbook.yaml b/infra/ansible/playbook.yaml
index 3484fdc72ce..524b2a8c70c 100644
--- a/infra/ansible/playbook.yaml
+++ b/infra/ansible/playbook.yaml
@@ -16,7 +16,7 @@
"Pass the required variable with: --e \"{{ item.name }}=\""
loop:
- name: stage
- pattern: ^(build|release)$
+ pattern: ^(build|build_plugin|release)$
- name: arch
pattern: ^(aarch64|amd64)$
- name: accelerator
@@ -73,6 +73,7 @@
src_root: "/src"
tags: fetch_srcs
+ # TODO: better name now that there are two builds
- role: build_srcs
vars:
src_root: "/src"
@@ -81,8 +82,20 @@
combine(build_env[arch] | default({}, true)) |
combine(build_env[accelerator] | default({}, true))
}}"
+ when: stage == "build"
tags: build_srcs
+ - role: build_plugin
+ vars:
+ src_root: "/src"
+ env_vars: "{{
+ build_env.common | default({}, true) |
+ combine(build_env[arch] | default({}, true)) |
+ combine(build_env[accelerator] | default({}, true))
+ }}"
+ when: stage == "build_plugin"
+ tags: build_plugin
+
- role: configure_env
vars:
env_vars: "{{
diff --git a/infra/ansible/roles/build_plugin/tasks/main.yaml b/infra/ansible/roles/build_plugin/tasks/main.yaml
new file mode 100644
index 00000000000..2e2590b150a
--- /dev/null
+++ b/infra/ansible/roles/build_plugin/tasks/main.yaml
@@ -0,0 +1,32 @@
+- name: Create /dist directory for exported wheels
+ ansible.builtin.file:
+ path: /dist
+ state: directory
+ mode: '0755'
+
+- name: Build PyTorch/XLA CUDA Plugin
+ ansible.builtin.command:
+ cmd: pip wheel -w /dist plugins/cuda -v
+ chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
+ environment: "{{ env_vars }}"
+ when: accelerator == "cuda"
+
+- name: Find CUDA plugin wheel pytorch/xla/dist
+ ansible.builtin.find:
+ path: "/dist"
+ pattern: "torch_xla_cuda_plugin*.whl"
+ when: accelerator == "cuda"
+ register: plugin_wheels
+
+- name: Install CUDA plugin wheels
+ ansible.builtin.pip:
+ name: "{{ plugin_wheels.files | map(attribute='path') }}"
+ state: "forcereinstall"
+ when: accelerator == "cuda"
+
+# TODO: Pass libtpu to next release stage somehow. This only runs during build
+- name: Install libtpu
+ ansible.builtin.pip:
+ name: torch_xla[tpu]
+ extra_args: -f https://storage.googleapis.com/libtpu-releases/index.html
+ when: accelerator == "tpuvm"
diff --git a/infra/ansible/roles/build_srcs/tasks/main.yaml b/infra/ansible/roles/build_srcs/tasks/main.yaml
index bc708e2c680..d945f150d38 100644
--- a/infra/ansible/roles/build_srcs/tasks/main.yaml
+++ b/infra/ansible/roles/build_srcs/tasks/main.yaml
@@ -23,13 +23,6 @@
chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
environment: "{{ env_vars }}"
-- name: Build PyTorch/XLA CUDA Plugin
- ansible.builtin.command:
- cmd: pip wheel -w dist plugins/cuda -v
- chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
- environment: "{{ env_vars }}"
- when: accelerator == "cuda"
-
- name: Find XLA *.whl files in pytorch/xla/dist
ansible.builtin.find:
path: "{{ (src_root, 'pytorch/xla/dist') | path_join }}"
diff --git a/infra/tpu-pytorch-releases/dev_images.auto.tfvars b/infra/tpu-pytorch-releases/dev_images.auto.tfvars
index db5627d2a09..ba6f273ba09 100644
--- a/infra/tpu-pytorch-releases/dev_images.auto.tfvars
+++ b/infra/tpu-pytorch-releases/dev_images.auto.tfvars
@@ -4,15 +4,11 @@ dev_images = [
extra_tags = ["tpu"]
python_version = "3.10"
},
- {
- accelerator = "cuda"
- cuda_version = "11.8"
- extra_tags = ["cuda"]
- },
{
accelerator = "cuda"
cuda_version = "12.1"
extra_tags = ["cuda"]
+ python_version = "3.10"
},
{
accelerator = "cuda"
diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml
index fd8bbf59f6c..d44a2ea3bd5 100644
--- a/plugins/cuda/pyproject.toml
+++ b/plugins/cuda/pyproject.toml
@@ -1,15 +1,15 @@
[build-system]
-requires = ["setuptools"]
+requires = ["setuptools", "numpy"]
build-backend = "setuptools.build_meta"
[project]
name = "torch_xla_cuda_plugin"
-version = "0.0.1"
authors = [
{name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"},
]
description = "PyTorch/XLA CUDA Plugin"
requires-python = ">=3.8"
+dynamic = ["version"]
[tool.setuptools.package-data]
torch_xla_cuda_plugin = ["lib/*.so"]
diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py
index 4207d598ed2..d8756c4f72e 100644
--- a/plugins/cuda/setup.py
+++ b/plugins/cuda/setup.py
@@ -1,3 +1,4 @@
+import datetime
import os
import sys
@@ -10,4 +11,9 @@
build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
'torch_xla_cuda_plugin/lib', ['--config=cuda'])
-setuptools.setup()
+setuptools.setup(
+ # TODO: Use a common version file
+ version=os.getenv(
+ 'TORCH_XLA_VERSION',
+ f'2.4.0.dev{datetime.date.today().strftime("%Y%m%d")}')
+)
diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
index 9321d26a1a6..e6863ff711a 100644
--- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py
+++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -27,6 +27,9 @@ def physical_chip_count(self) -> int:
# TODO: default to actual device count
return xu.getenv_as('GPU_NUM_DEVICES', int, 1)
+ def configure_single_process(self):
+ pass
+
def client_create_options(self) -> dict:
local_process_rank, global_process_rank = self._get_process_rank()
local_world_size, global_world_size = self._get_world_size()
diff --git a/setup.py b/setup.py
index c049eb6baf0..dbe47007aff 100644
--- a/setup.py
+++ b/setup.py
@@ -64,7 +64,7 @@
base_dir = os.path.dirname(os.path.abspath(__file__))
-_date = '20240404'
+_date = '20240418'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.27.dev{_date}'
diff --git a/test/benchmarks/run_tests.sh b/test/benchmarks/run_tests.sh
index 7d404a7ee7f..3832b21ed22 100755
--- a/test/benchmarks/run_tests.sh
+++ b/test/benchmarks/run_tests.sh
@@ -39,10 +39,14 @@ function run_make_tests {
}
function run_python_tests {
- python3 "$CDIR/test_experiment_runner.py"
- python3 "$CDIR/test_benchmark_experiment.py"
- python3 "$CDIR/test_benchmark_model.py"
- python3 "$CDIR/test_result_analyzer.py"
+ # HACK: don't confuse local `torch_xla` folder with installed package
+ # Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
+ pushd $CDIR
+ python3 "test_experiment_runner.py"
+ python3 "test_benchmark_experiment.py"
+ python3 "test_benchmark_model.py"
+ python3 "test_result_analyzer.py"
+ popd
}
function run_tests {
diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh
index 74244322840..d6b492dc694 100755
--- a/test/cpp/run_tests.sh
+++ b/test/cpp/run_tests.sh
@@ -5,7 +5,7 @@ BUILDTYPE="opt"
VERB=
FILTER=
LOGFILE=/tmp/pytorch_cpp_test.log
-XLA_EXPERIMENTAL="nonzero:masked_select"
+XLA_EXPERIMENTAL="nonzero:masked_select:nms"
BAZEL_REMOTE_CACHE="0"
BAZEL_VERB="test"
diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp
index 0a0d84d463a..ff6130ca1b9 100644
--- a/test/cpp/test_aten_xla_tensor_4.cpp
+++ b/test/cpp/test_aten_xla_tensor_4.cpp
@@ -1226,7 +1226,6 @@ TEST_F(AtenXlaTensorTest, TestPixelShuffle) {
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
- ExpectCounterChanged("xla::permute_copy", cpp_test::GetIgnoredCounters());
}
TEST_F(AtenXlaTensorTest, TestSumToSize) {
diff --git a/test/cpp/test_aten_xla_tensor_5.cpp b/test/cpp/test_aten_xla_tensor_5.cpp
index 4070779529f..07e4c2dae86 100644
--- a/test/cpp/test_aten_xla_tensor_5.cpp
+++ b/test/cpp/test_aten_xla_tensor_5.cpp
@@ -267,6 +267,27 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) {
});
}
+TEST_F(AtenXlaTensorTest, TestEmbeddingBag) {
+ torch::Tensor weight =
+ torch::rand({32, 4}, torch::TensorOptions(torch::kFloat));
+ torch::Tensor indices =
+ torch::randint(0, 31, {10}, torch::TensorOptions(torch::kLong));
+ torch::Tensor offsets = torch::arange(0, 10, 3);
+ auto out = torch::embedding_bag(weight, indices, offsets);
+ torch::Tensor result = std::get<0>(out);
+ ForEachDevice([&](const torch::Device& device) {
+ torch::Tensor xla_weight = CopyToDevice(weight, device);
+ torch::Tensor xla_indices = CopyToDevice(indices, device);
+ torch::Tensor xla_offsets = CopyToDevice(offsets, device);
+ auto xla_out = torch::embedding_bag(xla_weight, xla_indices, xla_offsets);
+ torch::Tensor xla_result = std::get<0>(xla_out);
+ AllClose(result, xla_result);
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+ ExpectCounterChanged("xla::_embedding_bag_forward_only",
+ cpp_test::GetIgnoredCounters());
+ });
+}
+
TEST_F(AtenXlaTensorTest, TestOneHot) {
int num_classes = 5;
torch::Tensor input =
diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index f0aee024f75..c3dfe6bbed1 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -7,6 +7,7 @@
import torch_xla.utils.utils as xu
import torch_xla.debug.metrics as met
from torch_xla import runtime as xr
+import torch_xla.debug.profiler as xp
import torch.optim as optim
import torch.nn as nn
import torch._dynamo as dynamo
@@ -91,6 +92,20 @@ def test_mark_step_after_dynamo(self):
self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])
+class DynamoProfilerTest(unittest.TestCase):
+
+ def dummy_fn(self, a):
+ return torch.sin(a) + a
+
+ def test_dynamo_with_trace(self):
+ dynamo_dummy = torch.compile(
+ self.dummy_fn, backend="openxla", fullgraph=True)
+ t = torch.randn(2, 3, 4, device=xm.xla_device())
+ for i in range(10):
+ with xp.Trace('build_graph'):
+ t = dynamo_dummy(t)
+
+
class DynamoInferenceBasicTest(unittest.TestCase):
@classmethod
@@ -137,7 +152,7 @@ def test_simple_model(self):
# Tests that the dynamo bridge automatically moves tensors to XLA device,
# then back to the original device.
- @unittest.skipIf(xr.device_type() != "CUDA",
+ @unittest.skipIf(xr.device_type() != "CUDA" or not torch.cuda.is_available(),
f"GPU tests should only run on GPU devices.")
def test_simple_model_automoves_tensors(self):
x = torch.tensor(100.0).to(device="cuda")
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 88ad0f6bc3d..3a6dcdd96c6 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -70,6 +70,7 @@
'test_pdist_norm_backward_xla', # pdist_single
'test_pdist_norm_forward_xla', # pdist_single
'test_nuclear_norm_axes_small_brute_force',
+ 'test_nondeterministic_alert_EmbeddingBag_max_xla', # FIXME: implement embedding_bag_backward
'test_mul_intertype_scalar',
'test_masked_select_discontiguous', # FIXME: wrong result
'test_memory_format_type',
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 4d4bd530e27..8926318dc38 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -104,7 +104,7 @@ function run_xla_hlo_debug {
function run_dynamic {
echo "Running in DynamicShape mode: $@"
- XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@"
+ XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
}
function run_eager_debug {
diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 0595f502da0..d1f6cdc3dce 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -205,6 +205,8 @@ def test_dynamo_spmd_mark_sharding_outside_of_compile(self):
dynamo_res = dynamo_linear(xla_x)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
+ # https://github.com/pytorch/xla/pull/6921#issuecomment-2062106737
+ @unittest.skip("Failing in CI")
def test_mark_sharding_inside_compile(self):
met.clear_counters()
device = xm.xla_device()
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index cbd85eb532b..c6a30fd4bd8 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1100,6 +1100,119 @@ def test_global_mesh(self):
self.assertEqual(id(mesh), id(expected_mesh))
+ def test_mark_manual_sharding(self):
+ x = torch.zeros(3, 2).to(xm.xla_device())
+ with self.assertRaises(RuntimeError):
+ xt = xs._mark_manual_sharding(x)
+
+ xx = x + 1
+ xt = xs._mark_manual_sharding(xx)
+
+ hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
+ self.assertIn(', sharding={manual}', hlo)
+ self.assertEqual(xt.sharding_type, xs.ShardingType.MANUAL)
+ self.assertEqual(xt.sharding_spec, "{manual}")
+
+ # It looks like XLA does't like only having manual sharding in the HLO.
+ # It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
+ # The following exception cannot be caught somehow.
+ # xt.global_tensor.cpu()
+
+ def test_spmd_full_to_shard_shape(self):
+ x = torch.zeros(8, 8).to(xm.xla_device())
+ with self.assertRaises(RuntimeError):
+ x = torch_xla._XLAC._spmd_full_to_shard_shape(x)
+
+ # Sharded shape
+ xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), (0, 1))
+ xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)
+
+ hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
+ self.assertEqual(xx.shape, (8, 8 // self.n_devices))
+ self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
+ self.assertIn(
+ f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
+ self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
+
+ # It looks like XLA does't like only having manual sharding in the HLO.
+ # It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
+ # The following exception cannot be caught somehow.
+ # xx.cpu()
+
+ # Replicated shape
+ x = torch.zeros(8, 4).to(xm.xla_device())
+ xt = xs.mark_sharding(x, self._get_mesh((self.n_devices, 1)), (None, None))
+ xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)
+
+ hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
+ self.assertEqual(xx.shape, (8, 4))
+ self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
+ self.assertIn(
+ f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
+ self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
+
+ def test_spmd_shard_to_full_shape(self):
+ x = torch.zeros(8, 8).to(xm.xla_device())
+ x += 1
+ # No sharding spec attached.
+ with self.assertRaises(RuntimeError):
+ x = torch_xla._XLAC._spmd_shard_to_full_shape(
+ x, torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
+ x.shape, x.dtype)
+
+ xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), (0, 1))
+ # Not manual sharding.
+ with self.assertRaises(RuntimeError):
+ x = torch_xla._XLAC._spmd_shard_to_full_shape(
+ xt.global_tensor,
+ torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
+ x.shape, x.dtype)
+
+ xs.clear_sharding(xt)
+ xt = xs._mark_manual_sharding(xt)
+ xx = torch_xla._XLAC._spmd_shard_to_full_shape(
+ xt.global_tensor,
+ torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
+ x.shape, x.dtype)
+
+ hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
+ self.assertEqual(xx.shape, x.shape)
+ self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
+ self.assertIn(
+ 'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
+ self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
+
+ def test_manual_sharding_e2e(self):
+ x = torch.zeros(8, 8).to(xm.xla_device())
+ mesh = self._get_mesh((1, self.n_devices))
+ partition_spec = (0, 1)
+ xt = xs.mark_sharding(x, mesh, partition_spec)
+
+ xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)
+ self.assertEqual(xx.shape, (8, 8 // self.n_devices))
+
+ xx = xx + 1
+ xxt = xs._mark_manual_sharding(xx)
+ xxx = torch_xla._XLAC._spmd_shard_to_full_shape(
+ xxt.global_tensor, mesh.get_op_sharding(partition_spec), x.shape,
+ x.dtype)
+ self.assertEqual(xxx.shape, (8, 8))
+
+ self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))
+
+ def test_manual_sharding_api_e2e(self):
+ xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
+ x = torch.zeros(8, 8).to(xm.xla_device())
+ partition_spec = (0, 1)
+
+ xx = xs.enable_manual_sharding(x, partition_spec)
+ self.assertEqual(xx.shape, (8, 8 // self.n_devices))
+
+ xx = xx + 1
+ xxx = xs.disable_manual_sharding(xx, partition_spec, x.shape)
+ self.assertEqual(xxx.shape, (8, 8))
+ self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))
+
if __name__ == '__main__':
test = unittest.main()
diff --git a/test/stablehlo/test_exports.py b/test/stablehlo/test_exports.py
index 155e4148b14..6208ae1ca52 100644
--- a/test/stablehlo/test_exports.py
+++ b/test/stablehlo/test_exports.py
@@ -45,7 +45,7 @@ def test_interpolate(self):
exported = torch.export.export(model, arg)
shlo = exported_program_to_stablehlo(exported)
ans2 = shlo(*arg).cpu().to(torch.float32)
- self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))
+ torch.testing.assert_close(ans, ans2, rtol=1e-5, atol=1e-4)
def test_constant(self):
@@ -113,6 +113,26 @@ def forward(self, x):
self.assertTrue(re.search(r'%arg.*tensor', shlo_text) is not None)
self.assertFalse('stablehlo.constant dense<5.000000e+00>' in shlo_text)
+ def test_export_no_weights(self):
+
+ class M(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.randn(10, 10))
+
+ def forward(self, x):
+ return torch.ops.aten.add(x, self.weight)
+
+ arg = (torch.randn(10, 10),)
+ model = M()
+ with torch.no_grad():
+ exported = torch.export.export(model, arg)
+ export_options = StableHLOExportOptions()
+ export_options.export_weights = False
+ shlo = exported_program_to_stablehlo(exported, options=export_options)
+ self.assertEqual(shlo._bundle.state_dict, {})
+
if __name__ == '__main__':
unittest.main()
diff --git a/test/test_devices.py b/test/test_devices.py
index ff93f64a5c5..e1fc804736d 100644
--- a/test/test_devices.py
+++ b/test/test_devices.py
@@ -4,14 +4,19 @@
import torch
import torch_xla as xla
import torch_xla.runtime as xr
+import torch_xla.debug.metrics as met
class TestDevices(parameterized.TestCase):
- def setUpClass():
+ @classmethod
+ def setUpClass(cls):
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'
+ def tearDown(self):
+ met.clear_metrics()
+
@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
(3, torch.device('xla:3')))
@@ -29,6 +34,12 @@ def test_real_devices(self):
def test_device_count(self):
self.assertEqual(xla.device_count(), 4)
+ def test_sync(self):
+ torch.ones((3, 3), device=xla.device())
+ xla.sync()
+
+ self.assertEqual(met.counter_value('MarkStep'), 1)
+
if __name__ == "__main__":
absltest.main()
diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
index fdbc58d12d0..a76197cc736 100644
--- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -6,17 +6,26 @@
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
+from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
def _fake_while_loop(cond_fn, body_fn, operands):
- while cond_fn(operands[0], operands[1]):
- operands = body_fn(operands[0], operands[1])
+ # operands need to be more than one here
+ while cond_fn(*operands):
+ operands = body_fn(*operands)
return operands
+def _fake_fori_loop(lower, upper, body_fun, *init_val):
+ (plus_value, init_val) = init_val
+ for i in range((upper - lower)[0]):
+ plus_value, init_val = body_fun(plus_value, init_val)
+ return init_val
+
+
class WhileLoopTest(unittest.TestCase):
def test_while_loop_tpu_subtraction(self):
@@ -73,7 +82,25 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)
+ def test_fori_loop_tpu_addition(self):
+
+ xm.mark_step()
+ device = xm.xla_device()
+
+ lower = torch.tensor([2], dtype=torch.int32, device=device)
+ upper = torch.tensor([52], dtype=torch.int32, device=device)
+ plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+ init_val = torch.tensor([1], dtype=torch.int32, device=device)
+
+ def body_fun(*argus):
+ plus_value, init_val = argus
+ return plus_value, torch.add(plus_value, init_val)
+
+ _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
+ expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
+ self.assertEqual(expected, actual)
+
if __name__ == '__main__':
test = unittest.main()
- sys.exit(0 if test.result.wasSuccessful() else 1)
+ sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file
diff --git a/test/test_operations.py b/test/test_operations.py
index 7fb9f5bc3e3..ff32c268927 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -88,6 +88,12 @@ def onlyOnCUDA(fn):
return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)
+def onlyIfXLAExperimentalContains(feat):
+ experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
+ return unittest.skipIf(feat not in experimental,
+ f"XLA_EXPERIMENTAL={feat} required")
+
+
def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)
@@ -2454,6 +2460,7 @@ def test_dropout(self):
# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
+@onlyIfXLAExperimentalContains("nms")
class TestNMS(test_utils.XlaTestCase):
def _reference_nms(self, boxes, scores, iou_threshold):
diff --git a/test/test_pallas.py b/test/test_pallas.py
index 8ab0caf9bc8..f8480782094 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -89,30 +89,6 @@ def test_tpu_custom_call_pallas_flash_attention(self):
[q.dtype])
self.assertTrue(torch.allclose(o[0].cpu(), expected_o.cpu()))
- @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
- @unittest.skip("TODO: Make the tpu_custom_call_ as functional.")
- @unittest.mock.patch.dict(os.environ, {"XLA_DISABLE_FUNCTIONALIZATION": "1"})
- def test_tpu_custom_call_pallas_add_one_dynamo(self):
- # This payload is generated by the following Pallas code:
- # def add_vectors_kernel(x_ref, o_ref):
- # o_ref[...] = x_ref[...] + 1
- payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"
-
- x = torch.arange(8, dtype=torch.int).to("xla")
- expected_output = x + 1
- output = torch.arange(8, dtype=torch.int).to("xla")
-
- import torch_xla.experimental.custom_kernel
-
- def add_one_pallas(output, inputs, payload):
- torch.ops.xla.tpu_custom_call(output, inputs, payload)
-
- compiled_add_one_pallas = torch.compile(
- add_one_pallas, backend='openxla', fullgraph=True)
-
- compiled_add_one_pallas(output, [x], payload)
- self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))
-
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_extract_add_payload(self):
import jax._src.pallas.mosaic.pallas_call_registration
@@ -441,6 +417,7 @@ def test__flash_attention_bwd_dkv(self):
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_backward(self):
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention
torch.manual_seed(42)
@@ -473,9 +450,9 @@ def test_flash_attention_backward(self):
loss.backward()
xm.mark_step()
- mse = torch.nn.MSELoss()
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
- self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4)
+ self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
if __name__ == '__main__':
diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
new file mode 100644
index 00000000000..33434594191
--- /dev/null
+++ b/test/test_pallas_spmd.py
@@ -0,0 +1,110 @@
+import logging
+import os
+import unittest
+
+import torch
+from torch import nn as nn
+
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
+from torch_xla import runtime as xr
+from torch_xla._internal import tpu
+
+if xr.device_type() == 'TPU':
+ from torch_xla.experimental.custom_kernel import flash_attention
+ from torch_xla.experimental.custom_kernel import jax_import_guard
+ jax_import_guard()
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental import pallas as pl
+
+
+class PallasTest(unittest.TestCase):
+
+ def _attention(self, q, k, v):
+ attn_weight = q @ k.transpose(-2, -1)
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+ attn_output = attn_weight @ v
+ return attn_output
+
+ @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+ "This test only works on TPUv3+.")
+ def test_flash_attention_spmd_data_parallel(self):
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+ n_devices = xr.global_runtime_device_count()
+ xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+ q = torch.randn(4, 2, 128, 4).to("xla")
+ k = torch.randn(4, 2, 128, 4).to("xla")
+ v = torch.randn(4, 2, 128, 4).to("xla")
+
+ o = flash_attention(q, k, v, partition_spec=range(n_devices))
+ self.assertEqual(
+ torch_xla._XLAC._get_xla_sharding_spec(o),
+ f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+ expected_o = self._attention(q, k, v)
+ self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+ @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+ "This test only works on TPUv3+.")
+ def test_flash_attention_backward_spmd_data_parallel(self):
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+ n_devices = xr.global_runtime_device_count()
+ xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+ torch.manual_seed(42)
+ q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ q.retain_grad()
+ k.retain_grad()
+ v.retain_grad()
+
+ o = flash_attention(q, k, v, partition_spec=range(n_devices))
+ loss = o.sum()
+ loss.backward()
+ xm.mark_step()
+
+ q_grad = q.grad
+ k_grad = k.grad
+ v_grad = v.grad
+ self.assertEqual(
+ torch_xla._XLAC._get_xla_sharding_spec(q_grad),
+ f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+ self.assertEqual(
+ torch_xla._XLAC._get_xla_sharding_spec(k_grad),
+ f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+ self.assertEqual(
+ torch_xla._XLAC._get_xla_sharding_spec(v_grad),
+ f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+ torch.manual_seed(42)
+ q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+ q.retain_grad()
+ k.retain_grad()
+ v.retain_grad()
+
+ o = self._attention(q, k, v)
+ loss = o.sum()
+ loss.backward()
+ xm.mark_step()
+
+ for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+ self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+ jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_default_dtype(torch.float32)
+ torch.manual_seed(42)
+ torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
+ use_full_mat_mul_precision=True)
+ xr.use_spmd()
+ test = unittest.main()
+ sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 413951854d6..dc2f4e96dba 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -11,8 +11,8 @@ python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
-XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v
-XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v
+XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
+XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 379416ec73f..624acb9cb6f 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -429,7 +429,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
for xla_arg in xla_model.xla_args:
if isinstance(xla_arg, torch.Tensor):
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
- xm.mark_step()
+ # Don't reset the scope as we might be under some profiler trace scope.
+ xm.mark_step(reset_scope=False)
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -614,8 +615,9 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
torch._functionalize_sync(a)
- # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
- xm.mark_step()
+ # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids.
+ # Don't reset the scope as we might be under some profiler trace scope.
+ xm.mark_step(reset_scope=False)
# Find tensor constructor nodes that create CPU tensors, and make
# them create XLA tensors, where possible, instead. i.e. replace the
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 5e9433ce91e..7591e13af29 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -491,8 +491,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
if scale == 1.0 and groups == [] and pin_layout:
# TODO(alanwaketan): Support groups.
# Only c10d_functional version cc ops are traceable by Dynamo.
- result = torch.ops.c10d_functional.all_reduce(inputs, reduce_type, "", [],
- 0)
+ result = torch.ops._c10d_functional.all_reduce(inputs, reduce_type, "")
else:
result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, scale,
groups, pin_layout)
@@ -1046,7 +1045,7 @@ def _run_step_closures():
return devctx
-def mark_step(wait=False):
+def mark_step(wait=False, reset_scope=True):
if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
print(
'torch_xla.core.xla_model::mark_step\n',
@@ -1055,7 +1054,8 @@ def mark_step(wait=False):
flush=True)
torch_xla._XLAC._xla_step_marker(
torch_xla._XLAC._xla_get_default_device(), [],
- wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
+ wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
+ reset_scope=reset_scope)
# Only emit metrics from the first local device index, to avoid emitting the
# same values from different threads.
if is_master_ordinal():
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 1c91b29bc5b..56a69ca1e05 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1290,6 +1290,38 @@ at::Tensor XLANativeFunctions::embedding_dense_backward(
num_weights, padding_idx, scale_grad_by_freq));
}
+std::tuple
+XLANativeFunctions::_embedding_bag_forward_only(
+ const at::Tensor& weight, const at::Tensor& indices,
+ const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode,
+ bool sparse, const c10::optional& per_sample_weights,
+ bool include_last_offset, int64_t padding_idx) {
+ TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+ if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) {
+ return at::native::call_fallback_fn<
+ &xla_cpu_fallback,
+ ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets,
+ scale_grad_by_freq, mode,
+ sparse, per_sample_weights,
+ include_last_offset,
+ padding_idx);
+ }
+ auto indices_tensor = bridge::GetXlaTensor(indices);
+ auto sample_weights =
+ per_sample_weights.has_value() && per_sample_weights.value().defined()
+ ? bridge::GetXlaTensor(per_sample_weights.value())
+ : tensor_methods::full_like(indices_tensor, 1.0,
+ *torch_xla::bridge::GetXlaDevice(weight),
+ at::ScalarType::Float);
+ auto result = tensor_methods::embedding_bag(
+ bridge::GetXlaTensor(weight), indices_tensor,
+ bridge::GetXlaTensor(offsets), mode, sample_weights, include_last_offset);
+ return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(result)),
+ bridge::AtenFromXlaTensor(std::get<1>(result)),
+ bridge::AtenFromXlaTensor(std::get<2>(result)),
+ bridge::AtenFromXlaTensor(std::get<3>(result)));
+}
+
at::Tensor XLANativeFunctions::empty_symint(
at::SymIntArrayRef sym_size, c10::optional dtype,
c10::optional layout, c10::optional device,
@@ -3709,6 +3741,7 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
scale_grad_by_freq, sparse);
}
+ // TODO: We need to make use of the TPU embedding core here eventually.
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::embedding(
bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices)));
@@ -3748,10 +3781,8 @@ at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self,
at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self,
int64_t upscale_factor) {
- XLA_CHECK(
- !runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
- return at::functionalization::functionalize_aten_op::call(self, upscale_factor);
+ return bridge::AtenFromXlaTensor(tensor_methods::pixel_shuffle(
+ bridge::GetXlaTensor(self), upscale_factor));
}
at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self,
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 1778fd81e61..d140a486223 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -112,23 +112,20 @@ std::shared_ptr CreateToken(
// order. RFC: https://github.com/pytorch/pytorch/issues/93173
////////////////////////////////////////////////////////////////////////////////////
-// tag is ignored as it's only used in PyTorch to provide backward compatibility
-// with the traditional process group API.
-at::Tensor all_reduce(const at::Tensor& self, c10::string_view reduceOp,
- c10::string_view /*tag*/, at::IntArrayRef /*ranks*/,
- int64_t /*group_size*/) {
+at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
+ std::string /*group_name*/) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
auto self_tensor = bridge::GetXlaTensor(self);
- // TODO(alanwaketan): Use ranks and group_size to generate groups. Currently
- // we just suse {} as a workaround. Scale is always 1.0 here, and we always
- // pin layout.
+ // TODO(alanwaketan): Use group_name to generate groups. Currently we just
+ // use {} as a workaround. Scale is always 1.0 here, and we always pin
+ // layout.
auto result = tensor_methods::all_reduce(self_tensor, GetReduceType(reduceOp),
/*scale*/ 1.0,
/*groups*/ {}, /*pin_layout*/ true);
return bridge::AtenFromXlaTensor(result);
}
-TORCH_LIBRARY_IMPL(c10d_functional, XLA, m) {
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
m.impl("all_reduce", all_reduce);
}
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 7a4a52ad1d2..c603e5d27a5 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -182,6 +182,12 @@ std::vector GetXlaTensors(const std::vector& tensors,
return xtensors;
}
+bool IsNonDeviceDataIR(const at::Tensor& tensor) {
+ XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
+ return xtensor->CurrentIrValue() &&
+ !DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+}
+
std::vector> CreateReduceGroups(const py::list& groups) {
std::vector> replica_groups;
for (auto& group : groups) {
@@ -458,12 +464,13 @@ void SyncLiveTensors(const std::string& device_str,
}
void StepMarker(const std::string& device_str,
- const std::vector& devices, bool wait) {
+ const std::vector& devices, bool wait,
+ bool reset_scope) {
tsl::profiler::TraceMe activity("StepMarker",
tsl::profiler::TraceMeLevel::kInfo);
torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait);
- XLAGraphExecutor::Get()->MarkStep(device);
+ XLAGraphExecutor::Get()->MarkStep(device, reset_scope);
bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
if (TF_PREDICT_FALSE(debug_mode)) {
std::string report = runtime::metrics::CreatePerformanceReport(
@@ -901,6 +908,43 @@ class PyLoweringContext {
lowering_ctx.AddResult(root);
}
computation = ConsumeValue(lowering_ctx.BuildXla());
+ }
+
+ // Builds a HLO graph given a set of output tensors, and add unused parameters
+ // needed in xlacomputation.
+ void BuildForiLoop(std::vector tensors,
+ std::vector input_arguments = {}) {
+ if (GetNameString() == "condctx") {
+ xla::XlaBuilder* local_builder = lowering_ctx.builder();
+ // hard-code parameter_idx to 2 to skip existing upper/lower arguments
+ int64_t parameter_idx = 2;
+ for (at::Tensor input_argument : input_arguments) {
+ xla::Shape shape =
+ xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
+ xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
+ "UnusedArgumentsPlaceholder");
+ parameter_idx += 1;
+ }
+ }
+
+ // Get the backing XLA tensors from the output torch tensor handles
+ std::vector xtensors =
+ GetXlaTensors(tensors, /*want_all=*/true);
+
+ // Get the lazy IR value from the output XLA tensors
+ std::vector ir_values;
+ for (auto& xtensor : xtensors) {
+ torch::lazy::Value value = xtensor->GetIrValue();
+ ir_values.push_back(value);
+ }
+
+ // Lower the graph using the output IR values
+ for (auto& ir_value : ir_values) {
+ xla::XlaOp root = lowering_ctx.GetOutputOp(
+ torch::lazy::Output(ir_value.node.get(), ir_value.index));
+ lowering_ctx.AddResult(root);
+ }
+ computation = ConsumeValue(lowering_ctx.BuildXla());
// wrap inputs of cond/body_computation
if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
@@ -1043,6 +1087,7 @@ void BuildLoweringContextSubmodule(py::module* m) {
lowering_context_class.def(py::init<>())
.def("build", &PyLoweringContext::Build)
+ .def("buildforiloop", &PyLoweringContext::BuildForiLoop)
.def("hlo", &PyLoweringContext::GetHlo)
.def("hlo_text", &PyLoweringContext::GetHloText)
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
@@ -1649,11 +1694,12 @@ void InitXlaModuleBindings(py::module m) {
m.def(
"_xla_step_marker",
[](const std::string& device, const std::vector& devices,
- bool wait) {
+ bool wait, bool reset_scope) {
NoGilSection nogil;
- StepMarker(device, devices, wait);
+ StepMarker(device, devices, wait, reset_scope);
},
- py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+ py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
+ py::arg("reset_scope") = true);
m.def("_get_stablehlo",
[](const std::vector& tensors, const std::string& device,
const std::vector& devices,
@@ -1899,6 +1945,49 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
+ m.def("_mark_manual_sharding",
+ [](const at::Tensor& input, xla::OpSharding sharding) {
+ XLA_CHECK(IsNonDeviceDataIR(input))
+ << "Marking any data tensors as manual is not supported";
+ ShardingUtil::XlaMarkSharding(input, sharding);
+ });
+ m.def("_spmd_full_to_shard_shape", [](const at::Tensor& input) -> at::Tensor {
+ XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+ auto sharding_spec = xtensor->sharding_spec();
+ XLA_CHECK(sharding_spec != nullptr) << "Input tensor is not sharded";
+
+ auto shard_shape = xla::ShapeUtil::MakeShape(
+ MakeXlaPrimitiveType(xtensor->dtype(), &(xtensor->GetDevice())),
+ ShardingUtil::GetShardShape(sharding_spec));
+ auto output = xtensor->CreateFrom(torch::lazy::MakeNode(
+ xtensor->GetIrValue(), shard_shape,
+ CustomSharding::Type::kSPMDFullToShardShape));
+ output->SetShardingSpec(XLATensor::ShardingSpec(
+ xla::HloSharding::Manual().ToProto(), shard_shape));
+ return bridge::AtenFromXlaTensor(output);
+ });
+ m.def(
+ "_spmd_shard_to_full_shape",
+ [](const at::Tensor& input, const xla::OpSharding& sharding,
+ const std::vector& output_shape,
+ const py::object& output_dtype) -> at::Tensor {
+ XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+ auto sharding_spec = xtensor->sharding_spec();
+ XLA_CHECK(sharding_spec != nullptr &&
+ sharding_spec->sharding.type() == xla::OpSharding::MANUAL)
+ << "Input tensor is not manual sharded";
+
+ auto full_shape = xla::ShapeUtil::MakeShape(
+ MakeXlaPrimitiveType(
+ reinterpret_cast(output_dtype.ptr())->scalar_type,
+ &(xtensor->GetDevice())),
+ output_shape);
+ auto output = xtensor->CreateFrom(torch::lazy::MakeNode(
+ xtensor->GetIrValue(), full_shape,
+ CustomSharding::Type::kSPMDShardToFullShape));
+ output->SetShardingSpec(XLATensor::ShardingSpec(sharding, full_shape));
+ return bridge::AtenFromXlaTensor(output);
+ });
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
diff --git a/torch_xla/csrc/ops/custom_sharding.cpp b/torch_xla/csrc/ops/custom_sharding.cpp
index 3fccc1b497f..0a6b8bf486f 100644
--- a/torch_xla/csrc/ops/custom_sharding.cpp
+++ b/torch_xla/csrc/ops/custom_sharding.cpp
@@ -5,24 +5,42 @@
#include "torch_xla/csrc/xla_lower_util.h"
namespace torch_xla {
+namespace {
+std::string TypeToString(const CustomSharding::Type& type) {
+ switch (type) {
+ case CustomSharding::Type::kSharding:
+ return "Sharding";
+ case CustomSharding::Type::kSPMDFullToShardShape:
+ return "SPMDFullToShardShape";
+ case CustomSharding::Type::kSPMDShardToFullShape:
+ return "SPMDShardToFullShape";
+ }
+}
+} // namespace
-CustomSharding::CustomSharding(const torch::lazy::Value& input)
- : XlaNode(xla_custom_sharding, {input}, GetXlaShape(input),
- /*num_outputs=*/1, torch::lazy::MHash(std::string("Sharding"))) {}
+CustomSharding::CustomSharding(const torch::lazy::Value& input,
+ const xla::Shape& output_shape,
+ const CustomSharding::Type& type)
+ : XlaNode(xla_custom_sharding, {input}, output_shape,
+ /*num_outputs=*/1, torch::lazy::MHash(static_cast(type))),
+ type(type),
+ output_shape(output_shape) {}
torch::lazy::NodePtr CustomSharding::Clone(torch::lazy::OpList operands) const {
- return torch::lazy::MakeNode(operands.at(0));
+ return torch::lazy::MakeNode(operands.at(0), output_shape,
+ type);
}
XlaOpVector CustomSharding::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
- xla::XlaOp output = BuildCustomSharding(input);
+ xla::XlaOp output =
+ BuildCustomSharding(input, TypeToString(type), output_shape);
return ReturnOp(output, loctx);
}
std::string CustomSharding::ToString() const {
std::stringstream ss;
- ss << XlaNode::ToString() << ", Sharding";
+ ss << XlaNode::ToString() << ", " << TypeToString(type);
return ss.str();
}
diff --git a/torch_xla/csrc/ops/custom_sharding.h b/torch_xla/csrc/ops/custom_sharding.h
index 6a956d97423..5c7599cce2b 100644
--- a/torch_xla/csrc/ops/custom_sharding.h
+++ b/torch_xla/csrc/ops/custom_sharding.h
@@ -7,14 +7,27 @@ namespace torch_xla {
class CustomSharding : public XlaNode {
public:
+ // The following enum represents the custom_call_target being
+ // passed to xla builder. The actual sharding will still be
+ // attached to the XLATensor.
+ enum class Type {
+ kSharding,
+ kSPMDFullToShardShape,
+ kSPMDShardToFullShape,
+ };
+
// Make a custom call to Sharding.
- CustomSharding(const torch::lazy::Value& input);
+ CustomSharding(const torch::lazy::Value& input,
+ const xla::Shape& output_shape, const Type& type);
torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
XlaOpVector Lower(LoweringContext* loctx) const override;
std::string ToString() const override;
+
+ Type type;
+ xla::Shape output_shape;
};
} // namespace torch_xla
diff --git a/torch_xla/csrc/ops/embedding_bag.cpp b/torch_xla/csrc/ops/embedding_bag.cpp
new file mode 100644
index 00000000000..d2bb034a005
--- /dev/null
+++ b/torch_xla/csrc/ops/embedding_bag.cpp
@@ -0,0 +1,192 @@
+#include "torch_xla/csrc/ops/embedding_bag.h"
+
+#include "torch_xla/csrc/helpers.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+#include "tsl/platform/stacktrace.h"
+#include "xla/client/lib/constants.h"
+#include "xla/client/lib/loops.h"
+#include "xla/client/lib/slicing.h"
+#include "xla/shape_util.h"
+
+namespace torch_xla {
+namespace {
+const int MODE_SUM = 0;
+const int MODE_MEAN = 1;
+const int MODE_MAX = 2;
+std::vector BuildEmbeddingBag(xla::XlaOp weight, xla::XlaOp indices,
+ xla::XlaOp offsets,
+ xla::XlaOp per_sample_weights,
+ bool include_last_offset, int mode) {
+ xla::Shape offset_shape = ShapeHelper::ShapeOfXlaOp(offsets);
+ int64_t n = offset_shape.dimensions(0);
+ xla::Shape weight_shape = ShapeHelper::ShapeOfXlaOp(weight);
+ int64_t weight_dim = weight_shape.dimensions(1);
+ xla::Shape indices_shape = ShapeHelper::ShapeOfXlaOp(indices);
+ int64_t num_embeddings = indices_shape.dimensions(0);
+ XLA_CHECK(indices_shape.rank() == 1 || indices_shape.rank() == 2)
+ << "input has to be a 1D or 2D Tensor, but got Tensor of dimension "
+ << indices_shape.rank();
+ if (indices_shape.rank() == 1) {
+ XLA_CHECK(offset_shape.rank() == 1)
+ << "offsets has to be a 1D Tensor, but got Tensor of dimension "
+ << offset_shape.rank();
+ }
+ XLA_CHECK(weight_shape.rank() == 2)
+ << "weight has to be a 2D Tensor, but got Tensor of dimension "
+ << weight_shape.rank();
+
+ xla::XlaOp output2 = xla::ZerosLike(indices);
+ xla::XlaOp output3 = xla::ZerosLike(offsets);
+ std::vector sizes = {n, weight_dim};
+ xla::XlaOp output4 =
+ xla::Zeros(offsets.builder(),
+ xla::ShapeUtil::MakeShape(offset_shape.element_type(), sizes));
+
+ xla::XlaOp embeddings = xla::TorchIndexSelect(weight, indices, 0);
+ xla::XlaOp embeddings_weighted = xla::Mul(
+ embeddings, xla::ConvertElementType(
+ xla::BroadcastInDim(per_sample_weights,
+ {num_embeddings, weight_dim}, {0}),
+ weight_shape.element_type()));
+
+ std::vector shape_elements = {
+ xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}),
+ xla::ShapeUtil::MakeShape(offset_shape.element_type(), {}),
+ xla::ShapeUtil::MakeShape(weight_shape.element_type(),
+ {num_embeddings, weight_dim}),
+ xla::ShapeUtil::MakeShape(weight_shape.element_type(), {1, weight_dim})};
+ xla::Shape result_shape = xla::ShapeUtil::MakeTupleShape(shape_elements);
+
+ xla::XlaComputation condition;
+ {
+ xla::XlaBuilder builder("condition");
+ auto prev = xla::Parameter(&builder, 0, result_shape, "prev");
+ auto index = xla::GetTupleElement(prev, 0);
+ auto final_value = xla::GetTupleElement(prev, 1);
+ xla::Lt(index, final_value);
+ condition = builder.Build().value();
+ }
+
+ xla::XlaComputation body;
+ {
+ xla::XlaBuilder builder("body");
+ auto prev = xla::Parameter(&builder, 0, result_shape, "prev");
+ auto index = xla::GetTupleElement(prev, 0);
+ auto emb = xla::GetTupleElement(prev, 2);
+ auto w = xla::GetTupleElement(prev, 3);
+
+ xla::XlaOp slice = xla::DynamicSlice(
+ emb,
+ {index, xla::ConvertElementType(xla::ConstantR0(&builder, 0),
+ offset_shape.element_type())},
+ {1, weight_dim});
+ xla::XlaOp result =
+ mode == MODE_SUM ? xla::Add(w, slice) : xla::Max(w, slice);
+
+ xla::Tuple(&builder,
+ {
+ xla::Add(index, xla::ConvertElementType(
+ xla::ConstantR0(&builder, 1),
+ offset_shape.element_type())),
+ xla::GetTupleElement(prev, 1),
+ xla::GetTupleElement(prev, 2),
+ result,
+ });
+ body = builder.Build().value();
+ }
+
+ xla::Array initial_vector({1, weight_dim}, 0.f);
+ std::vector results;
+ for (int64_t i = 0; i < n; i++) {
+ xla::XlaOp start = xla::DynamicSlice(
+ offsets, {xla::ConstantR0(offsets.builder(), i)}, {1});
+ if (i == n - 1 && include_last_offset) continue;
+ xla::XlaOp end =
+ i == n - 1 && !include_last_offset
+ ? xla::ConvertElementType(xla::ConstantR1(
+ offsets.builder(), 1, num_embeddings),
+ offset_shape.element_type())
+ : xla::DynamicSlice(
+ offsets, {xla::ConstantR0(offsets.builder(), i + 1)},
+ {1});
+ // Create a While node with computations for the condition and the body.
+ auto init_tuple = xla::Tuple(
+ offsets.builder(),
+ {xla::Reshape(start, {0}, {}), xla::Reshape(end, {0}, {}),
+ embeddings_weighted,
+ xla::ConvertElementType(
+ xla::ConstantFromArray(offsets.builder(), initial_vector),
+ weight_shape.element_type())});
+ auto result = xla::While(condition, body, init_tuple);
+ results.push_back(xla::GetTupleElement(result, 3));
+ };
+ xla::XlaOp output1 = xla::ConcatInDim(offsets.builder(), results, 0);
+ return {output1, output2, output3, output4};
+}
+
+xla::Shape NodeOutputShapes(const torch::lazy::Value& weight,
+ const torch::lazy::Value& indices,
+ const torch::lazy::Value& offsets,
+ const torch::lazy::Value& per_sample_weights,
+ bool include_last_offset, bool mode) {
+ auto lower_for_shapes_fn =
+ [&](absl::Span operands) -> xla::XlaOp {
+ return xla::Tuple(
+ operands[0].builder(),
+ BuildEmbeddingBag(operands[0], operands[1], operands[2], operands[3],
+ include_last_offset, mode));
+ };
+
+ std::vector input_shapes = {
+ GetXlaShape(weight), GetXlaShape(indices), GetXlaShape(offsets),
+ GetXlaShape(per_sample_weights)};
+
+ return InferOutputShape(absl::MakeSpan(input_shapes), lower_for_shapes_fn);
+}
+} // namespace
+
+std::string EmbeddingBag::ToString() const {
+ std::stringstream ss;
+ ss << XlaNode::ToString();
+ return ss.str();
+}
+
+EmbeddingBag::EmbeddingBag(const torch::lazy::Value& weight,
+ const torch::lazy::Value& indices,
+ const torch::lazy::Value& offsets, int64_t mode,
+ const torch::lazy::Value& per_sample_weights,
+ bool include_last_offset)
+ : XlaNode(
+ torch::lazy::OpKind(at::aten::embedding_bag),
+ {weight, indices, offsets, per_sample_weights},
+ [&]() {
+ return NodeOutputShapes(weight, indices, offsets,
+ per_sample_weights, include_last_offset,
+ mode);
+ },
+ /*num_outputs=*/4, torch::lazy::MHash(mode, include_last_offset)),
+ mode_(mode),
+ include_last_offset_(include_last_offset) {}
+
+torch::lazy::NodePtr EmbeddingBag::Clone(torch::lazy::OpList operands) const {
+ return torch::lazy::MakeNode(operands.at(0), operands.at(1),
+ operands.at(2), mode_,
+ operands.at(3), false);
+}
+
+XlaOpVector EmbeddingBag::Lower(LoweringContext* loctx) const {
+ xla::XlaOp weight = loctx->GetOutputOp(operand(0));
+ xla::XlaOp indices = loctx->GetOutputOp(operand(1));
+ xla::XlaOp offsets = loctx->GetOutputOp(operand(2));
+ xla::XlaOp per_sample_weights = loctx->GetOutputOp(operand(3));
+ std::vector ops =
+ BuildEmbeddingBag(weight, indices, offsets, per_sample_weights,
+ include_last_offset_, mode_);
+ return ReturnOps(absl::MakeSpan(ops), loctx);
+}
+
+} // namespace torch_xla
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/embedding_bag.h b/torch_xla/csrc/ops/embedding_bag.h
new file mode 100644
index 00000000000..4d9b0a6eecb
--- /dev/null
+++ b/torch_xla/csrc/ops/embedding_bag.h
@@ -0,0 +1,31 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_
+#define XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_
+
+#include
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class EmbeddingBag : public XlaNode {
+ public:
+ EmbeddingBag(const torch::lazy::Value& weight,
+ const torch::lazy::Value& indices,
+ const torch::lazy::Value& offsets, int64_t mode,
+ const torch::lazy::Value& per_sample_weights,
+ bool include_last_offset);
+
+ std::string ToString() const override;
+
+ torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+ XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+ int64_t mode_;
+ bool include_last_offset_;
+};
+
+} // namespace torch_xla
+
+#endif // XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 7391f8ff714..af4daf28648 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -593,6 +593,34 @@ torch::lazy::NodePtr Pdist_forward(const torch::lazy::Value& input,
std::move(lower_fn), 1);
}
+torch::lazy::NodePtr PixelShuffle(const torch::lazy::Value& input,
+ int64_t upscale_factor) {
+ auto lower_fn = [=](const XlaNode& node,
+ LoweringContext* loctx) -> XlaOpVector {
+ xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+ return node.ReturnOp(BuildPixelShuffle(xla_input, upscale_factor), loctx);
+ };
+ auto lower_for_shape_fn =
+ [&](absl::Span operands) -> xla::XlaOp {
+ return BuildPixelShuffle(operands[0], upscale_factor);
+ };
+ const xla::Shape& input_shape = GetXlaShape(input);
+ absl::Span dimensions = input_shape.dimensions();
+ int64_t channels = dimensions[1];
+
+ if (channels % (upscale_factor * upscale_factor) != 0) {
+ XLA_ERROR() << "Number of channels must be divisible by the square of the "
+ "upscale factor.";
+ }
+
+ return GenericOp(
+ torch::lazy::OpKind(at::aten::pixel_shuffle), {input},
+ [&]() {
+ return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+ },
+ std::move(lower_fn), 1);
+}
+
torch::lazy::NodePtr LinalgVectorNorm(const torch::lazy::Value& input,
const at::Scalar& ord,
std::vector dimensions,
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 013474aa03c..5d423b3b1ee 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -177,6 +177,9 @@ torch::lazy::NodePtr Pdist_forward(const torch::lazy::Value& input,
const c10::optional& p,
c10::optional dtype);
+torch::lazy::NodePtr PixelShuffle(const torch::lazy::Value& input,
+ int64_t upscale_factor);
+
torch::lazy::NodePtr LinalgVectorNorm(const torch::lazy::Value& input,
const at::Scalar& ord,
std::vector dimensions,
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 6f746972355..582b69d8a50 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -237,7 +237,7 @@ cc_library(
deps = [
":debug_macros",
":sys_util",
- "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager",
+ "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
"@xla//xla/pjrt/distributed",
],
)
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index 029f9268342..20ee9b0bfa6 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -58,18 +58,6 @@ torch::lazy::hash_t hash_comp_env(
xla::ifrt::Client* client,
std::vector& ordered_devices) {
torch::lazy::hash_t hash = hash::HashXlaEnvVars();
- auto topology_desc = client->GetTopologyForDevices(ordered_devices);
- if (topology_desc.ok()) {
- // Some backends support a topology description which provides a better
- // view of the specific compilation environment.
- auto serialized = topology_desc.value()->Serialize();
- if (serialized.ok()) {
- return torch::lazy::HashCombine(
- hash,
- torch::lazy::DataHash(serialized->data(), serialized->length()));
- }
- // If serialization fails, fallthrough to the manual approach.
- }
std::string platform_name(client->platform_name());
std::string platform_version(client->platform_version());
hash = torch::lazy::HashCombine(
@@ -78,10 +66,26 @@ torch::lazy::hash_t hash_comp_env(
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(platform_version.c_str()));
// Include global devices in the hash, ensuring order is consistent.
+ xla::ifrt::DeviceList::Devices ifrt_devices;
for (auto& device : ordered_devices) {
std::string device_str(device->ToString());
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(device_str.c_str()));
+ ifrt_devices.push_back(device);
+ }
+
+ xla::ifrt::DeviceList device_list(std::move(ifrt_devices));
+ auto topology_desc = client->GetTopologyForDevices(device_list);
+ if (topology_desc.ok()) {
+ // Some backends support a topology description which provides a better
+ // view of the specific compilation environment.
+ auto serialized = topology_desc.value()->Serialize();
+ if (serialized.ok()) {
+ return torch::lazy::HashCombine(
+ hash,
+ torch::lazy::DataHash(serialized->data(), serialized->length()));
+ }
+ // If serialization fails, fallthrough to the manual approach.
}
return hash;
}
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 188e26f8ac2..a129a476a2e 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -463,7 +463,7 @@ std::vector PjRtComputationClient::TransferFromDevice(
metrics::TimedSection timed(TransferFromDeviceMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
tsl::profiler::TraceMeLevel::kInfo);
- std::vector> futures;
+ std::vector> futures;
futures.reserve(handles.size());
std::vector literals;
literals.reserve(handles.size());
@@ -679,7 +679,7 @@ PjRtComputationClient::ExecuteComputation(
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
<< " Done";
- std::optional> returned_future;
+ std::optional> returned_future;
std::vector> results =
pjrt_computation.executable
->ExecuteSharded(buffers, pjrt_device, execute_options,
@@ -779,8 +779,8 @@ PjRtComputationClient::ExecuteReplicated(
TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
<< spmd_device_str << " Done";
- std::optional>> returned_futures =
- std::vector>();
+ std::optional>> returned_futures =
+ std::vector>();
std::vector>> results;
{
tsl::profiler::TraceMe activity(
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index 648076757be..99e23f4b555 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -60,7 +60,8 @@ InitializePjRt(const std::string& device_type) {
std::unique_ptr client;
std::unique_ptr coordinator;
- if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
+ if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false) &&
+ device_type != "CPU") {
std::shared_ptr plugin = GetPjRtPlugin(device_type);
if (plugin) {
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h
index ae85c79a941..fb2cfaf99f5 100644
--- a/torch_xla/csrc/runtime/xla_coordinator.h
+++ b/torch_xla/csrc/runtime/xla_coordinator.h
@@ -3,8 +3,8 @@
#include
-#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"
#include "xla/pjrt/distributed/distributed.h"
+#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h"
namespace torch_xla {
namespace runtime {
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index b9ab1bc2764..fbb240f31d3 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -40,7 +40,6 @@
#include "torch_xla/csrc/ops/count_nonzero.h"
#include "torch_xla/csrc/ops/cumprod.h"
#include "torch_xla/csrc/ops/cumsum.h"
-#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/ops/dequant_tensor.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal.h"
@@ -49,6 +48,7 @@
#include "torch_xla/csrc/ops/dynamic_view.h"
#include "torch_xla/csrc/ops/einsum.h"
#include "torch_xla/csrc/ops/einsum_backward.h"
+#include "torch_xla/csrc/ops/embedding_bag.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/expand_symint.h"
#include "torch_xla/csrc/ops/exponential.h"
@@ -522,9 +522,10 @@ std::pair collective_permute(
void custom_sharding_(
const XLATensorPtr& input,
- const std::shared_ptr& sharding_spec) {
- input->SetInPlaceIrValue(
- torch::lazy::MakeNode(input->GetIrValue()));
+ const std::shared_ptr& sharding_spec,
+ const CustomSharding::Type& type) {
+ input->SetInPlaceIrValue(torch::lazy::MakeNode(
+ input->GetIrValue(), input->shape().get(), type));
input->SetShardingSpec(*sharding_spec);
}
@@ -1010,6 +1011,12 @@ XLATensorPtr pdist_forward(const XLATensorPtr& input, double p) {
return input->CreateFrom(Pdist_forward(input->GetIrValue(), p, dtype));
}
+XLATensorPtr pixel_shuffle(const XLATensorPtr& input, int64_t upscale_factor) {
+ c10::optional dtype = input->dtype_optional();
+ torch::lazy::NodePtr node = PixelShuffle(input->GetIrValue(), upscale_factor);
+ return input->CreateFrom(node, dtype);
+}
+
XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha) {
return input->CreateFrom(Celu(input->GetIrValue(), alpha));
}
@@ -1286,6 +1293,20 @@ XLATensorPtr embedding(const XLATensorPtr& weight,
return tensor_ops::Embedding(weight, indices);
}
+std::tuple
+embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices,
+ const XLATensorPtr& offsets, int64_t mode,
+ const XLATensorPtr& per_sample_weights,
+ bool include_last_offset) {
+ torch::lazy::NodePtr node = torch::lazy::MakeNode(
+ weight->GetIrValue(), indices->GetIrValue(), offsets->GetIrValue(), mode,
+ per_sample_weights->GetIrValue(), include_last_offset);
+ return std::make_tuple(weight->CreateFrom(torch::lazy::Value(node, 0)),
+ weight->CreateFrom(torch::lazy::Value(node, 1)),
+ weight->CreateFrom(torch::lazy::Value(node, 2)),
+ weight->CreateFrom(torch::lazy::Value(node, 3)));
+}
+
XLATensorPtr exp(const XLATensorPtr& input) {
return input->CreateFrom(Exp(input->GetIrValue()));
}
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 991c105a78a..6a7005a5f0f 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -2,6 +2,7 @@
#define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_
#include "torch_xla/csrc/cross_replica_reduces.h"
+#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/tensor.h"
@@ -79,8 +80,10 @@ std::pair collective_permute(
const XLATensorPtr& input, const torch::lazy::Value& token,
std::vector> source_target_pairs);
-void custom_sharding_(const XLATensorPtr& input,
- const std::shared_ptr& spec);
+void custom_sharding_(
+ const XLATensorPtr& input,
+ const std::shared_ptr& spec,
+ const CustomSharding::Type& type = CustomSharding::Type::kSharding);
std::vector tpu_custom_call(
const std::vector& inputs, const std::string& payload,
@@ -287,6 +290,8 @@ XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2,
XLATensorPtr pdist_forward(const XLATensorPtr& input, double p);
+XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
+
XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
void celu_(XLATensorPtr& input, const at::Scalar& alpha);
@@ -376,6 +381,11 @@ XLATensorPtr embedding_dense_backward(const XLATensorPtr& grad_output,
int64_t num_weights, int64_t padding_idx,
bool scale_grad_by_freq);
+std::tuple
+embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices,
+ const XLATensorPtr& offsets, int64_t mode,
+ const XLATensorPtr& per_sample_weights, bool include_last_offset);
+
XLATensorPtr embedding(const XLATensorPtr& weight, const XLATensorPtr& indices);
XLATensorPtr eq(const XLATensorPtr& input, const at::Scalar& other);
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0eddefc39f3..fe12e392ea4 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -404,14 +404,17 @@ void XLAGraphExecutor::SyncLiveTensorsGraph(
SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
}
-void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
+void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device,
+ bool reset_scope) {
// TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
// remain as XLA_COUNTER to support
// runtime::metrics::CreatePerformanceReport(). For more information, see
// NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
XLA_COUNTER("MarkStep", 1);
DeviceContextArena::Get()->MarkStep(device);
- torch::lazy::ScopePusher::ResetScopes();
+ if (reset_scope) {
+ torch::lazy::ScopePusher::ResetScopes();
+ }
ResetTrimCounter();
}
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index ca874274a98..b2b76b8ae33 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -134,7 +134,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
// Marks an execution step, which allows the tensor framework to understand
// the computation boundaries.
// Override to use our own DeviceContextArena.
- void MarkStep(const torch::lazy::BackendDevice& device) final;
+ void MarkStep(const torch::lazy::BackendDevice& device, bool reset_scope);
// Waits for all the outstanding operations on all the supplied devices.
// If devices is empty, the wait will happen for all local devices.
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 3efd289d5a1..0954c2fa3ac 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1194,6 +1194,27 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,
}
}
+xla::XlaOp BuildPixelShuffle(xla::XlaOp input, int64_t upscale_factor) {
+ const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+ absl::Span dimensions = input_shape.dimensions();
+ int64_t batch_size = dimensions[0];
+ int64_t channels = dimensions[1];
+ int64_t height = dimensions[2];
+ int64_t width = dimensions[3];
+
+ int64_t new_channels = channels / (upscale_factor * upscale_factor);
+ int64_t new_height = height * upscale_factor;
+ int64_t new_width = width * upscale_factor;
+
+ xla::XlaOp tmp =
+ xla::Reshape(input, {batch_size, new_channels, upscale_factor,
+ upscale_factor, height, width});
+ tmp = xla::Transpose(tmp, {0, 1, 4, 2, 5, 3});
+ xla::XlaOp output =
+ xla::Reshape(tmp, {batch_size, new_channels, new_height, new_width});
+ return output;
+}
+
xla::XlaOp BuildMultinomial(xla::XlaOp input, int64_t num_samples,
bool replacement, xla::XlaOp seed) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -1245,9 +1266,10 @@ xla::XlaOp BuildMultinomial(xla::XlaOp input, int64_t num_samples,
return output;
}
-xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
- return xla::CustomCall(input.builder(), /*call_target_name=*/"Sharding",
- {input}, ShapeHelper::ShapeOfXlaOp(input));
+xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
+ const xla::Shape& output_shape) {
+ return xla::CustomCall(input.builder(), /*call_target_name=*/type, {input},
+ output_shape);
}
std::vector BuildTpuCustomCall(
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index 8e632796c23..d0e9afca9fa 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -148,9 +148,12 @@ xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,
bool use_hamming, bool use_chebyshev);
+xla::XlaOp BuildPixelShuffle(xla::XlaOp input, int64_t upscale_factor);
+
xla::XlaOp BuildUpperTriangle(xla::XlaOp input);
-xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);
+xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
+ const xla::Shape& output_shape);
std::vector BuildTpuCustomCall(
const std::vector& inputs, const xla::Shape& output_shape,
diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp
index dc7df436ec7..6020ef6bc04 100644
--- a/torch_xla/csrc/xla_manual_registration.cpp
+++ b/torch_xla/csrc/xla_manual_registration.cpp
@@ -1,7 +1,9 @@
#include
#include
+#include "torch_xla/csrc/aten_cpu_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
+#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_methods.h"
@@ -11,10 +13,22 @@ namespace torch_xla {
namespace manual {
namespace {
+struct NmsOp {
+ using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double);
+ using ptr_schema = schema*;
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms")
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+};
+
at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
double iou_threshold) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+ if (!DebugUtil::ExperimentEnabled("nms")) {
+ return at::native::call_fallback_fn<&xla_cpu_fallback, NmsOp>::call(
+ boxes, scores, iou_threshold);
+ }
+
XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor.";
XLA_CHECK_EQ(boxes.size(1), 4)
<< "nms(): boxes should be a 2D tensor of shape [N, 4].";
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index de01e5ba163..dd3eb566d06 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -240,8 +240,7 @@ xla::OpSharding ShardingUtil::CreateOpSharding(
xla::OpSharding sharding;
switch (sharding_type) {
case ShardingType::MANUAL: {
- TF_LOG(ERROR) << "Invalid arguments: sharding_type (MANUAL) is "
- << "currently not supported";
+ sharding = xla::HloSharding::Manual().ToProto();
break;
}
case ShardingType::TUPLE: {
@@ -323,7 +322,7 @@ std::vector ShardingUtil::GetShardShape(
return shard_shape;
} else {
- TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
+ XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
}
@@ -429,7 +428,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
shard_indices[device_index[core]] = std::make_pair(replica_id, indices);
}
} else {
- TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
+ XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
return shard_indices;
}
@@ -488,9 +487,8 @@ std::vector ShardingUtil::ShardTensor(
shards[i], c10::IntArrayRef(pads.data(), pads.size()), 0);
}
}
- } else if ((sharding.type() == xla::OpSharding::MANUAL) ||
- (sharding.type() == xla::OpSharding::TUPLE)) {
- TF_LOG(ERROR) << "Unsupported OpSharding type " << sharding.type();
+ } else {
+ XLA_CHECK(false) << "Unsupported OpSharding type " << sharding.type();
}
return shards;
}
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 5e0a414b00c..d243c8872a3 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -80,11 +80,12 @@ class ShardingUtil {
// based on the `sharding` spec. REPLICATED sharding should result in shards
// identical to the input; OTHERS (tiled) sharding result in shards where
// each data dimension is sharded across devices along the same dimension in
- // the `tile_assignment`; the returned tensor shards vector is indexed by the
- // device IDs. There is no data duplication. Shards are not padded in case the
- // input tensor is not evenly partitionable, unless `padded` is set.
- // The the returned tensors will be in 1:1 correspondence with the `devices`
- // vector, so the `i`th result will belong on the `i`th device.
+ // the `tile_assignment`; the returned tensor shards vector is
+ // indexed by the device IDs. There is no data duplication. Shards are not
+ // padded in case the input tensor is not evenly partitionable, unless
+ // `padded` is set. The the returned tensors will be in 1:1 correspondence
+ // with the `devices` vector, so the `i`th result will belong on the `i`th
+ // device.
static std::vector ShardTensor(
const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
const std::vector& devices, bool padded = true);
diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py
index 87ac6f8e965..099f25e9fb5 100644
--- a/torch_xla/distributed/spmd/__init__.py
+++ b/torch_xla/distributed/spmd/__init__.py
@@ -2,7 +2,9 @@
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward,
- set_global_mesh, get_global_mesh)
+ set_global_mesh, get_global_mesh,
+ _mark_manual_sharding, enable_manual_sharding,
+ disable_manual_sharding)
from .api import xla_distribute_tensor, xla_distribute_module, auto_policy
__all__ = [
@@ -22,4 +24,9 @@
"xla_patched_nn_linear_forward",
"set_global_mesh",
"get_global_mesh",
+ "_mark_manual_sharding",
+ "enable_manual_sharding",
+ "disable_manual_sharding",
+ "enable_manual_sharding",
+ "disable_manual_sharding",
]
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 1cdc66a20c2..916fb56f7c9 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -474,6 +474,52 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
return tuple(_partition_spec)
+def _mark_manual_sharding(
+ t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+ """
+ This API is meant to be paired with the upcoming pause_spmd&resume_spmd APIs.
+ Don't use it alone.
+ """
+ manual_sharding = torch_xla._XLAC.OpSharding([], [], [], ShardingType.MANUAL)
+ torch_xla._XLAC._mark_manual_sharding(
+ unwrap_sharded_tensor(t), manual_sharding)
+ return wrap_as_sharded_tensor(t)
+
+
+def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+ partition_spec: Tuple[Union[Tuple, int, str, None]],
+ *,
+ mesh: Mesh = None) -> XLAShardedTensor:
+ """
+ This API enables manual sharding for the given tensor. Manual sharding disables SPMD sharding proporgation and auto
+ partition for the given tensor and all subsequential tensors that produced by an op that uses the given tensor as
+ input, and therefore allows the user to manually call collectives for the tensor and subsequential tensors. It
+ requires the user to provide the partition spec to shard the tensor before enabling the manual sharding. To be noted,
+ the leaf tensors need to pass to disable_manual_sharding before ending the graph.
+ """
+ mesh = get_global_mesh() if mesh is None else mesh
+ t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
+ t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))
+ return wrap_as_sharded_tensor(t)
+
+
+def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+ partition_spec: Tuple[Union[Tuple, int, str, None]],
+ full_shape: torch.Size,
+ *,
+ mesh: Mesh = None) -> XLAShardedTensor:
+ """
+ This API disables manual sharding for the given tensor. The partition_spec and full_shape are used to construct the
+ output tensor as if the input tensor has not been manual sharded.
+ """
+ mesh = get_global_mesh() if mesh is None else mesh
+ t = _mark_manual_sharding(unwrap_sharded_tensor(t))
+ t = torch_xla._XLAC._spmd_shard_to_full_shape(
+ unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec),
+ full_shape, t.dtype)
+ return wrap_as_sharded_tensor(t)
+
+
@xr.requires_pjrt
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
mesh: Mesh,
@@ -541,7 +587,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
"""Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
- torch_xla._XLAC._xla_clear_sharding(t)
+ torch_xla._XLAC._xla_clear_sharding(unwrap_sharded_tensor(t))
if isinstance(t, XLAShardedTensor):
return t.global_tensor
return t
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index bad35db07b3..bb4ce0c4e23 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -5,6 +5,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
+import torch_xla.distributed.spmd as xs
from typing import List, Callable
from torch.library import impl
@@ -12,22 +13,6 @@
_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
-XLA_LIB.define(
- "tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",)
-
-
-@impl(XLA_LIB, "tpu_custom_call_", "XLA")
-def tpu_custom_call_xla_(output: torch.Tensor, inputs: List[torch.Tensor],
- payload: str):
- torch_xla._XLAC._xla_tpu_custom_call_(output, inputs, payload)
-
-
-@impl(XLA_LIB, "tpu_custom_call_", "CompositeExplicitAutograd")
-def tpu_custom_call_(output: torch.Tensor, inputs: List[torch.Tensor],
- payload: str):
- # Do nothing for non-xla tensor.
- return
-
def _extract_backend_config(
module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> str | None:
@@ -184,15 +169,29 @@ class FlashAttention(torch.autograd.Function):
}
@staticmethod
- def forward(ctx, q, k, v, causal=False):
+ def forward(ctx, q, k, v, causal=False, partition_spec=None, mesh=None):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
ctx.causal = causal
+ ctx.partition_spec = partition_spec
+ ctx.mesh = mesh
+ ctx.full_shape = None
save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
+ # SPMD integration.
+ # mark_sharding is in-placed, and therefore save the full q, k, v for the backward.
+ full_q = q
+ full_k = k
+ full_v = v
+ if partition_spec is not None:
+ ctx.full_shape = q.shape
+ q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
+ k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
+ v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
+
# It returns the shape and type of o, l, m.
def shape_dtype(q, *arg):
if not save_residuals:
@@ -224,11 +223,24 @@ def shape_dtype(q, *arg):
False,
static_argnums=range(5, 13))
if not save_residuals:
+ # SPMD integration
+ if partition_spec is not None:
+ o = xs.disable_manual_sharding(
+ o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
return o
o, *aux = o
l, m = (v[..., 0] for v in aux[-2:])
- ctx.save_for_backward(q, k, v, o, l, m)
+ # SPMD integration
+ if partition_spec is not None:
+ o = xs.disable_manual_sharding(
+ o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
+ l = xs.disable_manual_sharding(
+ l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
+ m = xs.disable_manual_sharding(
+ m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
+
+ ctx.save_for_backward(full_q, full_k, full_v, o, l, m)
return o
@staticmethod
@@ -237,6 +249,9 @@ def backward(ctx, grad_output):
q, k, v, o, l, m = ctx.saved_tensors
causal = ctx.causal
+ partition_spec = ctx.partition_spec
+ mesh = ctx.mesh
+ full_shape = ctx.full_shape
grad_q = grad_k = grad_v = None
grad_i = torch.sum(
@@ -250,6 +265,20 @@ def backward(ctx, grad_output):
expanded_grad_i = grad_i.unsqueeze(-1).expand(
[-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE])
+ # SPMD integration
+ if partition_spec is not None:
+ q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
+ k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
+ v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
+ expanded_l = xs.enable_manual_sharding(
+ expanded_l, partition_spec, mesh=mesh).global_tensor
+ expanded_m = xs.enable_manual_sharding(
+ expanded_m, partition_spec, mesh=mesh).global_tensor
+ grad_output = xs.enable_manual_sharding(
+ grad_output, partition_spec, mesh=mesh).global_tensor
+ expanded_grad_i = xs.enable_manual_sharding(
+ expanded_grad_i, partition_spec, mesh=mesh).global_tensor
+
if ctx.needs_input_grad[0]:
payload, _ = trace_pallas(
_flash_attention_bwd_dq,
@@ -319,7 +348,16 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[2]:
grad_v = grads[1]
- return grad_q, grad_k, grad_v, None
+ # SPMD integration
+ if partition_spec is not None:
+ grad_q = xs.disable_manual_sharding(
+ grad_q, partition_spec, full_shape, mesh=mesh).global_tensor
+ grad_k = xs.disable_manual_sharding(
+ grad_k, partition_spec, full_shape, mesh=mesh).global_tensor
+ grad_v = xs.disable_manual_sharding(
+ grad_v, partition_spec, full_shape, mesh=mesh).global_tensor
+
+ return grad_q, grad_k, grad_v, None, None, None
def flash_attention(
@@ -327,8 +365,10 @@ def flash_attention(
k, # [batch_size, num_heads, kv_seq_len, d_model]
v, # [batch_size, num_heads, kv_seq_len, d_model]
causal=False,
-):
- return FlashAttention.apply(q, k, v, causal)
+ *,
+ partition_spec=None,
+ mesh=None):
+ return FlashAttention.apply(q, k, v, causal, partition_spec, mesh)
XLA_LIB.define(
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index ec47f7cbfce..bf32a712f3e 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -12,20 +12,53 @@
from torch._higher_order_ops.while_loop import while_loop_op
+def fori_loop(lower, upper, user_body_func, *init_val):
+
+ device = xm.xla_device()
+
+ def cond_fn(upper, lower, *init_val):
+ return lower[0] < upper[0]
+
+ def body_fn(upper, lower, *init_val):
+ one_value_i = torch.ones(1, dtype=torch.int32, device=device)
+ res_list = list(user_body_func(*init_val))
+ res_list.insert(0, lower)
+ res_list.insert(0, torch.sub(upper, one_value_i))
+ return res_list
+
+ res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
+ return res
+
+
@while_loop_op.py_impl(DispatchKey.XLA)
-def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
+def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
+ # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
# cond_fn&body_fn: callable
# carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
if additional_inputs is None:
additional_inputs = tuple()
- return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
+ return _xla_while_loop(
+ cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)
-def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
+def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
+ # untuple carried_inputs from while_loop
+ carried_inputs = carried_inputs[0]
+ # fake carried_inputs to split formal code
+ fake_carried_inputs = []
+ for carried_input in carried_inputs:
+ device = carried_input.device
+ fake_carried_inputs.append(
+ torch.randint(10, carried_input.size(),
+ dtype=carried_input.dtype).to(device))
+ fake_carried_inputs = tuple(fake_carried_inputs)
- # create inputs placeholder
+ # trans fake_carried_inputs from list(tensor) to list(xla::op)
kwargs = {}
- shapes = xb.tensor_shape(carried_inputs)
+ if type(fake_carried_inputs) is tuple:
+ shapes = xb.tensor_shape(fake_carried_inputs)
+ else:
+ shapes = xb.tensor_shape((fake_carried_inputs))
builder = xb.create_builder('test_while')
params = []
for shape in shapes:
@@ -33,19 +66,19 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
params.append(p)
# generate cond_fn xlacomputation
- cond_result = cond_fn(carried_inputs[0], carried_inputs[1])
+ cond_result = cond_fn(*fake_carried_inputs)
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")
- cond_ctx.build([cond_result])
+ cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
cond_hlo = cond_ctx.hlo()
cond_computation = xb.computation_from_module_proto("condcomputation",
cond_hlo)
# generate body_fn xlacomputation
- body_result = body_fn(carried_inputs[0], carried_inputs[1])
+ body_result = body_fn(*fake_carried_inputs)
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")
- body_ctx.build(list(body_result))
+ body_ctx.buildforiloop(list(body_result), [])
body_hlo = body_ctx.hlo()
body_computation = xb.computation_from_module_proto("bodycomputation",
body_hlo)
@@ -61,7 +94,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
# gain final result with generated while xlacomputation
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
- tuple(carried_inputs),
- computation)
+ (carried_inputs), computation)
- return result
+ return result
\ No newline at end of file
diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py
index 77c2a572de3..620dff7e45c 100644
--- a/torch_xla/experimental/plugins.py
+++ b/torch_xla/experimental/plugins.py
@@ -76,7 +76,9 @@ def use_dynamic_plugins():
def using_dynamic_plugins():
- return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool, False)
+ # TODO: dummy plugin for CPU
+ return xu.getenv_as(xenv.PJRT_DYNAMIC_PLUGINS, bool,
+ False) and xr.device_type() != "CPU"
def default() -> DevicePlugin:
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 47018cd949f..3642354ab91 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -1,17 +1,13 @@
import copy
from dataclasses import dataclass
import enum
-import itertools
import json
-import shutil
import os
-import re
from typing import List, Tuple, Optional, Mapping, Any, Dict
import dataclasses
import numpy as np
import torch
-from torch import nn
from torch.fx import _pytree as fx_pytree
import torch_xla
from torch_xla.core import xla_model as xm
@@ -19,13 +15,10 @@
from torch_xla.debug import metrics
import torch_xla.experimental.quantized
from torch_xla.experimental.unbounded_dynamism_export import exported_program_has_symbolic_input_shape, process_exported_program_with_symbolic_input
-import torch._dynamo as torchdynamo
from torch.utils import _pytree as pytree
from torch._decomp import get_decompositions
-from typing import Tuple, Type, Callable
-
-import sys
+from typing import Tuple
def _get_numpy_dtype(dtype):
@@ -64,6 +57,9 @@ class StableHLOExportOptions:
# special constants (0 and 1) are inlined.
inline_all_constant: bool = True
+ # Whether to export the weights
+ export_weights: bool = True
+
class StableHLOGraphModule:
@@ -432,11 +428,15 @@ def _exported_program_to_stablehlo_bundle(exported_model,
output_pytree_spec=pytree.treespec_dumps(
exported_model.call_spec.out_spec),
)
+
+ exported_state_dict = {}
+ if options.export_weights:
+ exported_state_dict = pytree.tree_map_only(
+ torch.Tensor, lambda x: x.detach().cpu().numpy(), state_dict)
+
bundle = StableHLOModelBundle(
stablehlo_funcs=[StableHLOFunc(meta, stablehlo_content, stablehlo_text)],
- state_dict=pytree.tree_map_only(torch.Tensor,
- lambda x: x.detach().cpu().numpy(),
- state_dict),
+ state_dict=exported_state_dict,
additional_constants=additional_constants,
)
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 961f6a3217e..141d7e3e5a7 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -45,3 +45,8 @@ def real_devices() -> List[str]:
def device_count() -> int:
"""Returns number of addressable devices in the current process."""
return len(real_devices())
+
+
+def sync():
+ """Launches all pending graph operations."""
+ xm.mark_step()