diff --git a/.circleci/common.sh b/.circleci/common.sh index 0f09e02f2981..aaf61901a179 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -150,10 +150,10 @@ function run_torch_xla_python_tests() { # CUDA tests if [ -x "$(command -v nvidia-smi)" ]; then - # These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) - # PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 - # PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 - # XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + # These tests fail on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) + PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 + XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 # Syncfree SGD optimizer tests if [ -d ./torch_xla/amp/syncfree ]; then echo "Running Syncfree Optimizer Test" diff --git a/.kokoro/Dockerfile b/.kokoro/Dockerfile index a7ffc24dfdd5..32cc499477d1 100644 --- a/.kokoro/Dockerfile +++ b/.kokoro/Dockerfile @@ -3,7 +3,7 @@ WORKDIR / RUN apt-get update RUN apt-get -y upgrade RUN apt-get -y install clang time -RUN pip install pytest tf-nightly +RUN pip install pytest ARG USE_MKLDNN=0 ARG SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2 ARG DISABLE_XRT=1 @@ -53,4 +53,4 @@ RUN time pip install -e . # Run tests ENV PJRT_DEVICE=CPU ENV XLA_STABLEHLO_COMPILE=1 -ENTRYPOINT pytest test/stablehlo \ No newline at end of file +ENTRYPOINT pytest test/stablehlo diff --git a/CODEGEN_MIGRATION_GUIDE.md b/CODEGEN_MIGRATION_GUIDE.md index f234c5193ca7..6cad3d3b84d3 100644 --- a/CODEGEN_MIGRATION_GUIDE.md +++ b/CODEGEN_MIGRATION_GUIDE.md @@ -7,7 +7,7 @@ As PyTorch/XLA migrates to the LTC (Lazy Tensor Core), we need to clean up the e You should follow the instructions in [here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to install required dependencies and build pytorch and pytorch/XLA from the source. You do not need access to TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running ``` -export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:51011" +export PJRT_DEVICE=CPU ``` It is also recommended that you're familiar with our [op lowering process](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md) before you work on the codegen. 
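A minimal sketch of verifying the PJRT configuration above from Python, assuming `torch` and `torch_xla` are installed and `PJRT_DEVICE=CPU` has been exported (the tensor shape is illustrative):

```python
# Sanity check for the XLA:CPU setup described in the codegen guide above.
# Assumes PJRT_DEVICE=CPU is set in the environment.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()          # resolves to an XLA device backed by CPU under PJRT
t = torch.randn(2, 3).to(device)  # runs a trivial op through the XLA lowering path
print(device, t.sum())
```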
diff --git a/README.md b/README.md
index d894d53a46ba..bfda642b2f0e 100644
--- a/README.md
+++ b/README.md
@@ -25,10 +25,12 @@ started:
 ## Getting Started
-To install PyTorch/XLA a new VM:
+**PyTorch/XLA is now on PyPI!**
+
+To install PyTorch/XLA on a new TPU VM:
 ```
-pip install torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
+pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
 ```
 To update your existing training loop, make the following changes:
@@ -130,26 +132,37 @@ Our comprehensive user guides are available at:
 ## Available docker images and wheels
-### Wheel
+### Python packages
+
+PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You
+can now install the main build with `pip install torch_xla`. To also install the
+Cloud TPU plugin, install the optional `tpu` dependencies:
+
+```
+pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+GPU, XRT (legacy runtime), and nightly builds are available in our public GCS
+bucket.
 | Version | Cloud TPU VMs Wheel |
 | --- | ----------- |
-| 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` |
-| nightly >= 2023/04/25 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |
-| nightly >= 2023/04/25 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` |
+| 2.1 (CUDA 12.0 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
+| 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` |
+| nightly (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |
+| nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl` |
- older versions + +older versions | Version | Cloud TPU VMs Wheel | |---------|-------------------| +| 2.0 (Python 3.8) | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` | | 1.13 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl` | | 1.12 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl` | | 1.11 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl` | | 1.10 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl` | -| nightly <= 2023/04/25 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | - -

@@ -204,53 +217,58 @@ pip3 install torch_xla[tpuvm] This is only required on Cloud TPU VMs. + + ### Docker | Version | Cloud TPU VMs Docker | | --- | ----------- | -2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | -1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | -nightly python 3.10 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` | -nightly python 3.8 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm` | -nightly python 3.10(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD` | -nightly python 3.8(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD` | -nightly at date(< 2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD` | +| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm` | +| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | +| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | +| nightly python | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` |
-| Version | GPU CUDA 12.0 + Python 3.8 Docker | +| Version | GPU CUDA 12.0 Docker | | --- | ----------- | +| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.0` | | nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0` | -| nightly at date(>=2023/06/27) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0_YYYYMMDD` | +| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0_YYYYMMDD` |
-| Version | GPU CUDA 11.8 + Python 3.8 Docker |
+| Version | GPU CUDA 11.8 Docker |
 | --- | ----------- |
+| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8` |
 | 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8` |
 | nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8` |
-| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
-| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
+| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` |
-| Version | GPU CUDA 11.7 + Python 3.8 Docker | +
+
+older versions
+
+| Version | GPU CUDA 11.7 Docker |
 | --- | ----------- |
 | 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7` |
-| Version | GPU CUDA 11.2 + Python 3.8 Docker |
+| Version | GPU CUDA 11.2 Docker (Python 3.8) |
 | --- | ----------- |
 | 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2` |
-| Version | GPU CUDA 11.2 + Python 3.7 Docker |
+| Version | GPU CUDA 11.2 Docker (Python 3.7) |
 | --- | ----------- |
-1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` |
-1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` |
+| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` |
+| 1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` |
+
To run on [compute instances with GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus). diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index ed4b9173fa59..842deabd1868 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -203,6 +203,8 @@ only be enabled for debugging. * ```XLA_SAVE_TENSORS_FMT```: The format of the graphs stored within the _XLA_SAVE_TENSORS_FILE_ file. Can be ```text``` (the default), ```dot``` (the _Graphviz_ format) or ```hlo```. +* ```XLA_FLAGS=--xla_dump_to```: If set to ```=/tmp/dir_name```, XLA compiler will dump the unoptimized and optimzed HLO per compilation. + * ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be saved at every step. Metrics will be appended to the file, if already existing. @@ -261,61 +263,3 @@ only be enabled for debugging. * ```XLA_DUMP_HLO_GRAPH```: If set to `=1` in case of a compilation or execution error the offending HLO graph will be dumped as part of the runtime error raised by `xla_util.cc`. -### Retrieving Stack Traces - -In the event that the _PyTorch_ process is hanging, it might be useful to include the stack -traces together with the GitHub issue. - -First thing is to find out which PID the _PyTorch_ process is associated with. Using the ```ps``` -command it is possible to find that information. It will be a _python_ process running your -main _python_ file. - -In order to allow _GDB_ to attach a user process the following command should be run as root: - -```Shell -echo 0 > /proc/sys/kernel/yama/ptrace_scope -``` - -The above command remains active until the machine is rebooted. - -The, given the PID, it is possible to grab the stack traces with the following command: - -```Shell -./scripts/dump_stacks.py PID > /tmp/stack-traces.log -``` - -## Using debug_run.py To Collect Debug Information - -A utility is provided in `scripts/debug_run.py` which can be used to create a `tar.gz` -archive with the information required to debug _PyTorch/XLA_ executions. - -Example: - -```Shell -./scripts/debug_run.py --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...] -``` - -The _python_ `-u` flag is suggested to disable buffering so that captured logs are correctly -interleaved (otherwise STDOUT will be rendered after all STDERR). - -The above command line example will leave the temporary folder containing the archived -information on the filesystem. Use the `--tidy` flag to have that removed on exit: - -```Shell -./scripts/debug_run.py --tidy --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...] -``` - -The `debug_run.tar.gz` file should then be attached to bug reports when necessary. - -Since the script will collect a lot of data, it should usually be let run for no more -than hundred steps or so. - -If the SCRIPT has arguments to control the number of steps, those should be used, -otherwise hitting `CTRL^C` will interrupt the run. - -It is also suggested to run in single-core mode, to minimize the amount of data. -Running in single-core mode is also strongly suggested when debugging execution issues. - -## Common Issues - -* `Missing XLA configuration` error message: You need to set `XRT_TPU_CONFIG` if using TPUs. If using GPUs set `GPU_NUM_DEVICES=N` for `N` number of GPUs. 
If using CPUs set `XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"` and `XRT_WORKERS="localservice:0;grpc://localhost:9002"` diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 74055955bd7b..bdb6c38e8cc0 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -63,6 +63,8 @@ full_codegen: - log_sigmoid_forward - lt.Scalar - lt.Tensor + - masked_fill.Scalar + - masked_fill.Tensor - maximum - minimum - native_dropout_backward @@ -217,8 +219,6 @@ supported: - log2 - log10 - logsumexp - - masked_fill.Scalar - - masked_fill.Tensor - masked_scatter - masked_select - max diff --git a/docs/pjrt.md b/docs/pjrt.md index aedcfacec2f4..265c0abcce50 100644 --- a/docs/pjrt.md +++ b/docs/pjrt.md @@ -1,7 +1,7 @@ # PJRT Runtime PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the [PJRT -runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/pjrt) +runtime](https://github.com/openxla/xla/tree/main/xla/pjrt) used by [JAX](https://github.com/google/jax). If you encounter a bug with PJRT, please file an issue on GitHub with the diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index bd5edb6d02af..cbec743e3806 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -35,14 +35,14 @@ xrt_versioned_builds = [ { accelerator = "tpu" python_version = "3.10" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+xrt" }, { accelerator = "cuda" python_version = "3.10" cuda_version = "12.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+xrt" }, ] @@ -51,14 +51,14 @@ xrt_versioned_builds = [ versioned_builds = [ { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "tpu" bundle_libtpu = "0" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "tpu" python_version = "3.10" @@ -66,7 +66,7 @@ versioned_builds = [ }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+libtpu" accelerator = "tpu" python_version = "3.10" @@ -84,26 +84,41 @@ versioned_builds = [ }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0", accelerator = "cuda" cuda_version = "12.0" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "cuda" cuda_version = "11.8" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "cuda" + cuda_version = "12.1" + }, + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "cuda" cuda_version = "11.8" python_version = "3.10" }, + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "cuda" + cuda_version = "12.1" + python_version = "3.10" + }, { git_tag = "v2.0.0" package_version = "2.0" diff --git a/scripts/fixup_binary.py b/scripts/fixup_binary.py deleted file mode 100755 index 8d12d1c78f22..000000000000 --- a/scripts/fixup_binary.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python - -import argparse -import glob -import os -import site -import subprocess - - -def find_torch_xla_site(site_paths): - 
for site_path in site_paths: - # If there is one named 'torch_xla', this is what we pick. - path = os.path.join(site_path, 'torch_xla', 'lib') - if os.path.isdir(path): - return [site_path, path] - dirs = glob.glob(os.path.join(site_path, 'torch_xla*')) - # Get the most recent one. - for xpath in sorted(dirs, key=os.path.getmtime): - path = os.path.join(xpath, 'lib') - if os.path.isdir(path): - return [site_path, path] - if os.path.isfile(os.path.join(xpath, 'libptxla.so')): - return [site_path, xpath, os.path.join(xpath, 'torch_xla', 'lib')] - raise RuntimeError('Unable to find torch_xla package in {}'.format(site_path)) - - -def find_torch_site(site_paths): - for site_path in site_paths: - path = os.path.join(site_path, 'torch', 'lib') - if os.path.isdir(path): - return [path] - raise RuntimeError('Unable to find torch package in {}'.format(site_path)) - - -def list_rpaths(path): - if subprocess.call(['patchelf', '--shrink-rpath', path]) != 0: - raise RuntimeError('Failed to shrink RPATH folders: {}'.format(path)) - return subprocess.check_output(['patchelf', '--print-rpath', - path]).decode('utf-8').strip('\n').split(':') - - -def set_rpaths(path, rpaths): - if subprocess.call(['patchelf', '--set-rpath', ':'.join(rpaths), path]) != 0: - raise RuntimeError('Failed to set RPATH folders {}: {}'.format( - rpaths, path)) - - -def fixup_binary(args): - site_paths = site.getsitepackages() - xla_rpaths = find_torch_xla_site(site_paths) - torch_rpaths = find_torch_site(site_paths) - rpaths = list_rpaths(args.binary) - rpaths = xla_rpaths + torch_rpaths + rpaths - set_rpaths(args.binary, rpaths) - - -if __name__ == '__main__': - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument( - 'binary', - type=str, - metavar='BINARY', - help='The path to the binary to be patched') - args, files = arg_parser.parse_known_args() - fixup_binary(args) diff --git a/setup.py b/setup.py index b3b39cd709e4..3548701a9cea 100644 --- a/setup.py +++ b/setup.py @@ -307,15 +307,37 @@ def run(self): super().run() +# Read in README.md for our long_description +cwd = os.path.dirname(os.path.abspath(__file__)) +with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f: + long_description = f.read() + setup( name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'), version=version, description='XLA bridge for PyTorch', + long_description=long_description, + long_description_content_type="text/markdown", url='https://github.com/pytorch/xla', author='PyTorch/XLA Dev Team', author_email='pytorch-xla@googlegroups.com', - # Exclude the build files. 
- packages=find_packages(exclude=['build']), + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: C++", + "Programming Language :: Python :: 3", + ], + python_requires=">=3.8.0", + packages=find_packages(include=['torch_xla*']), ext_modules=[ BazelExtension('//:_XLAC.so'), ], @@ -334,12 +356,12 @@ def run(self): }, extras_require={ # On Cloud TPU VM install with: - # $ sudo pip3 install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl + # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html + 'tpu': [f'libtpu-nightly=={_libtpu_version}'], + # On nightly, install libtpu with `pip install torch_xla[tpuvm]` + # Remove from release branches since this is not allowed by PyPI. 'tpuvm': [f'libtpu-nightly @ {_libtpu_storage_path}'], }, - data_files=[ - 'scripts/fixup_binary.py', - ], cmdclass={ 'build_ext': BuildBazelExtension, 'clean': Clean, diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index fd669ce49e00..1e61a6fa05ac 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -1081,6 +1081,24 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast2) { ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast3) { + torch::Tensor input = + torch::rand({2, 1}, torch::TensorOptions(torch::kFloat)); + torch::Tensor mask = + torch::randint(0, 2, {4, 2, 3}, torch::TensorOptions(torch::kBool)); + torch::Scalar value(42); + torch::Tensor result = torch::masked_fill(input, mask, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_mask = CopyToDevice(mask, device); + torch::Tensor xla_result = torch::masked_fill(xla_input, xla_mask, value); + AllClose(result, xla_result); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters()); +} + TEST_F(AtenXlaTensorTest, TestFill) { torch::Scalar value(42); ForEachDevice([&](const torch::Device& device) { diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 7b3d5eb86fbd..276571e59793 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -16,7 +16,7 @@ create_default_global_save_plan, ) from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner -from torch_xla.experimental._distributed_checkpoint_helpers import ( +from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index ce2cae18dd6c..1b128164a22b 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -900,6 +900,139 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, 
None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) + def test_from_cpu_shards_replicated(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (None,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.arange(4)] * self.n_devices + + # No shape should result in the shape of a single shard. + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # Specify a valid shape for the global tensor + global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # All invalid shapes should raise + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((5,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((3,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((2, 2))) + + def test_from_cpu_shards_tiled(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.LongTensor([i]) for i in range(self.n_devices)] + + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue( + torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) + + # Test incorrect number of shards + with self.assertRaises(RuntimeError): + from_cpu_shards(shards[:-1], op_sharding) + + # Test an invalid global shape - too many values. + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,))) + + # Test an invalid global shape - incorrect rank + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices))) + + # Test a valid global shape - restrict the number of meaningful values + # to 1, treating the rest as padding. 
+ global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,))) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) + + def test_from_cpu_shards_2d(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an appropriate 2D mesh for the number of devices + if self.n_devices >= 4: + mesh_shape = (self.n_devices // 2, 2) + else: + mesh_shape = (1, self.n_devices) + mesh_2d = self._get_mesh(mesh_shape) + + # Replicated sharding + shards = [torch.LongTensor([self.n_devices])] * self.n_devices + partition_spec = (None, None) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + if self.n_devices > 1: + # Tiled sharding + shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] + partition_spec = (0, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + expected = torch.arange(self.n_devices).reshape(*mesh_shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Partially replicated sharding + shards = [torch.LongTensor([[i]]) for i in range(2)] * ( + self.n_devices // 2) + partition_spec = (None, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + # Partial replication along the 0th axis represents a global tensor + # of torch.Tensor([[0, 1]]). + expected = torch.arange(2).reshape(1, 2) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + def test_from_cpu_shards_global_shape(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + mesh = self._get_mesh((self.n_devices,)) + numel = self.n_devices**2 + # The global tensor is torch.arange(numel). + shards = [ + torch.arange(self.n_devices) + (i * self.n_devices) + for i in range(self.n_devices) + ] + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + + # No global shape specified will include all data from the shards + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel))) + + # Too large of a global shape will error out + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,))) + + if self.n_devices > 1: + # When the global tensor has fewer elements than the sum of its shards, + # there are two cases: + + # Case 1: If the global shape is within n_devices of numel, the excess + # data is treated as padding and ignored. + for delta in range(self.n_devices): + size = torch.Size((numel - delta,)) + global_tensor = from_cpu_shards(shards, op_sharding, size) + expected = torch.arange(size[0]) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Case 2: Otherwise, it is not possible to have that much padding in a + # sharded tensor, and the shards are incompatible with the shape. 
+ with self.assertRaises(RuntimeError): + shape = torch.Size((numel - self.n_devices,)) + from_cpu_shards(shards, op_sharding, shape) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1,))) + if __name__ == '__main__': test = unittest.main() diff --git a/test/test_operations.py b/test/test_operations.py index db46574bd8cf..220805b3fbea 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -309,6 +309,19 @@ def test_get_xla_tensor(self): tx = t.select(1, 12) self.assertEqual(tx, sx.data.cpu()) + def test_masked_fill_scalar(self): + + def fn(tensor): + # Build a mask from the first line of tensor. + # Also, make it have the same rank as the original tensor. + mask = tensor[0].ge(0.5).unsqueeze(dim=0) + # Call masked_fill. + return tensor.masked_fill(mask, 10) + + x = _gen_tensor(2, 2, device=xm.xla_device()) + x_cpu = x.cpu() + self.assertEqual(fn(x_cpu), fn(x)) + class TestRandom(test_utils.XlaTestCase): diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index a40f3bef74fa..fdfdc8a698c1 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -110,7 +110,7 @@ transformer_auto_wrap_policy) DEFAULT_KWARGS = dict( - batch_size=128, + batch_size=64, test_set_batch_size=64, num_epochs=18, momentum=0.9, diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index ecad6314f495..b51fbba98cdf 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1639,25 +1639,6 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self, - const at::Tensor& mask, - const at::Tensor& value) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLA_CHECK_EQ(value.dim(), 0) << "masked_fill_ only supports a 0-dimensional " - << "value tensor, but got tensor " - << "with " << value.dim() << " dimension(s)."; - return masked_fill(self, mask, value.item()); -} - -at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self, - const at::Tensor& mask, - const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - return bridge::AtenFromXlaTensor(tensor_methods::masked_fill( - self_tensor, bridge::GetXlaTensor(mask), value)); -} - at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, const at::Tensor& mask, const at::Tensor& source) { @@ -3208,7 +3189,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( // our XLA lowering. XlaDeviceType hw_type = static_cast(grad_output_tensor->GetDevice().type()); - if (hw_type != XlaDeviceType::TPU) { + if (hw_type != XlaDeviceType::TPU && hw_type != XlaDeviceType::NEURON) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP(upsample_nearest2d_backward)>::call(grad_output, output_size, diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 89931e77ecff..e5425a93001a 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -31,7 +31,7 @@ bool IsSparseGather(const xla::Shape& input_shape, // to avoid gather on a single float on TPU. XlaDeviceType hw_type = static_cast(bridge::GetCurrentDevice().type()); - if (hw_type == XlaDeviceType::TPU) { + if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) { // XLA_DENSE_GATHER_FACTOR can be used to finely control the // sparsity check. 
static int dense_gather_factor = @@ -144,6 +144,25 @@ xla::XlaOp BuildExpand(xla::XlaOp input, torch::lazy::Iota(output_sizes.size())); } +xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, + xla::XlaOp scalar) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask); + + if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) { + xla::Shape shape = XlaHelpers::GetPromotedShape(input_shape, mask_shape); + input = BuildExpand(input, shape.dimensions()); + mask = BuildExpand(mask, shape.dimensions()); + } + + xla::XlaOp zero = xla::Zero(mask.builder(), XlaHelpers::TypeOfXlaOp(mask)); + xla::XlaOp mask_pred = xla::Ne(mask, zero); + xla::XlaOp update_scalar = + ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(), + ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr); + return xla::Select(mask_pred, update_scalar, input); +} + std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim) { std::vector output_dimensions; diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index 5d05e0a6d285..e22821a7bb0e 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -43,6 +43,9 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input); xla::XlaOp BuildExpand(xla::XlaOp input, absl::Span output_sizes); +xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, + xla::XlaOp scalar); + std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b3957d7a68f7..421066ba72cd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -34,6 +34,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" +#include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -1663,6 +1664,72 @@ void InitXlaModuleBindings(py::module m) { } return std::nullopt; }); + // Reassemble the CPU shards into a global tensor. A new sharded tensor is + // created from the local shards with the provided sharding annotation + // attached. The order of the shards should coincide with the order of + // devices returned by `torch_xla.runtime.local_runtime_devices()`. + m.def( + "_global_tensor_from_cpu_shards", + [](const std::vector& shards, const xla::OpSharding& sharding, + std::optional>& global_shape) -> at::Tensor { + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); + XLA_CHECK(local_devices.size() == shards.size()) + << "Must specify a shard for each local device"; + XLA_CHECK(!global_shape.has_value() || + global_shape.value().size() == shards[0].sizes().size()) + << "Global shape rank must agree with shard rank: expected rank " + << shards[0].sizes().size() << ", got " + << global_shape.value().size(); + + if (!global_shape.has_value()) { + // Set a default value for the global shape based on the sharding + // type. + if (sharding.type() == xla::OpSharding::OTHER) { + // Infer the global shape to be the shard shape scaled by the tiling + // dimensionality. 
+ auto tile_shape = sharding.tile_assignment_dimensions(); + global_shape = std::vector(); + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; + global_shape->push_back(global_dim); + } + } else if (sharding.type() == xla::OpSharding::REPLICATED) { + global_shape = shards[0].sizes().vec(); + } else { + XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type(); + } + } + + auto device = GetVirtualDevice(); + auto primitive_type = + MakeXlaPrimitiveType(shards[0].type().scalarType(), &device); + xla::Shape tensor_shape = MakeArrayShapeFromDimensions( + global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type, + static_cast(device.type())); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + + // Verify that the shard shape is correct for the global shape and + // sharding spec. + auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec); + for (auto shard : shards) { + XLA_CHECK(shard.sizes() == expected_shard_shape) + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << expected_shard_shape; + } + + auto data_handle = ShardingUtil::CreateShardedData( + shards, local_devices, sharding_spec); + XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); + xla_tensor->SetShardingSpec(*sharding_spec); + auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor)); + return torch::autograd::make_variable(tensor, + shards[0].requires_grad()); + }, + py::arg("shards"), py::arg("sharding"), + py::arg("global_shape") = py::none()); // Returns the local shards of the tensor, with values taken from the // underlying ComputationClient::GetDataShards. As such, the shards will // contain any padding that was applied to ensure they all have the same diff --git a/torch_xla/csrc/ops/masked_fill.cpp b/torch_xla/csrc/ops/masked_fill.cpp deleted file mode 100644 index baad805fd446..000000000000 --- a/torch_xla/csrc/ops/masked_fill.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "torch_xla/csrc/ops/masked_fill.h" - -#include "torch_xla/csrc/helpers.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/scalar.h" -#include "xla/client/lib/constants.h" - -namespace torch_xla { - -MaskedFill::MaskedFill(const torch::lazy::Value& input, - const torch::lazy::Value& mask, const at::Scalar& value) - : XlaNode(torch::lazy::OpKind(at::aten::masked_fill), {input, mask}, - GetXlaShape(input), - /*num_outputs=*/1, ScalarHash(value)), - value_(std::move(value)) {} - -torch::lazy::NodePtr MaskedFill::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0), operands.at(1), - value_); -} - -XlaOpVector MaskedFill::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp mask = loctx->GetOutputOp(operand(1)); - xla::XlaOp zero = xla::Zero(loctx->builder(), XlaHelpers::TypeOfXlaOp(mask)); - xla::XlaOp mask_pred = xla::Ne(mask, zero); - // Input shape is the same as output shape. 
- const xla::Shape& input_shape = xla_shape(); - xla::XlaOp value = - xla::Broadcast(XlaHelpers::ScalarValue(value_, input_shape.element_type(), - input.builder()), - input_shape.dimensions()); - return ReturnOp(xla::Select(mask_pred, value, input), loctx); -} - -std::string MaskedFill::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", value=" << value_; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/masked_fill.h b/torch_xla/csrc/ops/masked_fill.h deleted file mode 100644 index 8269b4678bbe..000000000000 --- a/torch_xla/csrc/ops/masked_fill.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_ -#define XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_ - -#include - -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { - -class MaskedFill : public XlaNode { - public: - MaskedFill(const torch::lazy::Value& input, const torch::lazy::Value& mask, - const at::Scalar& value); - - std::string ToString() const override; - - torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - at::Scalar value() const { return value_; } - - private: - at::Scalar value_; -}; - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_FILL_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 28587c38df8b..1fbb0dec44b0 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -548,6 +548,20 @@ torch_xla::XlaOpVector LogSigmoidBackward::Lower(LoweringContext* loctx) const { BuildLogSigmoidBackward(xla_grad_output, xla_input, xla_buffer), loctx); } +torch_xla::XlaOpVector MaskedFillScalar::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + xla::XlaOp mask = loctx->GetOutputOp(operand(1)); + xla::XlaOp scalar = loctx->GetOutputOp(operand(2)); + return ReturnOp(BuildMaskedFillScalar(xla_input, mask, scalar), loctx); +} + +torch_xla::XlaOpVector MaskedFillTensor::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + xla::XlaOp mask = loctx->GetOutputOp(operand(1)); + xla::XlaOp tensor = loctx->GetOutputOp(operand(2)); + return ReturnOp(BuildMaskedFillScalar(xla_input, mask, tensor), loctx); +} + torch_xla::XlaOpVector Maximum::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); xla::XlaOp xla_other = loctx->GetOutputOp(operand(1)); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index b9a9d4048390..101476fa1127 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -637,6 +637,24 @@ xla::Shape LogSigmoidBackwardOutputShape(const torch::lazy::Value& grad_output, return GetXlaShape(grad_output); } +xla::Shape MaskedFillScalarOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& mask, + const torch::lazy::Value& value) { + auto lower_for_shape_fn = + [&](absl::Span operands) -> xla::XlaOp { + return BuildMaskedFillScalar(operands[0], operands[1], operands[2]); + }; + return InferOutputShape( + {GetXlaShape(input), GetXlaShape(mask), GetXlaShape(value)}, + lower_for_shape_fn); +} + +xla::Shape MaskedFillTensorOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& mask, + const torch::lazy::Value& value) { + return MaskedFillScalarOutputShape(input, mask, value); +} + xla::Shape 
MaximumOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& other) { auto lower_for_shape_fn = diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 7a79196fb772..57cda7b83f20 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -210,6 +210,14 @@ xla::Shape LogSigmoidBackwardOutputShape(const torch::lazy::Value& grad_output, const torch::lazy::Value& input, const torch::lazy::Value& buffer); +xla::Shape MaskedFillScalarOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& mask, + const torch::lazy::Value& value); + +xla::Shape MaskedFillTensorOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& mask, + const torch::lazy::Value& value); + xla::Shape MaximumOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& other); diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index db1d90b2d6c7..d77d1dd84100 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -267,7 +267,7 @@ xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input, XlaDeviceType hw_type = static_cast(bridge::GetCurrentDevice().type()); - if (hw_type == XlaDeviceType::TPU) { + if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) { // TPU uses custom call implementation resized = xla::CustomCall(input.builder(), target, {tinput}, resized_shape, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 67e694491bcc..6f2cfbf6b8e1 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -62,7 +62,6 @@ #include "torch_xla/csrc/ops/linspace.h" #include "torch_xla/csrc/ops/log_softmax.h" #include "torch_xla/csrc/ops/logsumexp.h" -#include "torch_xla/csrc/ops/masked_fill.h" #include "torch_xla/csrc/ops/masked_scatter.h" #include "torch_xla/csrc/ops/masked_select.h" #include "torch_xla/csrc/ops/max_in_dim.h" @@ -1574,22 +1573,6 @@ XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other) { return DispatchComparisonOp(at::aten::lt, input, other); } -XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask, - const at::Scalar& value) { - torch::lazy::ScopePusher ir_scope(at::aten::masked_fill.toQualString()); - auto input_value = input->GetIrValue(); - // Expand input tensor to mask if needed (same as masked_scatter below). - // An additional check makes sure to only expand if the rank of input tensor - // is less than that of the mask tensor. 
- if (input->shape().get().rank() <= mask->shape().get().rank() && - input->shape().get().dimensions() < mask->shape().get().dimensions()) { - input_value = MaybeExpand(input->GetIrValue(), mask->shape()); - } - return input->CreateFrom(torch::lazy::MakeNode( - input_value, MaybeExpand(mask->GetIrValue(), GetXlaShape(input_value)), - value)); -} - XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask, const XLATensorPtr& source) { torch::lazy::ScopePusher ir_scope(at::aten::masked_scatter.toQualString()); diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 9462b1848d72..88d6e8b44965 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -483,9 +483,6 @@ XLATensorPtr lt(const XLATensorPtr& input, const at::Scalar& other); XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other); -XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask, - const at::Scalar& value); - XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask, const XLATensorPtr& source); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index f7da463fb647..cde74256eeee 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -706,7 +706,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( } runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( - std::vector& local_shards, std::vector& devices, + const std::vector& local_shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 4a595f4e99b0..32060c7fc098 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -147,7 +147,8 @@ class ShardingUtil { // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. 
static runtime::ComputationClient::DataPtr CreateShardedData( - std::vector& shards, std::vector& devices, + const std::vector& shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); }; diff --git a/torch_xla/experimental/distributed_checkpoint/__init__.py b/torch_xla/experimental/distributed_checkpoint/__init__.py new file mode 100644 index 000000000000..cad57c3a4058 --- /dev/null +++ b/torch_xla/experimental/distributed_checkpoint/__init__.py @@ -0,0 +1,8 @@ +from .manager import CheckpointManager +from .planners import SPMDSavePlanner, SPMDLoadPlanner + +__all__ = [ + "CheckpointManager", + "SPMDSavePlanner", + "SPMDLoadPlanner", +] diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py similarity index 100% rename from torch_xla/experimental/_distributed_checkpoint_helpers.py rename to torch_xla/experimental/distributed_checkpoint/_helpers.py diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py new file mode 100644 index 000000000000..cd36cbe1eb64 --- /dev/null +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -0,0 +1,148 @@ +import torch.distributed.checkpoint as dist_cp +import torch_xla.experimental.distributed_checkpoint as xc + +from typing import List, Optional +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +class CheckpointManager: + """ + The CheckpointManager class provides a higher-level wrapper around the + torch.distributed.checkpoint APIs to manage checkpointing. It builds on top + of those APIs to enable a few key features: + - Per-step checkpointing: Each checkpoint taken by the CheckpointManager is + identified by the step at which it was taken, and any step tracked + by the CheckpointManager can be restored. + - Async checkpointing: The torch.distributed.checkpoint APIs are + synchronous, which will block training for the duration of the + checkpoint. The CheckpointManager's save_async method can be used to + offload checkpointing to a background thread, unblocking training + while the checkpoint is written to persistent storage. + - Automatic checkpointing: If the training process would be shut down due + to a SIGTERM, the CheckpointManager will automatically take a + checkpoint at the next step. + - Native fsspec integration: Any storage protocol compatible with fsspec + can be used with CheckpointManager. + + The intended usage of CheckpointManager is as follows: + + >>> # Create a CheckpointManager to checkpoint every 10 steps into GCS. + >>> chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10) + + >>> # Select a checkpoint to restore from, and restore if applicable + >>> tracked_steps = chkpt_mgr.all_steps() + >>> if tracked_steps: + >>> # Choose the highest step + >>> best_step = max(tracked_steps) + >>> state_dict = {'model': model.state_dict()} + >>> chkpt_mgr.restore(best_step, state_dict) + >>> model.load_state_dict(state_dict['model']) + + >>> # Call `save` or `save_async` every step within the train loop. + >>> for step, data in enumerate(dataloader): + >>> ... 
+ >>> state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} + >>> if chkpt_mgr.save_async(step, state_dict): + >>> print(f'Checkpoint taken at step {step}') + + By calling `save` or `save_async` every step, the CheckpointManager has the + opportunity to take a checkpoint on steps which are out-of-cycle with its + step_period, as would be the case in auto checkpointing. + + This class is inspired by Orbax's CheckpointManager, which can be found here: + https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py + """ + + def __init__(self, + path: str, + save_period: int, + max_to_keep: Optional[int] = -1, + async_queue_size: Optional[int] = 1): + """ + Create a checkpoint manager that reads and writes checkpoints into + the provided directory. + + Args: + path: The base path for the CheckpointManager to write checkpoints into. + save_period: The number of steps between saving checkpoints. + max_to_keep: The maximum number of checkpoints to be tracked by the + CheckpointManager. When a new checkpoint will be taken, the + checkpoint for the lowest tracked step will be deleted. + Default: -1, indicating no upper bound on the number of checkpoints. + async_queue_size: The size of the execution queue which processes async + checkpoints. This should be a small value to ensure training doesn't + get too far ahead of the last finished checkpoint, but increasing + the value to 2 can unblock training when there are transient + network issues which slow down the active checkpoint. + Default: 1, which only allows a single async checkpoint to be + pending at a time. + """ + raise NotImplementedError + + def should_save(self, step: int) -> bool: + """ + Returns true if a checkpoint should be saved for the current step or if + a preemption has been detected. + """ + raise NotImplementedError + + def save(self, + step, + state_dict: STATE_DICT_TYPE, + force: Optional[bool] = False) -> bool: + """ + Take a checkpoint synchronously if `self.should_save(step)`. + + Args: + step: The current training step. + state_dict: The state dict to be checkpointed. + force: Option to force a checkpoint to be taken regardless of the result + of `should_save(step)`. + Returns: + True if a checkpoint was taken and False otherwise. + """ + raise NotImplementedError + + def save_async(self, + step: int, + state_dict: STATE_DICT_TYPE, + force: Optional[bool] = False) -> bool: + """ + Take a checkpoint asynchronously if `self.should_save(step)`. The + input state_dict will be transferred to the CPU device using the + `sharded_cpu_state_dict` function. + + This function will do the following: + 1. Transfer `state_dict` to the CPU device. + 2. Dispatch the checkpoint workload to an asynchronous execution + queue. This will block training until the ongoing async + checkpoint finishes when the queue is full. + + Args: + step: The current training step. + state_dict: The state dict to be checkpointed. + force: Option to force a checkpoint to be taken regardless of the result + of `should_save(step)`. + Returns: + True if a checkpoint was taken and False otherwise. + """ + raise NotImplementedError + + def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: + """ + Restores the checkpoint taken at the given step into the state_dict. The + caller is responsible for calling `model.load_state_dict` to restore any + non-tensor values. + + Args: + step: The step whose checkpoint is to be restored. 
+ state_dict: The state dict to restore the checkpoint into. Values are + updated in-place within the state_dict. + """ + raise NotImplementedError + + def all_steps(self) -> List[int]: + """ + List all steps tracked by the CheckpointManager. + """ + raise NotImplementedError diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint/planners.py similarity index 99% rename from torch_xla/experimental/distributed_checkpoint.py rename to torch_xla/experimental/distributed_checkpoint/planners.py index 5b1ee97b7d64..fbf466ff28a9 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/distributed_checkpoint/planners.py @@ -35,16 +35,11 @@ from torch.distributed.checkpoint.utils import find_state_dict_object from torch.utils._pytree import tree_map from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard -from torch_xla.experimental._distributed_checkpoint_helpers import ( +from torch_xla.experimental.distributed_checkpoint._helpers import ( FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor, set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards) from typing import Any, Dict, List, Tuple, Union -__all__ = [ - "SPMDSavePlanner", - "SPMDLoadPlanner", -] - class SPMDSavePlanner(SavePlanner): """ diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 95f4a88128bb..21d0e2e570ac 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -87,6 +87,14 @@ def get_op_sharding(self, Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ + partition_spec = _translate_named_partition_spec(self, partition_spec) + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): # Use partial replication for sharding a tensor over a higher-rank mesh @@ -482,19 +490,12 @@ def mark_sharding( assert num_devices > 0, "This requires XLA supported device(s)." assert mesh.size() == num_devices, \ f"{mesh.mesh_shape} is not mappable over {num_devices} devices." - partition_spec = _translate_named_partition_spec(mesh, partition_spec) # We only allow fully specified `partition_spec` to be applicable, as opposed # to filling in the unspecified replicated dims. Fully specified `partiion_spec` # should be of the same rank as `t`. This is to support partial replication # where the group assignment may vary with different input ranks. assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." 
op_sharding = mesh.get_op_sharding(partition_spec)
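For reference, a minimal sketch of the user-facing flow that exercises `Mesh.get_op_sharding` after this change; the device count, axis name, and tensor shape are illustrative and not part of the patch:

```python
# Assumes an SPMD-capable PyTorch/XLA environment (e.g. a TPU VM).
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.randn(16, 4).to(xm.xla_device())
# Partition-spec validation now happens inside Mesh.get_op_sharding,
# which mark_sharding calls under the hood.
xs.mark_sharding(t, mesh, ('data', None))
```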