diff --git a/.circleci/common.sh b/.circleci/common.sh
index 0f09e02f2981..aaf61901a179 100755
--- a/.circleci/common.sh
+++ b/.circleci/common.sh
@@ -150,10 +150,10 @@ function run_torch_xla_python_tests() {
# CUDA tests
if [ -x "$(command -v nvidia-smi)" ]; then
- # These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
- # PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
- # PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
- # XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
+ # Previously disabled due to GPU failures after the 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
+ PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
+ PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
+ XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
if [ -d ./torch_xla/amp/syncfree ]; then
echo "Running Syncfree Optimizer Test"
diff --git a/.kokoro/Dockerfile b/.kokoro/Dockerfile
index a7ffc24dfdd5..32cc499477d1 100644
--- a/.kokoro/Dockerfile
+++ b/.kokoro/Dockerfile
@@ -3,7 +3,7 @@ WORKDIR /
RUN apt-get update
RUN apt-get -y upgrade
RUN apt-get -y install clang time
-RUN pip install pytest tf-nightly
+RUN pip install pytest
ARG USE_MKLDNN=0
ARG SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2
ARG DISABLE_XRT=1
@@ -53,4 +53,4 @@ RUN time pip install -e .
# Run tests
ENV PJRT_DEVICE=CPU
ENV XLA_STABLEHLO_COMPILE=1
-ENTRYPOINT pytest test/stablehlo
\ No newline at end of file
+ENTRYPOINT pytest test/stablehlo
diff --git a/CODEGEN_MIGRATION_GUIDE.md b/CODEGEN_MIGRATION_GUIDE.md
index f234c5193ca7..6cad3d3b84d3 100644
--- a/CODEGEN_MIGRATION_GUIDE.md
+++ b/CODEGEN_MIGRATION_GUIDE.md
@@ -7,7 +7,7 @@ As PyTorch/XLA migrates to the LTC (Lazy Tensor Core), we need to clean up the e
You should follow the instructions in [here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to install required dependencies and build pytorch and pytorch/XLA from the source. You do not need access to TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running
```
-export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:51011"
+export PJRT_DEVICE=CPU
```
It is also recommended that you're familiar with our [op lowering process](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md) before you work on the codegen.
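As a quick check that the `PJRT_DEVICE=CPU` setting above took effect, here is a minimal sketch (assumes PyTorch and PyTorch/XLA are built and importable):

```python
import os

os.environ.setdefault("PJRT_DEVICE", "CPU")  # same effect as the export above

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(xm.xla_device_hw(device))          # expected: 'CPU'
print(torch.ones(2, 2, device=device))   # materializes a lazy tensor on XLA:CPU
```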
diff --git a/README.md b/README.md
index d894d53a46ba..bfda642b2f0e 100644
--- a/README.md
+++ b/README.md
@@ -25,10 +25,12 @@ started:
## Getting Started
-To install PyTorch/XLA a new VM:
+**PyTorch/XLA is now on PyPI!**
+
+To install PyTorch/XLA on a new TPU VM:
```
-pip install torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
+pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```
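As a quick sanity check after installation, here is a minimal sketch (assumes the command above was run on a Cloud TPU VM):

```python
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
print(xm.xla_device_hw(dev))                       # expected: 'TPU'
print(torch.ones(2, 2, device=dev).sum().item())   # expected: 4.0
```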
To update your existing training loop, make the following changes:
@@ -130,26 +132,37 @@ Our comprehensive user guides are available at:
## Available docker images and wheels
-### Wheel
+### Python packages
+
+PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You
+can now install the main build with `pip install torch_xla`. To also install the
+Cloud TPU plugin, include the optional `tpu` dependencies:
+
+```
+pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+GPU, XRT (legacy runtime), and nightly builds are available in our public GCS
+bucket.
| Version | Cloud TPU VMs Wheel |
| --- | ----------- |
-| 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` |
-| nightly >= 2023/04/25 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |
-| nightly >= 2023/04/25 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` |
+| 2.1 (CUDA 12.0 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
+| 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` |
+| nightly (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |
+| nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` |
- older versions
+
+older versions
| Version | Cloud TPU VMs Wheel |
|---------|-------------------|
+| 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` |
| 1.13 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl` |
| 1.12 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl` |
| 1.11 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl` |
| 1.10 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl` |
-| nightly <= 2023/04/25 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |
-
-
@@ -204,53 +217,58 @@ pip3 install torch_xla[tpuvm]
This is only required on Cloud TPU VMs.
+
+
### Docker
| Version | Cloud TPU VMs Docker |
| --- | ----------- |
-2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` |
-1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` |
-nightly python 3.10 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` |
-nightly python 3.8 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm` |
-nightly python 3.10(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD` |
-nightly python 3.8(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD` |
-nightly at date(< 2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD` |
+| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm` |
+| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` |
+| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` |
+| nightly python 3.10 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` |
-| Version | GPU CUDA 12.0 + Python 3.8 Docker |
+| Version | GPU CUDA 12.0 Docker |
| --- | ----------- |
+| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.0` |
| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0` |
-| nightly at date(>=2023/06/27) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0_YYYYMMDD` |
+| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0_YYYYMMDD` |
-| Version | GPU CUDA 11.8 + Python 3.8 Docker |
+| Version | GPU CUDA 11.8 Docker |
| --- | ----------- |
+| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8` |
| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8` |
| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8` |
-| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
-| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
+| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
-| Version | GPU CUDA 11.7 + Python 3.8 Docker |
+
+
+older versions
+
+| Version | GPU CUDA 11.7 + Python 3.8 Docker |
| --- | ----------- |
| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7` |
-| Version | GPU CUDA 11.2 + Python 3.8 Docker |
+| Version | GPU CUDA 11.2 + Python 3.8 Docker |
| --- | ----------- |
| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2` |
-| Version | GPU CUDA 11.2 + Python 3.7 Docker |
+| Version | GPU CUDA 11.2 + Python 3.7 Docker |
| --- | ----------- |
-1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` |
-1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` |
+| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` |
+| 1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` |
+
To run on [compute instances with
GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus).
diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md
index ed4b9173fa59..842deabd1868 100644
--- a/TROUBLESHOOTING.md
+++ b/TROUBLESHOOTING.md
@@ -203,6 +203,8 @@ only be enabled for debugging.
* ```XLA_SAVE_TENSORS_FMT```: The format of the graphs stored within the _XLA_SAVE_TENSORS_FILE_
file. Can be ```text``` (the default), ```dot``` (the _Graphviz_ format) or ```hlo```.
+* ```XLA_FLAGS=--xla_dump_to```: If set to ```=/tmp/dir_name```, the XLA compiler will dump the unoptimized and optimized HLO for each compilation into that directory (see the sketch at the end of this list).
+
* ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be
saved at every step. Metrics will be appended to the file, if already existing.
@@ -261,61 +263,3 @@ only be enabled for debugging.
* ```XLA_DUMP_HLO_GRAPH```: If set to `=1` in case of a compilation or execution error the
offending HLO graph will be dumped as part of the runtime error raised by `xla_util.cc`.
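A minimal sketch of the `XLA_FLAGS=--xla_dump_to` option described earlier in this list (the dump directory `/tmp/xla_dump` is just an example):

```python
import os

# Must be set before the first XLA compilation happens.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t = torch.randn(4, 4, device=dev) @ torch.randn(4, 4, device=dev)
xm.mark_step()  # triggers a compilation; unoptimized and optimized HLO files land in /tmp/xla_dump
```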
-### Retrieving Stack Traces
-
-In the event that the _PyTorch_ process is hanging, it might be useful to include the stack
-traces together with the GitHub issue.
-
-First thing is to find out which PID the _PyTorch_ process is associated with. Using the ```ps```
-command it is possible to find that information. It will be a _python_ process running your
-main _python_ file.
-
-In order to allow _GDB_ to attach a user process the following command should be run as root:
-
-```Shell
-echo 0 > /proc/sys/kernel/yama/ptrace_scope
-```
-
-The above command remains active until the machine is rebooted.
-
-The, given the PID, it is possible to grab the stack traces with the following command:
-
-```Shell
-./scripts/dump_stacks.py PID > /tmp/stack-traces.log
-```
-
-## Using debug_run.py To Collect Debug Information
-
-A utility is provided in `scripts/debug_run.py` which can be used to create a `tar.gz`
-archive with the information required to debug _PyTorch/XLA_ executions.
-
-Example:
-
-```Shell
-./scripts/debug_run.py --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
-```
-
-The _python_ `-u` flag is suggested to disable buffering so that captured logs are correctly
-interleaved (otherwise STDOUT will be rendered after all STDERR).
-
-The above command line example will leave the temporary folder containing the archived
-information on the filesystem. Use the `--tidy` flag to have that removed on exit:
-
-```Shell
-./scripts/debug_run.py --tidy --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
-```
-
-The `debug_run.tar.gz` file should then be attached to bug reports when necessary.
-
-Since the script will collect a lot of data, it should usually be let run for no more
-than hundred steps or so.
-
-If the SCRIPT has arguments to control the number of steps, those should be used,
-otherwise hitting `CTRL^C` will interrupt the run.
-
-It is also suggested to run in single-core mode, to minimize the amount of data.
-Running in single-core mode is also strongly suggested when debugging execution issues.
-
-## Common Issues
-
-* `Missing XLA configuration` error message: You need to set `XRT_TPU_CONFIG` if using TPUs. If using GPUs set `GPU_NUM_DEVICES=N` for `N` number of GPUs. If using CPUs set `XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"` and `XRT_WORKERS="localservice:0;grpc://localhost:9002"`
diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 74055955bd7b..bdb6c38e8cc0 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -63,6 +63,8 @@ full_codegen:
- log_sigmoid_forward
- lt.Scalar
- lt.Tensor
+ - masked_fill.Scalar
+ - masked_fill.Tensor
- maximum
- minimum
- native_dropout_backward
@@ -217,8 +219,6 @@ supported:
- log2
- log10
- logsumexp
- - masked_fill.Scalar
- - masked_fill.Tensor
- masked_scatter
- masked_select
- max
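With `masked_fill` moved to full codegen, the lowering handles input/mask broadcasting explicitly (see `BuildMaskedFillScalar` later in this diff). A minimal sketch of the user-visible behavior, assuming any XLA device is available:

```python
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
x = torch.rand(2, 1, device=dev)                                      # input broadcasts against the mask
mask = torch.randint(0, 2, (4, 2, 3), dtype=torch.bool, device=dev)
y = x.masked_fill(mask, 42.0)
print(y.shape)  # torch.Size([4, 2, 3]), matching the CPU result
```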
diff --git a/docs/pjrt.md b/docs/pjrt.md
index aedcfacec2f4..265c0abcce50 100644
--- a/docs/pjrt.md
+++ b/docs/pjrt.md
@@ -1,7 +1,7 @@
# PJRT Runtime
PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the [PJRT
-runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/pjrt)
+runtime](https://github.com/openxla/xla/tree/main/xla/pjrt)
used by [JAX](https://github.com/google/jax).
If you encounter a bug with PJRT, please file an issue on GitHub with the
diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars
index bd5edb6d02af..cbec743e3806 100644
--- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars
+++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars
@@ -35,14 +35,14 @@ xrt_versioned_builds = [
{
accelerator = "tpu"
python_version = "3.10"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0+xrt"
},
{
accelerator = "cuda"
python_version = "3.10"
cuda_version = "12.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0+xrt"
},
]
@@ -51,14 +51,14 @@ xrt_versioned_builds = [
versioned_builds = [
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0"
accelerator = "tpu"
bundle_libtpu = "0"
},
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0"
accelerator = "tpu"
python_version = "3.10"
@@ -66,7 +66,7 @@ versioned_builds = [
},
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0+libtpu"
accelerator = "tpu"
python_version = "3.10"
@@ -84,26 +84,41 @@ versioned_builds = [
},
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0",
accelerator = "cuda"
cuda_version = "12.0"
},
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0"
accelerator = "cuda"
cuda_version = "11.8"
},
{
git_tag = "v2.1.0"
- pytorch_git_rev = "v2.1.0-rc6"
+ pytorch_git_rev = "v2.1.0"
+ package_version = "2.1.0"
+ accelerator = "cuda"
+ cuda_version = "12.1"
+ },
+ {
+ git_tag = "v2.1.0"
+ pytorch_git_rev = "v2.1.0"
package_version = "2.1.0"
accelerator = "cuda"
cuda_version = "11.8"
python_version = "3.10"
},
+ {
+ git_tag = "v2.1.0"
+ pytorch_git_rev = "v2.1.0"
+ package_version = "2.1.0"
+ accelerator = "cuda"
+ cuda_version = "12.1"
+ python_version = "3.10"
+ },
{
git_tag = "v2.0.0"
package_version = "2.0"
diff --git a/scripts/fixup_binary.py b/scripts/fixup_binary.py
deleted file mode 100755
index 8d12d1c78f22..000000000000
--- a/scripts/fixup_binary.py
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/usr/bin/env python
-
-import argparse
-import glob
-import os
-import site
-import subprocess
-
-
-def find_torch_xla_site(site_paths):
- for site_path in site_paths:
- # If there is one named 'torch_xla', this is what we pick.
- path = os.path.join(site_path, 'torch_xla', 'lib')
- if os.path.isdir(path):
- return [site_path, path]
- dirs = glob.glob(os.path.join(site_path, 'torch_xla*'))
- # Get the most recent one.
- for xpath in sorted(dirs, key=os.path.getmtime):
- path = os.path.join(xpath, 'lib')
- if os.path.isdir(path):
- return [site_path, path]
- if os.path.isfile(os.path.join(xpath, 'libptxla.so')):
- return [site_path, xpath, os.path.join(xpath, 'torch_xla', 'lib')]
- raise RuntimeError('Unable to find torch_xla package in {}'.format(site_path))
-
-
-def find_torch_site(site_paths):
- for site_path in site_paths:
- path = os.path.join(site_path, 'torch', 'lib')
- if os.path.isdir(path):
- return [path]
- raise RuntimeError('Unable to find torch package in {}'.format(site_path))
-
-
-def list_rpaths(path):
- if subprocess.call(['patchelf', '--shrink-rpath', path]) != 0:
- raise RuntimeError('Failed to shrink RPATH folders: {}'.format(path))
- return subprocess.check_output(['patchelf', '--print-rpath',
- path]).decode('utf-8').strip('\n').split(':')
-
-
-def set_rpaths(path, rpaths):
- if subprocess.call(['patchelf', '--set-rpath', ':'.join(rpaths), path]) != 0:
- raise RuntimeError('Failed to set RPATH folders {}: {}'.format(
- rpaths, path))
-
-
-def fixup_binary(args):
- site_paths = site.getsitepackages()
- xla_rpaths = find_torch_xla_site(site_paths)
- torch_rpaths = find_torch_site(site_paths)
- rpaths = list_rpaths(args.binary)
- rpaths = xla_rpaths + torch_rpaths + rpaths
- set_rpaths(args.binary, rpaths)
-
-
-if __name__ == '__main__':
- arg_parser = argparse.ArgumentParser()
- arg_parser.add_argument(
- 'binary',
- type=str,
- metavar='BINARY',
- help='The path to the binary to be patched')
- args, files = arg_parser.parse_known_args()
- fixup_binary(args)
diff --git a/setup.py b/setup.py
index b3b39cd709e4..3548701a9cea 100644
--- a/setup.py
+++ b/setup.py
@@ -307,15 +307,37 @@ def run(self):
super().run()
+# Read in README.md for our long_description
+cwd = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
+ long_description = f.read()
+
setup(
name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'),
version=version,
description='XLA bridge for PyTorch',
+ long_description=long_description,
+ long_description_content_type="text/markdown",
url='https://github.com/pytorch/xla',
author='PyTorch/XLA Dev Team',
author_email='pytorch-xla@googlegroups.com',
- # Exclude the build files.
- packages=find_packages(exclude=['build']),
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: BSD License",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Mathematics",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "Programming Language :: C++",
+ "Programming Language :: Python :: 3",
+ ],
+ python_requires=">=3.8.0",
+ packages=find_packages(include=['torch_xla*']),
ext_modules=[
BazelExtension('//:_XLAC.so'),
],
@@ -334,12 +356,12 @@ def run(self):
},
extras_require={
# On Cloud TPU VM install with:
- # $ sudo pip3 install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl
+ # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+ 'tpu': [f'libtpu-nightly=={_libtpu_version}'],
+ # On nightly, install libtpu with `pip install torch_xla[tpuvm]`
+ # Remove from release branches since this is not allowed by PyPI.
'tpuvm': [f'libtpu-nightly @ {_libtpu_storage_path}'],
},
- data_files=[
- 'scripts/fixup_binary.py',
- ],
cmdclass={
'build_ext': BuildBazelExtension,
'clean': Clean,
diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp
index fd669ce49e00..1e61a6fa05ac 100644
--- a/test/cpp/test_aten_xla_tensor_4.cpp
+++ b/test/cpp/test_aten_xla_tensor_4.cpp
@@ -1081,6 +1081,24 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast2) {
ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
}
+TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast3) {
+ torch::Tensor input =
+ torch::rand({2, 1}, torch::TensorOptions(torch::kFloat));
+ torch::Tensor mask =
+ torch::randint(0, 2, {4, 2, 3}, torch::TensorOptions(torch::kBool));
+ torch::Scalar value(42);
+ torch::Tensor result = torch::masked_fill(input, mask, value);
+ ForEachDevice([&](const torch::Device& device) {
+ torch::Tensor xla_input = CopyToDevice(input, device);
+ torch::Tensor xla_mask = CopyToDevice(mask, device);
+ torch::Tensor xla_result = torch::masked_fill(xla_input, xla_mask, value);
+ AllClose(result, xla_result);
+ });
+
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+ ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
+}
+
TEST_F(AtenXlaTensorTest, TestFill) {
torch::Scalar value(42);
ForEachDevice([&](const torch::Device& device) {
diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 7b3d5eb86fbd..276571e59793 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -16,7 +16,7 @@
create_default_global_save_plan,
)
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
-from torch_xla.experimental._distributed_checkpoint_helpers import (
+from torch_xla.experimental.distributed_checkpoint._helpers import (
_sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index ce2cae18dd6c..1b128164a22b 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -900,6 +900,139 @@ def test_op_sharding_cache(self):
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+ def test_from_cpu_shards_replicated(self):
+ from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+
+ # Create an OpSharding with all devices on a single axis
+ mesh = self._get_mesh((self.n_devices,))
+ partition_spec = (None,)
+ op_sharding = mesh.get_op_sharding(partition_spec)
+ shards = [torch.arange(4)] * self.n_devices
+
+ # No shape should result in the shape of a single shard.
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
+
+ # Specify a valid shape for the global tensor
+ global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
+
+ # All invalid shapes should raise
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((5,)))
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((3,)))
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((2, 2)))
+
+ def test_from_cpu_shards_tiled(self):
+ from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+
+ # Create an OpSharding with all devices on a single axis
+ mesh = self._get_mesh((self.n_devices,))
+ partition_spec = (0,)
+ op_sharding = mesh.get_op_sharding(partition_spec)
+ shards = [torch.LongTensor([i]) for i in range(self.n_devices)]
+
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ self.assertTrue(
+ torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices)))
+
+ # Test incorrect number of shards
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards[:-1], op_sharding)
+
+ # Test an invalid global shape - too many values.
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,)))
+
+ # Test an invalid global shape - incorrect rank
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices)))
+
+ # Test a valid global shape - restrict the number of meaningful values
+ # to 1, treating the rest as padding.
+ global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,)))
+ self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1)))
+
+ def test_from_cpu_shards_2d(self):
+ from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+
+ # Create an appropriate 2D mesh for the number of devices
+ if self.n_devices >= 4:
+ mesh_shape = (self.n_devices // 2, 2)
+ else:
+ mesh_shape = (1, self.n_devices)
+ mesh_2d = self._get_mesh(mesh_shape)
+
+ # Replicated sharding
+ shards = [torch.LongTensor([self.n_devices])] * self.n_devices
+ partition_spec = (None, None)
+ op_sharding = mesh_2d.get_op_sharding(partition_spec)
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0]))
+
+ if self.n_devices > 1:
+ # Tiled sharding
+ shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)]
+ partition_spec = (0, 1)
+ op_sharding = mesh_2d.get_op_sharding(partition_spec)
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ expected = torch.arange(self.n_devices).reshape(*mesh_shape)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), expected))
+
+ # Partially replicated sharding
+ shards = [torch.LongTensor([[i]]) for i in range(2)] * (
+ self.n_devices // 2)
+ partition_spec = (None, 1)
+ op_sharding = mesh_2d.get_op_sharding(partition_spec)
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ # Partial replication along the 0th axis represents a global tensor
+ # of torch.Tensor([[0, 1]]).
+ expected = torch.arange(2).reshape(1, 2)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), expected))
+
+ def test_from_cpu_shards_global_shape(self):
+ from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+
+ mesh = self._get_mesh((self.n_devices,))
+ numel = self.n_devices**2
+ # The global tensor is torch.arange(numel).
+ shards = [
+ torch.arange(self.n_devices) + (i * self.n_devices)
+ for i in range(self.n_devices)
+ ]
+ partition_spec = (0,)
+ op_sharding = mesh.get_op_sharding(partition_spec)
+
+ # No global shape specified will include all data from the shards
+ global_tensor = from_cpu_shards(shards, op_sharding)
+ self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel)))
+
+ # Too large of a global shape will error out
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,)))
+
+ if self.n_devices > 1:
+ # When the global tensor has fewer elements than the sum of its shards,
+ # there are two cases:
+
+ # Case 1: If the global shape is within n_devices of numel, the excess
+ # data is treated as padding and ignored.
+ for delta in range(self.n_devices):
+ size = torch.Size((numel - delta,))
+ global_tensor = from_cpu_shards(shards, op_sharding, size)
+ expected = torch.arange(size[0])
+ self.assertTrue(torch.allclose(global_tensor.cpu(), expected))
+
+ # Case 2: Otherwise, it is not possible to have that much padding in a
+ # sharded tensor, and the shards are incompatible with the shape.
+ with self.assertRaises(RuntimeError):
+ shape = torch.Size((numel - self.n_devices,))
+ from_cpu_shards(shards, op_sharding, shape)
+ with self.assertRaises(RuntimeError):
+ from_cpu_shards(shards, op_sharding, torch.Size((1,)))
+
if __name__ == '__main__':
test = unittest.main()
diff --git a/test/test_operations.py b/test/test_operations.py
index db46574bd8cf..220805b3fbea 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -309,6 +309,19 @@ def test_get_xla_tensor(self):
tx = t.select(1, 12)
self.assertEqual(tx, sx.data.cpu())
+ def test_masked_fill_scalar(self):
+
+ def fn(tensor):
+ # Build a mask from the first line of tensor.
+ # Also, make it have the same rank as the original tensor.
+ mask = tensor[0].ge(0.5).unsqueeze(dim=0)
+ # Call masked_fill.
+ return tensor.masked_fill(mask, 10)
+
+ x = _gen_tensor(2, 2, device=xm.xla_device())
+ x_cpu = x.cpu()
+ self.assertEqual(fn(x_cpu), fn(x))
+
class TestRandom(test_utils.XlaTestCase):
diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py
index a40f3bef74fa..fdfdc8a698c1 100644
--- a/test/test_train_mp_imagenet_fsdp.py
+++ b/test/test_train_mp_imagenet_fsdp.py
@@ -110,7 +110,7 @@
transformer_auto_wrap_policy)
DEFAULT_KWARGS = dict(
- batch_size=128,
+ batch_size=64,
test_set_batch_size=64,
num_epochs=18,
momentum=0.9,
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index ecad6314f495..b51fbba98cdf 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1639,25 +1639,6 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self,
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}
-at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
- const at::Tensor& mask,
- const at::Tensor& value) {
- TORCH_LAZY_FN_COUNTER("xla::");
- XLA_CHECK_EQ(value.dim(), 0) << "masked_fill_ only supports a 0-dimensional "
- << "value tensor, but got tensor "
- << "with " << value.dim() << " dimension(s).";
- return masked_fill(self, mask, value.item());
-}
-
-at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
- const at::Tensor& mask,
- const at::Scalar& value) {
- TORCH_LAZY_FN_COUNTER("xla::");
- XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
- return bridge::AtenFromXlaTensor(tensor_methods::masked_fill(
- self_tensor, bridge::GetXlaTensor(mask), value));
-}
-
at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self,
const at::Tensor& mask,
const at::Tensor& source) {
@@ -3208,7 +3189,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward(
// our XLA lowering.
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
- if (hw_type != XlaDeviceType::TPU) {
+ if (hw_type != XlaDeviceType::TPU && hw_type != XlaDeviceType::NEURON) {
return at::native::call_fallback_fn<
&xla_cpu_fallback,
ATEN_OP(upsample_nearest2d_backward)>::call(grad_output, output_size,
diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp
index 89931e77ecff..e5425a93001a 100644
--- a/torch_xla/csrc/data_ops.cpp
+++ b/torch_xla/csrc/data_ops.cpp
@@ -31,7 +31,7 @@ bool IsSparseGather(const xla::Shape& input_shape,
// to avoid gather on a single float on TPU.
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
- if (hw_type == XlaDeviceType::TPU) {
+ if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) {
// XLA_DENSE_GATHER_FACTOR can be used to finely control the
// sparsity check.
static int dense_gather_factor =
@@ -144,6 +144,25 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
torch::lazy::Iota<int64_t>(output_sizes.size()));
}
+xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
+ xla::XlaOp scalar) {
+ const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+ const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask);
+
+ if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) {
+ xla::Shape shape = XlaHelpers::GetPromotedShape(input_shape, mask_shape);
+ input = BuildExpand(input, shape.dimensions());
+ mask = BuildExpand(mask, shape.dimensions());
+ }
+
+ xla::XlaOp zero = xla::Zero(mask.builder(), XlaHelpers::TypeOfXlaOp(mask));
+ xla::XlaOp mask_pred = xla::Ne(mask, zero);
+ xla::XlaOp update_scalar =
+ ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(),
+ ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr);
+ return xla::Select(mask_pred, update_scalar, input);
+}
+
std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim) {
std::vector<int64_t> output_dimensions;
diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h
index 5d05e0a6d285..e22821a7bb0e 100644
--- a/torch_xla/csrc/data_ops.h
+++ b/torch_xla/csrc/data_ops.h
@@ -43,6 +43,9 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);
+xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
+ xla::XlaOp scalar);
+
std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim);
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index b3957d7a68f7..421066ba72cd 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -34,6 +34,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
+#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -1663,6 +1664,72 @@ void InitXlaModuleBindings(py::module m) {
}
return std::nullopt;
});
+ // Reassemble the CPU shards into a global tensor. A new sharded tensor is
+ // created from the local shards with the provided sharding annotation
+ // attached. The order of the shards should coincide with the order of
+ // devices returned by `torch_xla.runtime.local_runtime_devices()`.
+ m.def(
+ "_global_tensor_from_cpu_shards",
+ [](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
+ std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
+ XLA_CHECK(UseVirtualDevice())
+ << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
+ auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
+ XLA_CHECK(local_devices.size() == shards.size())
+ << "Must specify a shard for each local device";
+ XLA_CHECK(!global_shape.has_value() ||
+ global_shape.value().size() == shards[0].sizes().size())
+ << "Global shape rank must agree with shard rank: expected rank "
+ << shards[0].sizes().size() << ", got "
+ << global_shape.value().size();
+
+ if (!global_shape.has_value()) {
+ // Set a default value for the global shape based on the sharding
+ // type.
+ if (sharding.type() == xla::OpSharding::OTHER) {
+ // Infer the global shape to be the shard shape scaled by the tiling
+ // dimensionality.
+ auto tile_shape = sharding.tile_assignment_dimensions();
+ global_shape = std::vector<int64_t>();
+ for (int dim = 0; dim < shards[0].sizes().size(); ++dim) {
+ auto global_dim = tile_shape[dim] * shards[0].sizes()[dim];
+ global_shape->push_back(global_dim);
+ }
+ } else if (sharding.type() == xla::OpSharding::REPLICATED) {
+ global_shape = shards[0].sizes().vec();
+ } else {
+ XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type();
+ }
+ }
+
+ auto device = GetVirtualDevice();
+ auto primitive_type =
+ MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
+ xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
+ global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
+ static_cast<XlaDeviceType>(device.type()));
+ auto sharding_spec =
+ std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+
+ // Verify that the shard shape is correct for the global shape and
+ // sharding spec.
+ auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec);
+ for (auto shard : shards) {
+ XLA_CHECK(shard.sizes() == expected_shard_shape)
+ << "Input shard shape must include padding: " << shard.sizes()
+ << " vs " << expected_shard_shape;
+ }
+
+ auto data_handle = ShardingUtil::CreateShardedData(
+ shards, local_devices, sharding_spec);
+ XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
+ xla_tensor->SetShardingSpec(*sharding_spec);
+ auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor));
+ return torch::autograd::make_variable(tensor,
+ shards[0].requires_grad());
+ },
+ py::arg("shards"), py::arg("sharding"),
+ py::arg("global_shape") = py::none());
// Returns the local shards of the tensor, with values taken from the
// underlying ComputationClient::GetDataShards. As such, the shards will
// contain any padding that was applied to ensure they all have the same
diff --git a/torch_xla/csrc/ops/masked_fill.cpp b/torch_xla/csrc/ops/masked_fill.cpp
deleted file mode 100644
index baad805fd446..000000000000
--- a/torch_xla/csrc/ops/masked_fill.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-#include "torch_xla/csrc/ops/masked_fill.h"
-
-#include "torch_xla/csrc/helpers.h"
-#include "torch_xla/csrc/lowering_context.h"
-#include "torch_xla/csrc/ops/scalar.h"
-#include "xla/client/lib/constants.h"
-
-namespace torch_xla {
-
-MaskedFill::MaskedFill(const torch::lazy::Value& input,
- const torch::lazy::Value& mask, const at::Scalar& value)
- : XlaNode(torch::lazy::OpKind(at::aten::masked_fill), {input, mask},
- GetXlaShape(input),
- /*num_outputs=*/1, ScalarHash(value)),
- value_(std::move(value)) {}
-
-torch::lazy::NodePtr MaskedFill::Clone(torch::lazy::OpList operands) const {
- return torch::lazy::MakeNode<MaskedFill>(operands.at(0), operands.at(1),
- value_);
-}
-
-XlaOpVector MaskedFill::Lower(LoweringContext* loctx) const {
- xla::XlaOp input = loctx->GetOutputOp(operand(0));
- xla::XlaOp mask = loctx->GetOutputOp(operand(1));
- xla::XlaOp zero = xla::Zero(loctx->builder(), XlaHelpers::TypeOfXlaOp(mask));
- xla::XlaOp mask_pred = xla::Ne(mask, zero);
- // Input shape is the same as output shape.
- const xla::Shape& input_shape = xla_shape();
- xla::XlaOp value =
- xla::Broadcast(XlaHelpers::ScalarValue(value_, input_shape.element_type(),
- input.builder()),
- input_shape.dimensions());
- return ReturnOp(xla::Select(mask_pred, value, input), loctx);
-}
-
-std::string MaskedFill::ToString() const {
- std::stringstream ss;
- ss << XlaNode::ToString() << ", value=" << value_;
- return ss.str();
-}
-
-} // namespace torch_xla
diff --git a/torch_xla/csrc/ops/masked_fill.h b/torch_xla/csrc/ops/masked_fill.h
deleted file mode 100644
index 8269b4678bbe..000000000000
--- a/torch_xla/csrc/ops/masked_fill.h
+++ /dev/null
@@ -1,29 +0,0 @@
-#ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_
-#define XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_
-
-#include <c10/core/Scalar.h>
-
-#include "torch_xla/csrc/ir.h"
-
-namespace torch_xla {
-
-class MaskedFill : public XlaNode {
- public:
- MaskedFill(const torch::lazy::Value& input, const torch::lazy::Value& mask,
- const at::Scalar& value);
-
- std::string ToString() const override;
-
- torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
-
- XlaOpVector Lower(LoweringContext* loctx) const override;
-
- at::Scalar value() const { return value_; }
-
- private:
- at::Scalar value_;
-};
-
-} // namespace torch_xla
-
-#endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 28587c38df8b..1fbb0dec44b0 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -548,6 +548,20 @@ torch_xla::XlaOpVector LogSigmoidBackward::Lower(LoweringContext* loctx) const {
BuildLogSigmoidBackward(xla_grad_output, xla_input, xla_buffer), loctx);
}
+torch_xla::XlaOpVector MaskedFillScalar::Lower(LoweringContext* loctx) const {
+ xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+ xla::XlaOp mask = loctx->GetOutputOp(operand(1));
+ xla::XlaOp scalar = loctx->GetOutputOp(operand(2));
+ return ReturnOp(BuildMaskedFillScalar(xla_input, mask, scalar), loctx);
+}
+
+torch_xla::XlaOpVector MaskedFillTensor::Lower(LoweringContext* loctx) const {
+ xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+ xla::XlaOp mask = loctx->GetOutputOp(operand(1));
+ xla::XlaOp tensor = loctx->GetOutputOp(operand(2));
+ return ReturnOp(BuildMaskedFillScalar(xla_input, mask, tensor), loctx);
+}
+
torch_xla::XlaOpVector Maximum::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index b9a9d4048390..101476fa1127 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -637,6 +637,24 @@ xla::Shape LogSigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
return GetXlaShape(grad_output);
}
+xla::Shape MaskedFillScalarOutputShape(const torch::lazy::Value& input,
+ const torch::lazy::Value& mask,
+ const torch::lazy::Value& value) {
+ auto lower_for_shape_fn =
+ [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+ return BuildMaskedFillScalar(operands[0], operands[1], operands[2]);
+ };
+ return InferOutputShape(
+ {GetXlaShape(input), GetXlaShape(mask), GetXlaShape(value)},
+ lower_for_shape_fn);
+}
+
+xla::Shape MaskedFillTensorOutputShape(const torch::lazy::Value& input,
+ const torch::lazy::Value& mask,
+ const torch::lazy::Value& value) {
+ return MaskedFillScalarOutputShape(input, mask, value);
+}
+
xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 7a79196fb772..57cda7b83f20 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -210,6 +210,14 @@ xla::Shape LogSigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
const torch::lazy::Value& input,
const torch::lazy::Value& buffer);
+xla::Shape MaskedFillScalarOutputShape(const torch::lazy::Value& input,
+ const torch::lazy::Value& mask,
+ const torch::lazy::Value& value);
+
+xla::Shape MaskedFillTensorOutputShape(const torch::lazy::Value& input,
+ const torch::lazy::Value& mask,
+ const torch::lazy::Value& value);
+
xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);
diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp
index db1d90b2d6c7..d77d1dd84100 100644
--- a/torch_xla/csrc/resize_ops.cpp
+++ b/torch_xla/csrc/resize_ops.cpp
@@ -267,7 +267,7 @@ xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input,
XlaDeviceType hw_type =
static_cast(bridge::GetCurrentDevice().type());
- if (hw_type == XlaDeviceType::TPU) {
+ if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) {
// TPU uses custom call implementation
resized =
xla::CustomCall(input.builder(), target, {tinput}, resized_shape,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 67e694491bcc..6f2cfbf6b8e1 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -62,7 +62,6 @@
#include "torch_xla/csrc/ops/linspace.h"
#include "torch_xla/csrc/ops/log_softmax.h"
#include "torch_xla/csrc/ops/logsumexp.h"
-#include "torch_xla/csrc/ops/masked_fill.h"
#include "torch_xla/csrc/ops/masked_scatter.h"
#include "torch_xla/csrc/ops/masked_select.h"
#include "torch_xla/csrc/ops/max_in_dim.h"
@@ -1574,22 +1573,6 @@ XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other) {
return DispatchComparisonOp(at::aten::lt, input, other);
}
-XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask,
- const at::Scalar& value) {
- torch::lazy::ScopePusher ir_scope(at::aten::masked_fill.toQualString());
- auto input_value = input->GetIrValue();
- // Expand input tensor to mask if needed (same as masked_scatter below).
- // An additional check makes sure to only expand if the rank of input tensor
- // is less than that of the mask tensor.
- if (input->shape().get().rank() <= mask->shape().get().rank() &&
- input->shape().get().dimensions() < mask->shape().get().dimensions()) {
- input_value = MaybeExpand(input->GetIrValue(), mask->shape());
- }
- return input->CreateFrom(torch::lazy::MakeNode<MaskedFill>(
- input_value, MaybeExpand(mask->GetIrValue(), GetXlaShape(input_value)),
- value));
-}
-
XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
const XLATensorPtr& source) {
torch::lazy::ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 9462b1848d72..88d6e8b44965 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -483,9 +483,6 @@ XLATensorPtr lt(const XLATensorPtr& input, const at::Scalar& other);
XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other);
-XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask,
- const at::Scalar& value);
-
XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
const XLATensorPtr& source);
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index f7da463fb647..cde74256eeee 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -706,7 +706,8 @@ void ShardingUtil::PrepareOutputShardingPropagation(
}
runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
- std::vector<at::Tensor>& local_shards, std::vector<std::string>& devices,
+ const std::vector<at::Tensor>& local_shards,
+ const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec) {
XLA_CHECK(local_shards.size() == devices.size())
<< "A device must be speficied for each shard";
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 4a595f4e99b0..32060c7fc098 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -147,7 +147,8 @@ class ShardingUtil {
// Transfers the individual shards to the devices and returns a DataPtr for
// the PjRtShardedData wrapping the shards.
static runtime::ComputationClient::DataPtr CreateShardedData(
- std::vector<at::Tensor>& shards, std::vector<std::string>& devices,
+ const std::vector<at::Tensor>& shards,
+ const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);
};
diff --git a/torch_xla/experimental/distributed_checkpoint/__init__.py b/torch_xla/experimental/distributed_checkpoint/__init__.py
new file mode 100644
index 000000000000..cad57c3a4058
--- /dev/null
+++ b/torch_xla/experimental/distributed_checkpoint/__init__.py
@@ -0,0 +1,8 @@
+from .manager import CheckpointManager
+from .planners import SPMDSavePlanner, SPMDLoadPlanner
+
+__all__ = [
+ "CheckpointManager",
+ "SPMDSavePlanner",
+ "SPMDLoadPlanner",
+]
diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py
similarity index 100%
rename from torch_xla/experimental/_distributed_checkpoint_helpers.py
rename to torch_xla/experimental/distributed_checkpoint/_helpers.py
diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
new file mode 100644
index 000000000000..cd36cbe1eb64
--- /dev/null
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -0,0 +1,148 @@
+import torch.distributed.checkpoint as dist_cp
+import torch_xla.experimental.distributed_checkpoint as xc
+
+from typing import List, Optional
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+
+
+class CheckpointManager:
+ """
+ The CheckpointManager class provides a higher-level wrapper around the
+ torch.distributed.checkpoint APIs to manage checkpointing. It builds on top
+ of those APIs to enable a few key features:
+ - Per-step checkpointing: Each checkpoint taken by the CheckpointManager is
+ identified by the step at which it was taken, and any step tracked
+ by the CheckpointManager can be restored.
+ - Async checkpointing: The torch.distributed.checkpoint APIs are
+ synchronous, which will block training for the duration of the
+ checkpoint. The CheckpointManager's save_async method can be used to
+ offload checkpointing to a background thread, unblocking training
+ while the checkpoint is written to persistent storage.
+ - Automatic checkpointing: If the training process would be shut down due
+ to a SIGTERM, the CheckpointManager will automatically take a
+ checkpoint at the next step.
+ - Native fsspec integration: Any storage protocol compatible with fsspec
+ can be used with CheckpointManager.
+
+ The intended usage of CheckpointManager is as follows:
+
+ >>> # Create a CheckpointManager to checkpoint every 10 steps into GCS.
+ >>> chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
+
+ >>> # Select a checkpoint to restore from, and restore if applicable
+ >>> tracked_steps = chkpt_mgr.all_steps()
+ >>> if tracked_steps:
+ >>> # Choose the highest step
+ >>> best_step = max(tracked_steps)
+ >>> state_dict = {'model': model.state_dict()}
+ >>> chkpt_mgr.restore(best_step, state_dict)
+ >>> model.load_state_dict(state_dict['model'])
+
+ >>> # Call `save` or `save_async` every step within the train loop.
+ >>> for step, data in enumerate(dataloader):
+ >>> ...
+ >>> state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
+ >>> if chkpt_mgr.save_async(step, state_dict):
+ >>> print(f'Checkpoint taken at step {step}')
+
+ By calling `save` or `save_async` every step, the CheckpointManager has the
+ opportunity to take a checkpoint on steps which are out-of-cycle with its
+ save_period, as would be the case in auto checkpointing.
+
+ This class is inspired by Orbax's CheckpointManager, which can be found here:
+ https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
+ """
+
+ def __init__(self,
+ path: str,
+ save_period: int,
+ max_to_keep: Optional[int] = -1,
+ async_queue_size: Optional[int] = 1):
+ """
+ Create a checkpoint manager that reads and writes checkpoints into
+ the provided directory.
+
+ Args:
+ path: The base path for the CheckpointManager to write checkpoints into.
+ save_period: The number of steps between saving checkpoints.
+ max_to_keep: The maximum number of checkpoints to be tracked by the
+ CheckpointManager. When a new checkpoint would exceed this limit, the
+ checkpoint for the lowest tracked step is deleted.
+ Default: -1, indicating no upper bound on the number of checkpoints.
+ async_queue_size: The size of the execution queue which processes async
+ checkpoints. This should be a small value to ensure training doesn't
+ get too far ahead of the last finished checkpoint, but increasing
+ the value to 2 can unblock training when there are transient
+ network issues which slow down the active checkpoint.
+ Default: 1, which only allows a single async checkpoint to be
+ pending at a time.
+ """
+ raise NotImplementedError
+
+ def should_save(self, step: int) -> bool:
+ """
+ Returns true if a checkpoint should be saved for the current step or if
+ a preemption has been detected.
+ """
+ raise NotImplementedError
+
+ def save(self,
+ step,
+ state_dict: STATE_DICT_TYPE,
+ force: Optional[bool] = False) -> bool:
+ """
+ Take a checkpoint synchronously if `self.should_save(step)`.
+
+ Args:
+ step: The current training step.
+ state_dict: The state dict to be checkpointed.
+ force: Option to force a checkpoint to be taken regardless of the result
+ of `should_save(step)`.
+ Returns:
+ True if a checkpoint was taken and False otherwise.
+ """
+ raise NotImplementedError
+
+ def save_async(self,
+ step: int,
+ state_dict: STATE_DICT_TYPE,
+ force: Optional[bool] = False) -> bool:
+ """
+ Take a checkpoint asynchronously if `self.should_save(step)`. The
+ input state_dict will be transferred to the CPU device using the
+ `sharded_cpu_state_dict` function.
+
+ This function will do the following:
+ 1. Transfer `state_dict` to the CPU device.
+ 2. Dispatch the checkpoint workload to an asynchronous execution
+ queue. If the queue is full, this will block training until the
+ ongoing async checkpoint finishes.
+
+ Args:
+ step: The current training step.
+ state_dict: The state dict to be checkpointed.
+ force: Option to force a checkpoint to be taken regardless of the result
+ of `should_save(step)`.
+ Returns:
+ True if a checkpoint was taken and False otherwise.
+ """
+ raise NotImplementedError
+
+ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
+ """
+ Restores the checkpoint taken at the given step into the state_dict. The
+ caller is responsible for calling `model.load_state_dict` to restore any
+ non-tensor values.
+
+ Args:
+ step: The step whose checkpoint is to be restored.
+ state_dict: The state dict to restore the checkpoint into. Values are
+ updated in-place within the state_dict.
+ """
+ raise NotImplementedError
+
+ def all_steps(self) -> List[int]:
+ """
+ List all steps tracked by the CheckpointManager.
+ """
+ raise NotImplementedError
diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint/planners.py
similarity index 99%
rename from torch_xla/experimental/distributed_checkpoint.py
rename to torch_xla/experimental/distributed_checkpoint/planners.py
index 5b1ee97b7d64..fbf466ff28a9 100644
--- a/torch_xla/experimental/distributed_checkpoint.py
+++ b/torch_xla/experimental/distributed_checkpoint/planners.py
@@ -35,16 +35,11 @@
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.utils._pytree import tree_map
from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard
-from torch_xla.experimental._distributed_checkpoint_helpers import (
+from torch_xla.experimental.distributed_checkpoint._helpers import (
FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor,
set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
from typing import Any, Dict, List, Tuple, Union
-__all__ = [
- "SPMDSavePlanner",
- "SPMDLoadPlanner",
-]
-
class SPMDSavePlanner(SavePlanner):
"""
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 95f4a88128bb..21d0e2e570ac 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -87,6 +87,14 @@ def get_op_sharding(self,
Return the OpSharding for the given partition spec. This is an expensive
operation as the mesh grows, so the value is cached for reuse.
"""
+ partition_spec = _translate_named_partition_spec(self, partition_spec)
+ flat_specs = np.hstack([d for d in partition_spec])
+ specs = [d for d in flat_specs if d is not None]
+ assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+ f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+ assert len(specs) == len(np.unique(specs)), \
+ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
tile_assignment = _get_tile_assignment(self, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
@@ -482,19 +490,12 @@ def mark_sharding(
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
- partition_spec = _translate_named_partition_spec(mesh, partition_spec)
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partiion_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
- flat_specs = np.hstack([d for d in partition_spec])
- specs = [d for d in flat_specs if d is not None]
- assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
- f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
- assert len(specs) == len(np.unique(specs)), \
- f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
op_sharding = mesh.get_op_sharding(partition_spec)
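Since the partition-spec validation now lives in `Mesh.get_op_sharding`, callers of `mark_sharding` hit the same checks through this path. A minimal sketch (assumes an SPMD-enabled runtime):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1))

t = torch.randn(8, 4, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))    # OK: each mesh axis appears at most once
# xs.mark_sharding(t, mesh, (0, 0))  # would trip the duplicate-axis assertion in get_op_sharding
```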