From 8ec552d9d9e0a67f329652cbf7e73ec65e4f0c11 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:11:52 -0700 Subject: [PATCH 1/9] add profiler hints in paralloader (#8244) --- torch_xla/distributed/parallel_loader.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index a177c92b59d..b7d2519eccf 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -7,6 +7,7 @@ import torch_xla.utils.keyd_queue as kq import torch_xla.utils.utils as xu import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp class PerDeviceQueue(object): @@ -160,7 +161,8 @@ def _loader_worker(self): try: while not self._done: try: - _, data = next(data_iter) + with xp.Trace("cpu_loader.next"): + _, data = next(data_iter) except StopIteration: break batch.append(data) @@ -227,12 +229,14 @@ def _worker(self, dqueue, host_to_device_transfer_threads): try: while True: - batch = self._get_batch(dqueue) + with xp.Trace("get_batch_from_cpu_queue"): + batch = self._get_batch(dqueue) if not batch: break with torch.no_grad(): try: - batch = self.send_cpu_data_to_device(batch, device) + with xp.Trace("cpu_data_to_xla_device"): + batch = self.send_cpu_data_to_device(batch, device) except Exception as e: # _worker is being run in a daemon thread, raise the error # will not work. Put the error in an error queue instead. From 0a91f79919335e0d6ced7fa043323d286d019aa7 Mon Sep 17 00:00:00 2001 From: Yenkai Wang Date: Thu, 10 Oct 2024 10:49:48 -0500 Subject: [PATCH 2/9] Op info test for `linalg.solve_ex .. linalg.tensorinv` (#7504) (#8251) --- experimental/torch_xla2/test/test_ops.py | 12 +++---- .../torch_xla2/torch_xla2/ops/jaten.py | 32 +++++++++++++++++-- .../torch_xla2/torch_xla2/ops/jtorch.py | 10 ++++++ 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aa07381cfda..fde60bd0205 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -40,11 +40,6 @@ "linalg.lu_solve", "linalg.matrix_norm", "linalg.matrix_power", - "linalg.solve_ex", - "linalg.solve_triangular", - "linalg.svd", - "linalg.svdvals", - "linalg.tensorinv", "linalg.tensorsolve", "linalg.vector_norm", "linspace", @@ -162,7 +157,12 @@ 'exponential', } -atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0), "linalg.eigvalsh": (5e1, 3e0)} +atol_dict = {"linalg.eig": (2e0, 3e0), + "linalg.eigh": (5e1, 3e0), + "linalg.eigvalsh": (5e1, 3e0), + "linalg.pinv": (8e-1, 2e0), + "linalg.svd": (1e0, 1e0), + "matrix_exp": (2e-1, 2e-4)} def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True): if isinstance(output1, torch.Tensor): diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index f671e039839..28062b05615 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4204,7 +4204,7 @@ def _aten__linalg_slogdet(input): # torch.linalg.svd @op(torch.ops.aten._linalg_svd) def _aten__linalg_svd(a, full_matrices=True): - return jnp.linalg.svd(a, full_matrices) + return jnp.linalg.svd(a, full_matrices=full_matrices) # torch.linalg.pinv @@ -4216,7 +4216,35 @@ def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs): # torch.linalg.solve @op(torch.ops.aten._linalg_solve_ex) def _aten__linalg_solve_ex(a, b): - return jnp.linalg.solve(a, b), jnp.array(0) + res = jnp.linalg.solve(a, b) + info_shape = a.shape[0] if len(a.shape) >= 3 else [] + info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) + return res, info + + +# torch.linalg.solve_triangular +@op(torch.ops.aten.linalg_solve_triangular) +def _aten_linalg_solve_triangular(a, b, *, upper=True, left=True, unitriangular=False): + if left is False: + a = jnp.matrix_transpose(a) + b = jnp.matrix_transpose(b) + upper = not upper + res = jax.scipy.linalg.solve_triangular(a, b, lower=not upper, unit_diagonal=unitriangular) + if left is False: + res = jnp.matrix_transpose(res) + return res + + +@op(torch.ops.aten.linalg_inv_ex) +def _aten_linalg_inv_ex(a): + ainv = jnp.linalg.inv(a) + info = jnp.array(0) + return ainv, info + + +@op(torch.ops.aten._linalg_check_errors) +def _aten__linalg_check_errors(*args, **kwargs): + pass @op(torch.ops.aten.median) diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index 2354834af73..50bbba6252d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -289,3 +289,13 @@ def tensor_split(input, indices_or_sections, dim=0): def linalg_solve(a, b): res, _ = jaten._aten__linalg_solve_ex(a, b) return res + + +@register_function(torch.linalg.solve_ex) +def linalg_solve_ex(a, b): + res, info = jaten._aten__linalg_solve_ex(a, b) + return res, info + +@register_function(torch.linalg.svd) +def linalg_svd(a, full_matrices=True, **kwargs): + return jaten._aten__linalg_svd(a, full_matrices=full_matrices, **kwargs) From 543ccc8602dd946e8480c0f9ae4cccddb8964e87 Mon Sep 17 00:00:00 2001 From: Michael Green <59619482+mikegre-google@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:54:25 +0000 Subject: [PATCH 3/9] reorganized files and toc (#8246) --- FIX_LOWERING_FOR_CORE_ATEN_OPS.md | 39 -- TROUBLESHOOTING.md | 382 ------------ docs/amp.md | 95 --- docs/bazel.md | 177 ------ docs/ddp.md | 280 --------- docs/dynamo.md | 116 ---- docs/eager.md | 119 ---- docs/first_steps.md | 257 -------- docs/fsdp.md | 112 ---- docs/glossry.md | 137 ----- docs/gpu.md | 172 ------ docs/kubernetes.md | 312 ---------- docs/openxla.md | 19 - docs/pjrt.md | 429 -------------- docs/plugins.md | 82 --- docs/source/accelerators/gpu.md | 6 + docs/source/accelerators/tpu.md | 24 + docs/source/contribute/bazel.md | 247 ++++++++ docs/source/contribute/codegen_migration.md | 326 ++++++++++ .../contribute/configure-environment.md} | 91 +-- docs/source/contribute/op_lowering.md | 186 ++++++ docs/source/contribute/plugins.md | 85 +++ docs/source/debug.rst | 1 - docs/source/eager_mode.rst | 1 - docs/{ => source/features}/pallas.md | 55 +- docs/{ => source/features}/stablehlo.md | 146 +++-- docs/{ => source/features}/triton.md | 26 +- docs/source/gpu.rst | 1 - docs/source/index.rst | 144 ++--- docs/source/learn/api-guide.rst | 89 +++ docs/{ => source/learn}/dynamic_shape.md | 2 +- docs/source/learn/eager.md | 132 +++++ docs/source/learn/pjrt.md | 438 ++++++++++++++ docs/source/learn/pytorch-on-xla-devices.md | 393 ++++++++++++ docs/source/learn/troubleshoot.md | 474 +++++++++++++++ docs/source/learn/xla-overview.md | 558 ++++++++++++++++++ docs/source/multi_process_distributed.rst | 2 - docs/source/perf/amp.md | 149 +++++ docs/source/perf/ddp.md | 252 ++++++++ docs/source/perf/dynamo.md | 233 ++++++++ docs/{ => source/perf}/fori_loop.md | 49 +- docs/source/perf/fsdp.md | 186 ++++++ docs/{ => source/perf}/fsdpv2.md | 50 +- docs/{ => source/perf}/quantized_ops.md | 122 +++- docs/source/perf/recompilation.md | 176 ++++++ docs/{ => source/perf}/spmd_advanced.md | 8 +- docs/source/perf/spmd_basic.md | 116 ++++ .../perf/spmd_distributed_checkpoint.md | 142 +++++ docs/source/perf/spmd_gpu.md | 48 ++ docs/source/quantized_ops.rst | 1 - docs/source/runtime.rst | 1 - docs/source/spmd.rst | 4 - docs/source/torch_compile.rst | 1 - docs/spmd_basic.md | 83 --- docs/spmd_distributed_checkpoint.md | 125 ---- docs/spmd_gpu.md | 40 -- 56 files changed, 4661 insertions(+), 3280 deletions(-) delete mode 100644 FIX_LOWERING_FOR_CORE_ATEN_OPS.md delete mode 100644 TROUBLESHOOTING.md delete mode 100644 docs/amp.md delete mode 100644 docs/bazel.md delete mode 100644 docs/ddp.md delete mode 100644 docs/dynamo.md delete mode 100644 docs/eager.md delete mode 100644 docs/first_steps.md delete mode 100644 docs/fsdp.md delete mode 100644 docs/glossry.md delete mode 100644 docs/gpu.md delete mode 100644 docs/kubernetes.md delete mode 100644 docs/openxla.md delete mode 100644 docs/pjrt.md delete mode 100644 docs/plugins.md create mode 100644 docs/source/accelerators/gpu.md create mode 100644 docs/source/accelerators/tpu.md create mode 100644 docs/source/contribute/bazel.md create mode 100644 docs/source/contribute/codegen_migration.md rename docs/{workflow.md => source/contribute/configure-environment.md} (57%) create mode 100644 docs/source/contribute/op_lowering.md create mode 100644 docs/source/contribute/plugins.md delete mode 100644 docs/source/debug.rst delete mode 100644 docs/source/eager_mode.rst rename docs/{ => source/features}/pallas.md (55%) rename docs/{ => source/features}/stablehlo.md (67%) rename docs/{ => source/features}/triton.md (84%) delete mode 100644 docs/source/gpu.rst create mode 100644 docs/source/learn/api-guide.rst rename docs/{ => source/learn}/dynamic_shape.md (98%) create mode 100644 docs/source/learn/eager.md create mode 100644 docs/source/learn/pjrt.md create mode 100644 docs/source/learn/pytorch-on-xla-devices.md create mode 100644 docs/source/learn/troubleshoot.md create mode 100644 docs/source/learn/xla-overview.md delete mode 100644 docs/source/multi_process_distributed.rst create mode 100644 docs/source/perf/amp.md create mode 100644 docs/source/perf/ddp.md create mode 100644 docs/source/perf/dynamo.md rename docs/{ => source/perf}/fori_loop.md (57%) create mode 100644 docs/source/perf/fsdp.md rename docs/{ => source/perf}/fsdpv2.md (57%) rename docs/{ => source/perf}/quantized_ops.md (52%) create mode 100644 docs/source/perf/recompilation.md rename docs/{ => source/perf}/spmd_advanced.md (97%) create mode 100644 docs/source/perf/spmd_basic.md create mode 100644 docs/source/perf/spmd_distributed_checkpoint.md create mode 100644 docs/source/perf/spmd_gpu.md delete mode 100644 docs/source/quantized_ops.rst delete mode 100644 docs/source/runtime.rst delete mode 100644 docs/source/spmd.rst delete mode 100644 docs/source/torch_compile.rst delete mode 100644 docs/spmd_basic.md delete mode 100644 docs/spmd_distributed_checkpoint.md delete mode 100644 docs/spmd_gpu.md diff --git a/FIX_LOWERING_FOR_CORE_ATEN_OPS.md b/FIX_LOWERING_FOR_CORE_ATEN_OPS.md deleted file mode 100644 index deb3b804d9c..00000000000 --- a/FIX_LOWERING_FOR_CORE_ATEN_OPS.md +++ /dev/null @@ -1,39 +0,0 @@ -In order for PyTorch/XLA to support the PyTorch core ATen opset, it requires lowering each core ATen op in PyTorch/XLA. Note that this document will serve as a guide for fixing these lowering for core aten opset, specifically looking at [test_core_aten_ops.py test](https://github.com/pytorch/xla/blob/master/test/test_core_aten_ops.py). This guide will **not** cover how to lower an op in PyTorch/XLA, please refer our [op lowering guide](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md) for this. - -We also have a worklog for lowering these core aten ops in a GitHub issue, so we can track who's working on which ops and share some findings: [[Core Aten ops] Logs for fixing core aten ops coverage issues](https://github.com/pytorch/xla/issues/5934). - -Let's go back and take a closer look at the [test_core_aten_ops.py test](https://github.com/pytorch/xla/blob/master/test/test_core_aten_ops.py), which is the source of truth to verify and correctness of these lowerings. The core of this test file is the `run_export_and_compare` at https://github.com/pytorch/xla/blob/master/test/test_core_aten_ops.py#L28. Each op unit test initializes the input and passes the op as a function and its inputs. The `run_export_and_compare` has multiple subtests that have the following structure: -- `torch_eval` - - `torch_xla_eval` - - `torch_xla_diff` - - `can_export` - - `can_convert_to_stablehlo` - - `stablehlo_can_run` - - `stablehlo_diff` - -Below we'll describe what each of these subtests mean and give some recommendations on fixing it. - -### `torch_eval` - -This subtest directly calls torch version of the op with the given inputs. If the unit test fails in this subtest, this implies that torch there is a problem with the unit test itself. One common reason might be due to inputs (or types of inputs) not being compatible with the op. We recommend you to look at the official torch documentation of the corresponding op to ensure that that unit tests are passing valid inputs to the op. - -### `torch_eval_xla` - -This subtest calls the torch_xla version of the op. If you've made changes to lower the op and this subtest fails, this means there may be something wrong with the lowering. We recommend you to take another look at our [op lowering guide](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md). If you're unable to debug further, feel free to leave a comment in your assigned GitHub issue. - -### `torch_xla_diff` - -This subtest compares the output of the op between torch and torch_xla. -If this subtest fails, it implies that your lowering runs successfully -but produced a different result than torch eager mode. - -If the test uses 16-bit floats (float16, bfloat16); This is very likely -that the tolerances that we give to `torch.allclose` to compare was to -strict. You can relax it a bit. Take a look at [this issue](https://github.com/pytorch/xla/issues/5934) of one such example. - -If the result torchxla produces is totally different than what torch produces, that means it's a bug in lowering code; and probably need -more work. Feel free to tag more people (such as qihqi to look). - -### `can_export`, `can_convert_to_stablehlo`, `stablehlo_can_run`, `stablehlo_diff` - -These subtests are related to `export` and `stablehlo`. If the lowering is complete and the above `torch_*` subtests all succeed, it is highly likely that these tests will also succeed. diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md deleted file mode 100644 index 22f6d01e374..00000000000 --- a/TROUBLESHOOTING.md +++ /dev/null @@ -1,382 +0,0 @@ -# Troubleshooting - -Note that the information in this section is subject to be removed in future releases of the _PyTorch/XLA_ software, -since many of them are peculiar to a given internal implementation which might change. - -## Sanity Check -Before performing any in depth debugging, we want to do a sanity check on the installed PyTorch/XLA. - -### Check PyTorch/XLA Version -PyTorch and PyTorch/XLA version should match. Check out our [README](https://github.com/pytorch/xla#getting-started) for more detials on versions available. -``` -vm:~$ python ->>> import torch ->>> import torch_xla ->>> print(torch.__version__) -2.1.0+cu121 ->>> print(torch_xla.__version__) -2.1.0 -``` - -### Perform A Simple Calculation -``` -vm:~$ export PJRT_DEVICE=TPU -vm:~$ python3 ->>> import torch ->>> import torch_xla.core.xla_model as xm ->>> t1 = torch.tensor(100, device=xm.xla_device()) ->>> t2 = torch.tensor(200, device=xm.xla_device()) ->>> print(t1 + t2) -tensor(300, device='xla:0') -``` - -### Run Resnet With Fake Data -For nightly -``` -vm:~$ git clone https://github.com/pytorch/xla.git -vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data -``` - -For release version `x.y`, you want to use the branch `rx.y`. For example if you installed 2.1 release, you should do -``` -vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git -vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data -``` - -If you can get the resnet to run we can conclude that torch_xla is installed correctly. - - -## Performance Debugging - -To diagnose performance issues, we can use the execution metrics and counters provided by _PyTorch/XLA_ -The **first thing** to check when model is slow is to generate a metrics report. - -Metrics report is extremely helpful in diagnosing issues. Please try to include it in your bug -report sent to us if you have it. - -## PyTorch/XLA Debugging Tool - -You can enable the PyTorch/XLA debugging tool by setting `PT_XLA_DEBUG_LEVEL=2`, which provides a couple useful debugging features. You can also lower the debug level to `1` to slip the execution analysis. - -### Perform A Auto-Metrics Analysis - -The debugging tool will analyze the metrics report and provide a summary. Some example output would be - -``` -pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps -pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps -pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests. -pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps -pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps -``` - -### Compilation & Execution Analysis -The debugging tool will analyze every compilation and execution for your model. Some example output would be -``` -Compilation Analysis: ================================================================================ -Compilation Analysis: Compilation Cause -Compilation Analysis: mark_step in parallel loader at step end -Compilation Analysis: Graph Info: -Compilation Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943 -Compilation Analysis: Number of Graph Inputs: 35 -Compilation Analysis: Number of Graph Outputs: 107 -Compilation Analysis: Python Frame Triggered Execution: -Compilation Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055) -Compilation Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44) -Compilation Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32) -Compilation Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48) -Compilation Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65) -Compilation Analysis: (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73) -Compilation Analysis: -------------------------------------------------------------------------------- -Compilation Analysis: ================================================================================ - -Post Compilation Analysis: ================================================================================ -Post Compilation Analysis: Graph input size: 1.548000 GB -Post Compilation Analysis: Graph output size: 7.922460 GB -Post Compilation Analysis: Aliased Input size: 1.547871 GB -Post Compilation Analysis: Intermediate tensor size: 12.124478 GB -Post Compilation Analysis: Compiled program size: 0.028210 GB -Post Compilation Analysis: -------------------------------------------------------------------------------- -Post Compilation Analysis: ================================================================================ - -Execution Analysis: ================================================================================ -Execution Analysis: Execution Cause -Execution Analysis: mark_step in parallel loader at step end -Execution Analysis: Graph Info: -Execution Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943 -Execution Analysis: Number of Graph Inputs: 35 -Execution Analysis: Number of Graph Outputs: 107 -Execution Analysis: Python Frame Triggered Execution: -Execution Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055) -Execution Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44) -Execution Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32) -Execution Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48) -Execution Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65) -Execution Analysis: (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73) -Execution Analysis: -------------------------------------------------------------------------------- -Execution Analysis: ================================================================================ -``` - -Some common causes of Compilation/Executation are -1. User manually call `mark_step`. -2. [Parallel loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51) call `mark_step` for every x (configurable) batch. -3. Exiting a [profiler StepTrace region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171). -4. Dynamo decide to compile/execute the graph. -5. User trying to access(often due to logging) the value of a tensor before the `mark_step`. - -The executation caused by 1-4 are expected, and we want to avoid 5 by either reduce the frequency of accessing tensor values or manually add a `mark_step` before accessing. - -Users should expect to see this `Compilation Cause` + `Executation Cause` pairs for first couple steps. After the model stabilize users should expect to only see `Execution Cause`(you can disable execution analysis by `PT_XLA_DEBUG_LEVEL=1`). To use PyTorch/XLA efficiently, we expect the same models code to be run for every step and compilation only happen once for every graph. If you keep seeing `Compilation Cause`, you should try to dump the IR/HLO following [this section](#common-debugging-environment-variables-combinations) and compare the graphs for each step and understand the source of the differences. - -Following section will explain how to get and understand a more detail metrics report. - -## Get A Metrics Report - -Put the following line in your program to generate a report: - -```Python -import torch_xla.debug.metrics as met - -# For short report that only contains a few key metrics. -print(met.short_metrics_report()) -# For full report that includes all metrics. -print(met.metrics_report()) -``` - -## Understand The Metrics Report - -The report includes things like: -- how many time we issue _XLA_ compilations and time spent on issuing. -- how many times we execute and time spent on execution -- how many device data handles we create/destroy etc. - -This information is reported in terms of percentiles of the samples. An example is: - -``` -Metric: CompileTime - TotalSamples: 202 - Counter: 06m09s401ms746.001us - ValueRate: 778ms572.062us / second - Rate: 0.425201 / second - Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us -``` - -We also provide counters, which are named integer variables which track internal software status. For example: - -``` -Counter: CachedSyncTensors - Value: 395 -``` - -In this report, any counter that starts with `aten::` -indicates a context switch between the XLA device and CPU, which can be a -potential performance optimization area in the model code. - -Counters are useful to understand which operations are routed back to the CPU engine of _PyTorch_. -They are fully qualified with their C++ namespace: - -``` -Counter: aten::nonzero - Value: 33 -``` - -If you see `aten::` ops other than `nonzero` and `_local_scalar_dense`, that usually means a missing -lowering in PyTorch/XLA. Feel free to open a feature request for it on [GitHub issues](https://github.com/pytorch/xla/issues). - -## Clear The Metrics Report -If you want to clear the metrics between steps/epochs, you can use -```Python -import torch_xla.debug.metrics as met - -met.clear_all() -``` - -## PyTorch/XLA + Dynamo Debugging Tool - -You can enable the PyTorch/XLA + Dynamo debugging tool by setting `XLA_DYNAMO_DEBUG=1`. - -## Performance Profiling -To profile your workload in depth to understand bottlenecks please check the following resources: -* [Official tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) -* [Colab notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/pytorch-xla-profiling-colab.ipynb) -* [Sample MNIST training script with profiling](https://github.com/pytorch/xla/blob/master/test/test_profile_mp_mnist.py) -* [Utility script for capturing performance profiles](https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py) - -## Simple Benchmarking -Take a look at [`examples/train_resnet_benchmark.py`](https://github.com/pytorch/xla/blob/master/examples/train_resnet_benchmark.py) for how to benchmark a PyTorch/XLA model. - -## Known Performance Caveats - -PyTorch/XLA behaves semantically like regular PyTorch and XLA tensors share the full tensor interface with CPU & GPU tensors. -However, constraints in XLA/hardware and the lazy evaluation model suggest certain patterns might result in bad performance. - -If your model shows bad performance, keep in mind the following caveats: - -1. **XLA/TPU yield degraded performance with too many recompilations.** - - XLA compilation is expensive. PyTorch/XLA automatically recompiles the graph every time new shapes are encountered. - Usually models should stabilize within a few steps and you can see huge speedup for the rest of training. - - In order to avoid recompilations, not only must shapes be constant, but computations across XLA devices in all hosts should also be constant. - - _Possible sources_: - * Direct or indirect uses of `nonzero` introduce dynamic shapes; for example, masked indexing `base[index]` where `index` is a mask tensor. - * Loops with a different number of iterations between steps can result in different execution graphs, thus require recompilations. - - _Solution_: - * Tensor shapes should be the same between iterations, or a low number of shape variations should be used. - * Pad tensors to fixed sizes when possible. - -1. **Certain operations don't have native translations to XLA.** - - For these operations PyTorch/XLA automatically transfers to the CPU memory, evaluates on CPU, and transfers the result back to the XLA device. - Doing too many such operations during the training step can lead to significant slowdowns. - - _Possible sources_: - - - The `item()` operation explicitly asks to evaluate the result. Don't use it unless it's necessary. - - _Solution_: - - - For most ops we can lower them to XLA to fix it. Checkout [metrics report section](#metrics-report) to find out the missing ops and open a feature request on [GitHub](https://github.com/pytorch/xla/issues). - - Even when a PyTorch tensor is known as a scalar, avoid using `tensor.item()`. Keep it as a tensor and use tensor operations on it. - - Use `torch.where` to substitute control flow when applicable. - E.g. The control flow with `item()` used in [clip_grad_norm_](https://github.com/pytorch/pytorch/blob/de19eeee99a2a282fc441f637b23d8e50c75ecd1/torch/nn/utils/clip_grad.py#L33) is problematic and impacts performance, so we have [patched](https://github.com/pytorch/xla/blob/master/torch_patches/X10-clip_grad.diff) `clip_grad_norm_` by calling `torch.where` instead, which gives us a dramatic performance improvement. - ```python - ... - else: - device = parameters[0].device - total_norm = torch.zeros([], device=device if parameters else None) - for p in parameters: - param_norm = p.grad.data.norm(norm_type) ** norm_type - total_norm.add_(param_norm) - total_norm = (total_norm ** (1. / norm_type)) - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) - for p in parameters: - p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device))) - ``` - -1. **Iterators in `torch_xla.distributed.data_parallel` may drop the last few batches in the input iterator.** - - This is to make sure we do the same amount of work on all XLA devices. - - _Solution_: - - * When dataset is small, and there are too few steps, this may result in a no-op epoch. Therefore, it is better to use - small batch sizes in those cases. - -## XLA Tensor Quirks - -1. **XLA tensor internals are opaque.** XLA tensors always appear to be -contiguous and without storage. Networks should not try to check the strides -of XLA tensors. - -1. **XLA tensors should be moved to the CPU before saving them.** Saving -XLA tensors directly causes them to be loaded back on the device(s) they were -saved from. If a device is unavailable at load time then the load will fail. -Moving XLA tensors to the CPU before saving them lets you decide which -device(s) to put the loaded tensors on. This is necessary if you want to -load the tensors on a machine without XLA devices. Care should be taken -moving the XLA tensors to the CPU before saving them, however, as moving -tensors across device types does not preserve view relationships. Instead, -views should be reconstructed as necessary after the tensors are loaded. - -1. **Copying an XLA Tensor with Python's copy.copy returns a deep copy, not a -shallow copy.** Use a view of an XLA tensor to get a shallow copy of it. - -1. **Handling shared weights.** Modules can share weights by setting the -Parameters of one module to another. This "tying" of module weights should -be done **AFTER** the modules are moved to an XLA device. Otherwise two -independent copies of the shared tensor will be made on the XLA device. - -## More Debugging Tools - -We don't expect users to use tools in this section to debug their models. But we might ask for -them when you submit a bug report since they provide additional information that metrics report -doesn't have. - -* ```print(torch_xla._XLAC._get_xla_tensors_text([res]))``` where `res` is the result tensor prints out the IR. -* ```print(torch_xla._XLAC._get_xla_tensors_hlo([res]))``` where `res` is the result tensor prints out the generated XLA HLO. - -Note these functions must be called prior to `mark_step()`, otherwise the tensor will already be materialized. - -### Environment Variables - -There are also a number of environment variables which control the behavior of the _PyTorch/XLA_ -software stack. - -Setting such variables will cause different degrees of performance degradation, so they should -only be enabled for debugging. - -* ```XLA_IR_DEBUG```: Enables the _Python_ stack trace to be captured where creating IR nodes, - hence allowing to understand which _PyTorch_ operation was responsible for generating the IR. - -* ```XLA_HLO_DEBUG```: Enables the _Python_ stack frame captured when _XLA_IR_DEBUG_ is active, - to be propagated to the _XLA_ _HLO_ metadata. - -* ```XLA_SAVE_TENSORS_FILE```: The path to a file which will be used to dump the IR graphs during - execution. Note that the file can become really big if the option is left enabled and the - _PyTorch_ program let run for long time. The graphs are appended to the file, so to have a clean - sheet from run to run, the file should be explicitly removed. - -* ```XLA_SAVE_TENSORS_FMT```: The format of the graphs stored within the _XLA_SAVE_TENSORS_FILE_ - file. Can be ```text``` (the default), ```dot``` (the _Graphviz_ format) or ```hlo```. - -* ```XLA_FLAGS=--xla_dump_to```: If set to ```=/tmp/dir_name```, XLA compiler will dump the unoptimized and optimzed HLO per compilation. - -* ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be - saved at every step. Metrics will be appended to the file, if already existing. - -* ```XLA_SAVE_HLO_FILE```: If set, the path to a local file where, in case of compilation/execution - error, the offending HLO graph will be saved. - -* ```XLA_SYNC_WAIT```: Forces the XLA tensor sync operation to wait for its completion, before - moving to the next step. - -* ```XLA_USE_EAGER_DEBUG_MODE```: Forces the XLA tensor to execute eagerly, meaning compile and execute the torch operations one - by one. This is useful to bypass the long compilation time but overall step time will be a lot slower and memory usage will be higher - since all compiler optimizaiton will be skipped. - -* ```TF_CPP_LOG_THREAD_ID```: If set to 1, the TF logs will show the thread ID - helping with debugging multithreaded processes. - -* ```TF_CPP_VMODULE```: Environment variable used for TF VLOGs and takes the - form of `TF_CPP_VMODULE=name=value,...`. Note that for VLOGs you must set - `TF_CPP_MIN_LOG_LEVEL=0`. - -* ```TF_CPP_MIN_LOG_LEVEL```: Level to print messages for. `TF_CPP_MIN_LOG_LEVEL=0` will turn - on INFO logging, `TF_CPP_MIN_LOG_LEVEL=1` WARNING and so on. Our PyTorch/XLA `TF_VLOG` uses - `tensorflow::INFO` level by default so to see VLOGs set `TF_CPP_MIN_LOG_LEVEL=0`. - -* ```XLA_DUMP_HLO_GRAPH```: If set to `=1` in case of a compilation or execution error the - offending HLO graph will be dumped as part of the runtime error raised by `xla_util.cc`. - -### Common Debugging Environment Variables Combinations - -* Record the graph execution in the IR format - ``` - XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="text" XLA_SAVE_TENSORS_FILE="/tmp/save1.ir" - ``` - -* Record the graph execution in the HLO format - ``` - XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo" - ``` - -* Show debugging VLOG for runtime and graph compilation/execution - ``` - TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3" - ``` - -### Reproducing PyTorch/XLA CI/CD unit test failures. - -You may see some test failures for a PR such as: - -``` -To execute this test, run the following from the base repo dir: - PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8 -``` - -Running this directly in the command line does not work. You need to set the environment variable `TORCH_TEST_DEVICES` to your local `pytorch/xla/test/pytorch_test_base.py`. For example: - -`TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8` should work. diff --git a/docs/amp.md b/docs/amp.md deleted file mode 100644 index bd019df7970..00000000000 --- a/docs/amp.md +++ /dev/null @@ -1,95 +0,0 @@ -# AMP (Automatic Mixed Precision) with Pytorch/XLA - -Pytorch/XLA's AMP extends [Pytorch's AMP package](https://pytorch.org/docs/stable/amp.html) with support for automatic mixed precision on XLA:GPU and XLA:TPU devices. -AMP is used to accelerate training and inference by executing certain operations in `float32` and other operations in a lower precision datatype (`float16` or `bfloat16` depending on hardware support). -This document describes how to use AMP on XLA devices and best practices. - -## AMP for XLA:TPU -AMP on TPUs automatically casts operations to run in either `float32` or `bfloat16` because TPUs natively support bfloat16. A simple TPU AMP example is below: - -``` -# Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) -# Pytorch/XLA provides sync-free optimizers for improved performance -optimizer = syncfree.SGD(model.parameters(), ...) - -for input, target in data: - optimizer.zero_grad() - - # Enables autocasting for the forward pass - with autocast(xm.xla_device()): - output = model(input) - loss = loss_fn(output, target) - - # Exits the context manager before backward() - loss.backward() - xm.optimizer_step.(optimizer) -``` -`autocast(xm.xla_device())` aliases `torch.autocast('xla')` when the XLA Device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. - -Please file an issue or submit a pull request if there is an operator that should be autocasted that is not included. - - -### Best Practices -1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops. -2. Since TPU's use bfloat16 mixed precision, gradient scaling is not necessary. -3. Pytorch/XLA provides modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host. - -### Supported Operators -AMP on TPUs operates like Pytorch's AMP. Rules for how autocasting is applied is summarized below: - -Only out-of-place ops and Tensor methods are eligible to be autocasted. In-place variants and calls that explicitly supply an out=... Tensor are allowed in autocast-enabled regions, but won’t go through autocasting. For example, in an autocast-enabled region a.addmm(b, c) can autocast, but a.addmm_(b, c) and a.addmm(b, c, out=d) cannot. For best performance and stability, prefer out-of-place ops in autocast-enabled regions. - -Ops that run in float64 or non-floating-point dtypes are not eligible, and will run in these types whether or not autocast is enabled. Additionally, Ops called with an explicit dtype=... argument are not eligible, and will produce output that respects the dtype argument. - -Ops not listed below do not go through autocasting. They run in the type defined by their inputs. Autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops. - -**Ops that autocast to `bfloat16`:** - -`__matmul__`, `addbmm`, `addmm`, `addmv`, `addr`, `baddbmm`,` bmm`, `conv1d`, `conv2d`, `conv3d`, `conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`, `linear`, `matmul`, `mm`, `relu`, `prelu`, `max_pool2d` - -**Ops that autocast to `float32`:** - -`batch_norm`, `log_softmax`, `binary_cross_entropy`, `binary_cross_entropy_with_logits`, `prod`, `cdist`, `trace`, `chloesky` ,`inverse`, `reflection_pad`, `replication_pad`, `mse_loss`, `cosine_embbeding_loss`, `nll_loss`, `multilabel_margin_loss`, `qr`, `svd`, `triangular_solve`, `linalg_svd`, `linalg_inv_ex` - -**Ops that autocast to widest input type:** - -`stack`, `cat`, `index_copy` - -## AMP for XLA:GPU -AMP on XLA:GPU devices reuse Pytorch's AMP rules. See [Pytorch's AMP documentation](https://pytorch.org/docs/stable/amp.html) for CUDA specific behavior. A simple CUDA AMP example is below: - -``` -# Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) -# Pytorch/XLA provides sync-free optimizers for improved performance -optimizer = syncfree.SGD(model.parameters(), ...) -scaler = GradScaler() - -for input, target in data: - optimizer.zero_grad() - - # Enables autocasting for the forward pass - with autocast(xm.xla_device()): - output = model(input) - loss = loss_fn(output, target) - - # Exits the context manager before backward pass - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size()) - scaler.step(optimizer) - scaler.update() -``` - -`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be directly used, but requires `torch` is compiled with `cuda` support for datatype of `torch.bfloat16`. We recommend using `autocast(xm.xla_device())` on XLA:GPU as it does not require `torch.cuda` support for any datatypes, including `torch.bfloat16`. - -### Best Practices -1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops. -2. Do not set `XLA_USE_F16` flag when using AMP on Cuda devices. This will override the per-operator precision settings provided by AMP and cause all operators to execute in float16. -3. Use gradient scaling to prevent float16 gradients from underflowing. -4. Pytorch/XLA provides modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host. - -## Examples -Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py) demonstrate how AMP is used on both TPUs and GPUs. - diff --git a/docs/bazel.md b/docs/bazel.md deleted file mode 100644 index 53d520dd683..00000000000 --- a/docs/bazel.md +++ /dev/null @@ -1,177 +0,0 @@ -## Bazel in Pytorch/XLA - -[Bazel](https://bazel.build/) is a free software tool used for the automation of building and testing software. [TensorFlow](https://www.tensorflow.org/http) and [OpenXLA](https://github.com/openxla/xla) both use it, which makes it a good fit for PyTorch/XLA as well. - -## Bazel dependencies - -Tensorflow is a [bazel external dependency](https://bazel.build/external/overview) for PyTorch/XLA, which can be seen in the `WORKSPACE` file: - -`WORKSPACE` -```bzl -http_archive( - name = "org_tensorflow", - strip_prefix = "tensorflow-f7759359f8420d3ca7b9fd19493f2a01bd47b4ef", - urls = [ - "https://github.com/tensorflow/tensorflow/archive/f7759359f8420d3ca7b9fd19493f2a01bd47b4ef.tar.gz", - ], -) -``` - -TensorFlow pin can be updated by pointing this repository to a different revision. Patches may be added as needed. -Bazel will resolve the dependency, prepare the code and patch it hermetically. - -For PyTorch, a different dependency mechanism is deployed because a local [PyTorch](https://github.com/pytorch/pytorch) -checkout is used, and this local checkout has to be `built` from source and ideally installed on the system for version -compatibility (e.g. codegen in PyTorch/XLA uses `torchgen` python module that should be installed in the system). - -The local directory can either set in `bazel/dependencies.bzl`, or overriden on the command line: - -```bash -bazel build --override_repository=org_tensorflow=/path/to/exported/tf_repo //... -``` - -```bash -bazel build --override_repository=torch=/path/to/exported/and/built/torch_repo //... -``` - -Please make sure that the overridden repositories are at the appropriate revisions and in case of `torch`, that it -has been built with `USE_CUDA=0 python setup.py bdist_wheel` to make sure that all expected build objects are present; -ideally installed into the system. - -`WORKSPACE` -```bzl -new_local_repository( - name = "torch", - build_file = "//bazel:torch.BUILD", - path = PYTORCH_LOCAL_DIR, -) -``` - -PyTorch headers are directly sourced from the `torch` dependency, the local checkout of PyTorch. The shared libraries -(e.g. `libtorch.so`) are sourced from the same local checkout where the code has been built and `build/lib/` contains the -built objects. For this to work, it's required to pass `-isystemexternal/torch` to the compiler so it can find `system` -libraries and satisfy them from the local checkout. Some are included as `` and some as `"user"` headers. - -Bazel brings in [pybind11](https://github.com/pybind/pybind11) embeded python and links against it to provide `libpython` -to the plugin using this mechanism. Python headers are also sourced from there instead of depending on the system version. -These are satisfied from the `"@pybind11//:pybind11_embed"`, which sets up compiler options for linking with `libpython` -transitively. - -## How to build XLA libraries - -Building the libraries is simple: - -```bash -bazel build //torch_xla/csrc/runtime/... -``` - -Bazel is configred via `.bazelrc`, but it can also take flags on the command line. - -```bash -bazel build --config=remote_cache //torch_xla/csrc/runtime/... -``` - -The `remote_cache` configurations use gcloud for caching and usually faster, but require -authentication with gcloud. See `.bazelrc` for the configuration. - -Using bazel makes it easy to express complex dependencies and there is a lot of gain from having a single build graph -with everything expressed in the same way. Therefore, there is no need to build the XLA libraries separately from the -rest of the pluing as used to be the case, building the whole repository, or the plugin shared object that links everythin -else in, is enough. - -## How to build the Torch/XLA plugin - -The normal build can be achieved by the invoking the standard `python setup.py bdist_wheel`, but C++ bindings can be built simply with: - -```bash -bazel build //:_XLAC.so -``` - -This will build the XLA client and the PyTorch plugin and link it all together. This can be useful when testing changes, to be -able to compile the C++ code without building the python plugin faster iteration cycles. - -## Remote caching - -Bazel comes with [remote caching](https://bazel.build/remote/caching) built in. There are plenty of cache backends that can be used; we deploy our caching on (GCS)[https://bazel.build/remote/caching#cloud-storage]. You can see the configuration in `.bazelrc`, under config name `remote_cache`. - -Remote caching is disabled by default but because it speeds up incremental builds by a huge margin, it is almost always recommended, and it is enabled by default in the CI automation and on Cloud Build. - -To authenticate on a machine, please ensure that you have the credentials present with `gcloud auth application-default login --no-launch-browser` or equivalent. - -Using the remote cache configured by `remote_cache` configuration setup requires authentication with GCP. -There are various ways to authenticate with GCP. For individual developers who have access to the development GCP project, one only needs to -specify the `--config=remote_cache` flag to bazel, and the default `--google_default_credentials` will be used and if the -gcloud token is present on the machine, it will work out of the box, using the logged in user for authentication. The user -needs to have remote build permissions in GCP (add new developers into the `Remote Bazel` role). In the CI, the service account key -is used for authentication and is passed to bazel using `--config=remote_cache --google_credentials=path/to/service.key`. -On [Cloud Build](https://cloud.google.com/build), `docker build --network=cloudbuild` is used to pass the authentication from the service -account running the cloud build down into the docker image doing the compilation: [Application Default Credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc) does the work there and authenticates as the service account. All accounts, both user and service accounts, need to have remote cache read/write permissions. - -Remote cache uses cache silos. Each unique machine and build should specify a unique silo key to benefit from consistent caching. The silo key can be passed using a flag: `-remote_default_exec_properties=cache-silo-key=SOME_SILO_KEY'`. - -Running the build with remote cache: - -```bash -BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" TPUVM_MODE=1 python setup.py bdist_wheel -``` - -Adding - -```bash -GCLOUD_SERVICE_KEY_FILE=~/.config/gcloud/application_default_credentials.json -``` - -might help too if `bazel` cannot find the auth token. - -`YOUR-USER` here can the author's username or machine name, a unique name that ensures good cache behavior. Other `setup.py` functionality works as intended too (e.g. `develop`). - -The first time the code is compiled using a new cached key will be slow because it will compile everything from scratch, but incremental compilations will be very fast. On updating the TensorFlow pin, it will once again be a bit slower the first time per key, and then until the next update quite fast again. - -## Running tests - -Currently C++ code is built and tested by bazel. Python code will be migrated in the future. - -Bazel is a test plafrom too, making it easy to run tests: - -```bash -bazel test //test/cpp:main -``` - -Ofcourse the XLA and PJRT configuration have to be present in the environment to run the tests. Not all environmental variables are passed into the bazel test environment to make sure that the remote cache misses are not too common (environment -is part of the cache key), see `.bazelrc` test configuration to see which ones are passed in, and add new ones as required. - -You can run the tests using the helper script too: - -```bash -BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" ./test/cpp/run_tests.sh -R -``` - -The `xla_client` tests are pure hermetic tests that can be easily executed. The `torch_xla` plugin tests are more complex: -they require `torch` and `torch_xla` to be installed, and they cannot run in parallel, since they are using either -XRT server/client on the same port, or because they use a GPU or TPU device and there's only one available at the time. -For that reason, all tests under `torch_xla/csrc/` are bundled into a single target `:main` that runs them all sequentially. - -## Code coverage - -When running tests, it can be useful to calculate code coverage. - -```bash -bazel coverage //torch_xla/csrc/runtime/... -``` - -Coverage can be visualized using `lcov` as described in [Bazel's documentation](https://bazel.build/configure/coverage), or in your editor of choice with lcov plugins, e.g. [Coverage Gutters](https://marketplace.visualstudio.com/items?itemName=ryanluker.vscode-coverage-gutters) for VSCode. - - -## Language Server - -Bazel can power a language server like [clangd](https://clangd.llvm.org/) that brings code references, autocompletion and semantic understanding of the underlying code to your editor of choice. For VSCode, -one can use [Bazel Stack](https://github.com/stackb/bazel-stack-vscode-cc) that can be combined with -[clangd](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) functionality to bring powerful features to assist code editing. - -## Building PyTorch/XLA - -As always, PyTorch/XLA can be built using Python `distutils`: - -```bash -BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" TPUVM_MODE=1 python setup.py bdist_wheel -``` diff --git a/docs/ddp.md b/docs/ddp.md deleted file mode 100644 index 86876db1028..00000000000 --- a/docs/ddp.md +++ /dev/null @@ -1,280 +0,0 @@ -# How to do DistributedDataParallel(DDP) - -This document shows how to use torch.nn.parallel.DistributedDataParallel in xla, -and further describes its difference against the native xla data parallel -approach. You can find a minimum runnable example [here](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_ddp.py). - - -## Background / Motivation - -Customers have long requested the ability to use PyTorch’s -DistributedDataParallel API with xla. And here we enable it as an experimental -feature. - - -## How to use DistributedDataParallel - -For those who switched from the PyTorch eager mode to XLA, here are all the -changes you need to do to convert your eager DDP model into XLA model. We assume -that you already know how to use XLA [on a single -device](../API_GUIDE.md#running-on-a-single-xla-device). - -1. Import xla specific distributed packages: - -``` -import torch_xla -import torch_xla.runtime as xr -import torch_xla.distributed.xla_backend -``` - -2. Init xla process group similar to other process groups such as nccl and gloo. - -``` -dist.init_process_group("xla", rank=rank, world_size=world_size) -``` - -3. Use xla specific APIs to get rank and world\_size if you need to. - -``` -new_rank = xr.global_ordinal() -world_size = xr.world_size() -``` - -4. Pass `gradient_as_bucket_view=True` to the DDP wrapper. - -``` -ddp_model = DDP(model, gradient_as_bucket_view=True) -``` - -5. Finally launch your model with xla specific launcher. - -``` -torch_xla.launch(demo_fn) -``` - -Here we have put everything together (the example is actually taken from the -[DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)). -The way you code it is pretty similar to the eager experience. Just with xla -specific touches on a single device plus the above five changes to your script. - -``` -import os -import sys -import tempfile -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.optim as optim - -from torch.nn.parallel import DistributedDataParallel as DDP - -# additional imports for xla -import torch_xla -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr -import torch_xla.distributed.xla_backend - -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - - # initialize the xla process group - dist.init_process_group("xla", rank=rank, world_size=world_size) - -def cleanup(): - dist.destroy_process_group() - -class ToyModel(nn.Module): - def __init__(self): - super(ToyModel, self).__init__() - self.net1 = nn.Linear(10, 1000000) - self.relu = nn.ReLU() - self.net2 = nn.Linear(1000000, 5) - - def forward(self, x): - return self.net2(self.relu(self.net1(x))) - -def demo_basic(rank): - # xla specific APIs to get rank, world_size. - new_rank = xr.global_ordinal() - assert new_rank == rank - world_size = xr.world_size() - - print(f"Running basic DDP example on rank {rank}.") - setup(rank, world_size) - - # create model and move it to XLA device - device = xm.xla_device() - model = ToyModel().to(device) - # currently, graident_as_bucket_view is needed to make DDP work for xla - ddp_model = DDP(model, gradient_as_bucket_view=True) - - loss_fn = nn.MSELoss() - optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) - - optimizer.zero_grad() - outputs = ddp_model(torch.randn(20, 10).to(device)) - labels = torch.randn(20, 5).to(device) - loss_fn(outputs, labels).backward() - optimizer.step() - # xla specific API to execute the graph - xm.mark_step() - - cleanup() - - -def run_demo(demo_fn): - # xla specific launcher - torch_xla.launch(demo_fn) - -if __name__ == "__main__": - run_demo(demo_basic) -``` - -## Benchmarking - - -### Resnet50 with fake data - -The following results are collected with the command: `python -test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1` on a -TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA. And the statistical -metrics are produced by using the script in this [pull -request](https://github.com/pytorch/xla/pull/4107). The unit for the rate is -images per second. - - - - - - - - - - - - - - - - - - - - - - - - - - -
Type - Mean - Median - 90th % - Std Dev - CV -
xm.optimizer_step - 418.54 - 419.22 - 430.40 - 9.76 - 0.02 -
DDP - 395.97 - 395.54 - 407.13 - 7.60 - 0.02 -
- - -The performance difference between our native approach for distributed data -parallel and DistributedDataParallel wrapper is: 1 - 395.97 / 418.54 = 5.39%. -This result seems reasonable given the DDP wrapper introduces extra overheads on -tracing the DDP runtime. - -### MNIST with fake data - -The following results are collected with the command: `python -test/test_train_mp_mnist.py --fake_data` on a TPU VM V3-8 environment with ToT -PyTorch and PyTorch/XLA. And the statistical metrics are produced by using the -script in this [pull request](https://github.com/pytorch/xla/pull/4107). The -unit for the rate is images per second. - - - - - - - - - - - - - - - - - - - - - - - - - - -
Type - Mean - Median - 90th % - Std Dev - CV -
xm.optimizer_step - 17864.19 - 20108.96 - 24351.74 - 5866.83 - 0.33 -
DDP - 10701.39 - 11770.00 - 14313.78 - 3102.92 - 0.29 -
- - -The performance difference between our native approach for distributed data -parallel and DistributedDataParallel wrapper is: 1 - 14313.78 / 24351.74 = -41.22%. Here we compare 90th % instead since the dataset is small and first a -few rounds are heavily impacted by data loading. This slowdown is huge but makes -sense given the model is small. The additional DDP runtime tracing overhead is -hard to amortize. - -### MNIST with real data - -The following results are collected with the command: `python -test/test_train_mp_mnist.py --logdir mnist/` on a TPU VM V3-8 environment with -ToT PyTorch and PyTorch/XLA. - -![learning_curves](_static/img/ddp_md_mnist_with_real_data.png) - -And we can observe that the DDP wrapper converges slower than the native XLA -approach even though it still achieves a high accuracy rate at 97.48% at the -end. (The native approach achieves 99%.) - -## Disclaimer - -This feature is still experimental and under active development. Use it in -cautions and feel free to file any bugs to the [xla github -repo](https://github.com/pytorch/xla/). For those who are interested in the -native xla data parallel approach, here is the -[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing). - -Here are some of the known issues that are under investigation: -* `gradient_as_bucket_view=True` needs to be enforced. -* There are some issues while being used with `torch.utils.data.DataLoader`. `​​test_train_mp_mnist.py` with real data crashes before exiting. diff --git a/docs/dynamo.md b/docs/dynamo.md deleted file mode 100644 index 794c7aeb090..00000000000 --- a/docs/dynamo.md +++ /dev/null @@ -1,116 +0,0 @@ -## TorchDynamo(torch.compile) integration in PyTorch XLA - -[TorchDynamo](https://pytorch.org/docs/stable/torch.compiler.html) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in and its biggest feature is to dynamically modify Python bytecode right before it is executed. In the pytorch/xla 2.0 release, PyTorch/XLA provided an experimental backend for the TorchDynamo for both inference and training. - -The way that XLA bridge works is that Dynamo will provide a TorchFX graph when it recognizes a model pattern and PyTorch/XLA will use existing Lazy Tensor technology to compile the FX graph and return the compiled function. - -### Integration - -Support for PyTorch/XLA and Dynamo currently exists by adding the `backend='openxla'` argument to `torch.compile`. For example: - -``` -import torch -import torch_xla.core.xla_model as xm - -def add(a, b): - a_xla = a.to(xm.xla_device()) - b_xla = b.to(xm.xla_device()) - return a_xla + b_xla - -compiled_code = torch.compile(add, backend='openxla') -print(compiled_code(torch.randn(10), torch.randn(10))) -``` - - -### Inference -Here is a small code example of running resnet18 with `torch.compile` - -```python -import torch -import torchvision -import torch_xla.core.xla_model as xm - -def eval_model(loader): - device = xm.xla_device() - xla_resnet18 = torchvision.models.resnet18().to(device) - xla_resnet18.eval() - dynamo_resnet18 = torch.compile( - xla_resnet18, backend='openxla') - for data, _ in loader: - with torch.no_grad(): - output = dynamo_resnet18(data) -``` - -With the `torch.compile` you will see that PyTorch/XLA only traces the resent18 model once during the init time and executes the compiled binary every time `dynamo_resnet18` is invoked, instead of tracing the model every time. Here is a inference speed analysis to compare Dynamo and Lazy using torch bench on Cloud TPU v4-8 - -| model | Speed up | -| --- | ----------- | -resnet18 | 2.59 -resnet50 | 2.64 -resnext50_32x4d | 1.91 -alexnet | 1.28 -mobilenet_v2 | 18.62 -mnasnet1_0 | 2.68 -vgg16 | 1.33 -BERT_pytorch | 7.49 -squeezenet1_1 | 2.29 -timm_vision_transformer | 3.52 -geomean | 3.04 - - -### Training -PyTorch/XLA also supports Dynamo for training, but it is experimental and we are working with the PyTorch Compiler team to iterate on the implementation. Here is an example of training a resnet18 with `torch.compile` - -```python -import torch -import torchvision -import torch_xla.core.xla_model as xm - -def train_model(model, data, target, optimizer): - loss_fn = torch.nn.CrossEntropyLoss() - pred = model(data) - loss = loss_fn(pred, target) - loss.backward() - optimizer.step() - return pred - -def train_model_main(loader): - device = xm.xla_device() - xla_resnet18 = torchvision.models.resnet18().to(device) - xla_resnet18.train() - dynamo_train_model = torch.compile( - train_model, backend='openxla') - for data, target in loader: - xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2) - output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer) -``` - -We expect to extract and execute 3 graphs per training step instead of 1 graph per training step if you use the Lazy tensor. Here is a training speed analysis to compare Dynamo and Lazy using a torch bench on Cloud TPU v4-8. - -| model | Speed up | -| --- | ----------- | -resnet50 | 1.33 -resnet18 | 1.33 -BERT_pytorch | 3.07 -resnext50_32x4d | 1.43 -alexnet | 1.12 -mobilenet_v2 | 1.4 -mnasnet1_0 | 1.19 -vgg16 | 0.81 -timm_vision_transformer | 1.87 -squeezenet1_1 | 1.41 -geomean | 1.41 - -> **NOTE:** We run each model's fwd and bwd for a single step and then collect the e2e time. In the real world we will run multiple steps at each training job which can easily hide the tracing cost from execution(since it is async). Lazy Tensor will have much better performance in that scenario. - -### Feature gaps -There is one gap we want to call out that are preventing us from using the TorchDynamo on larger scale models. - -1. TorchDynamo will trace forward and backward into separate graphs. For PyTorch/XLA it is important to let the XLA compiler see the whole step as one graph to best optimize the speed. There is also a fixed overhead to launch every device execution which make executing multiple graphs per training step less ideal. - -This gap compared to Lazy Tensor makes it less efficient in real world training use cases, especially the tracing cost can be overlapped with the execution in training. - -### Take away -TorchDynamo provides a really promising way for the compiler backend to hide the complexity from the user and easily retrieve the modeling code in a graph format. Compared with PyTorch/XLA's traditional Lazy Tensor way of extracting the graph, TorchDynamo can skip the graph tracing for every iteration, hence providing a much better inference response time. - -Most models supported by PyTorch/XLA, have seen significant speedup when running inference with the new dynamo-xla bridge. Our community is working hard to expand the set of supported models. Regarding the training feature gaps mentioned above, the PyTorch/XLA community is super excited to improve the training gap in our upcoming development work. The team continues to heavily invest in TorchDynamo and work with the upstream to mature the training story. diff --git a/docs/eager.md b/docs/eager.md deleted file mode 100644 index 0d1fde2995c..00000000000 --- a/docs/eager.md +++ /dev/null @@ -1,119 +0,0 @@ -# Eager Mode + Compile API - -In this doc we will go over how to use PyTorch/XLA’s new experimental `eager` mode with the `compile` API. The goal is to make PyTorch/XLA experience more aligned with the native PyTorch and make development process easier. - - -## Background -Currently PyTorch/XLA runs on the LazyTensor tracing mode by default. In the following code -```python -import torch -import torch_xla -import torchvision - -device = torch_xla.device() -model = torchvision.models.resnet18().to(device) -input = torch.randn(64, 3, 224, 224).to(device) - -# model tracing -res = model(input) - -# model execution, same as `xm.mark_step` -torch_xla.sync() -``` -The actual model compilation and device execution happens when `torch_xla.sync` is called. There are multiple drawback of this approach. - -1. Users are often confused about when the framework is tracing and when the framework is executing. -2. Non-core model code(data preprocessing for example) often generates some small pending execution that gets leaked into the main graph(step function) and causes recompilation. The recompilation of the whole graph is usually very expensive. -3. It is hard to debug when/why recompilation happens. - -To mitigate above issues we want to introduce the new UX with eager and compile. - -## Basic Usage -```python -import torch -import torch_xla -import torchvision - -# Run ops eagerly by default -torch_xla.experimental.eager_mode(True) - -device = torch_xla.device() -model = torchvision.models.resnet18().to(device) - -# Mark the function to be compiled -compiled_model = torch_xla.compile(model) -input = torch.randn(64, 3, 224, 224).to(device) - -# Compilation and execution happens right away. -res = compiled_model(input) -``` -Note that - -1. Currently user has to manually enable the eager mode by `torch_xla.experimental.eager_mode(True)`. -2. The region of the code that wants to be compiled should be wrapped by `torch_xla.compile`. - -The implementation of the `torch_xla.compile` is actually pretty straight forward, it disable the eager mode when entering the target function and start tracing. It will call the `torch_xla.sync()` when target function returns and reenable the eager mode. You can expect the same perfomrance by using the `eager` + `compile` API compared to the existing `mark_step/sync` approach. - - -### Inference -```python -torch_xla.experimental.eager_mode(True) - -compiled_model = torch.compile(model, backend="openxla") -``` -It is recommened to use the `torch.compile` instead of `torch_xla.compile` for inference to reduce the tracing overhad. - -### Training -```python -torch_xla.experimental.eager_mode(True) - -def step_fn(model, data, target, loss_fn, optimizer): - optimizer.zero_grad() - logits = model(data) - loss = loss_fn(logits, target) - loss.backward() - optimizer.step() - return loss - -step_fn = torch_xla.compile(step_fn) -``` -In training we asked user to refactor the `step_fn` out because it is usually better to compile the model's forward, backward and optimizer together. The long term goal is to also use `torch.compile` for training but right now we recommend user to use `torch_xla.compile`(for perfomrance reason). - -## Benchmark - -I run a 2 layer decoder only model training(it is pretty much just a llama2) with fake data on a single chip of v4-8 for 300 steps. Below is the number I observed. - - - - - - - - - - - - - - - - - - - - - - -
- token/s - -
Tracing mode(base line) - 147 -
Eager mode - 65 -
Eager + torch_xla compile - 147 -
- - -Eager mode can achieve ~45% performance of the fully compiled model for the decoder only model. The trainer I used to test can be found [here](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py) and [here](https://github.com/pytorch/xla/tree/master/examples/eager). Note that perfomrane of the eager mode is very model dependent. When I tried to run the resnet50, the eager mode perfomrance is ~1% of the compiled mode. We don't exepct user to use eager mode to execute the main training loop. Eager mode is meant to be used to handle non-core part of the training/inference logic(Data preprocessing, random number generations etc) or debug. diff --git a/docs/first_steps.md b/docs/first_steps.md deleted file mode 100644 index 2db784fcf95..00000000000 --- a/docs/first_steps.md +++ /dev/null @@ -1,257 +0,0 @@ -### **Objective:** -This document provides a high-level overview of PyTorch XLA and illustrates a -few examples how PyTorch code is converted to run on XLA devices (e.g. TPUs). -This is not a complete solution, and additional changes may be required -depending on the specific code. However, this document should serve as a -starting point for the conversion process. - - -# Basic high-level understanding of some XLA details -This section provides a brief overview of the basic details of PyTorch XLA, - which should help readers better understand the required modifications and - optimizations of code. - -Unlike regular PyTorch, which executes code line by line and does not block execution until the value of a PyTorch tensor is fetched, PyTorch XLA works differently. It iterates through the python code and records the operations on (PyTorch) XLA tensors in an intermediate representation (IR) graph until it encounters a barrier (discussed below). This process of generating the IR graph is referred to as tracing (LazyTensor tracing or code tracing). PyTorch XLA then converts the IR graph to a lower-level machine-readable format called HLO (High-Level Opcodes). HLO is a representation of a computation that is specific to the XLA compiler and allows it to generate efficient code for the hardware that it is running on. HLO is fed to the XLA compiler for compilation and optimization. Compilation is then cached by PyTorch XLA to be reused later if/when needed. The compilation of the graph is done on the host (CPU), which is the machine that runs the Python code. If there are multiple XLA devices, the host compiles the code for each of the devices separately except when using SPMD (single-program, multiple-data). For example, v4-8 has one host machine and [four devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4). In this case the host compiles the code for each of the four devices separately. In case of pod slices, when there are multiple hosts, each host does the compilation for XLA devices it is attached to. If SPMD is used, then the code is compiled only once (for given shapes and computations) on each host for all the devices. - -![img](_static/img/pytorchXLA_flow.svg) - -For more details and examples, please refer to the [LazyTensor guide](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/). - -The operations in the IR graph are executed only when values of tensors are needed. This is referred to as evaluation or materialization of tensors. Sometimes this is also called lazy evaluation and it can lead to significant [performance improvements](https://arxiv.org/pdf/2102.13267.pdf). - -The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromDevice`, which results in slower performance. - -A _barrier_ is a special instruction that tells XLA to execute the IR graph and materialize the tensors. This means that the PyTorch XLA tensors will be evaluated, and the results will be available to the host. The user-exposed barrier in Pytorch XLA is [xm.mark_step()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), which breaks the IR graph and results in code execution on the XLA devices. One of the key properties of `xm.mark_step` is that unlike synchronous operations it does not block the further tracing while the device is executing the graph. However, it does block access to the values of the tensors that are being materialized. - -The example in the LazyTensor guide illustrates what happens in a simple case of adding two tensors. Now, suppose we have a for loop that adds XLA tensors and uses the value later: - -``` -for x, y in tensors_on_device: - z += x + y -``` - -Without a barrier, the Python tracing will result in a single graph that wraps the addition of tensors `len(tensors_on_device)` times. This is because the `for` loop is not captured by the tracing, so each iteration of the loop will create a new subgraph corresponding to the computation of `z += x+y` and add it to the graph. Here is an example when `len(tensors_on_device)=3`. - -![img](_static/img/IRgraph_no_markstep.png) - -However, introducing a barrier at the end of the loop will result in a smaller graph that will be compiled once during the first pass inside the `for` loop and will be reused for the next `len(tensors_on_device)-1 ` iterations. The barrier will signal to the tracing that the graph traced so far can be submitted for execution, and if that graph has been seen before, a cached compiled program will be reused. - -``` -for x, y in tensors_on_device: - z += x + y - xm.mark_step() -``` - -In this case there will be a small graph that is used `len(tensors_on_device)=3` times. - -![img](_static/img/IRgraph_markstep.png) - -It is important to highlight that in PyTorch XLA Python code inside for loops is traced and a new graph is constructed for each iteration if there is a barrier at the end. This can be a significant performance bottleneck. - -The XLA graphs can be reused when the same computation happens on the same shapes of tensors. If the shapes of the inputs or intermediate tensors change, then the XLA compiler will recompile a new graph with the new tensor shapes. This means that if you have dynamic shapes or if your code does not reuse tensor graphs, running your model on XLA will not be suitable for that use case. Padding the input into a fixed shape can be an option to help avoid dynamic shapes. Otherwise, a significant amount of time will be spent by the compiler on optimizing and fusing operations which will not be used again. - -The trade-off between graph size and compilation time is also important to consider. If there is one large IR graph, the XLA compiler can spend a lot of time on optimization and fusion of the ops. This can result in a very long compilation time. However, the later execution may be much faster, due to the optimizations that were performed during compilation. - -Sometimes it is worth breaking the IR graph with `xm.mark_step()`. As explained above, this will result in a smaller graph that can be reused later. However making graphs smaller can reduce optimizations that otherwise could be done by the XLA compiler. - -Another important point to consider is [MPDeviceLoader](https://github.com/pytorch/xla/blob/a1f822e2627a5639464273241821852677401026/torch_xla/distributed/parallel_loader.py#L186). Once your code is running on an XLA device, consider wrapping the torch dataloader with XLA `MPDeviceLoader` which preloads data to the device to improve performance and includes `xm.mark_step()` in it. The latter automatically breaks the iterations over batches of data and sends them for execution. Note, if you are not using MPDeviceLoader, you might need to set `barrier=True` in the `optimizer_step()` to enable `xm.mark_step()` if running a training job or explicitly adding `xm.mark_step()`. - -# TPU Setup -Create TPU with base image to use nightly wheels or from the stable release by specifying the `RUNTIME_VERSION`. -``` -export ZONE=us-central2-b -export PROJECT_ID=your-project-id -export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, … -export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base -export TPU_NAME=your_tpu_name - -gcloud compute tpus tpu-vm create ${TPU_NAME} \ ---zone=${ZONE} \ ---accelerator-type=${ACCELERATOR_TYPE} \ ---version=${RUNTIME_VERSION} \ ---subnetwork=tpusubnet -``` - -If you have a single host VM (e.g. v4-8), you can ssh to your vm and run the following commands from the vm directly. Otherwise, in case of TPU pods, you can use `--worker=all --command=""` similar to - -``` -gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ ---zone=us-central2-b \ ---worker=all \ ---command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl" -``` - -Next, if you are using base image, install nightly packages and required libraries - -``` -pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl -​​pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl -sudo apt-get install libopenblas-dev -y - -sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific -``` - -# Converting code to PyTorch XLA -General guidelines to modify your code: -* Replace `cuda` with `xm.xla_device()` -* Remove progress bar, printing that would access the XLA tensor values -* Reduce logging and callbacks that would access the XLA tensor values -* Wrap data loader with MPDeviceLoader -* Profile to further optimize the code - -Remember: each case is unique so you might need to do something different for each case. - -# Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device - -As a first example consider the [inference code](https://github.com/pytorch-tpu/stable-diffusion/blob/main/scripts/txt2img.py) of the stable diffusion model in PyTorch Lightning which can be run from command line as -``` -python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" -``` - -For your reference, the diff of modifications described below can be found [here](https://github.com/pytorch-tpu/stable-diffusion/commit/57f398eb784387e244dc5fb78421aa5261abd1ef). Let's go over them step by step. -As in the general guideline above, start with changes related to `cuda` device. This inference code is written to run on GPUs and `cuda` can be found in multiple places. Start making changes by removing `model.cuda()` from [this line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L64), and `precision_scope` from [here](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L290). Additionally, replace the `cuda` device in [this line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L248) with the `xla` device similar to the code below: - -Next, this particular configuration of the model is using `FrozenCLIPEmbedder`, therefore we will modify this [line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/modules/encoders/modules.py#L143) as well. For simplicity we will directly define the `device` in this tutorial, but you can pass the `device` value to the function as well. -``` -import torch_xla.core.xla_model as xm -self.device = xm.xla_device() -``` - -Another place in the code that has cuda specific code is DDIM scheduler. Add `import torch_xla.core.xla_model as xm` on top of the file then replace [these](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L21-L22) lines - - -``` -if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) -``` - -with - -``` -device = xm.xla_device() -attr = attr.to(torch.device(device)) -``` - -Next, you can reduce device (TPU) and host (CPU) communication by removing print statements, disabling progress bars, and reducing or removing callbacks and logging. These operations require the device to stop executing, falling back to the CPU, executing the logging/callbacks, and then returning to the device. This can be a significant performance bottleneck, especially on large models. - -After making these changes, the code will run on TPUs. However, the performance will be very slow. This is because the XLA compiler tries to build a single (huge) graph that wraps the number of inference steps (in this case, 50) as there is no barrier inside the for loop. It is difficult for the compiler to optimize the graph, and this leads to significant performance degradation. As discussed above, breaking the for loop with the barrier (xm.mark_step()) will result in a smaller graph that is easier for the compiler to optimize. This will also allow the compiler to reuse the graph from the previous step, which can improve performance. - -Now the [code](https://github.com/pytorch-tpu/stable-diffusion/blob/ss-inference/scripts/txt2img.py) is ready to run on TPUs in a reasonable time. More optimization and analysis can be done by [capturing a profile](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) and investigating further. However, this is not covered here. - -Note: if you are running on v4-8 TPU, then you have 4 available XLA (TPU) devices. Running the code as above will only use one XLA device. In order to run on all 4 devices you need to use `torch_xla.launch()` function to spawn the code on all the devices. We will discuss a `torch_xla.launch` in the next example. - -# Example 2. HF Stable Diffusion Inference -Now, consider using [Stable Diffusion Inference](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) in the HuggingFace diffusers library for both the SD-XL and 2.1 versions of the model. For your reference, the changes described below can be found in this [repo](https://github.com/pytorch-tpu/diffusers). You can clone the repo and run the inference using the following command on your TPU VM: - -``` -(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git -(vm)$ cd diffusers/examples/text_to_image/ -(vm)$ python3 inference_tpu_single_device.py -``` - -# Running on a Single TPU device - -This section describes the changes that need to be made to the [text_to_image inference example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) code to run it on TPUs. - -The original code uses Lora for inference, but this tutorial will not use it. Instead, we will set the `model_id` argument to `stabilityai/stable-diffusion-xl-base-0.9` when initializing the pipeline. We will also use the default scheduler (DPMSolverMultistepScheduler). However, similar changes can be made to the other schedulers as well. -``` -git clone https://github.com/huggingface/diffusers -cd diffusers -pip install . # pip install -e . - -cd examples/text_to_image/ -pip install -r requirements.txt -pip install invisible_watermark transformers accelerate safetensors -``` -(If `accelerate` is not found, log out, log back in.) - -Log in to HF and agree to the [sd-xl 0.9 license](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) on the model card. Next, go to [account→settings→access](https://huggingface.co/settings/tokens) token and generate a new token. Copy the token and run the following command with that specific token value on your vm -``` -(vm)$ huggingface-cli login --token _your_copied_token__ -``` - -The HuggingFace readme provides PyTorch code that is written to run on GPUs. To run it on TPUs, the first step is to change the CUDA device to an XLA device. This can be done by replacing the line `pipe.to("cuda")` with the following lines: - -``` -import torch_xla.core.xla_model as xm -device = xm.xla_device() -pipe.to(device) -``` - -Additionally, it is important to note that the first time you run inference with XLA, it will take a long time to compile. For example, compilation time for stable diffusion XL model inference from HuggingFace can take about an hour to compile, whereas the actual inference may take only 5 seconds, depending on the batch size. Likewise, a GPT-2 model can take about 10-15 mins to compile, after which the training epoch time becomes much faster. This is because XLA builds a graph of the computation that will be performed, and then optimizes this graph for the specific hardware that it is running on. However, once the graph has been compiled, it can be reused for subsequent inferences, which will be much faster. Therefore, if you are only running inference once, you may not benefit from using XLA. However, if you are running inference multiple times, or if you are running inference on a list of prompts, you will start to see the advantages of XLA after the first few inferences. For example, if you run inference on a list of 10 prompts, the first inference (maybe two[^1]) may take a long time to compile, but the remaining inference steps will be much faster. This is because XLA will reuse the graph that it compiled for the first inference. - -If you try to run the code without making any additional changes, you will notice that the compilation time is very long (>6 hours). This is because the XLA compiler tries to build a single graph for all of the scheduler steps at once similar to what we have discussed in the previous example. To make the code run faster, we need to break the graph up into smaller pieces with `xm.mark_step()` and reuse them in the next steps. This happens inside the `pipe.__call__` [function](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L559) in [these lines](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L805-L839). Disabling the progress bar, removing callbacks and adding `xm.mark_step()` at the end of the for loop speeds up the code significantly. Changes are provided in this [commit](https://github.com/huggingface/diffusers/compare/main...pytorch-tpu:diffusers:main). - - -Additionally, the `self.scheduler.step()` function, which by default uses the DPMSolverMultistepScheduler scheduler, has a few issues that are described in the -[PyTorch XLA caveats](https://pytorch.org/xla/release/2.0/index.html#known-performance-caveats). The `.nonzero()` and `.item()` calls in this function send requests to the CPU for tensor evaluation, which trigger device-host communication. This is not desirable, as it can slow down the code. In this particular case, we can avoid these calls by passing the index to the function directly. This will prevent the function from sending requests to the CPU, and will improve the performance of the code. Changes are available in [this](https://github.com/pytorch-tpu/diffusers/commit/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d) commit. The code now is ready to be run on TPUs. - -[^1]: 0 and 1 are magic numbers in XLA and treated as constants in the HLO. So if there is a random number generator in the code that can generate these values, the code will compile for each value separately. This can be disabled with `XLA_NO_SPECIAL_SCALARS=1` environment variable. - - -# Profiling and performance analysis - -To further investigate the performance of the model, we can profile it using the profiling [guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). As a rule of thumb, the profiling script should be run with the maximum batch size that fits into the memory for [optimal memory usage](https://cloud.google.com/tpu/docs/performance-guide). It also helps to overlap tracing of the code with device execution which leads to more optimal device usage. The duration of profiling should be long enough to capture at least one step. Good performance of the model on TPUs means that device-host communication is minimized and the device is constantly running processes with no idle time. - -Starting a server in the `inference_tpu_*.py` file and running `capture_profile.py` script as described in the guide will give us information on processes that run on the devices. Currently, only one XLA device is profiled. To better understand the TPU idle time (gaps in the profile), profiling traces (`xp.Trace()`) should be added to the code. The `xp.Trace()` measures the time it takes to trace the python code on the host machine wrapped with the trace. For this example, `xp.Trace()` traces were added inside the [pipeline](https://github.com/ssusie/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) and the [U-net model](https://github.com/ssusie/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) to measure the time to run specific sections of the code on the host (CPU). - -If the gaps in the profile are due to Python code tracing that happens on the host, then this might be a bottleneck and there is no further straightforward optimization that can be done. Otherwise, the code should be analyzed further to understand the caveats and improve the performance further. Note that you cannot `xp.Trace()` wrap portions of the code where `xm.mark_step()` is called. - -To illustrate this we can look at already captured profiles that were uploaded to tensorboard following the profiling guide. - -Starting from Stable Diffusion model version 2.1 - -If we capture a profile without inserting any traces, we will see the following: - -![Alt text](_static/img/image.png) - -The single TPU device on v4-8, which has two cores, appears to be busy. There are no significant gaps in their usage, except for a small one in the middle. If we scroll up to try to find which process is occupying the host machine, we will not find any information. Therefore, we will add `xp.traces` to the pipeline [file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) as well as the U-net [function](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py). The latter may not be useful for this particular use case, but it does demonstrate how traces can be added in different places and how their information is displayed in TensorBoard. - -If we add traces and re-capture the profile with the largest batch size that can fit on the device (32 in this case), we will see that the gap in the device is caused by a Python process that is running on the host machine. -![Alt text](_static/img/image-1.png) -![Alt text](_static/img/image-2.png) - -We can use the appropriate tool to zoom in on the timeline and see which process is running during that period. This is when the Python code tracing happens on the host, and we cannot improve the tracing further at this point. - - -Now, let's examine the XL version of the model and do the same thing. We will add traces to the pipeline [file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) in the same way that we did for the 2.1 version and capture a profile. - -![Alt text](_static/img/image-4.png) - -This time, in addition to the large gap in the middle, which is caused by the `pipe_watermark` tracing, there are many small gaps between the inference steps within [this loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). - -First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromDevice` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is. - -Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromDevice` operation happens. -![Alt text](_static/img/image-3.png) - - -If we investigate the U-Net function and the scheduler, we can see that the U-Net code does not contain any optimization targets for PyTorch/XLA. However, there are `.item()` and `.nonzero()` calls inside the [scheduler.step](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L371). We can [rewrite](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/schedulers/scheduling_euler_discrete.py#L310) the function to avoid those calls. If we fix this issue and rerun a profile, we will not see much difference. However, since we have reduced the device-host communication that was introducing smaller graphs, we allowed the compiler to optimize the code better. The function [scale_model_input](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L205) has similar issues, and we can fix these by making the changes we made above to the `step` function. Overall, since many of the gaps are caused from python level code tracing and graph building, these gaps are not possible to optimize with the current version of PyTorch XLA, but we may see improvements in the future when dynamo is enabled in PyTorch XLA. - - -# Running on Multiple TPU Devices - -To use multiple TPU devices, you can use the `torch_xla.launch` function to spawn the function you ran on a single device to multiple devices. The `torch_xla.launch` function will start processes on multiple TPU devices and sync them when needed. This can be done by passing the `index` argument to the function that runs on a single device. For example, -``` -import torch_xla - -def my_function(index): - # function that runs on a single device - -torch_xla.launch(my_function, args=(0,)) -``` - -In this example, the `my_function` function will be spawned on 4 TPU devices on v4-8, with each device being assigned an index from 0 to 3. Note that by default, the launch() function will spawn preocesses on all TPU devices. If you only want to run single process, set the argument `launch(..., debug_single_process=True)`. - -[This file](https://github.com/ssusie/diffusers/blob/main/examples/text_to_image/inference_tpu_multidevice.py) illustrates how xmp.spawn can be used to run stable diffusion 2.1 version on multiple TPU devices. For this version similar to the above changes were made to the [pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) file. - - -# Running on Pods -Once you have the code for running on a single host device, there is no further change needed. You can create the TPU pod, for example, by following these [instructions](https://cloud.google.com/tpu/docs/pytorch-pods#create-tpu-vm). Then run your script with -``` -gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ - --zone=${ZONE} \ - --worker=all \ - --command="python3 your_script.py" -``` - diff --git a/docs/fsdp.md b/docs/fsdp.md deleted file mode 100644 index f0ccad4fc9f..00000000000 --- a/docs/fsdp.md +++ /dev/null @@ -1,112 +0,0 @@ -## Fully Sharded Data Parallel (FSDP) in PyTorch XLA - -Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for sharding Module parameters across data-parallel workers. - -Example usage: -```python3 -import torch -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr -from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP - -model = FSDP(my_module) -optim = torch.optim.Adam(model.parameters(), lr=0.0001) -output = model(x, y) -loss = output.sum() -loss.backward() -optim.step() -``` -It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. - -Notes: -* The `XlaFullyShardedDataParallel` class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) in https://arxiv.org/abs/1910.02054. - * The ZeRO-3 optimizer should be implemented via nested FSDP with `reshard_after_forward=True`. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example. - * For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See [`FSDPViTModel`](https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) for an example. -* a simple wrapper `checkpoint_module` is provided (based on `torch_xla.utils.checkpoint.checkpoint` from https://github.com/pytorch/xla/pull/3524) to perform [gradient checkpointing](https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs) over a given `nn.Module` instance. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example. -* Auto-wrapping submodules: instead of manually nested FSDP wrapping, one can also specify an `auto_wrap_policy` argument to automatically wrap the submodules with inner FSDP. `size_based_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of `auto_wrap_policy` callable, this policy wraps layers with the number of parameters larger than 100M. `transformer_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of `auto_wrap_policy` callable for transformer-like model architectures. - -For example, to automatically wrap all `torch.nn.Conv2d` submodules with inner FSDP, one can use: -```python3 -from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy -auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d}) -``` - -Additionally, one can also specify an `auto_wrapper_callable` argument to use a custom callable wrapper for the submodules (the default wrapper is just the `XlaFullyShardedDataParallel` class itself). For example, one can use the following to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule. -```python3 -from torch_xla.distributed.fsdp import checkpoint_module -auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel( - checkpoint_module(m), *args, **kwargs) -``` -* When stepping the optimizer, directly call `optimizer.step` and do not call `xm.optimizer_step`. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded). -* When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use `master_only=False` and set different paths for each rank in `xm.save`). When resuming, it needs to load the checkpoint for the corresponding rank. -* Please also save `model.get_shard_metadata()` along with `model.state_dict()` as follows and use `consolidate_sharded_model_checkpoints` to stitch the sharded model checkpoints together into a full model state dict. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example. -```python3 -ckpt = { - 'model': model.state_dict(), - 'shard_metadata': model.get_shard_metadata(), - 'optimizer': optimizer.state_dict(), -} -ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth' -xm.save(ckpt, ckpt_path, master_only=False) -``` -* The checkpoint consolidation script can also be launched from the command line as follows. -```bash -# consolidate the saved checkpoints via command line tool -python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \ - --ckpt_prefix /path/to/your_sharded_checkpoint_files \ - --ckpt_suffix "_rank-*-of-*.pth" -``` - -The implementation of this class is largely inspired by and mostly follows the structure of `fairscale.nn.FullyShardedDataParallel` in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from `fairscale.nn.FullyShardedDataParallel` is that in XLA we don't have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3. - ---- - -### Example training scripts on MNIST and ImageNet -* Minimum example : [`examples/fsdp/train_resnet_fsdp_auto_wrap.py`](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py) -* MNIST: [`test/test_train_mp_mnist_fsdp_with_ckpt.py`](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py) (it also tests checkpoint consolidation) -* ImageNet: [`test/test_train_mp_imagenet_fsdp.py`](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py) - -#### Installation -FSDP is available on PyTorch/XLA 1.12 release and newer nightly. Please refer to https://github.com/pytorch/xla#-available-images-and-wheels for installation guide. - -#### Clone PyTorch/XLA repo -```bash -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch/ -git clone --recursive https://github.com/pytorch/xla.git -cd ~/ -``` -#### Train MNIST on v3-8 TPU -It gets around 98.9 accuracy for 2 epochs: -```bash -python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \ - --batch_size 16 --drop_last --num_epochs 2 \ - --use_nested_fsdp --use_gradient_checkpointing -``` -This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via -```bash -# consolidate the saved checkpoints via command line tool -python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \ - --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \ - --ckpt_suffix "_rank-*-of-*.pth" -``` - -#### Train ImageNet with ResNet-50 on v3-8 TPU - -It gets around 75.9 accuracy for 100 epochs; download [ImageNet-1k](https://github.com/pytorch/examples/tree/master/imagenet#requirements) to `/datasets/imagenet-1k`: -```bash -python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \ - --datadir /datasets/imagenet-1k --drop_last \ - --model resnet50 --test_set_batch_size 64 --eval_interval 10 \ - --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \ - --use_nested_fsdp -``` -You can also add ` --use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply gradient checkpointing on the residual blocks. - ---- - -### Example training scripts on TPU pod (with 10 billion parameters) - -To train large models that cannot fit into a single TPU, one should apply auto-wrap or manually wrap the submodules with inner FSDP when building the entire model to implement the ZeRO-3 algorithm. - -Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR. diff --git a/docs/glossry.md b/docs/glossry.md deleted file mode 100644 index 0dedc24c41a..00000000000 --- a/docs/glossry.md +++ /dev/null @@ -1,137 +0,0 @@ -# PyTorch/XLA Glossary - -This glossary defines common terms used in the PyTorch/XLA documentation. - -## A - -**Accelerator** - A specialized hardware component designed to accelerate specific computational tasks, such as deep learning. Examples include GPUs and TPUs. - -## B - -**Barrier** - In the context of PyTorch/XLA, a synchronization point that ensures all operations on XLA tensors have completed before proceeding. It's often used to ensure the host (CPU) and device (TPU/GPU) are synchronized. - -**bfloat16** - A 16-bit floating-point data type commonly used on TPUs for faster training. - -## C - -**CUDA** - A parallel computing platform and programming model developed by NVIDIA for use with their GPUs. - -**Core Aten Op Set** - A collection of fundamental operations from PyTorch's ATen library considered essential for core functionality and model export. - -## D - -**Data Parallelism** - A parallelization strategy where the same model is replicated on multiple devices, with each device processing a different subset of the training data. - -**Device Mesh** - A logical representation of the interconnected devices (TPUs/GPUs) used for distributed training, defining the arrangement and communication paths between them. - -**DistributedDataParallel (DDP)** - A PyTorch module that enables data-parallel training across multiple devices, typically used in conjunction with torch.distributed. - -**Distributed Tensor** - A PyTorch API for representing tensors distributed across multiple devices, facilitating parallel and distributed computation. - -**Dynamo** (See **TorchDynamo**) - -## E - -**Eager Execution** - A computational model where operations are executed immediately as they are encountered in the code, as opposed to graph execution. - -**Environment Variables** - Variables that can be set outside of a program to control its behavior, often used in PyTorch/XLA to configure runtime options. - -## F - -**FSDP (Fully Sharded Data Parallel)** - A data-parallel training technique that shards model parameters, gradients, and optimizer states across devices - -**FX (TorchFX)** - An intermediate representation (IR) format used in PyTorch for representing computation graphs in a more structured way. - -**Functionalization** - A process of converting eager execution code into a functional representation, allowing for greater optimization and compilation opportunities. - -## G - -**GSPMD (General and Scalable Parallelization for ML Computation Graphs)** - A single API that enables a large variety of parallelism algorithms (including data parallelism, fully sharded data parallelism, spatial partitioning tensor and pipeline parallelism, as well as combinations of these algorithms) for different ML workloads and model architectures. - -## H - -**HLO (High-Level Optimizer)** - An intermediate representation (IR) format used by the XLA compiler, representing a computation graph at a higher level than machine code. - -**Hugging Face** - A community and platform providing tools and resources for natural language processing, including pre-trained models and a popular Trainer API. - -## I - -**IR (Intermediate Representation)** - A representation of a program or computation graph that is more abstract than machine code but closer to it than the original source code. - -## J - -**JAX** - A high-performance numerical computation library developed by Google, known for its automatic differentiation and XLA integration. - -**JIT (Just-in-Time Compilation)** - A compilation strategy where code is compiled at runtime, as needed, offering flexibility and potential optimizations based on runtime information. - -## K - -**Kaggle** - An online community and platform for machine learning practitioners to share code and solutions. - -## L - -**Lazy Tensor** - A type of tensor in PyTorch/XLA that delays operation execution until the results are explicitly needed, allowing for graph optimization and XLA compilation. - -**Lit-GPT** - Implements open-source large language models in XLA and supports fine-tuning - -## M - -**Model Parallelism** - A parallelization strategy where different parts of a model are distributed across multiple devices, enabling training of models too large to fit on a single device. - -**Multiprocessing** - A programming technique for running multiple processes concurrently, often used in PyTorch/XLA to utilize multiple TPU cores. - -**MpDeviceLoader** - A PyTorch/XLA utility for efficiently loading and distributing data across multiple devices during training. - -## N - -**NCCL (NVIDIA Collective Communications Library)** - A library for efficient collective communication operations (e.g., all-reduce, all-gather) on NVIDIA GPUs. - -## O - -**OpenXLA** - An open-source project aimed at developing and maintaining XLA, the deep learning compiler. - -**Ordinal** - A unique identifier for a device (TPU/GPU) within a distributed training setup, often used to determine the device's role and data partitioning. - -## P - -**Partition Spec** - In GSPMD, a specification that defines how a tensor is sharded across a device mesh. - -**PJRT (Portable JAX Runtime)** - A runtime environment for JAX that supports multiple backends. - -**Pod** - A group of interconnected TPU hosts, offering massive scale for training large models. - -**Preemption** - An event where a Cloud TPU is reclaimed by the cloud provider, requiring checkpointing to avoid losing training progress. - -## R - -**Rendezvous** - Used by Torch Distributed Elastic to gather participants of a training job (i.e. nodes) such that they all agree on the same list of participants and everyone’s roles, as well as make a consistent collective decision on when training can begin/resume. - -**Replication** - A data distribution strategy where a tensor is fully copied to all devices in a mesh, ensuring all devices have the same data. - -## S - -**Sharding** - The process of dividing a tensor into smaller pieces (shards) and distributing them across devices, commonly used to reduce memory footprint and enable parallel computation. - -**SPMD (Single Program, Multiple Data)** - A parallel programming model where the same program is executed on multiple devices. - -**State Dict**- A Python dictionary object that maps each layer to its parameter tensor. It is used for saving or loading models. - -## T - -**TensorBoard** - A visualization tool for monitoring and analyzing training progress, including performance metrics and computation graphs. - -**TorchDynamo** - A Python-level JIT compiler for PyTorch, dynamically modifying bytecode to enable graph capture and optimization. - -**TPU (Tensor Processing Unit)** - A custom-designed machine learning accelerator developed by Google, offering high performance for deep learning workloads. - -## X - -**XLA (Accelerated Linear Algebra)** - A deep learning compiler developed by Google. - -**XLATensor** - A tensor type in PyTorch/XLA representing data on an XLA device, enabling lazy execution and XLA compilation. - -**xla_device()** - A PyTorch/XLA function for retrieving the current XLA device. - -**xm (xla_model)** - A module in PyTorch/XLA providing core functions for interacting with XLA devices and executing computations. - -**xmp (xla_multiprocessing)** - A module in PyTorch/XLA for launching distributed training processes across multiple XLA devices. diff --git a/docs/gpu.md b/docs/gpu.md deleted file mode 100644 index ac8facf296c..00000000000 --- a/docs/gpu.md +++ /dev/null @@ -1,172 +0,0 @@ -# How to run with PyTorch/XLA:GPU - -PyTorch/XLA enables PyTorch users to utilize the XLA compiler which supports accelerators including TPU, GPU, and CPU. This doc will go over the basic steps to run PyTorch/XLA on a nvidia GPU instances. - -## Create a GPU instance - -You can either use a local machine with GPU attached or a GPU VM on the cloud. For example in Google Cloud you can follow this [doc](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus) to create the GPU VM. - -## Environment Setup - -Make sure you have cuda driver installed on the host. - -### Docker -Pytorch/XLA currently publish prebuilt docker images and wheels with cuda11.8/12.1 and python 3.8. We recommend users to create a docker container with corresponding config. For a full list of docker images and wheels, please refer to [this doc](https://github.com/pytorch/xla#available-docker-images-and-wheels). -``` -sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 - -# Installing the NVIDIA Container Toolkit per https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html -# For example -curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ - && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ - sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ - sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list -sudo apt-get update -sudo apt-get install -y nvidia-container-toolkit - -# Configuring the NVIDIA Container Toolkit -sudo nvidia-ctk runtime configure --runtime=docker -sudo systemctl restart docker - -sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 bin/bash -sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash -``` - -Note that you need to restart the docker to make gpu devices visible in the docker container. After logging into the docker, you can use `nvidia-smi` to verify the device is setup correctly. - -``` -(pytorch) root@20ab2c7a2d06:/# nvidia-smi -Thu Dec 8 06:24:29 2022 -+-----------------------------------------------------------------------------+ -| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 | -|-------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|===============================+======================+======================| -| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 | -| N/A 36C P0 38W / 300W | 0MiB / 16384MiB | 1% Default | -| | | N/A | -+-------------------------------+----------------------+----------------------+ - -+-----------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=============================================================================| -| No running processes found | -+-----------------------------------------------------------------------------+ - -``` - -### Check environment variable - -Make sure `PATH` and `LD_LIBRARY_PATH` environment variables account for cuda. Please do a `echo $PATH` and `echo $LD_LIBRARY_PATH` to verify. If not, please follow [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#mandatory-actions) to do so. Example: - -``` -echo "export PATH=\$PATH:/usr/local/cuda-12.1/bin" >> ~/.bashrc -echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc -source ~/.bashrc -``` - -### Wheel - -> **_NOTE:_** The wheel file is compatible only with x86_64 linux based architecutre. To check the architecture of your linux system, execute the following command: -> ``` ->uname -a -> ``` - -``` -pip3 install torch==2.4.0 -# GPU whl for python 3.10 + cuda 12.1 -pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl -``` -Wheels for other Python version and CUDA version can be found [here](https://github.com/pytorch/xla?tab=readme-ov-file#available-docker-images-and-wheels). - - -## Run some simple models -In order to run below examples, you need to clone the pytorch/xla repository. - -### MP_ImageNet Example -This example uses ImageNet. It is included in what we already cloned in our Docker container. -``` -(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA -(pytorch) root@20ab2c7a2d06:/# git clone --recursive https://github.com/pytorch/xla.git -(pytorch) root@20ab2c7a2d06:/# python xla/test/test_train_mp_imagenet.py --fake_data -==> Preparing data.. -Epoch 1 train begin 06:12:38 -| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=2.82 GlobalRate=2.82 Time=06:13:23 -| Training Device=xla:0/0 Epoch=1 Step=20 Loss=6.79297 Rate=117.16 GlobalRate=45.84 Time=06:13:36 -| Training Device=xla:0/0 Epoch=1 Step=40 Loss=6.43628 Rate=281.16 GlobalRate=80.49 Time=06:13:43 -| Training Device=xla:0/0 Epoch=1 Step=60 Loss=5.83108 Rate=346.88 GlobalRate=108.82 Time=06:13:49 -| Training Device=xla:0/0 Epoch=1 Step=80 Loss=4.99023 Rate=373.62 GlobalRate=132.43 Time=06:13:56 -| Training Device=xla:0/0 Epoch=1 Step=100 Loss=3.92699 Rate=384.33 GlobalRate=152.40 Time=06:14:02 -| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09 -``` -### ResNet Example -This example uses ResNet. -``` -(pytorch) root@20ab2c7a2d06:/# python3 /xla/examples/train_resnet_base.py -1:35PM UTC on Jun 08, 2024 -epoch: 1, step: 0, loss: 6.887794017791748, rate: 8.746502586051985 -epoch: 1, step: 10, loss: 6.877807140350342, rate: 238.4789458412044 -epoch: 1, step: 20, loss: 6.867819786071777, rate: 329.86095958663503 -epoch: 1, step: 30, loss: 6.857839584350586, rate: 367.3038003653586 -epoch: 1, step: 40, loss: 6.847847938537598, rate: 381.53141087190835 -epoch: 1, step: 50, loss: 6.837860584259033, rate: 387.80462249591113 -... -epoch: 1, step: 260, loss: 6.628140926361084, rate: 391.135639565343 -epoch: 1, step: 270, loss: 6.618192195892334, rate: 391.6901797745233 -epoch: 1, step: 280, loss: 6.608224391937256, rate: 391.1602680460045 -epoch: 1, step: 290, loss: 6.598264217376709, rate: 391.6731498290759 -Epoch 1 train end 1:36PM UTC -``` - - -## AMP (AUTOMATIC MIXED PRECISION) -AMP is very useful on GPU training and PyTorch/XLA reuse Cuda's AMP rule. You can checkout our [mnist example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py). Note that we also used a modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) to avoid the additional sync between device and host. - -## Develop PyTorch/XLA on a GPU instance (build PyTorch/XLA from source with GPU support) - -1. Inside a GPU VM, create a docker container from a development docker image. For example: - -``` -sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1 - -# Installing the NVIDIA Container Toolkit per https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html -# For example -curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ - && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ - sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ - sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list -sudo apt-get update -sudo apt-get install -y nvidia-container-toolkit - -# Configuring the NVIDIA Container Toolkit -sudo nvidia-ctk runtime configure --runtime=docker -sudo systemctl restart docker - -sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1 -sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash -``` - -2. Build PyTorch and PyTorch/XLA from source. - -Make sure `PATH` and `LD_LIBRARY_PATH` environment variables account for cuda. See the [above](https://github.com/pytorch/xla/blob/master/docs/gpu.md#check-environment-variable) for more info. - -``` -git clone https://github.com/pytorch/pytorch.git -cd pytorch -USE_CUDA=1 python setup.py install -USE_CUDA=1 python setup.py bdist_wheel # Required for hermetic Python in PyTorch/XLA build setup. - -git clone https://github.com/pytorch/xla.git -cd xla -XLA_CUDA=1 python setup.py install -``` - -3. Verify if PyTorch and PyTorch/XLA have been installed successfully. - -If you can run the tests in the section -[Run some simple models](#run-some-simple-models) successfully, then PyTorch and -PyTorch/XLA should have been installed successfully. diff --git a/docs/kubernetes.md b/docs/kubernetes.md deleted file mode 100644 index db49332724f..00000000000 --- a/docs/kubernetes.md +++ /dev/null @@ -1,312 +0,0 @@ -# Distributed training on GKE - -PyTorch/XLA supports distributed training on GKE via [indexed -`Job`s](https://kubernetes.io/docs/tasks/job/job-with-pod-to-pod-communication/) -and `torchrun`. For more information about creating a GKE cluster with -accelerators, see the documentation for -[TPUs](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus) and -[GPUs](https://cloud.google.com/kubernetes-engine/docs/how-to/gpus), -respectively. - -## GPU Example - -GKE is the recommended platform for distributed training with GPUs. This example -uses two hosts, each with two NVidia v100 GPUs. Adjust the values according -to the comments in the example for a larger or smaller cluster. - -Create a new file `gpu_test.yaml` with the following: - -```yaml -# Headless service used for service discovery. -# See https://kubernetes.io/docs/concepts/services-networking/service/#headless-services -apiVersion: v1 -kind: Service -metadata: - name: headless-svc -spec: - selector: - headless-svc: "true" - clusterIP: None ---- -apiVersion: batch/v1 -kind: Job -metadata: - name: torch-xla-resnet50-v100-x2x2 -spec: - # Don't retry upon failure - backoffLimit: 0 - # Indexed jobs pass rank to each replica - completionMode: Indexed - # Set `completions` and `parallelism` to the number of hosts in the cluster - completions: 2 - parallelism: 2 - template: - metadata: - creationTimestamp: null - labels: - headless-svc: "true" - spec: - subdomain: headless-svc - containers: - - name: main - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 - command: - - bash - - -cxue - - | - export PATH=/usr/local/nvidia/bin${PATH:+:${PATH}} - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/nvidia/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} - - nvidia-smi - - mkdir -p pytorch/xla - git clone -b r2.3 https://github.com/pytorch/xla.git pytorch/xla - - # Run `args` here - "${@:0}" - args: - - torchrun - # Set this to the number of hosts - - --nnodes=2 - # Index provided by Job - - --node_rank=$(JOB_COMPLETION_INDEX) - # Create one process per local GPU - - --nproc_per_node=2 - # Coordinator always runs on 0th instance of job - - --rdzv_endpoint=$(JOB_NAME)-0.headless-svc:12355 - # Replace this with your script and flags - - pytorch/xla/test/test_train_mp_imagenet.py - - --model=resnet50 - - --log_steps=200 - - --fake_data - - --pjrt_distributed - - --nometrics_debug - - --num_epochs=1 - env: - - name: JOB_NAME - valueFrom: - fieldRef: - apiVersion: v1 - fieldPath: metadata.labels['job-name'] - - name: PJRT_DEVICE - value: CUDA - resources: - limits: - # Change this to the number of GPUs per host - nvidia.com/gpu: "2" - # PyTorch requires a large `shm` - volumeMounts: - - mountPath: /dev/shm - name: dshm - restartPolicy: Never - # Change the node selector if you're using a different GPU type - nodeSelector: - cloud.google.com/gke-accelerator: nvidia-tesla-v100 - volumes: - - emptyDir: - medium: Memory - name: dshm -``` - -Once the job schedules, you should start seeing logs like this: - -``` -$ kubectl logs job/torch-xla-resnet50-v100-x2x2 -... -+ nvidia-smi -Fri Jun 28 20:15:43 2024 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 550.54.14 Driver Version: 550.54.14 CUDA Version: 12.4 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 Tesla V100-SXM2-16GB Off | 00000000:00:04.0 Off | 0 | -| N/A 35C P0 33W / 300W | 0MiB / 16384MiB | 0% Default | -| | | N/A | -+-----------------------------------------+------------------------+----------------------+ -| 1 Tesla V100-SXM2-16GB Off | 00000000:00:05.0 Off | 0 | -| N/A 35C P0 33W / 300W | 0MiB / 16384MiB | 0% Default | -| | | N/A | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| No running processes found | -+-----------------------------------------------------------------------------------------+ -+ mkdir -p pytorch/xla -+ git clone -b r2.3 https://github.com/pytorch/xla.git pytorch/xla -Cloning into 'pytorch/xla'... -+ torchrun --nnodes=2 --node_rank=0 --nproc_per_node=2 --rdzv_endpoint=torch-xla-resnet50-v100-x2x2-0.headless-svc:12355 pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --fake_data --pjrt_distributed --nometrics_debug --num_epochs=1 -... -I0000 00:00:1719605752.973014 55 service.cc:145] XLA service 0x59c1a6e31500 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: -I0000 00:00:1719605752.973052 55 service.cc:153] StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0 -... -==> Preparing data.. -==> Preparing data.. -Epoch 1 train begin 20:15:56 -| Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.89059 Rate=3.13 GlobalRate=3.13 Time=20:16:36 -| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=3.14 GlobalRate=3.14 Time=20:16:36 -... -| Training Device=xla:0/0 Epoch=1 Step=1800 Loss=0.00135 Rate=332.54 GlobalRate=314.11 Time=20:28:09 -| Training Device=xla:0/1 Epoch=1 Step=1800 Loss=0.00135 Rate=332.54 GlobalRate=314.06 Time=20:28:09 -... -| Training Device=xla:0/0 Epoch=1 Step=2200 Loss=0.00135 Rate=336.66 GlobalRate=318.00 Time=20:30:42 -| Training Device=xla:0/1 Epoch=1 Step=2200 Loss=0.00135 Rate=336.66 GlobalRate=317.96 Time=20:30:42 -Epoch 1 train end 20:31:36 -| Test Device=xla:0/0 Step=0 Epoch=1 Time=20:31:42 -| Test Device=xla:0/1 Step=0 Epoch=1 Time=20:31:42 -Epoch 1 test end 20:31:47, Accuracy=100.00 -Max Accuracy: 100.00% -... -``` - -## TPUs - -Training on TPU is similar to training on GPU in GKE, the same steps for `torchrun` apply. For more -information about TPU GKE clusters, see [GKE's official -docs](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus). - -The example below use two ct5lp-hightpu-4t VMs, with 4 v5e TPU each to construct a 2x4 topology nodepool. -You can adjust the values accordingly to match the training requirement. - -Create a new file `tpu_test.yaml` with the following: - -```yaml -apiVersion: v1 -kind: Service -metadata: - name: headless-svc - namespace: default -spec: - clusterIP: None - selector: - headless-svc: "true" ---- -apiVersion: batch/v1 -kind: Job -metadata: - name: torch-xla-tpu-2x4 -spec: - parallelism: 2 # num of nodes - completions: 2 # num of nodes - backoffLimit: 0 # default, no retries - completionMode: Indexed - template: - metadata: - labels: - headless-svc: "true" - spec: # pod-spec: - serviceAccountName: default - affinity: - nodeAffinity: - requiredDuringSchedulingIgnoredDuringExecution: - nodeSelectorTerms: - - matchExpressions: # need to be specified to get tpu resources - - key: cloud.google.com/gke-tpu-accelerator - operator: "In" - values: - - "tpu-v5-lite-podslice" - - key: cloud.google.com/gke-tpu-topology - operator: "In" - values: - - "2x4" # 2 nodes of 4 tpu's - tolerations: - - key: "google.com/tpu" - operator: "Equal" - value: "present" - effect: "NoSchedule" - restartPolicy: Never # look in https://kubernetes.io/docs/concepts/workloads/controllers/job/ - subdomain: headless-svc - volumes: - # Increase size of tmpfs /dev/shm to avoid OOM. - - name: shm - emptyDir: - medium: Memory - containers: - - name: training - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_libtpu_3.10_tpuvm - command: - - bash - - -cxue - - | - - mkdir -p pytorch/xla - git clone -b r2.3 https://github.com/pytorch/xla.git pytorch/xla - - # Run `args` here - "${@:0}" - args: - - torchrun - # Set this to the number of hosts - - --nnodes=2 - # Index provided by Job - - --node_rank=$(JOB_COMPLETION_INDEX) - # Create one process per local TPU - - --nproc_per_node=4 - # Coordinator always runs on 0th instance of job - - --rdzv_endpoint=$(JOB_NAME)-0.headless-svc:12355 - # Replace this with your script and flags - - pytorch/xla/test/test_train_mp_imagenet.py - - --model=resnet50 - - --log_steps=200 - - --fake_data - - --pjrt_distributed - - --nometrics_debug - - --num_epochs=1 - ports: - - containerPort: 8471 # 8471 is the default port for the TPU VMs communication - - containerPort: 12355 # used by the code - - containerPort: 8479 - - containerPort: 8478 - - containerPort: 8477 - - containerPort: 8476 - - containerPort: 8431 # Port to export TPU usage metrics, if supported. - volumeMounts: - - mountPath: /dev/shm - name: shm - env: - - name: PJRT_DEVICE - value: 'TPU' - - name: JOB_NAME - valueFrom: - fieldRef: - apiVersion: v1 - fieldPath: metadata.labels['job-name'] - resources: - requests: - google.com/tpu: 4 - memory: 16G - limits: - google.com/tpu: 4 -``` - -Once the job schedules, you should start seeing logs like this: - -``` -$ kubectl logs job/torch-xla-tpu-2x4 -... -Cloning into 'pytorch/xla'... -+ torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 --rdzv_endpoint=torch-xla-tpu-2x4-0.headless-svc:12355 pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --fake_data --pjrt_distributed --nometrics_debug --num_epochs=1 -... -==> Preparing data.. -==> Preparing data.. -Epoch 1 train begin 23:10:22 -|| Training Device=xla:0/3 Epoch=1 Step=0 Loss=6.89059 Rate=4.64 GlobalRate=4.64 Time=23:10:54 - Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=3.97 GlobalRate=3.97 Time=23:10:54 -| Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.89059 Rate=4.13 GlobalRate=4.13 Time=23:10:54 -| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89059 Rate=3.99 GlobalRate=3.99 Time=23:10:54 -...\ -| Training Device=xla:0/3 Epoch=1 Step=1000 Loss=0.00139 Rate=1343.24 GlobalRate=864.39 Time=23:12:54 -| Training Device=xla:0/2 Epoch=1 Step=1000 Loss=0.00139 Rate=1343.23 GlobalRate=839.12 Time=23:12:54 -Epoch 1 train end 23:13:07 -| Test Device=xla:0/1 Step=0 Epoch=1 Time=23:13:11 -| Test Device=xla:0/3 Step=0 Epoch=1 Time=23:13:11 -... -Epoch 1 test end 23:13:16, Accuracy=100.00 -Max Accuracy: 100.00% -``` diff --git a/docs/openxla.md b/docs/openxla.md deleted file mode 100644 index 5dd3c6a76ce..00000000000 --- a/docs/openxla.md +++ /dev/null @@ -1,19 +0,0 @@ -# OpenXLA - -As of June 28th, 2023, PyTorch/XLA now pulls XLA from OpenXLA. -OpenXLA is an [open source machine learning compiler XLA for GPUs, CPUs, and ML accelerators](https://github.com/openxla/xla). - -Previous to OpenXLA, PyTorch/XLA pulled XLA directly from [TensorFlow](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla). With our [XLA to OpenXLA migration](https://github.com/pytorch/xla/pull/5202), PyTorch/XLA now pulls XLA from [OpenXLA](https://github.com/openxla/xla). - -# How to use OpenXLA - -For [PJRT runtime](https://github.com/pytorch/xla/blob/master/docs/pjrt.md) users, there is no change with this migration. For XRT runtime users, there is a separate [XRT branch of PyTorch/XLA](https://github.com/pytorch/xla/tree/xrt) since OpenXLA doesn't support XRT. - - -# Performance -Below is a performance visual comparison of throughput for ResNet50 pre and post the migration on different TPU hardwares. - -| | resnet50-pjrt-v2-8 | resnet50-pjrt-v4-8 | resnet50-pjrt-v4-32 | -| :------------ | :------------ | :------------ | :------------ | -| Pre Migration | 18.59 | 20.06 | 27.92 | -| Post Migration | 18.63 | 19.94 | 27.14 | diff --git a/docs/pjrt.md b/docs/pjrt.md deleted file mode 100644 index 2d66d2ef925..00000000000 --- a/docs/pjrt.md +++ /dev/null @@ -1,429 +0,0 @@ -# PJRT Runtime - -PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the [PJRT -runtime](https://github.com/openxla/xla/tree/main/xla/pjrt) -used by [JAX](https://github.com/google/jax). - -If you encounter a bug with PJRT, please file an issue on GitHub with the -`runtime` tag. - -_New features in PyTorch/XLA r2.1_: - -* PJRT is stable in PyTorch/XLA r2.1! -* Public runtime APIs have moved from `torch_xla.experimental.pjrt` to - `torch_xla.runtime`. - * The `pjrt://` init method has been renamed to `xla://`, and it is registered - by `torch_xla.distributed.xla_backend`. - * The previous `torch_xla.experimental.*` names are still available in this - release for compatibility. -* `torchrun` is now supported when using `init_method='xla://'`. -* New plugins for XPU and Neuron via the PJRT C API. - -_New features in PyTorch/XLA r2.0_: - -* PJRT will be configured by default if you don't pass in any other runtime - configuration. If you continue to set XRT configuration (`XRT_TPU_CONFIG`), - this change has no impact -* New TPU runtime implementation in `libtpu` improves performance by up to 30%. -* New `xm.rendezvous` implementation that scales to thousands of TPU cores -* [experimental] `torch.distributed` support for TPU v2 and v3, including - `pjrt://` `init_method` - -## TL;DR - -* To use the PJRT preview runtime, set the `PJRT_DEVICE` environment variable to - `CPU`, `TPU`, or `CUDA` -* In XRT, all distributed workloads are multiprocess, with one process per - device. On TPU v2 and v3 in PJRT, workloads are multiprocess and multithreaded - (4 processes with 2 threads each), so your workload should be thread-safe. See - [Multithreading on TPU v2/v3](#multithreading-on-tpu-v2v3) and the - [Multiprocessing section of the API - guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing) - for more information. Key differences to keep in mind: - * To initialize a model in a thread-safe way, either broadcast the parameters - across replicas after initialization - (`torch_xla.experimental.pjrt.broadcast_master_param`) or load each - replica's parameters from a common checkpoint. - * For other random number generation, use `torch.Generator` where possible. - The global `torch` RNG is _not_ thread-safe, even if you set the same - `torch.manual_seed` across replicas. - * To use `torch.distributed`, import `torch_xla.experimental.pjrt_backend` and - use the `xla://` `init_method`. - * These steps are optional for GPU and TPU v4. - -Sample diff from XRT to PJRT: - -```diff - import os - - import torch - import torch.nn as nn - from torch.nn.parallel import DistributedDataParallel as DDP - import torch.optim as optim - import torch.distributed as dist - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as pl - import torch_xla.distributed.xla_backend -+import torch_xla.runtime as xr - - - def _mp_fn(index): - device = xm.xla_device() -- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) -+ dist.init_process_group('xla', init_method='xla://') - - torch.manual_seed(42) - model = nn.Linear(128, 10).to(device) - -+ # Optional for TPU v4 and GPU -+ xm.broadcast_master_param(model) - model = DDP(model, gradient_as_bucket_view=True) - - loss_fn = nn.MSELoss() - optimizer = optim.SGD(model.parameters(), lr=.001) - - for i in range(10): - data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device) - - optimizer.zero_grad() - output = model(data) - loss = loss_fn(output, target) - loss.backward() - - optimizer.step() - xm.mark_step() - - # Print mean parameters so we can confirm they're the same across replicas - print([p.mean() for p in model.parameters()]) - - if __name__ == '__main__': -- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011' -- os.environ['MASTER_ADDR'] = 'localhost' -- os.environ['MASTER_PORT'] = '12355' - -+ # Recommended: set PJRT_DEVICE to your local device type -+ os.environ['PJRT_DEVICE'] = 'TPU' - - torch_xla.launch(_mp_fn) -``` - -## Benefits - -* Simple runtime configuration: just set `PJRT_DEVICE` to `TPU`, `CPU`, or `CUDA` - and start using XLA! Or, let PJRT select a device automatically based on your - environment. -* Improved performance: reduced overhead from gRPC means faster end-to-end - execution. On TorchBench 2.0, we observed a >35% improvement in training time - on TPU v4. -* Easy pod execution: just copy your code to each TPU worker, and execute them - all at the same time with `gcloud compute tpus tpuvm ssh --worker=all`. -* Better scaling: removes [XRT's limitation on parameter - sizes](https://github.com/pytorch/xla/pull/3920) and supports up to 2048 TPU - chips. - -## Quickstart - -To start using PJRT with PyTorch/XLA, all you need to do is set the -`PJRT_DEVICE` environment variable. If you're working on a TPU v2 or v3, keep -reading to learn about the differences between TPU v2 and v3 and v4. - -### CPU - -On any machine with PyTorch/XLA installed, you can run our MNIST example on CPU -like this: - -``` -PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data -``` - -### TPU - -To create a new TPU with PyTorch/XLA r2.0 installed: - -``` -gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT -``` - -On a v4-8, you can run our ResNet50 example like this: - -``` -git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git -PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1 -``` - -By default, PJRT will use all TPU chips. To use only one TPU chip, configure -`TPU_PROCESS_BOUNDS` and `TPU_VISIBLE_CHIPS`: - -``` -TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1 -``` - -#### Pods - -On TPU Pods, use `gcloud` to run your command on each TPU in parallel: - -``` -gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git" -gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1" -``` - -#### Docker - -You can also use Docker to run your workload in a container with PyTorch/XLA -preinstalled: - -``` -export DOCKER_IMAGE=gcr.io/... - -# Optional: authenticate docker if your image is in a private GCP repository -gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker" - -# Run your workload -gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data" -``` - -Note that `docker run` requires privileged access to the host (`--privileged`) -to expose the TPU device to the container. Docker on TPU pods is only supported -with host networking `--net=host` at this time. See the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/run-in-container) -for more information. - -### GPU - -### Single-node GPU training - -To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure -`GPU_NUM_DEVICES` to the number of devices on the host. For example: - -``` -PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1 -``` - -You can also use `torchrun` to initiate the single-node multi-GPU training. For example, - -``` -PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 -``` - -In the above example, `--nnodes` means how many machines (physical machines or VMs) to be used (it is 1 since we do single-node training). `--nproc-per-node` means how many GPU devices to be used. - -### Multi-node GPU training - -**Note that this feature only works for cuda 12+**. Similar to how PyTorch uses multi-node training, you can run the command as below: - -``` -PJRT_DEVICE=CUDA torchrun \ ---nnodes=${NUMBER_GPU_VM} \ ---node_rank=${CURRENT_NODE_RANK} \ ---nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \ ---rdzv_endpoint= multinode_training.py -``` - -- `--nnodes`: how many GPU machines to be used. -- `--node_rank`: the index of the current GPU machines. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1. -- `--nproc_per_node`: the number of GPU devices to be used on the current machine. -- `--rdzv_endpoint`: the endpoint of the GPU machine with node_rank==0, in the form `host:port`. The `host` will be the internal IP address. The `port` can be any available port on the machine. For single-node training/inference, this parameter can be omitted. - -For example, if you want to train on 2 GPU machines: machine_0 and machine_1, on the first GPU machine machine_0, run - -``` -# PJRT_DEVICE=CUDA torchrun \ ---nnodes=2 \ ---node_rank=0 \ ---nproc_per_node=4 \ ---rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 -``` - -On the second GPU machine, run - -``` -# PJRT_DEVICE=CUDA torchrun \ ---nnodes=2 \ ---node_rank=1 \ ---nproc_per_node=4 \ ---rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 -``` - -the difference between the 2 commands above are `--node_rank` and potentially `--nproc_per_node` if you want to use different number of GPU devices on each machine. All the rest are identical. For more information about `torchrun`, please refer to this [page](https://pytorch.org/docs/stable/elastic/run.html). - -## Differences from XRT - -Although in most cases we expect PJRT and XRT to work mostly interchangeably -from the end-user's perspective (especially on TPU v4), there are some subtle -differences that are important to keep in mind. Importantly, XRT was designed -around the TPU Node architecture, so it will always spawn a client and a server -process, even on TPU VMs. Thus, every batch of inputs has additional latency -from serializing and deserializing data to send it over the network. - -PJRT uses the local device directly with no intermediate server process. In the -default configuration, PJRT will create one process per TPU chip, or 4 processes -per TPU host. See the [Cloud TPU -documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for -more information about TPU architecture. - -* Performance gains are possible for workloads constrained overhead from . -* Under XRT, the server process is the only process that interacts with the TPU - devices, and client processes don't have direct access to the TPU devices. - When profiling a single-host TPU (e.g. v3-8 or v4-8), you would normally see 8 - device traces (one for each TPU core). With PJRT, each process has one chip, - and a profile from that process will show only 2 TPU cores. - * For the same reason, profiling does not work on TPU Pods with XRT, because - the server process runs independently from the user's model code. PJRT does - not have that constraint, so it is possible to profile 2 TPU cores per - process in a TPU Pod. -* PJRT only supports the TPU VM architecture and we have no plans to support the - TPU Node architecture with PJRT. -* Runtime configuration is significantly simpler with PJRT. `xla_dist` is not - required to run TPU Pod workloads. Instead, copy your code to each TPU host - (`[gcloud compute tpus tpu-vm - scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)`) - and run the code on each host in parallel (e.g. `[gcloud compute tpus tpu-vm - ssh --workers=all --command="PJRT_DEVICE=TPU python - run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)`) -* `xm.rendezvous` has been reimplemented using XLA-native collective - communication to enhance stability on large TPU pods. See below for more - details. - - -### Multithreading on TPU v2/v3 - -On TPU v2 and v3, **distributed workloads always run multithreaded**, since each -TPU core exposes two TPU cores as devices and only one process may open a TPU -chip at a time. In its default configuration, `xmp.spawn` automatically spawns -as many processes as possible (4 per TPU host) and creates two threads per -process (one per TPU core). - -Note: on TPU v4, each TPU chip is represented as one PyTorch device, so -distributed workloads will run across 4 processes, each with only one thread. -This is identical to XRT's behavior. - -In most cases, this will not require substantial changes to your existing code. -The main change you will have to make in most cases is to model initialization. -Because `torch`'s global RNG is shared between threads, results will vary -between threads and runs even if you set `torch.manual_seed` to the same value -in every replica. To get consistent parameters between replicas, either use -`torch_xla.experimental.pjrt.broadcast_master_param` to broadcast one replica's -parameters to all other replicas, or load each replica's parameters from a -common checkpoint. - - -### Changes to xm.rendezvous - -_New in PyTorch/XLA r2.0_ - -With XRT, worker 0 runs a mesh master service, and all processes on all workers -connect to that service over gRPC. In practice, we found that running a single -mesh master process was unreliable on TPU pods with thousands of chips due to -the number of inbound connections to worker 0. A single client process timing -out could cause a failure and force the entire workload to restart. - -Thus, we have reimplemented `xm.rendezvous` with native XLA collective -communication, which is much more stable and well-tested on large TPU pods. This -imposes two new constraints compared to the XRT implementation: - -* Because the payload has to become part of the XLA graph, `xm.mark_step` is - called both before and after the data is transferred. Calling `xm.rendezvous` - in the middle of model code may force an unwanted compilation. -* Because XLA does not permit collective operations to run on a subset of - workers, all workers must participate in the `rendezvous`. - -If you require the old behavior of `xm.rendezvous` (i.e. communicating data -without altering the XLA graph and/or synchronizing a subset of workers), -consider using -[`torch.distributed.barrier`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier) -or -[`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object) -with a `gloo` process group. If you are also using the `xla` `torch.distributed` -backend, you can use `torch.new_group` to create a `gloo` subgroup. See [this -example](https://pytorch.org/docs/stable/distributed.html#monitored-barrier) -from the PyTorch documentation. Keep in mind these constraints: - -* `torch.distributed` is not fully supported on TPU v2/v3. Only a subset of - operations with the `xla` backend are implemented, and `gloo` will likely not - work as expected in a multithreaded context. -* In our experiments, `gloo` does not scale well to thousands of TPU chips, so - expect this alternative to be less reliable than using `xm.rendezvous` with - PJRT at large scales. - -### PJRT and torch.distributed - -_New in PyTorch/XLA r2.0_ - -When using PJRT with `torch.distributed` and -`[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)` -we strongly recommend using the new `xla://` `init_method`, which automatically -finds the replica IDs, world size, and master IP by querying the runtime. For -example: - -```python -import torch -import torch_xla -import torch.distributed as dist -import torch_xla.core.xla_model as xm -from torch_xla.experimental import pjrt - -# Required for `xla://` init_method and `xla` backend -import torch_xla.distributed.xla_backend - -def _all_gather(index: int): - # No need to pass in `rank` or `world_size` - dist.init_process_group('xla', init_method='xla://') - - t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device()) - output = [torch.zeros_like(t) for _ in range(dist.get_world_size())] - dist.all_gather(output, t) - - xm.mark_step() - print(output) - -if __name__ == '__main__': - torch_xla.launch(_all_gather) -``` - -Note: Although the `xla://` init_method is not required on TPU v4, it is still -recommended. If you use `env://`, `MASTER_ADDR` must be set to IP host that has -device 0, which is _not_ always worker 0. The `xla://` init_method finds this -IP automatically. - -Note: For TPU v2/v3, you still need to import -`torch_xla.experimental.pjrt_backend`, as TPU v2/v3 support in -`torch.distributed` is still experimental. - -For more information about using `DistributedDataParallel` on PyTorch/XLA, see -[`ddp.md`](./ddp.md) on TPU V4. For an example that uses DDP and PJRT together, -run the following [example script](../test/test_train_mp_imagenet.py) on a TPU: - -``` -PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1 -``` - -## Performance - -TorchBench shows improvements in average training time across tasks with PJRT -compared to XRT, with an average improvement of over 35% on TPU v4-8. The -benefits vary significantly by task and model type, ranging from 0% to 175%. -The following chart shows the breakdown by task: - -![PJRT vs XRT](_static/img/torchbench_pjrt_vs_xrt.svg) - -### New TPU runtime - -_New in PyTorch/XLA r2.0_ - -The PyTorch/XLA r2.0 release introduces support for the [PJRT Plugin -API](https://github.com/openxla/community/blob/main/rfcs/20230123-pjrt-plugin.md#rfc-openxla-pjrt-plugin), -used to access the new TFRT-based TPU runtime in `libtpu`. This is now the -default runtime when `PJRT_DEVICE=TPU` is set. The legacy StreamExecutor-based -TPU runtime used in 1.13 will still be available with `PJRT_DEVICE=TPU_LEGACY` -in the 2.0 release, but it will be removed in a future version. If you encounter -an issue that only happens on `TPU` and not `TPU_LEGACY`, please file an issue -on GitHub. - -In most cases, we expect performance to be similar between the two runtimes, but -in some cases, the new runtime may be up to 30% faster. The following chart -shows the breakdown by task: - -![TFRT vs StreamExecutor](_static/img/torchbench_tfrt_vs_se.svg) - -Note: the improvements shown in this chart are also included in the PJRT vs XRT -comparison. diff --git a/docs/plugins.md b/docs/plugins.md deleted file mode 100644 index 4008b437368..00000000000 --- a/docs/plugins.md +++ /dev/null @@ -1,82 +0,0 @@ -# Custom Hardware Plugins - -PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The -PyTorch/XLA team direclty supports plugins for Cloud TPU (`libtpu`) and GPU -([OpenXLA](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu)). The same -plugins may also be used by JAX and TF. - -## Implementing a PJRT Plugin - -PJRT C API plugins may be closed-source or open-source. They contain two parts: - -1. Binary exposing a PJRT C API implementation. This part can be shared with JAX -and TensorFlow. -2. Python package containing the above binary, as well as an implementation of -our `DevicePlugin` Python interface, which handles additional setup. - -### PJRT C API Implementation - -In short, you must implement a -[`PjRtClient`](https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h) -containing an XLA compiler and runtime for your device. The PJRT C++ interface -is mirrored in C in the -[`PJRT_Api`](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h). -The most straightforward option is to implement your plugin in C++ and [wrap -it](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_wrapper_impl.h) -as a C API implementation. This process is explained in detail in [OpenXLA's -documentation](https://openxla.org/xla/pjrt_integration#how_to_integrate_with_pjrt). - -For a concrete example, see the example [CPU plugin](../plugins/cpu). ([OpenXLA -implementation](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_cpu_internal.cc)). - -### PyTorch/XLA Plugin Package - -At this point, you should have a functional PJRT plugin binary, which you can -test with the placeholder `LIBRARY` device type. For example: - -``` -$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python ->>> import torch_xla ->>> torch_xla.devices() -# Assuming there are 4 devices. Your hardware may differ. -[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)] -``` - -To register your device type automatically for users as well as to handle extra -setup for e.g. multiprocessing, you may implement the `DevicePlugin` Python API. -PyTorch/XLA plugin packages contain two key components: - -1. An implementation of `DevicePlugin` that (at the very least) provides the -path to your plugin binary. For example: - -``` -class CpuPlugin(plugins.DevicePlugin): - - def library_path(self) -> str: - return os.path.join( - os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so') -``` - -2. A `torch_xla.plugins` [entry -point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html) that -identifies your `DevicePlugin`. For exmaple, to register the `EXAMPLE` device -type in a `pyproject.toml`: - -``` -[project.entry-points."torch_xla.plugins"] -example = "torch_xla_cpu_plugin:CpuPlugin" -``` - -With your package installed, you may then use your `EXAMPLE` device directly: - -``` -$ PJRT_DEVICE=EXAMPLE python ->>> import torch_xla ->>> torch_xla.devices() -[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)] -``` - -[`DevicePlugin`](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/plugins.py) -provides additional extension points for multiprocess initialization and client -options. The API is currently in an experimental state, but it is expected to -become stable in a future release. diff --git a/docs/source/accelerators/gpu.md b/docs/source/accelerators/gpu.md new file mode 100644 index 00000000000..56abb192a70 --- /dev/null +++ b/docs/source/accelerators/gpu.md @@ -0,0 +1,6 @@ +# Learn about GPUs + +For information on GPUs on Google Cloud, see: + +- [About GPUs on Google Cloud](https://cloud.google.com/compute/docs/gpus/overview) +- [GPU machine types](https://cloud.google.com/compute/docs/gpus) diff --git a/docs/source/accelerators/tpu.md b/docs/source/accelerators/tpu.md new file mode 100644 index 00000000000..3f60dcd6a60 --- /dev/null +++ b/docs/source/accelerators/tpu.md @@ -0,0 +1,24 @@ +# Learn about TPUs + +Google Cloud TPUs are custom-designed AI accelerators, which are +optimized for training and inference of large AI models. They are ideal +for a variety of use cases, such as chatbots, code generation, media +content generation, synthetic speech, vision services, recommendation +engines, personalization models, among others. + +Cloud TPUs are designed to scale cost-efficiently for a wide range of AI +workloads, spanning training, fine-tuning, and inference. Cloud TPUs +provide the versatility to accelerate workloads on leading AI +frameworks, including PyTorch, JAX, and TensorFlow. Seamlessly +orchestrate large-scale AI workloads through Cloud TPU integration in +Google Kubernetes Engine (GKE). Leverage Dynamic Workload Scheduler to +improve the scalability of workloads by scheduling all accelerators +needed simultaneously. Customers looking for the simplest way to develop +AI models can also leverage Cloud TPUs in Vertex AI, a fully-managed AI +platform. + +For more information, see: + +- [Introduction to Cloud TPUs](https://cloud.google.com/tpu/docs/intro-to-tpu) +- [Set up a Cloud TPU environment](https://cloud.google.com/tpu/docs/setup-gcp-account) +- [Manage Cloud TPU resources](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm) diff --git a/docs/source/contribute/bazel.md b/docs/source/contribute/bazel.md new file mode 100644 index 00000000000..71c8d318aeb --- /dev/null +++ b/docs/source/contribute/bazel.md @@ -0,0 +1,247 @@ +# Bazel in Pytorch/XLA + +[Bazel](https://bazel.build/) is a free software tool used for the +automation of building and testing software. +[TensorFlow](https://www.tensorflow.org/http) and +[OpenXLA](https://github.com/openxla/xla) both use it, which makes it a +good fit for PyTorch/XLA as well. + +## Bazel dependencies + +Tensorflow is a [bazel external dependency](https://bazel.build/external/overview) for PyTorch/XLA, +which can be seen in the `WORKSPACE` file: + +`WORKSPACE` + +``` python +http_archive( + name = "org_tensorflow", + strip_prefix = "tensorflow-f7759359f8420d3ca7b9fd19493f2a01bd47b4ef", + urls = [ + "https://github.com/tensorflow/tensorflow/archive/f7759359f8420d3ca7b9fd19493f2a01bd47b4ef.tar.gz", + ], +) +``` + +TensorFlow pin can be updated by pointing this repository to a different +revision. Patches may be added as needed. Bazel will resolve the +dependency, prepare the code and patch it hermetically. + +For PyTorch, a different dependency mechanism is deployed because a +local [PyTorch](https://github.com/pytorch/pytorch) checkout is used, +and this local checkout has to be `built` from source and ideally +installed on the system for version compatibility (e.g codegen in +PyTorch/XLA uses `torchgen` python module that should be installed in +the system). + +The local directory can either set in `bazel/dependencies.bzl`, or +overriden on the command line: + +``` bash +bazel build --override_repository=org_tensorflow=/path/to/exported/tf_repo //... +``` + +``` bash +bazel build --override_repository=torch=/path/to/exported/and/built/torch_repo //... +``` + +Please make sure that the overridden repositories are at the appropriate +revisions and in case of `torch`, that it has been built with +`USE_CUDA=0 python setup.py bdist_wheel` to make sure that all expected +build objects are present; ideally installed into the system. + +`WORKSPACE` + +``` python +new_local_repository( + name = "torch", + build_file = "//bazel:torch.BUILD", + path = PYTORCH_LOCAL_DIR, +) +``` + +PyTorch headers are directly sourced from the `torch` dependency, the +local checkout of PyTorch. The shared libraries (e.g. `libtorch.so`) are +sourced from the same local checkout where the code has been built and +`build/lib/` contains the built objects. For this to work, it's required +to pass `-isystemexternal/torch` to the compiler so it can find `system` +libraries and satisfy them from the local checkout. Some are included as +`` and some as `"user"` headers. + +Bazel brings in [pybind11](https://github.com/pybind/pybind11) embeded +python and links against it to provide `libpython` to the plugin using +this mechanism. Python headers are also sourced from there instead of +depending on the system version. These are satisfied from the +`"@pybind11//:pybind11_embed"`, which sets up compiler options for +linking with `libpython` transitively. + +## How to build XLA libraries + +Building the libraries is simple: + +``` bash +bazel build //torch_xla/csrc/runtime/... +``` + +Bazel is configred via `.bazelrc`, but it can also take flags on the +command line. + +``` bash +bazel build --config=remote_cache //torch_xla/csrc/runtime/... +``` + +The `remote_cache` configurations use gcloud for caching and usually +faster, but require authentication with gcloud. See `.bazelrc` for the +configuration. + +Using bazel makes it easy to express complex dependencies and there is a +lot of gain from having a single build graph with everything expressed +in the same way. Therefore, there is no need to build the XLA libraries +separately from the rest of the pluing as used to be the case, building +the whole repository, or the plugin shared object that links everythin +else in, is enough. + +## How to build the Torch/XLA plugin + +The normal build can be achieved by the invoking the standard +`python setup.py bdist_wheel`, but C++ bindings can be built simply +with: + +``` bash +bazel build //:_XLAC.so +``` + +This will build the XLA client and the PyTorch plugin and link it all +together. This can be useful when testing changes, to be able to compile +the C++ code without building the python plugin faster iteration cycles. + +## Remote caching + +Bazel comes with [remote caching](https://bazel.build/remote/caching) +built in. There are plenty of cache backends that can be used; we deploy +our caching on +(GCS)\[\]. You can see +the configuration in `.bazelrc`, under config name `remote_cache`. + +Remote caching is disabled by default but because it speeds up +incremental builds by a huge margin, it is almost always recommended, +and it is enabled by default in the CI automation and on Cloud Build. + +To authenticate on a machine, please ensure that you have the +credentials present with: + +``` bash +gcloud auth application-default login --no-launch-browser +``` + +Using the remote cache configured by `remote_cache` configuration setup +requires authentication with GCP. There are various ways to authenticate +with GCP. For individual developers who have access to the development +GCP project, one only needs to specify the `--config=remote_cache` flag +to bazel, and the default `--google_default_credentials` will be used +and if the gcloud token is present on the machine, it will work out of +the box, using the logged in user for authentication. The user needs to +have remote build permissions in GCP (add new developers into the +`Remote Bazel` role). In the CI, the service account key is used for +authentication and is passed to bazel using +`--config=remote_cache --google_credentials=path/to/service.key`. On +[Cloud Build](https://cloud.google.com/build), +`docker build --network=cloudbuild` is used to pass the authentication +from the service account running the cloud build down into the docker +image doing the compilation: [Application Default +Credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc) +does the work there and authenticates as the service account. All +accounts, both user and service accounts, need to have remote cache +read/write permissions. + +Remote cache uses cache silos. Each unique machine and build should +specify a unique silo key to benefit from consistent caching. The silo +key can be passed using a flag: +`-remote_default_exec_properties=cache-silo-key=SOME_SILO_KEY'`. + +Running the build with remote cache: + +``` bash +BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" TPUVM_MODE=1 python setup.py bdist_wheel +``` + +Adding + +``` bash +GCLOUD_SERVICE_KEY_FILE=~/.config/gcloud/application_default_credentials.json +``` + +might help too if `bazel` cannot find the auth token. + +`YOUR-USER` here can the author's username or machine name, a unique +name that ensures good cache behavior. Other `setup.py` functionality +works as intended too (e.g. `develop`). + +The first time the code is compiled using a new cached key will be slow +because it will compile everything from scratch, but incremental +compilations will be very fast. On updating the TensorFlow pin, it will +once again be a bit slower the first time per key, and then until the +next update quite fast again. + +## Running tests + +Currently C++ code is built and tested by bazel. Python code will be +migrated in the future. + +Bazel is a test plafrom too, making it easy to run tests: + +``` bash +bazel test //test/cpp:main +``` + +Of course the XLA and PJRT configuration have to be present in the +environment to run the tests. Not all environmental variables are passed +into the bazel test environment to make sure that the remote cache +misses are not too common (environment is part of the cache key), see +`.bazelrc` test configuration to see which ones are passed in, and add +new ones as required. + +You can run the tests using the helper script too: + +``` bash +BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" ./test/cpp/run_tests.sh -R +``` + +The `xla_client` tests are pure hermetic tests that can be easily +executed. The `torch_xla` plugin tests are more complex: they require +`torch` and `torch_xla` to be installed, and they cannot run in +parallel, since they are using either XRT server/client on the same +port, or because they use a GPU or TPU device and there's only one +available at the time. For that reason, all tests under +`torch_xla/csrc/` are bundled into a single target `:main` that runs +them all sequentially. + +## Code coverage + +When running tests, it can be useful to calculate code coverage. + +``` bash +bazel coverage //torch_xla/csrc/runtime/... +``` + +Coverage can be visualized using `lcov` as described in [Bazel's +documentation](https://bazel.build/configure/coverage), or in your +editor of choice with lcov plugins, e.g. [Coverage +Gutters](https://marketplace.visualstudio.com/items?itemName=ryanluker.vscode-coverage-gutters) +for VSCode. + +## Language Server + +Bazel can power a language server like [clangd](https://clangd.llvm.org/) that brings code references, +autocompletion and semantic understanding of the underlying code to your +editor of choice. For VSCode, one can use [Bazel Stack](https://github.com/stackb/bazel-stack-vscode-cc) +that can be combined with [Visual Studio clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) +functionality to bring powerful features to assist code editing. + +## Building PyTorch/XLA + +As always, PyTorch/XLA can be built using Python `distutils`: + +``` bash +BAZEL_REMOTE_CACHE=1 SILO_NAME="cache-silo-YOUR-USER" TPUVM_MODE=1 python setup.py bdist_wheel +``` diff --git a/docs/source/contribute/codegen_migration.md b/docs/source/contribute/codegen_migration.md new file mode 100644 index 00000000000..a84e5568ba5 --- /dev/null +++ b/docs/source/contribute/codegen_migration.md @@ -0,0 +1,326 @@ +# Codegen migration Guide + +As PyTorch/XLA migrates to the LTC (Lazy Tensor Core), we need to clean +up the existing stub code (which spans over 6+ files) that were used to +do the op lowering. The complete process and file structure for the old +op lowering can be found in the op lowering guide :ref:\'op-lowering\'. +Replacing the supported op with the codegen SHOULD NOT introduce any new +behavior, it is purely for the clean up purpose. + +## Before you start + +You should follow the instructions in +[here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to +install required dependencies and build pytorch and pytorch/XLA from the +source. You do not need access to TPU to implement the lowering. It is +recommended to experiment on a workstation and configure it to use +XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running + +``` bash +export PJRT_DEVICE=CPU +``` + +It is also recommended that you're familiar with our [op lowering +process](https://github.com/pytorch/xla/blob/master/OP_LOWERING_GUIDE.md) +before you work on the codegen. + +PyTorch/XLA uses to track +the status of codegen migration. When working on a codegen, please put +your GitHub alias with the PR link on the issue to avoid duplicate work. + +## File structure + +All file mentioned below lives under the `xla/torch_xla/csrc` folder, +with the exception of `xla_native_functions.yaml` + +### PyTorch Codegen files + +- torch/csrc/lazy/core/shape_inference.h + - Shape inference functions defined for each op that will take for + input torch::lazy::shapes and return output torch::lazy::shape. + Only the ops that is not structural will require a manual shape + inference function +- torchgen/gen_lazy_tensor.py + - Builds on existing data models and helpers used by all ATen + backends, and adds new functionality specific to lazy tensor + backends. run_gen_lazy_tensor is defined in this file +- torchgen/dest/lazy_ir.py + - Contains data class GenLazyIR that can be overridden by the back + and defined the generated IR class + +### PyTorch/XLA Codegen files + +- xla/xla_native_functions.yaml + - Contains all the op XLA supported today. Most of the ops are + under the supported category, the goal of this document is to + move most of the ops to the full_codegen category. +- xla/scripts/gen_lazy_tensor.py + - Provides necessary XLA versions of the codegen Codegen class and + calls the upstream codegen API. +- xla/torch_xla/csrc/XLANativeFunctions.cpp + - Result of the full_codegen column of the + xla/codegen/xla_native_functions.yaml. The op function defined + here will implement the op declared in the XLANativeFunctions.h. + Each op will take at::tensor and return another at::tensor + wrapped around a XLATensor. +- xla/torch_xla/csrc/LazyIr.h + - Result of the full_codegen column of the + xla/codegen/xla_native_functions.yaml. Defines the IR that is + used to construct the full_codegen ops. + +### PyTorch/XLA Old Op Lowering files + +- xla/torch_xla/csrc/generated/aten_xla_type.cpp + - Manually implements ops defined in + xla/codegen/xla_native_functions.yaml. Will be replaced by + XLANativeFunctions.cpp +- xla/torch_xla/csrc/generated/tensor.h + - Defines XLATensor class and XLATensor method declarations. These + declarations are usually a one to one mapping of the at::Tensor + nodes we declared in XLANativeFunctions.h. XLATensor method will + be removed for full_codegen ops +- xla/torch_xla/csrc/generated/tensor_method.cpp + - Implements tensor methods defined in tensor.h. This file will be + removed for full_codegen ops +- xla/torch_xla/csrc/generated/ops/... + - Defines IR class for "most" ops. It is possible that multiple + ops share the same IR. + +## Codegen step by step + +### 1. Identify the op + +When you work on your first few codegens, we generally recommend you to +start with the simpler ops. This guide will go over one unary one one +binary op as examples, but it is recommend that you avoid ops with the +following characteristics: 1. Contains custom fallback code. For example +in \_adaptive_avg_pool3d, there is a conditional fallback: + +``` c++ +if (!IsSupportedAdaptivePool(XlaHelpers::I64List(self.sizes()), + output_size_list, /*pool_dim=*/3)) { + return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(_adaptive_avg_pool3d)>::call(self, output_size); +} +``` + +2. Results in dynamic shape as these ops are WIP and may evolve over + time. At some future point, we may bring the ops into codegen. +3. Does not invoke a tensor_method directly. For example: + +``` c++ +if (!self_tensor) { + static bool sync_update = + torch_xla::runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true); + XLA_CHECK(dst_tensor); + dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); +} +``` + +4. Has a complicated tensor_method, ideally it should be a directly + mapping from op to IR. + +An good example of a "simple" op would be something like `abs`: + +``` c++ +at::Tensor XLANativeFunctions::abs(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("xla::"); + return bridge::AtenFromXlaTensor(XLATensor::abs(bridge::GetXlaTensor(self))); +} +``` + +### 2. Codegen the op and inspect the generated file + +Find the op in `xla/codegen/xla_native_functions.yaml` and move it to +the full_codegen column and run `python setup.py install` under xla +directory again. The build will fail (reason explained later in this +guide) but you can still see the generated file. The code snippets below +uses `abs` as an example. \#### XLANativeFunctions.cpp + +``` c++ +at::Tensor XLANativeFunctions::abs(const at::Tensor & self) { + TORCH_LAZY_FN_COUNTER("xla::"); + auto common_device = torch_xla::bridge::GetXlaDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + + torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device); + + torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue()); + if (!node) { + node = torch_xla::MakeNode(lazy_self->GetIrValue()); + CacheNode(node); + } + + auto result = torch_xla::bridge::AtenFromXlaTensor( + torch_xla::XLATensor::Create(std::move(node), *common_device)); + return result; +}; +``` + +Describing the generated code line by line: - Get and verify device from +input tensor + +``` c++ +auto common_device = torch_xla::bridge::GetXlaDevice(self); +TORCH_INTERNAL_ASSERT(common_device); +``` + +Check if we can reuse the node from previous creation. If not, create corresponding IR node and cache it. + +``` c++ +torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue()); +if (!node) { + node = torch_xla::MakeNode(lazy_self->GetIrValue()); + CacheNode(node); +} +``` + +Wrap the newly created IR node in a XLATensor. And wrap the XLATensor within the at::Tensor +and return it as a result. Note that this part used to be manually done in tensor_method.cpp. + +``` c++ +auto result = torch_xla::bridge::AtenFromXlaTensor( + torch_xla::XLATensor::Create(std::move(node), *common_device)); +return result; +``` + +#### LazyIr.h + +``` c++ +class Abs : public XlaNode { + public: + Abs(const torch_xla::XlaValue& self) + : XlaNode(torch::lazy::OpKind(at::aten::abs), {self}, + [&]() { return AbsOutputShape(self); }, + /* num_outputs */ 1, torch::lazy::MHash()) + {} + + std::string ToString() const override { + std::stringstream ss; + ss << XlaNode::ToString(); + return ss.str(); + } + torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override; +}; +``` + +A couple of things to keep in mind: - Codegen does not generate the +`Clone` method which is expected. There is no use of the `Clone` method +even in PyTorch/XLA today, we will remove them as part of the +migration. - For every op, it will generate a {OP}OutputShape method. We +need to manually declare and implement this method in a separate file. +-For every op, it will generate a Lower declaration. We need to manually +implement this lowering function in a separate file. + +### 3. Implement the missing IR function + +#### torch_xla/csrc/ops/ops_xla_shape_fn.h + +Declare the {OP}OutputShape: + +``` c++ +xla::Shape AbsOutputShape(const XlaValue& input); +``` + +#### torch_xla/csrc/ops/ops_xla_shape_fn.cpp + +Implement the {OP}OutputShape: + +``` c++ +xla::Shape AbsOutputShape(const XlaValue& input) { return input.xla_shape(); } +``` + +`Abs` is an overly simplified example, in a normal case you need to call +the BuildXXXOp function again to get the output shape. A slightly better +example would be: + +``` c++ +xla::Shape MaximumOutputShape(const XlaValue& input, const XlaValue& other) { + auto lower_for_shape_fn = + [&](absl::Span operands) -> xla::XlaOp { + auto promoted = XlaHelpers::Promote(operands[0], operands[1]); + return xla::Max(promoted.first, promoted.second); + }; + return InferOutputShape({input.xla_shape(), other.xla_shape()}, + lower_for_shape_fn); +} +``` + +Note that you should not start from scratch. Find the Xla::Shape +computation logic from the existing op and move it this these two files. + +### 4. Implement the lowering function + +#### torch_xla/csrc/ops/ops_lower_fn.cpp + +``` c++ +torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(BuildAbs(xla_input), loctx); +} +``` + +Note that this function should be directly moved from the existing +lowering. Some Ops that were originally implemented in +`torch_xla/csrc/ops/ops.cpp` use `GenericOp`. You will need to slightly +modify their lowering implementation to fit the implementation provided +above. + +### 5. Cleanup + +Delete the existing op from aten_xla_type.cpp, tensor_methods.h, +tensor_methods.cpp, and ops/.... Note that sometimes you have to keep +the tensor_method, because it is being used in tensor_ops like. So, +before removing the op, cross reference it with `tensor_ops.cpp`. + +``` c++ +XLATensor s1 = XLATensor::sub(XLATensor::mul(u2, v3), XLATensor::mul(u3, v2), one); +``` + +Sometimes other IRNode uses the 'IRNode' you migrated. In this case you +need to update those IRNode lowering logic as well. In the long term we +need to get rid of these composite IR from our end and provide a +lowering function for each op. + +``` c++ +torch::lazy::NodePtr exp = Pow(Abs(input), norm_exp); +``` + +to + +``` c++ +torch::lazy::NodePtr exp = + Pow(torch_xla::MakeNode(input, std::vector()), + norm_exp); +``` + +## Run the test and verify the result + +Run the C++ op test or a simple test that only involves the generated +ops. To run the C++ test: 1. Build the xla through +`python setup.py install` (note: don't use the `BUILD_CPP_TESTS=0` flag +since this will skip building the C++ tests) 2. Go into the +`test/cpp/build` directory in your `pytorch/xla` 3. Run the command to +run the desired C++ test (for example, to run `Abs` C++ test): + +``` bash +./test_ptxla --gtest_filter=AtenXlaTensorTest.TestAbs +``` + +As usual, two things to verify are the correctness and the xla counter +being incremented correctly. + +## Sample PRs + +- Unary/Binary OP -\> Codegen erf, erfc, erfinv, and exp + () +- OP with optional -\> Codegen binary_cross_entropy/backward + () +- OP with `at::Scalar` -\> Codegen addcdiv and addcmul + () +- OP with vector that support negative index -\> Codegen amin amax + () +- OP with special fallback logic -\> partially codegen + adaptive_avgpool3d and backward + () To see more examples, + please take a look at the tracking issue + (). diff --git a/docs/workflow.md b/docs/source/contribute/configure-environment.md similarity index 57% rename from docs/workflow.md rename to docs/source/contribute/configure-environment.md index fbaa2dc971d..972765108b0 100644 --- a/docs/workflow.md +++ b/docs/source/contribute/configure-environment.md @@ -1,36 +1,36 @@ -# The Cloud TPU Workflow +# Configure a development environment -The goal of this guide is to set up an interactive development environment on a -Cloud TPU with PyTorch/XLA installed. If this is your first time using TPUs, we -recommend you start with +The goal of this guide is to set up an interactive development +environment on a Cloud TPU with PyTorch/XLA installed. If this is your +first time using TPUs, we recommend you start with [Colab](https://colab.sandbox.google.com/github/tensorflow/docs/blob/master/site/en/guide/tpu.ipynb) -and [Kaggle](https://www.kaggle.com/discussions/product-feedback/369338) or. -Both options have PyTorch/XLA preinstalled with dependencies and ecosystem -packages. For an up-to-date list of examples, see our main -[`README`](https://github.com/pytorch/xla). +and [Kaggle](https://www.kaggle.com/discussions/product-feedback/369338) +or. Both options have PyTorch/XLA preinstalled with dependencies and +ecosystem packages. For an up-to-date list of examples, see our main +[README](https://github.com/pytorch/xla). -If you would like to set up a more customized development environment, keep -reading. +If you would like to set up a more customized development environment, +keep reading. ## Visual Studio Code Prerequisites: -- [Visual Studio Code](https://code.visualstudio.com/download) with the [Remote - Development - extensions](https://code.visualstudio.com/docs/remote/remote-overview) - installed on your local machine -- A GCP project with Cloud TPU quota. For more information about requesting - Cloud TPU quota, see the [official - documentation](https://cloud.google.com/tpu/docs/quota) -- An SSH key registered with `ssh-agent`. If you have not already done this, see - [GitHub's - documentation](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent) - -Before you begin, export environment variables with the GCP project and zone -where you have Cloud TPU quota: - -``` +- [Visual Studio Code](https://code.visualstudio.com/download) with + the [Remote Development + extensions](https://code.visualstudio.com/docs/remote/remote-overview) + installed on your local machine +- A GCP project with Cloud TPU quota. For more information about + requesting Cloud TPU quota, see the [official + documentation](https://cloud.google.com/tpu/docs/quota) +- An SSH key registered with `ssh-agent`. If you have not already done + this, see [GitHub's + documentation](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent) + +Before you begin, export environment variables with the GCP project and +zone where you have Cloud TPU quota: + +``` bash export PROJECT=... export ZONE=... export TPU_TYPE=... # e.g. "v2-8" @@ -40,27 +40,27 @@ export TPU_TYPE=... # e.g. "v2-8" Create a Cloud TPU VM with your SSH key registered: -```bash +``` bash # Assuming your SSH key is named `id_ed25519` gcloud compute tpus tpu-vm create --project=$PROJECT --zone=$ZONE --accelerator-type=$TPU_TYPE --version=tpu-ubuntu2204-base --metadata="ssh-keys=$USER:$(cat ~/.ssh/id_ed25519.pub)" $USER-tpu ``` Check that your TPU has an external IP and SSH to it: -```bash +``` bash gcloud compute tpus tpu-vm describe --project=$PROJECT --zone=$ZONE $USER-tpu --format="value(networkEndpoints.accessConfig.externalIp)" # Output: 123.123.123.123 ``` Give your TPU a friendly name to make future steps easier: -```bash +``` bash echo -e Host $USER-tpu "\n " HostName $(gcloud compute tpus tpu-vm describe --project=$PROJECT --zone=$ZONE $USER-tpu --format="value(networkEndpoints.accessConfig.externalIp)") >> ~/.ssh/config ``` SSH to your TPU to test your connection: -``` +``` bash ssh $USER-tpu ``` @@ -68,41 +68,42 @@ ssh $USER-tpu From the [VS Code Command Palette](https://code.visualstudio.com/docs/getstarted/userinterface#_command-palette), -select [`Remote-SSH: Connect to -Host`](https://code.visualstudio.com/docs/remote/ssh) and select the host you -just created (named `$USER-tpu`). VS Code will then open a new window connected -to your TPU VM. +select `` `Remote-SSH: Connect to Host `` +\<\>[\_\_ and select the +host you just created (named ]{.title-ref}[\$USER-tpu]{.title-ref}\`). +VS Code will then open a new window connected to your TPU VM. -From the built-in `Terminal`, create a new folder to use as a workspace (e.g. -`mkdir ptxla`). Then open the folder from the UI or Command Palette. +From the built-in `Terminal`, create a new folder to use as a workspace +(e.g. `mkdir ptxla`). Then open the folder from the UI or Command +Palette. -Note: It is optional (but recommended) at this point to install the official -[Python +Note: It is optional (but recommended) at this point to install the +official [Python extension](https://marketplace.visualstudio.com/items?itemName=ms-python.python) -and create a [`venv` virtual +and create a [venv virtual environment](https://code.visualstudio.com/docs/python/environments#_using-the-create-environment-command) via the Command Palette (`Python: Create Environment`). Install the latest PyTorch and PyTorch/XLA releases: -``` +``` bash pip install numpy torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html ``` Create a file `test.py`: -```python +``` python import torch_xla as xla # Optional xla.runtime.set_device_type("TPU") -print("XLA devices:", xla.real_devices()) +print("XLA devices:", xla.real_devices()) ``` Run the test script from your terminal: -```bash +``` bash $ python test.py # Output: XLA devices: ['TPU:0', 'TPU:1', 'TPU:2', 'TPU:3', 'TPU:4', 'TPU:5', 'TPU:6', 'TPU:7'] # Number of devices will vary based on TPU type @@ -110,6 +111,6 @@ $ python test.py ### Next steps -That's it! You should now have a remote Visual Studio Code workspace set up with -PyTorch/XLA installed. To run more realistic examples, see our [examples -guide](https://github.com/pytorch/xla/tree/master/examples). +That's it! You should now have a remote Visual Studio Code workspace set +up with PyTorch/XLA installed. To run more realistic examples, see our +[examples guide](https://github.com/pytorch/xla/tree/master/examples). diff --git a/docs/source/contribute/op_lowering.md b/docs/source/contribute/op_lowering.md new file mode 100644 index 00000000000..45a445910aa --- /dev/null +++ b/docs/source/contribute/op_lowering.md @@ -0,0 +1,186 @@ +# OP Lowering Guide + +PyTorch wraps the C++ ATen tensor library that offers a wide range of +operations implemented on GPU and CPU. Pytorch/XLA is a PyTorch +extension; one of its purposes is to convert PyTorch operations to XLA +operations. Lowering defines a process of converting a higher-level +representation to a lower-level representation. In this document, I will +refer to the process of converting PyTorch operation to XLA operation as +the lowering. XLA Compiler will also lower XlaOp to HLO, but that's +beyond the scope of this documentation. We will forward operations that +we haven't provided an XLA lowering yet to CPU and call ATen +implementations. Operations that are forwarded to the CPU will cause a +significant slowdown. We must lower all operations used in the model to +achieve the best performance. + +Here's an example of what you might see from the PyTorch/XLA debugging +tool for an operation that has not been lowered: + +``` none + pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests. +``` + +## Before you start + +You should follow the instructions in +[Contributing to Pytorch/XLA](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to +install required dependencies and build pytorch and pytorch/XLA from the +source. You do not need access to TPU to implement the lowering. It is +recommended to experiment on a workstation and configure it to use +XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running + +``` bash + export PJRT_DEVICE=CPU +``` + +## Understanding the operation + +You can find the definition of the C++ ATen operations in +[native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). +After you build Pytorch/XLA from source, you will also find our default +implementation (a boxed kernel which forwards calls to either PyTorch +native kernels) in `xla/torch_xla/csrc/aten_fallback.h/cpp`. Pytorch +operations can usually be mapped to [PyTorch tensor +api](https://pytorch.org/docs/stable/index.html) easily. If that is not +the case searching the PyTorch native implementation under [PyTorch +repo](https://github.com/pytorch/pytorch) is recommended. The goal is to +lower the PyTorch operations into a sequence of XLA operations defined +in [XLA operation semantics](https://www.tensorflow.org/xla/operation_semantics). + +## File structure + +All file mentioned below lives under the `xla/torch_xla/csrc` folder, +with the exception of `codegen/xla_native_functions.yaml` + +1. `xla_native_functions.yaml` contains the list of all operators (from + the [Core Aten + list](https://pytorch.org/docs/stable/torch.compiler_ir.html)) that + are explicitly lowered. Composed operators are not listed here. Each + operator name here must directly match a pytorch operator listed in + [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). + This file serves as the interface to adding new xla operators, and + is an input to PyTorch's [codegen + machinery](https://github.com/pytorch/pytorch/blob/main/torchgen/gen_backend_stubs.py). + It generates the below 3 files: `XLANativeFunctions.h`, + `RegisterXLA.cpp`, and `RegisterAutogradXLA.cpp` +2. `XLANativeFunctions.h` and `aten_xla_type.cpp` are entry points of + PyTorch to the pytorch_xla world, and contain the manually written + lowerings to XLA for each operator. `XLANativeFunctions.h` is + auto-generated through a combination of `xla_native_functions.yaml` + and the PyTorch core `native_functions.yaml` file, and contains + declarations for kernels that need to be defined in + `aten_xla_type.cpp`. The kernels written here need to construct + 'XLATensor' using the input `at::Tensor` and other parameters. The + resulting `XLATensor` needs to be converted back to the `at::Tensor` + before returning to the PyTorch world. +3. `RegisterXLA.cpp` and `RegisterAutogradXLA.cpp` are auto-generated + files that register all lowerings to the PyTorch Dispatcher. They + also include auto-generated wrapper implementations of `out=` and + `inplace` operators. +4. `aten_fallback.h/.cpp` contain our boxed fallback implementation. + The boxed fallback kernel will be used if a lowering is not + explicitly defined in `xla_native_functions.yaml` + + `aten_xla_type.cpp`, and the operator is not composite. +5. `tensor_methods.h` contains the `XLATensor` declarations. These + declarations are usually a one to one mapping of the `at::Tensor` + nodes we declared in `XLANativeFunctions.h` +6. `tensor_methods.cpp` contains the implementation of `XLATensor node` + defined in `tensor_methods.h`. We constructed the corresponding + `ir::op` from the parameter's `ir::Value` and wrapped it inside a + `XLATensor`. Ir stands for intermediate representation. +7. `ops/` directory contains all `ir::ops` declaration and definition. + Smaller nodes can be put in `ops/ops.h/.cpp`. More complicated nodes + can be put into a separate file. All ops inherit from + `ir::ops::Node` and provide a way to lower input `ir::Value` to a + sequence of `XlaOp`. + +## Unit Test + +Our CI runs PyTorch native python tests for every change and every day. +Those tests will use XLA implementation if we provide a lowering. We +usually don't need to add additional python tests for PyTorch/XLA unless +we want to verify some xla behaviors(like dynamic shape) or we skipped +the pytorch native test for some reason. The python test should be added +to `xla/test/test_operations.py` if it is required. We also need to add +CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should +call PyTorch c++ API and verify our implementation yields the same +result as PyTorch native implementation. We also need to verify if the +xla implementation is called when the tensor is a XLA tensor by checking +the `aten::op` and `xla::op` counters. + +## Tips + +The process of lowering is breaking down the PyTorch operations into a +sequence of XlaOp. To provide a good lowering of the PyTorch operation, +one needs to have a good grasp of what XLA is capable of. Reading the +XlaOp document and looking into how similar ops is lowered is the best +way to achieve that. You can find a minimal Op lowering example in [this Op lowering PR](https://github.com/pytorch/xla/pull/2969). You can also find a +slightly more complicated example with backward lowering in [this backward lowering PR](https://github.com/pytorch/xla/pull/2972). + +We have auto-generated wrapper implementations of `out=` and `inplace` +operators for some operators in `RegisterXLA.cpp`. We only need to lower +the vanilla op in this case. An example would be `lerp` operator which +has 6 variants in `native_functions.yaml`, they are + + + - lerp_.Scalar + - lerp_.Tensor + - lerp.Scalar_out + - lerp.Tensor_out + - lerp.Scalar + - lerp.Tensor + +and will generate function prototypes + +``` c++ + at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); + at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); + at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); + at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out); + at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); + at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out); +``` + +in `XLANativeFunctions.h` if we add all of them to the +`xla_native_functions.yaml`. However if we only lower `lerp.Scalar` and +`lerp.Tensor` and check `RegisterXLA.cpp`, we will see + +``` c++ + namespace { + + at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + // No device check + + + // DeviceGuard omitted + return torch_xla::lerp(self, end, weight); + } + + } // anonymous namespace + + at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight); + at::_copy_from(wrapper_Scalar_lerp__tmp, self); + return self; + } + + ... + m.impl("lerp_.Scalar", + TORCH_FN(wrapper_Scalar_lerp_)); +``` + +The codegen will automatically generate lowerings for `lerp_.Scalar` and +`lerp.Scalar_out` that use our `lerp.Scalar` implementation, without us +having to provide an explicit lowering. + +In general, if there is an operator in pytorch core that has both an +out-of-place and an out= variant, it's better to write a lowering for +the out-of-place variant, since you'll get a code-generated out= +lowering for free. + +For each node we need to pass an `ir::OpKind`. Here is an +([example](https://github.com/pytorch/xla/blob/5ce99bff336325feb41a982dc80299fb53166b29/torch_xla/csrc/ops/var_mean.cpp#L36)). +You can find the `OpKind` definition in +[interned_strings.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/interned_strings.h). +If the aten symbol is missing, you can submit a PR like +[this](https://github.com/pytorch/pytorch/pull/36851). diff --git a/docs/source/contribute/plugins.md b/docs/source/contribute/plugins.md new file mode 100644 index 00000000000..08985f4a496 --- /dev/null +++ b/docs/source/contribute/plugins.md @@ -0,0 +1,85 @@ +# Custom Hardware Plugins + +PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The +PyTorch/XLA team direclty supports plugins for Cloud TPU (`libtpu`) and +GPU ([OpenXLA](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu)). +The same plugins may also be used by JAX and TF. + +## Implementing a PJRT Plugin + +PJRT C API plugins may be closed-source or open-source. They contain two +parts: + +1. Binary exposing a PJRT C API implementation. This part can be shared + with JAX and TensorFlow. +2. Python package containing the above binary, as well as an + implementation of our `DevicePlugin` Python interface, which handles + additional setup. + +### PJRT C API Implementation + +In short, you must implement a +[PjRtClient](https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h) +containing an XLA compiler and runtime for your device. The PJRT C++ +interface is mirrored in C in the +[PJRT_Api](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h). +The most straightforward option is to implement your plugin in C++ and +[wrap +it](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_wrapper_impl.h) +as a C API implementation. This process is explained in detail in +[OpenXLA's +documentation](https://openxla.org/xla/pjrt_integration#how_to_integrate_with_pjrt). + +For a concrete example, see the [example +implementation](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_cpu_internal.cc). + +### PyTorch/XLA Plugin Package + +At this point, you should have a functional PJRT plugin binary, which +you can test with the placeholder `LIBRARY` device type. For example: + + $ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python + >>> import torch_xla + >>> torch_xla.devices() + # Assuming there are 4 devices. Your hardware may differ. + [device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)] + +To register your device type automatically for users as well as to +handle extra setup for e.g. multiprocessing, you may implement the +`DevicePlugin` Python API. PyTorch/XLA plugin packages contain two key +components: + +1. An implementation of `DevicePlugin` that (at the very least) + provides the path to your plugin binary. For example: + +``` python +class CpuPlugin(plugins.DevicePlugin): + + def library_path(self) -> str: + return os.path.join( + os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so') +``` + +2. A `torch_xla.plugins` [entry + point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html) + that identifies your `DevicePlugin`. For exmaple, to register the + `EXAMPLE` device type in a `pyproject.toml`: + +```{=html} + +``` + [project.entry-points."torch_xla.plugins"] + example = "torch_xla_cpu_plugin:CpuPlugin" + +With your package installed, you may then use your `EXAMPLE` device +directly: + + $ PJRT_DEVICE=EXAMPLE python + >>> import torch_xla + >>> torch_xla.devices() + [device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)] + +[DevicePlugin](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/plugins.py) +provides additional extension points for multiprocess initialization and +client options. The API is currently in an experimental state, but it is +expected to become stable in a future release. diff --git a/docs/source/debug.rst b/docs/source/debug.rst deleted file mode 100644 index 7c6a6eee671..00000000000 --- a/docs/source/debug.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../../TROUBLESHOOTING.md \ No newline at end of file diff --git a/docs/source/eager_mode.rst b/docs/source/eager_mode.rst deleted file mode 100644 index 05e7d359e1d..00000000000 --- a/docs/source/eager_mode.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../eager.md \ No newline at end of file diff --git a/docs/pallas.md b/docs/source/features/pallas.md similarity index 55% rename from docs/pallas.md rename to docs/source/features/pallas.md index 99bf9b72496..0779bddd38d 100644 --- a/docs/pallas.md +++ b/docs/source/features/pallas.md @@ -1,9 +1,18 @@ # Custom Kernels via Pallas -With the rise of OpenAI [triton](https://openai.com/research/triton), custom kernels become more and more popular in the GPU community, for instance, the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to provide the feature parity in the TPU world, Google has introduced [Pallas](https://jax.readthedocs.io/en/latest/pallas/index.html). For PyTorch/XLA to continue pushing the performance in TPU, we have to support custom kernels, and the best way is through Pallas. The design doc is [TBA](). +With the rise of OpenAI [Triton](https://openai.com/research/triton), +custom kernels become more and more popular in the GPU community, for +instance, the introduction of +[FlashAttention](https://github.com/Dao-AILab/flash-attention) and +[PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to +provide the feature parity in the TPU world, Google has introduced +[Pallas](https://jax.readthedocs.io/en/latest/pallas/index.html). For +PyTorch/XLA to continue pushing the performance in TPU, we have to +support custom kernels, and the best way is through Pallas. Let's assume you have a Pallas kernel defined as follow: -```python3 + +``` python3 from torch_xla.experimental.custom_kernel import jax_import_guard jax_import_guard() @@ -22,12 +31,15 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: )(x, y) ``` -To be noted, it's very important to run `jax_import_guard()` before importing any jax modules. Otherwise, the program will hang on TPU as jax will lock the TPU and torch-xla cannot access it. +To be noted, it's very important to run `jax_import_guard()` before +importing any jax modules. Otherwise, the program will hang on TPU as +jax will lock the TPU and torch-xla cannot access it. ## Adopt the above kernel to be compatible with PyTorch/XLA Example usage: -```python3 + +``` python3 q = torch.randn(3, 2, 128, 4).to("xla") k = torch.randn(3, 2, 128, 4).to("xla") v = torch.randn(3, 2, 128, 4).to("xla") @@ -37,30 +49,39 @@ from torch_xla.experimental.custom_kernel import make_kernel_from_pallas pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)]) output = pt_kernel(q, k) ``` -For simple kernels, the adoption is just as simple as one liner. For more complicated kernels, you can refer to our Flash Attention implementation for details. + +For simple kernels, the adoption is just as simple as one liner. For +more complicated kernels, you can refer to our Flash Attention +implementation for details. ## Use built-in kernels -Besides manually wrapping external Pallas kernels, there are built-in kernels where the adoptions are done by PyTorch/XLA already. These built-in kernels can be used like any other torch.ops. The current built-in kernels that are suppored are: -- FlashAttention -- PagedAttention +Besides manually wrapping external Pallas kernels, there are built-in +kernels where the adoptions are done by PyTorch/XLA already. These +built-in kernels can be used like any other torch.ops. The current +built-in kernels that are suppored are: - FlashAttention -PagedAttention ### FlashAttention #### Example usage -```python3 + +``` python3 # Use built-in kernels import torch_xla.experimental.custom_kernel output = flash_attention(q, k, v) ``` #### Integration Example -We have an example of [FlashAttention integration here](https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention.py) in our training test script. + +We have an example of [FlashAttention integration +here](https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention.py) +in our training test script. ### PagedAttention #### Example usage -```python3 + +``` python3 # Use built-in kernels import torch_xla.experimental.custom_kernel output = torch.ops.xla.paged_attention( @@ -75,11 +96,17 @@ output = torch.ops.xla.paged_attention( ``` #### Integration Example -The vLLM TPU integration utilizes [PagedAttention here](https://github.com/vllm-project/vllm/blob/f5e1bf5d44877149eaabf9c04379a4e14a023145/vllm/attention/backends/pallas.py#L194) for effective memory management with KV cache. +The vLLM TPU integration utilizes [PagedAttention +here](https://github.com/vllm-project/vllm/blob/f5e1bf5d44877149eaabf9c04379a4e14a023145/vllm/attention/backends/pallas.py#L194) +for effective memory management with KV cache. ## Dependencies -The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX: -```bash + +The Pallas integration depends on JAX to function. However, not every +JAX version is compatible with your installed PyTorch/XLA. To install +the proper JAX: + +``` bash pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html ``` diff --git a/docs/stablehlo.md b/docs/source/features/stablehlo.md similarity index 67% rename from docs/stablehlo.md rename to docs/source/features/stablehlo.md index 2700a263ddb..b53d79fb10f 100644 --- a/docs/stablehlo.md +++ b/docs/source/features/stablehlo.md @@ -1,13 +1,9 @@ -Torch Export to StableHLO --------------------------- +# Torch Export to StableHLO -This document describes how to use torch export + torch xla to export to +This document describes how to use torch export + torch xla to export to [StableHLO](https://github.com/openxla/stablehlo) format. - -## How to use: - -```python +``` python from torch.export import export from torch_xla.stablehlo import exported_program_to_stablehlo import torch_xla.core.xla_model as xm @@ -42,58 +38,63 @@ output2 = stablehlo_program(*sample_input_xla) print(torch.allclose(output, output2.cpu(), atol=1e-5)) ``` -# Saving StableHLO bytecodes to disk +## Saving StableHLO bytecodes to disk + +One can now save stablehlo to disk with -One can now save stablehlo to disk with -```python +``` python stablehlo_program.save('/tmp/stablehlo_dir') ``` -The path should be path to an empty directory. If it doesn't exist, it will be created. -This directory can be loaded again as another stablehlo_program: -```python +The path should be path to an empty directory. If it doesn't exist, it +will be created. This directory can be loaded again as another +stablehlo_program: + +``` python from torch_xla.stablehlo import StableHLOGraphModule stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir') output3 = stablehlo_program2(*sample_input_xla) ``` -# Convert saved StableHLO for serving +## Convert saved StableHLO for serving -StableHLO is an open format and it is supported for serving in [tensorflow.serving](https://github.com/tensorflow/serving) model server. However, before giving it to tf.serving, we need to first -wrap the generated StableHLO bytecode into a `tf.saved_model` format. +StableHLO is an open format and it is supported for serving in +[tensorflow.serving](https://github.com/tensorflow/serving) model +server. However, before giving it to tf.serving, we need to first wrap +the generated StableHLO bytecode into a `tf.saved_model` format. -For that, first ensure that you have the latest `tensorflow` install in the current python env, -if not, install with +For that, first ensure that you have the latest `tensorflow` install in +the current python env, if not, install with -```bash +``` bash pip install tf-nightly ``` Now, you can run a converter (provided in the torch/xla installation) -``` -stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1 -``` -After that, you can run your model server on the newly generated `tf.saved_model` with -tf serving binary. + stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1 +After that, you can run your model server on the newly generated +`tf.saved_model` with tf serving binary. -``` +``` bash docker pull tensorflow/serving docker run -p 8500:8500 \ --mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \ -e MODEL_NAME=resnet_tf -t tensorflow/serving & ``` -You can also use the `tf.serving` binary directly without docker. -For more details, please follow the [tf serving guide](https://www.tensorflow.org/tfx/serving/serving_basic). +You can also use the `tf.serving` binary directly without docker. For +more details, please follow the [tf serving +guide](https://www.tensorflow.org/tfx/serving/serving_basic). -# Common wrappers +## Common wrappers ### I want to save directly tf.saved_model format without needing to run an separate command. You can accomplish this by using this helper function: -```python + +``` python from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model save_torch_module_as_tf_saved_model( @@ -103,43 +104,46 @@ save_torch_module_as_tf_saved_model( ) ``` -### Other common wrappers +## Other common wrappers -```python +``` python def save_as_stablehlo(exported_model: 'ExportedProgram', stablehlo_dir: os.PathLike, options: Optional[StableHLOExportOptions] = None): ``` -`save_as_stablehlo` (also aliased as `torch_xla.save_as_stablehlo`) -takes ExportedProgram and saves StableHLO on disk. i.e. - same as exported_program_to_stablehlo(...).save(...) +`save_as_stablehlo` (also aliased as `torch_xla.save_as_stablehlo`) +takes ExportedProgram and saves StableHLO on disk. i.e. same as +exported_program_to_stablehlo(...).save(...) -```python +``` python def save_torch_model_as_stablehlo( torchmodel: torch.nn.Module, args: Tuple[Any], path: os.PathLike, options: Optional[StableHLOExportOptions] = None) -> None: - """Convert a torch model to a callable backed by StableHLO. - + """Convert a torch model to a callable backed by StableHLO. ``` -takes `torch.nn.Module` and saves StableHLO on disk. i.e. - same as torch.export.export followed by save_as_stablehlo +takes `torch.nn.Module` and saves StableHLO on disk. i.e. same as +torch.export.export followed by save_as_stablehlo -# Files produced by `save_as_stablehlo`. +## Files produced by `save_as_stablehlo`. -Inside of `/tmp/stablehlo_dir` in the example above, you will find 3 directories: `data`, `constants`, `functions`. Both data and constants will consist of tensors used by the program -saved as `numpy.ndarray` using `numpy.save`. - -The functions directory will contain StableHLO bytecode, here named `forward.bytecode`, human readable StableHLO code (MLIR form) `forward.mlir`, and a JSON file specifying which weights -and original user's input become the which positional arguments of this StableHLO function; as well -as the dtypes and shapes of every argument. +Inside of `/tmp/stablehlo_dir` in the example above, you will find 3 +directories: `data`, `constants`, `functions`. Both data and constants +will consist of tensors used by the program saved as `numpy.ndarray` +using `numpy.save`. +The functions directory will contain StableHLO bytecode, here named +`forward.bytecode`, human readable StableHLO code (MLIR form) +`forward.mlir`, and a JSON file specifying which weights and original +user's input become the which positional arguments of this StableHLO +function; as well as the dtypes and shapes of every argument. Example: -``` + +``` bash $ find /tmp/stablehlo_dir ./functions ./functions/forward.mlir @@ -159,22 +163,39 @@ $ find /tmp/stablehlo_dir ... ``` -The JSON file is serialized form of the `torch_xla.stablehlo.StableHLOFunc` class. +The JSON file is serialized form of the +`torch_xla.stablehlo.StableHLOFunc` class. -This format is currently also in prototype stage and there are no backward compatibility guarantees. -The future plan is to standardize a format that the major frameworks (PyTorch, JAX, TensorFlow) can agree. +This format is currently also in prototype stage and there are no +backward compatibility guarantees. The future plan is to standardize a +format that the major frameworks (PyTorch, JAX, TensorFlow) can agree. -# Preserving High-Level PyTorch Operations in StableHLO by generating `stablehlo.composite` +## Preserving High-Level PyTorch Operations in StableHLO by generating `stablehlo.composite` -High level PyTorch ops (e.g. `F.scaled_dot_product_attention`) will be decomposed into low level ops during PyTorch -> StableHLO lowering. Capturing the high level op in downstream ML compilers can be crucial for genearting a performant, efficient specialized kernels. While pattern matching a bunch of low level ops in the ML compiler can be challenging and error-prone, we offer a more robust method to outline the high-level PyTorch op in StableHLO program - by generating [stablehlo.composite](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite) for the high level PyTorch ops. +High level PyTorch ops (e.g. `F.scaled_dot_product_attention`) will be +decomposed into low level ops during PyTorch -\> StableHLO lowering. +Capturing the high level op in downstream ML compilers can be crucial +for genearting a performant, efficient specialized kernels. While +pattern matching a bunch of low level ops in the ML compiler can be +challenging and error-prone, we offer a more robust method to outline +the high-level PyTorch op in StableHLO program - by generating +[stablehlo.composite](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite) +for the high level PyTorch ops. -With `StableHLOCompositeBuilder`, user can outline an arbitary region within the `forward` function of a `torch.nn.Module`. Then in the exported StableHLO program, a composite op for the outlined region will be produced. +With `StableHLOCompositeBuilder`, user can outline an arbitary region +within the `forward` function of a `torch.nn.Module`. Then in the +exported StableHLO program, a composite op for the outlined region will +be produced. -**NOTE:** Because the value of non-tensor inputs to the outlined region will be hardcoded in the exported graph, please store those values as composite attributes, if retrieving from the downstream compiler is desired. +**NOTE:** Because the value of non-tensor inputs to the outlined region +will be hardcoded in the exported graph, please store those values as +composite attributes, if retrieving from the downstream compiler is +desired. -The following example shows a pratical use case - capturing `scaled_product_attention` +The following example shows a pratical use case - capturing +`scaled_product_attention` -```python +``` python import torch import torch.nn.functional as F from torch_xla import stablehlo @@ -213,7 +234,7 @@ print(stablehlo) The main StableHLO graph is shown below: -```mlir +``` none module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> { ... @@ -229,15 +250,18 @@ module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_ou } ``` -The sdpa operation is encapsulated as a stablehlo composite call within the main graph. The name and attributes specified in the torch.nn.Module are propagated. +The sdpa operation is encapsulated as a stablehlo composite call within +the main graph. The name and attributes specified in the torch.nn.Module +are propagated. -```mlir +``` none %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} ``` -The reference PyTorch decomposition of the sdpa operation is captured in a StableHLO function: +The reference PyTorch decomposition of the sdpa operation is captured in +a StableHLO function: -```mlir +``` none func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> { // Actual implementation of the composite ... diff --git a/docs/triton.md b/docs/source/features/triton.md similarity index 84% rename from docs/triton.md rename to docs/source/features/triton.md index bcc4444a67a..9e036e31ba0 100644 --- a/docs/triton.md +++ b/docs/source/features/triton.md @@ -1,9 +1,14 @@ # Custom GPU Kernels via Triton -PyTorch/XLA now supports [Triton](https://openai.com/research/triton) kernels, enabling high-performance deep learning model execution on GPUs. Triton, a specialized language and compiler for GPU programming, empowers developers to write custom kernels that leverage the full potential of GPUs for various operations in deep learning models. +PyTorch/XLA now supports [Triton](https://openai.com/research/triton) +kernels, enabling high-performance deep learning model execution on +GPUs. Triton, a specialized language and compiler for GPU programming, +empowers developers to write custom kernels that leverage the full +potential of GPUs for various operations in deep learning models. Given a Triton kernel defined as follows: -```python3 + +``` python3 @triton.jit def add_kernel( x_ptr, # *Pointer* to first input vector. @@ -22,12 +27,12 @@ def add_kernel( y = tl.load(y_ptr + offsets, mask=mask) output = x + y tl.store(output_ptr + offsets, output, mask=mask) - ``` -We can run make this kernel a part of the PyTorch/XLA execution graph as follows: +We can run make this kernel a part of the PyTorch/XLA execution graph as +follows: -```python3 +``` python3 import torch import torch_xla.experimental.triton as xla_triton @@ -55,13 +60,16 @@ payload = xla_triton.triton_call( # regarding how the GPU buffers will be loaded when this node is executed. output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload, [output.shape], [torch.int64]) - ``` -For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA. +For more complex kernels, you can also refer to the Triton Flash +Attention kernel test in PyTorch/XLA. ## Dependencies -The Triton integration depends on the `triton` package to function. This code is tested with `triton==2.3.0`. To install: -```bash + +The Triton integration depends on the `triton` package to function. This +code is tested with `triton==2.3.0`. To install: + +``` bash pip install --no-deps triton==2.3.0 ``` diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst deleted file mode 100644 index 79d8385467a..00000000000 --- a/docs/source/gpu.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../gpu.md \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index edd6bcc5372..7b03724cc31 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -5,104 +5,68 @@ PyTorch/XLA documentation PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. .. toctree:: - :hidden: + :glob: + :maxdepth: 1 + :caption: Learn about Pytorch/XLA - self + learn/xla-overview + learn/pytorch-on-xla-devices + learn/api-guide + learn/dynamic_shape + learn/eager + learn/pjrt + learn/troubleshoot .. toctree:: :glob: :maxdepth: 1 - :caption: Docs - - * - -.. mdinclude:: ../../API_GUIDE.md - -PyTorch/XLA API -================================== - -torch_xla ----------------------------------- -.. automodule:: torch_xla -.. autofunction:: device -.. autofunction:: devices -.. autofunction:: device_count -.. autofunction:: sync -.. autofunction:: compile -.. autofunction:: manual_seed + :caption: Learn about accelerators -runtime ----------------------------------- -.. automodule:: torch_xla.runtime -.. autofunction:: device_type -.. autofunction:: local_process_count -.. autofunction:: local_device_count -.. autofunction:: addressable_device_count -.. autofunction:: global_device_count -.. autofunction:: global_runtime_device_count -.. autofunction:: world_size -.. autofunction:: global_ordinal -.. autofunction:: local_ordinal -.. autofunction:: get_master_ip -.. autofunction:: use_spmd -.. autofunction:: is_spmd -.. autofunction:: initialize_cache + accelerators/tpu + accelerators/gpu +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Run ML workloads with Pytorch/XLA -xla_model ----------------------------------- - -.. automodule:: torch_xla.core.xla_model -.. autofunction:: xla_device -.. autofunction:: xla_device_hw -.. autofunction:: is_master_ordinal -.. autofunction:: all_reduce -.. autofunction:: all_gather -.. autofunction:: all_to_all -.. autofunction:: add_step_closure -.. autofunction:: wait_device_ops -.. autofunction:: optimizer_step -.. autofunction:: save -.. autofunction:: rendezvous -.. autofunction:: mesh_reduce -.. autofunction:: set_rng_state -.. autofunction:: get_rng_state -.. autofunction:: get_memory_info -.. autofunction:: get_stablehlo -.. autofunction:: get_stablehlo_bytecode - -distributed ----------------------------------- - -.. automodule:: torch_xla.distributed.parallel_loader -.. autoclass:: MpDeviceLoader - -.. automodule:: torch_xla.distributed.xla_multiprocessing -.. autofunction:: spawn - -spmd ----------------------------------- -.. automodule:: torch_xla.distributed.spmd -.. autofunction:: mark_sharding -.. autofunction:: clear_sharding -.. autofunction:: set_global_mesh -.. autofunction:: get_global_mesh -.. autofunction:: get_1d_mesh -.. autoclass:: Mesh -.. autoclass:: HybridMesh + workloads/kubernetes -experimental ----------------------------------- -.. automodule:: torch_xla.experimental -.. autofunction:: eager_mode +.. toctree:: + :glob: + :maxdepth: 1 + :caption: PyTorch/XLA features -debug ----------------------------------- + features/pallas.md + features/stablehlo.md + features/triton.md -.. automodule:: torch_xla.debug.metrics -.. autofunction:: metrics_report -.. autofunction:: short_metrics_report -.. autofunction:: counter_names -.. autofunction:: counter_value -.. autofunction:: metric_names -.. autofunction:: metric_data +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Improve Pytorch/XLA workload performance + + perf/amp + perf/spmd_basic + perf/spmd_advanced + perf/spmd_distributed_checkpoint + perf/spmd_gpu + perf/ddp + perf/dynamo + perf/fori_loop + perf/fsdp + perf/fsdpv2 + perf/quantized_ops + perf/recompilation + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Contribute to Pytorch/XLA + + contribute/configure-environment + contribute/codegen_migration + contribute/op_lowering + contribute/plugins + contribute/bazel + contribute/recompilation diff --git a/docs/source/learn/api-guide.rst b/docs/source/learn/api-guide.rst new file mode 100644 index 00000000000..8b59a2fe9f1 --- /dev/null +++ b/docs/source/learn/api-guide.rst @@ -0,0 +1,89 @@ + +PyTorch/XLA API +================================== + +torch_xla +---------------------------------- +.. automodule:: torch_xla +.. autofunction:: device +.. autofunction:: devices +.. autofunction:: device_count +.. autofunction:: sync +.. autofunction:: compile +.. autofunction:: manual_seed + +runtime +---------------------------------- +.. automodule:: torch_xla.runtime +.. autofunction:: device_type +.. autofunction:: local_process_count +.. autofunction:: local_device_count +.. autofunction:: addressable_device_count +.. autofunction:: global_device_count +.. autofunction:: global_runtime_device_count +.. autofunction:: world_size +.. autofunction:: global_ordinal +.. autofunction:: local_ordinal +.. autofunction:: get_master_ip +.. autofunction:: use_spmd +.. autofunction:: is_spmd +.. autofunction:: initialize_cache + + +xla_model +---------------------------------- + +.. automodule:: torch_xla.core.xla_model +.. autofunction:: xla_device +.. autofunction:: xla_device_hw +.. autofunction:: is_master_ordinal +.. autofunction:: all_reduce +.. autofunction:: all_gather +.. autofunction:: all_to_all +.. autofunction:: add_step_closure +.. autofunction:: wait_device_ops +.. autofunction:: optimizer_step +.. autofunction:: save +.. autofunction:: rendezvous +.. autofunction:: mesh_reduce +.. autofunction:: set_rng_state +.. autofunction:: get_rng_state +.. autofunction:: get_memory_info +.. autofunction:: get_stablehlo +.. autofunction:: get_stablehlo_bytecode + +distributed +---------------------------------- + +.. automodule:: torch_xla.distributed.parallel_loader +.. autoclass:: MpDeviceLoader + +.. automodule:: torch_xla.distributed.xla_multiprocessing +.. autofunction:: spawn + +spmd +---------------------------------- +.. automodule:: torch_xla.distributed.spmd +.. autofunction:: mark_sharding +.. autofunction:: clear_sharding +.. autofunction:: set_global_mesh +.. autofunction:: get_global_mesh +.. autofunction:: get_1d_mesh +.. autoclass:: Mesh +.. autoclass:: HybridMesh + +experimental +---------------------------------- +.. automodule:: torch_xla.experimental +.. autofunction:: eager_mode + +debug +---------------------------------- + +.. automodule:: torch_xla.debug.metrics +.. autofunction:: metrics_report +.. autofunction:: short_metrics_report +.. autofunction:: counter_names +.. autofunction:: counter_value +.. autofunction:: metric_names +.. autofunction:: metric_data \ No newline at end of file diff --git a/docs/dynamic_shape.md b/docs/source/learn/dynamic_shape.md similarity index 98% rename from docs/dynamic_shape.md rename to docs/source/learn/dynamic_shape.md index 6804e0e49fb..a2855c909fc 100644 --- a/docs/dynamic_shape.md +++ b/docs/source/learn/dynamic_shape.md @@ -36,7 +36,7 @@ Here are some numbers we get when we run the MLP model for 100 iterations: | Number of compilations | 102 | 49 | | Compilation cache hit | 198 | 1953 | -![Performance comparison (a) without dynamic shape (b) with dynamic shape](_static/img/dynamic_shape_mlp_perf.png) +![Performance comparison (a) without dynamic shape (b) with dynamic shape](../_static/img/dynamic_shape_mlp_perf.png) One of the motivations of the dynamic shape is to reduce the number of excessive recompilation when the shape keeps changing between iterations. From the figure above, you can see the number of compilations reduced by half which results in the drop of the training time. diff --git a/docs/source/learn/eager.md b/docs/source/learn/eager.md new file mode 100644 index 00000000000..2b6b1fa48e6 --- /dev/null +++ b/docs/source/learn/eager.md @@ -0,0 +1,132 @@ +# Eager Mode + Compile API + +In this doc we will go over how to use PyTorch/XLA's new experimental +`eager` mode with the `compile` API. The goal is to make PyTorch/XLA +experience more aligned with the native PyTorch and make development +process easier. + +Currently PyTorch/XLA runs on the LazyTensor tracing mode by default. In +the following code + +``` python +import torch +import torch_xla +import torchvision + +device = torch_xla.device() +model = torchvision.models.resnet18().to(device) +input = torch.randn(64, 3, 224, 224).to(device) + +# model tracing +res = model(input) + +# model execution, same as `xm.mark_step` +torch_xla.sync() +``` + +The actual model compilation and device execution happens when +`torch_xla.sync` is called. There are multiple drawback of this +approach. + +1. Users are often confused about when the framework is tracing and + when the framework is executing. +2. Non-core model code(data preprocessing for example) often generates + some small pending execution that gets leaked into the main + graph(step function) and causes recompilation. The recompilation of + the whole graph is usually very expensive. +3. It is hard to debug when/why recompilation happens. + +To mitigate above issues we want to introduce the new UX with eager and +compile. + +## Basic Usage + +``` python +import torch +import torch_xla +import torchvision + +# Run ops eagerly by default +torch_xla.experimental.eager_mode(True) + +device = torch_xla.device() +model = torchvision.models.resnet18().to(device) + +# Mark the function to be compiled +compiled_model = torch_xla.compile(model) +input = torch.randn(64, 3, 224, 224).to(device) + +# Compilation and execution happens right away. +res = compiled_model(input) +``` + +Note that + +1. Currently user has to manually enable the eager mode by + `torch_xla.experimental.eager_mode(True)`. +2. The region of the code that wants to be compiled should be wrapped + by `torch_xla.compile`. + +The implementation of the `torch_xla.compile` is actually pretty +straight forward, it disable the eager mode when entering the target +function and start tracing. It will call the `torch_xla.sync()` when +target function returns and reenable the eager mode. You can expect the +same perfomrance by using the `eager` + `compile` API compared to the +existing `mark_step/sync` approach. + +### Inference + +``` python +torch_xla.experimental.eager_mode(True) +compiled_model = torch.compile(model, backend="openxla") +``` + +It is recommened to use the `torch.compile` instead of +`torch_xla.compile` for inference to reduce the tracing overhad. + +### Training + +``` python +torch_xla.experimental.eager_mode(True) + +def step_fn(model, data, target, loss_fn, optimizer): + optimizer.zero_grad() + logits = model(data) + loss = loss_fn(logits, target) + loss.backward() + optimizer.step() + return loss + +step_fn = torch_xla.compile(step_fn) +``` + +In training we asked user to refactor the `step_fn` out because it is +usually better to compile the model's forward, backward and optimizer +together. The long term goal is to also use `torch.compile` for training +but right now we recommend user to use `torch_xla.compile`(for +perfomrance reason). + +## Benchmark + +I run a 2 layer decoder only model training(it is pretty much just a +llama2) with fake data on a single chip of v4-8 for 300 steps. Below is +the number I observed. + + Mode token/s + --------------------------- --------- + Tracing mode (base line) 147 + Eager mode 65 + Eager + torch_xla compile 147 + + : Eager mode benchmarks + +Eager mode can achieve ~45% performance of the fully compiled model for +the decoder only model. For more information, see +[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py) +and [eager example](https://github.com/pytorch/xla/tree/master/examples/eager). +Note that perfomrane of the eager mode is very model dependent. When I +tried to run the resnet50, the eager mode perfomrance is \~1% of the +compiled mode. We don't exepct user to use eager mode to execute the +main training loop. Eager mode is meant to be used to handle non-core +part of the training/inference logic(Data preprocessing, random number +generations etc) or debug. diff --git a/docs/source/learn/pjrt.md b/docs/source/learn/pjrt.md new file mode 100644 index 00000000000..6fc84bf9de3 --- /dev/null +++ b/docs/source/learn/pjrt.md @@ -0,0 +1,438 @@ +# PJRT Runtime + +PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the +[PJRT runtime](https://github.com/openxla/xla/tree/main/xla/pjrt) used +by [JAX](https://github.com/google/jax). + +If you encounter a bug with PJRT, please file an issue on GitHub with +the `runtime` tag. + +*New features in PyTorch/XLA r2.1*: + +- PJRT is stable in PyTorch/XLA r2.1! +- Public runtime APIs have moved from `torch_xla.experimental.pjrt` to + `torch_xla.runtime`. + - The `pjrt://` init method has been renamed to `xla://`, and it + is registered by `torch_xla.distributed.xla_backend`. + - The previous `torch_xla.experimental.*` names are still + available in this release for compatibility. +- `torchrun` is now supported when using `init_method='xla://'`. +- New plugins for XPU and Neuron via the PJRT C API. + +*New features in PyTorch/XLA r2.0*: + +- PJRT will be configured by default if you don't pass in any other + runtime configuration. If you continue to set XRT configuration + (`XRT_TPU_CONFIG`), this change has no impact +- New TPU runtime implementation in `libtpu` improves performance by + up to 30%. +- New `xm.rendezvous` implementation that scales to thousands of TPU + cores +- \[experimental\] `torch.distributed` support for TPU v2 and v3, + including `pjrt://` `init_method` + +## TL;DR + +- To use the PJRT preview runtime, set the `PJRT_DEVICE` environment + variable to `CPU`, `TPU`, or `CUDA` +- In XRT, all distributed workloads are multiprocess, with one process + per device. On TPU v2 and v3 in PJRT, workloads are multiprocess and + multithreaded (4 processes with 2 threads each), so your workload + should be thread-safe. See [Multithreading on TPU + v2/v3](#multithreading-on-tpu-v2v3) and the [Multiprocessing section + of the API + guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing) + for more information. Key differences to keep in mind: + - To initialize a model in a thread-safe way, either broadcast the + parameters across replicas after initialization + (`torch_xla.experimental.pjrt.broadcast_master_param`) or load + each replica's parameters from a common checkpoint. + - For other random number generation, use `torch.Generator` where + possible. The global `torch` RNG is *not* thread-safe, even if + you set the same `torch.manual_seed` across replicas. + - To use `torch.distributed`, import + `torch_xla.experimental.pjrt_backend` and use the `xla://` + `init_method`. + - These steps are optional for GPU and TPU v4. + +Sample diff from XRT to PJRT: + +``` diff +import os + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.optim as optim +import torch.distributed as dist +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.distributed.parallel_loader as pl +import torch_xla.distributed.xla_backend ++import torch_xla.runtime as xr + + +def _mp_fn(index): + device = xm.xla_device() +- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) ++ dist.init_process_group('xla', init_method='xla://') + + torch.manual_seed(42) + model = nn.Linear(128, 10).to(device) + ++ # Optional for TPU v4 and GPU ++ xm.broadcast_master_param(model) + model = DDP(model, gradient_as_bucket_view=True) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(model.parameters(), lr=.001) + + for i in range(10): + data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device) + + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + + optimizer.step() + xm.mark_step() + + # Print mean parameters so we can confirm they're the same across replicas + print([p.mean() for p in model.parameters()]) + +if __name__ == '__main__': +- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011' +- os.environ['MASTER_ADDR'] = 'localhost' +- os.environ['MASTER_PORT'] = '12355' + ++ # Recommended: set PJRT_DEVICE to your local device type ++ os.environ['PJRT_DEVICE'] = 'TPU' + + torch_xla.launch(_mp_fn) +``` + +## Benefits + +- Simple runtime configuration: just set `PJRT_DEVICE` to `TPU`, + `CPU`, or `CUDA` and start using XLA! Or, let PJRT select a device + automatically based on your environment. +- Improved performance: reduced overhead from gRPC means faster + end-to-end execution. On TorchBench 2.0, we observed a \>35% + improvement in training time on TPU v4. +- Easy pod execution: just copy your code to each TPU worker, and + execute them all at the same time with + `gcloud compute tpus tpuvm ssh --worker=all`. +- Better scaling: removes [XRT's limitation on parameter + sizes](https://github.com/pytorch/xla/pull/3920) and supports up to + 2048 TPU chips. + +## Quickstart + +To start using PJRT with PyTorch/XLA, all you need to do is set the +`PJRT_DEVICE` environment variable. If you're working on a TPU v2 or v3, +keep reading to learn about the differences between TPU v2 and v3 and +v4. + +### CPU + +On any machine with PyTorch/XLA installed, you can run our MNIST example +on CPU like this: + + PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data + +### TPU + +To create a new TPU with PyTorch/XLA r2.0 installed: + + gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT + +On a v4-8, you can run our ResNet50 example like this: + + git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git + PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1 + +By default, PJRT will use all TPU chips. To use only one TPU chip, +configure `TPU_PROCESS_BOUNDS` and `TPU_VISIBLE_CHIPS`: + + TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1 + +#### Pods + +On TPU Pods, use `gcloud` to run your command on each TPU in parallel: + + gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git" + gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1" + +#### Docker + +You can also use Docker to run your workload in a container with +PyTorch/XLA preinstalled: + + export DOCKER_IMAGE=gcr.io/... + + # Optional: authenticate docker if your image is in a private GCP repository + gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker" + + # Run your workload + gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data" + +Note that `docker run` requires privileged access to the host +(`--privileged`) to expose the TPU device to the container. Docker on +TPU pods is only supported with host networking `--net=host` at this +time. See the [Cloud TPU +documentation](https://cloud.google.com/tpu/docs/run-in-container) for +more information. + +### GPU + +### Single-node GPU training + +To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure +`GPU_NUM_DEVICES` to the number of devices on the host. For example: + + PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1 + +You can also use `torchrun` to initiate the single-node multi-GPU +training. For example, + + PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 + +In the above example, `--nnodes` means how many machines (physical +machines or VMs) to be used (it is 1 since we do single-node training). +`--nproc-per-node` means how many GPU devices to be used. + +### Multi-node GPU training + +**Note that this feature only works for cuda 12+**. Similar to how +PyTorch uses multi-node training, you can run the command as below: + + PJRT_DEVICE=CUDA torchrun \ + --nnodes=${NUMBER_GPU_VM} \ + --node_rank=${CURRENT_NODE_RANK} \ + --nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \ + --rdzv_endpoint= multinode_training.py + +- `--nnodes`: how many GPU machines to be used. +- `--node_rank`: the index of the current GPU machines. The value can + be 0, 1, ..., \${NUMBER_GPU_VM}-1. +- `--nproc_per_node`: the number of GPU devices to be used on the + current machine. +- `--rdzv_endpoint`: the endpoint of the GPU machine with + node_rank==0, in the form `host:port`. The `host` will be the + internal IP address. The `port` can be any available port on the + machine. For single-node training/inference, this parameter can be + omitted. + +For example, if you want to train on 2 GPU machines: machine_0 and +machine_1, on the first GPU machine machine_0, run + + # PJRT_DEVICE=CUDA torchrun \ + --nnodes=2 \ + --node_rank=0 \ + --nproc_per_node=4 \ + --rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 + +On the second GPU machine, run + + # PJRT_DEVICE=CUDA torchrun \ + --nnodes=2 \ + --node_rank=1 \ + --nproc_per_node=4 \ + --rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 + +the difference between the 2 commands above are `--node_rank` and +potentially `--nproc_per_node` if you want to use different number of +GPU devices on each machine. All the rest are identical. For more +information about `torchrun`, please refer to this +[page](https://pytorch.org/docs/stable/elastic/run.html). + +## Differences from XRT + +Although in most cases we expect PJRT and XRT to work mostly +interchangeably from the end-user's perspective (especially on TPU v4), +there are some subtle differences that are important to keep in mind. +Importantly, XRT was designed around the TPU Node architecture, so it +will always spawn a client and a server process, even on TPU VMs. Thus, +every batch of inputs has additional latency from serializing and +deserializing data to send it over the network. + +PJRT uses the local device directly with no intermediate server process. +In the default configuration, PJRT will create one process per TPU chip, +or 4 processes per TPU host. See the [Cloud TPU +documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) +for more information about TPU architecture. + +- Performance gains are possible for workloads constrained overhead + from . +- Under XRT, the server process is the only process that interacts + with the TPU devices, and client processes don't have direct access + to the TPU devices. When profiling a single-host TPU (e.g. v3-8 or + v4-8), you would normally see 8 device traces (one for each TPU + core). With PJRT, each process has one chip, and a profile from that + process will show only 2 TPU cores. + - For the same reason, profiling does not work on TPU Pods with + XRT, because the server process runs independently from the + user's model code. PJRT does not have that constraint, so it is + possible to profile 2 TPU cores per process in a TPU Pod. +- PJRT only supports the TPU VM architecture and we have no plans to + support the TPU Node architecture with PJRT. +- Runtime configuration is significantly simpler with PJRT. `xla_dist` + is not required to run TPU Pod workloads. Instead, copy your code to + each TPU host + (`[gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)`) + and run the code on each host in parallel + (e.g. `[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)`) +- `xm.rendezvous` has been reimplemented using XLA-native collective + communication to enhance stability on large TPU pods. See below for + more details. + +### Multithreading on TPU v2/v3 + +On TPU v2 and v3, **distributed workloads always run multithreaded**, +since each TPU core exposes two TPU cores as devices and only one +process may open a TPU chip at a time. In its default configuration, +`xmp.spawn` automatically spawns as many processes as possible (4 per +TPU host) and creates two threads per process (one per TPU core). + +Note: on TPU v4, each TPU chip is represented as one PyTorch device, so +distributed workloads will run across 4 processes, each with only one +thread. This is identical to XRT's behavior. + +In most cases, this will not require substantial changes to your +existing code. The main change you will have to make in most cases is to +model initialization. Because `torch`'s global RNG is shared between +threads, results will vary between threads and runs even if you set +`torch.manual_seed` to the same value in every replica. To get +consistent parameters between replicas, either use +`torch_xla.experimental.pjrt.broadcast_master_param` to broadcast one +replica's parameters to all other replicas, or load each replica's +parameters from a common checkpoint. + +### Changes to xm.rendezvous + +*New in PyTorch/XLA r2.0* + +With XRT, worker 0 runs a mesh master service, and all processes on all +workers connect to that service over gRPC. In practice, we found that +running a single mesh master process was unreliable on TPU pods with +thousands of chips due to the number of inbound connections to worker 0. +A single client process timing out could cause a failure and force the +entire workload to restart. + +Thus, we have reimplemented `xm.rendezvous` with native XLA collective +communication, which is much more stable and well-tested on large TPU +pods. This imposes two new constraints compared to the XRT +implementation: + +- Because the payload has to become part of the XLA graph, + `xm.mark_step` is called both before and after the data is + transferred. Calling `xm.rendezvous` in the middle of model code may + force an unwanted compilation. +- Because XLA does not permit collective operations to run on a subset + of workers, all workers must participate in the `rendezvous`. + +If you require the old behavior of `xm.rendezvous` (i.e. communicating +data without altering the XLA graph and/or synchronizing a subset of +workers), consider using `` `torch.distributed.barrier `` +\<\>[\_\_ +or ]{.title-ref}`torch.distributed.all_gather_object` +\<\>[\_\_ +with a ]{.title-ref}[gloo]{.title-ref}[ process group. If you are also +using the ]{.title-ref}[xla]{.title-ref}[ +]{.title-ref}[torch.distributed]{.title-ref}[ backend, you can use +]{.title-ref}[torch.new_group]{.title-ref}[ to create a +]{.title-ref}[gloo]{.title-ref}[ subgroup. See \`this example +\]{.title-ref}\_\_ +from the PyTorch documentation. Keep in mind these constraints: + +- `torch.distributed` is not fully supported on TPU v2/v3. Only a + subset of operations with the `xla` backend are implemented, and + `gloo` will likely not work as expected in a multithreaded context. +- In our experiments, `gloo` does not scale well to thousands of TPU + chips, so expect this alternative to be less reliable than using + `xm.rendezvous` with PJRT at large scales. + +### PJRT and torch.distributed + +*New in PyTorch/XLA r2.0* + +When using PJRT with `torch.distributed` and +`[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)` +we strongly recommend using the new `xla://` `init_method`, which +automatically finds the replica IDs, world size, and master IP by +querying the runtime. For example: + +``` python +import torch +import torch_xla +import torch.distributed as dist +import torch_xla.core.xla_model as xm +from torch_xla.experimental import pjrt + +# Required for `xla://` init_method and `xla` backend +import torch_xla.distributed.xla_backend + +def _all_gather(index: int): + # No need to pass in `rank` or `world_size` + dist.init_process_group('xla', init_method='xla://') + + t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device()) + output = [torch.zeros_like(t) for _ in range(dist.get_world_size())] + dist.all_gather(output, t) + + xm.mark_step() + print(output) + +if __name__ == '__main__': + torch_xla.launch(_all_gather) +``` + +Note: Although the `xla://` init_method is not required on TPU v4, it is +still recommended. If you use `env://`, `MASTER_ADDR` must be set to IP +host that has device 0, which is *not* always worker 0. The `xla://` +init_method finds this IP automatically. + +Note: For TPU v2/v3, you still need to import +`torch_xla.experimental.pjrt_backend`, as TPU v2/v3 support in +`torch.distributed` is still experimental. + +For more information about using `DistributedDataParallel` on +PyTorch/XLA, see [ddp.md](./ddp.md) on TPU V4. For an example that uses +DDP and PJRT together, run the following [example +script](../test/test_train_mp_imagenet.py) on a TPU: + +``` bash +PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1 +``` + +## Performance + +TorchBench shows improvements in average training time across tasks with +PJRT compared to XRT, with an average improvement of over 35% on TPU +v4-8. The benefits vary significantly by task and model type, ranging +from 0% to 175%. The following chart shows the breakdown by task: + +![PJRT vs XRT](../_static/img/torchbench_pjrt_vs_xrt.svg) + +### New TPU runtime + +*New in PyTorch/XLA r2.0* + +The PyTorch/XLA r2.0 release introduces support for the [PJRT Plugin +API](https://github.com/openxla/community/blob/main/rfcs/20230123-pjrt-plugin.md#rfc-openxla-pjrt-plugin), +used to access the new TFRT-based TPU runtime in `libtpu`. This is now +the default runtime when `PJRT_DEVICE=TPU` is set. The legacy +StreamExecutor-based TPU runtime used in 1.13 will still be available +with `PJRT_DEVICE=TPU_LEGACY` in the 2.0 release, but it will be removed +in a future version. If you encounter an issue that only happens on +`TPU` and not `TPU_LEGACY`, please file an issue on GitHub. + +In most cases, we expect performance to be similar between the two +runtimes, but in some cases, the new runtime may be up to 30% faster. +The following chart shows the breakdown by task: + +![TFRT vs StreamExecutor](../_static/img/torchbench_tfrt_vs_se.svg) + +Note: the improvements shown in this chart are also included in the PJRT +vs XRT comparison. diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md new file mode 100644 index 00000000000..5bf4953a6ce --- /dev/null +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -0,0 +1,393 @@ +# PyTorch on XLA Devices + +PyTorch runs on XLA devices, like TPUs, with the [torch_xla package](https://github.com/pytorch/xla/). This document describes how +to run your models on these devices. + +## Creating an XLA Tensor + +PyTorch/XLA adds a new `xla` device type to PyTorch. This device type +works just like other PyTorch device types. For example, here's how to +create and print an XLA tensor: + +``` python +import torch +import torch_xla +import torch_xla.core.xla_model as xm + +t = torch.randn(2, 2, device=xm.xla_device()) +print(t.device) +print(t) +``` + +This code should look familiar. PyTorch/XLA uses the same interface as +regular PyTorch with a few additions. Importing `torch_xla` initializes +PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This +may be a CPU or TPU depending on your environment. + +## XLA Tensors are PyTorch Tensors + +PyTorch operations can be performed on XLA tensors just like CPU or CUDA +tensors. + +For example, XLA tensors can be added together: + +``` python +t0 = torch.randn(2, 2, device=xm.xla_device()) +t1 = torch.randn(2, 2, device=xm.xla_device()) +print(t0 + t1) +``` + +Or matrix multiplied: + +``` python +print(t0.mm(t1)) +``` + +Or used with neural network modules: + +``` python +l_in = torch.randn(10, device=xm.xla_device()) +linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_out = linear(l_in) +print(l_out) +``` + +Like other device types, XLA tensors only work with other XLA tensors on +the same device. So code like + +``` python +l_in = torch.randn(10, device=xm.xla_device()) +linear = torch.nn.Linear(10, 20) +l_out = linear(l_in) +print(l_out) +# Input tensor is not an XLA tensor: torch.FloatTensor +``` + +will throw an error since the `torch.nn.Linear` module is on the CPU. + +## Running Models on XLA Devices + +Building a new PyTorch network or converting an existing one to run on +XLA devices requires only a few lines of XLA-specific code. The +following snippets highlight these lines when running on a single device +and multiple devices with XLA multi-processing. + +### Running on a Single XLA Device + +The following snippet shows a network training on a single XLA device: + +``` python +import torch_xla.core.xla_model as xm + +device = xm.xla_device() +model = MNIST().train().to(device) +loss_fn = nn.NLLLoss() +optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + +for data, target in train_loader: + optimizer.zero_grad() + data = data.to(device) + target = target.to(device) + output = model(data) + loss = loss_fn(output, target) + loss.backward() + + optimizer.step() + xm.mark_step() +``` + +This snippet highlights how easy it is to switch your model to run on +XLA. The model definition, dataloader, optimizer and training loop can +work on any device. The only XLA-specific code is a couple lines that +acquire the XLA device and mark the step. Calling `xm.mark_step()` at +the end of each training iteration causes XLA to execute its current +graph and update the model's parameters. See [XLA Tensor Deep +Dive](#xla-tensor-deep-dive) for more on how XLA creates graphs and runs +operations. + +### Running on Multiple XLA Devices with Multi-processing + +PyTorch/XLA makes it easy to accelerate training by running on multiple +XLA devices. The following snippet shows how: + +``` python +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.distributed.parallel_loader as pl + +def _mp_fn(index): + device = xm.xla_device() + mp_device_loader = pl.MpDeviceLoader(train_loader, device) + + model = MNIST().train().to(device) + loss_fn = nn.NLLLoss() + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + + for data, target in mp_device_loader: + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + xm.optimizer_step(optimizer) + +if __name__ == '__main__': + torch_xla.launch(_mp_fn, args=()) +``` + +There are three differences between this multi-device snippet and the +previous single device snippet. Let's go over then one by one. + +- `torch_xla.launch()` + - Creates the processes that each run an XLA device. + - This function is a wrapper of multithreading spawn to allow user + run the script with torchrun command line also. Each process + will only be able to access the device assigned to the current + process. For example on a TPU v4-8, there will be 4 processes + being spawn up and each process will own a TPU device. + - Note that if you print the `xm.xla_device()` on each process you + will see `xla:0` on all devices. This is because each process + can only see one device. This does not mean multi-process is not + functioning. The only execution is with PJRT runtime on TPU v2 + and TPU v3 since there will be `#devices/2` processes and each + process will have 2 threads(check this + [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) + for more details). +- `MpDeviceLoader` + - Loads the training data onto each device. + - `MpDeviceLoader` can wrap on a torch dataloader. It can preload + the data to the device and overlap the dataloading with device + execution to improve the performance. + - `MpDeviceLoader` also call `xm.mark_step` for you every + `batches_per_execution`(default to 1) batch being yield. +- `xm.optimizer_step(optimizer)` + - Consolidates the gradients between devices and issues the XLA + device step computation. + - It is pretty much a `all_reduce_gradients` + + `optimizer.step()` + `mark_step` and returns the loss being + reduced. + +The model definition, optimizer definition and training loop remain the +same. + +> **NOTE:** It is important to note that, when using multi-processing, +> the user can start retrieving and accessing XLA devices only from +> within the target function of `torch_xla.launch()` (or any function +> which has `torch_xla.launch()` as parent in the call stack). + +See the [full multiprocessing +example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py) +for more on training a network on multiple XLA devices with +multi-processing. + +### Running on TPU Pods + +Multi-host setup for different accelerators can be very different. This +doc will talk about the device independent bits of multi-host training +and will use the TPU + PJRT runtime(currently available on 1.13 and 2.x +releases) as an example. + +Before you being, please take a look at our user guide at +[here](https://cloud.google.com/tpu/docs/run-calculation-pytorch) which +will explain some Google Cloud basis like how to use `gcloud` command +and how to setup your project. You can also check +[here](https://cloud.google.com/tpu/docs/how-to) for all Cloud TPU +Howto. This doc will focus on the PyTorch/XLA perspective of the Setup. + +Let's assume you have the above mnist example from above section in a +`train_mnist_xla.py`. If it is a single host multi device training, you +would ssh to the TPUVM and run command like + + PJRT_DEVICE=TPU python3 train_mnist_xla.py + +Now in order to run the same models on a TPU v4-16 (which has 2 host, +each with 4 TPU devices), you will need to - Make sure each host can +access the training script and training data. This is usually done by +using the `gcloud scp` command or `gcloud ssh` command to copy the +training scripts to all hosts. - Run the same training command on all +hosts at the same time. + +``` console + gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py" +``` + +Above `gcloud ssh` command will ssh to all hosts in TPUVM Pod and run +the same command at the same time.. + +> **NOTE:** You need to run run above `gcloud` command outside of the +> TPUVM vm. + +The model code and training script is the same for the multi-process +training and the multi-host training. PyTorch/XLA and the underlying +infrastructure will make sure each device is aware of the global +topology and each device's local and global ordinal. Cross-device +communication will happen across all devices instead of local devices. + +For more details regarding PJRT runtime and how to run it on pod, please +refer to this +[doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpu). For +more information about PyTorch/XLA and TPU pod and a complete guide to +run a resnet50 with fakedata on TPU pod, please refer to this +[guide](https://cloud.google.com/tpu/docs/pytorch-pods). + +## XLA Tensor Deep Dive + +Using XLA tensors and devices requires changing only a few lines of +code. But even though XLA tensors act a lot like CPU and CUDA tensors, +their internals are different. This section describes what makes XLA +tensors unique. + +### XLA Tensors are Lazy + +CPU and CUDA tensors launch operations immediately or eagerly. XLA +tensors, on the other hand, are lazy. They record operations in a graph +until the results are needed. Deferring execution like this lets XLA +optimize it. A graph of multiple separate operations might be fused into +a single optimized operation, for example. + +Lazy execution is generally invisible to the caller. PyTorch/XLA +automatically constructs the graphs, sends them to XLA devices, and +synchronizes when copying data between an XLA device and the CPU. +Inserting a barrier when taking an optimizer step explicitly +synchronizes the CPU and the XLA device. For more information about our +lazy tensor design, you can read [this +paper](https://arxiv.org/pdf/2102.13267.pdf). + +### Memory Layout + +The internal data representation of XLA tensors is opaque to the user. +They do not expose their storage and they always appear to be +contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a +tensor's memory layout for better performance. + +### Moving XLA Tensors to and from the CPU + +XLA tensors can be moved from the CPU to an XLA device and from an XLA +device to the CPU. If a view is moved then the data its viewing is also +copied to the other device and the view relationship is not preserved. +Put another way, once data is copied to another device it has no +relationship with its previous device or any tensors on it. Again, +depending on how your code operates, appreciating and accommodating this +transition can be important. + +### Saving and Loading XLA Tensors + +XLA tensors should be moved to the CPU before saving, as in the +following snippet: + +``` python +import torch +import torch_xla +import torch_xla.core.xla_model as xm + +device = xm.xla_device() + +t0 = torch.randn(2, 2, device=device) +t1 = torch.randn(2, 2, device=device) + +tensors = (t0.cpu(), t1.cpu()) + +torch.save(tensors, 'tensors.pt') + +tensors = torch.load('tensors.pt') + +t0 = tensors[0].to(device) +t1 = tensors[1].to(device) +``` + +This lets you put the loaded tensors on any available device, not just +the one on which they were initialized. + +Per the above note on moving XLA tensors to the CPU, care must be taken +when working with views. Instead of saving views it is recommended that +you recreate them after the tensors have been loaded and moved to their +destination device(s). + +A utility API is provided to save data by taking care of previously +moving it to CPU: + +``` python +import torch +import torch_xla +import torch_xla.core.xla_model as xm + +xm.save(model.state_dict(), path) +``` + +In case of multiple devices, the above API will only save the data for +the master device ordinal (0). + +In case where memory is limited compared to the size of the model +parameters, an API is provided that reduces the memory footprint on the +host: + +``` python +import torch_xla.utils.serialization as xser + +xser.save(model.state_dict(), path) +``` + +This API streams XLA tensors to CPU one at a time, reducing the amount +of host memory used, but it requires a matching load API to restore: + +``` python +import torch_xla.utils.serialization as xser + +state_dict = xser.load(path) +model.load_state_dict(state_dict) +``` + +Directly saving XLA tensors is possible but not recommended. XLA tensors +are always loaded back to the device they were saved from, and if that +device is unavailable the load will fail. PyTorch/XLA, like all of +PyTorch, is under active development and this behavior may change in the +future. + +## Compilation Caching + +The XLA compiler converts the traced HLO into an executable which runs +on the devices. Compilation can be time consuming, and in cases where +the HLO doesn't change across executions, the compilation result can be +persisted to disk for reuse, significantly reducing development +iteration time. + +Note that if the HLO changes between executions, a recompilation will +still occur. + +This is currently an experimental opt-in API, which must be activated +before any computations are executed. Initialization is done through the +`initialize_cache` API: + +``` python +import torch_xla.runtime as xr +xr.initialize_cache('YOUR_CACHE_PATH', readonly=False) +``` + +This will initialize a persistent compilation cache at the specified +path. The `readonly` parameter can be used to control whether the worker +will be able to write to the cache, which can be useful when a shared +cache mount is used for an SPMD workload. + +If you want to use persistent compilation cache in the multi process +training(with `torch_xla.launch` or `xmp.spawn`), you should use the +different path for different process. + +``` python +def _mp_fn(index): + # cache init needs to happens inside the mp_fn. + xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False) + .... + +if __name__ == '__main__': + torch_xla.launch(_mp_fn, args=()) +``` + +If you don't have the access to the `index`, you can use +`xr.global_ordinal()`. Check out the runnable example in +[here](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_xla_ddp.py). + +## Further Reading + +Additional documentation is available at the [PyTorch/XLA +repo](https://github.com/pytorch/xla/). More examples of running +networks on TPUs are available +[here](https://github.com/pytorch-tpu/examples). diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md new file mode 100644 index 00000000000..3014ee1d33c --- /dev/null +++ b/docs/source/learn/troubleshoot.md @@ -0,0 +1,474 @@ +# Troubleshoot + +Note that the information in this section is subject to be removed in +future releases of the *PyTorch/XLA* software, since many of them are +peculiar to a given internal implementation which might change. + +## Sanity Check + +Before performing any in depth debugging, we want to do a sanity check +on the installed PyTorch/XLA. + +### Check PyTorch/XLA Version + +PyTorch and PyTorch/XLA version should match. Check out our +[README](https://github.com/pytorch/xla#getting-started) for more +detials on versions available. + +``` sh +vm:~$ python +>>> import torch +>>> import torch_xla +>>> print(torch.__version__) +2.1.0+cu121 +>>> print(torch_xla.__version__) +2.1.0 +``` + +### Perform A Simple Calculation + +``` sh +vm:~$ export PJRT_DEVICE=TPU +vm:~$ python3 +>>> import torch +>>> import torch_xla.core.xla_model as xm +>>> t1 = torch.tensor(100, device=xm.xla_device()) +>>> t2 = torch.tensor(200, device=xm.xla_device()) +>>> print(t1 + t2) +tensor(300, device='xla:0') +``` + +### Run Resnet With Fake Data + +For nightly + +``` sh +vm:~$ git clone https://github.com/pytorch/xla.git +vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data +``` + +For release version `x.y`, you want to use the branch `rx.y`. For +example if you installed 2.1 release, you should do + +``` sh +vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git +vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data +``` + +If you can get the resnet to run we can conclude that torch_xla is +installed correctly. + +## Performance Debugging + +To diagnose performance issues, we can use the execution metrics and +counters provided by *PyTorch/XLA* The **first thing** to check when +model is slow is to generate a metrics report. + +Metrics report is extremely helpful in diagnosing issues. Please try to +include it in your bug report sent to us if you have it. + +## PyTorch/XLA Debugging Tool + +You can enable the PyTorch/XLA debugging tool by setting +`PT_XLA_DEBUG_LEVEL=2`, which provides a couple useful debugging +features. You can also lower the debug level to `1` to slip the +execution analysis. + +### Perform A Auto-Metrics Analysis + +The debugging tool will analyze the metrics report and provide a +summary. Some example output would be + +``` sh +pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps +pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps +pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests. +pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps +pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps +``` + +### Compilation & Execution Analysis + +The debugging tool will analyze every compilation and execution for your +model. Some example output would be: + +``` sh +Compilation Analysis: ================================================================================ +Compilation Analysis: Compilation Cause +Compilation Analysis: mark_step in parallel loader at step end +Compilation Analysis: Graph Info: +Compilation Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943 +Compilation Analysis: Number of Graph Inputs: 35 +Compilation Analysis: Number of Graph Outputs: 107 +Compilation Analysis: Python Frame Triggered Execution: +Compilation Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055) +Compilation Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44) +Compilation Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32) +Compilation Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48) +Compilation Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65) +Compilation Analysis: (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73) +Compilation Analysis: -------------------------------------------------------------------------------- +Compilation Analysis: ================================================================================ + +Post Compilation Analysis: ================================================================================ +Post Compilation Analysis: Graph input size: 1.548000 GB +Post Compilation Analysis: Graph output size: 7.922460 GB +Post Compilation Analysis: Aliased Input size: 1.547871 GB +Post Compilation Analysis: Intermediate tensor size: 12.124478 GB +Post Compilation Analysis: Compiled program size: 0.028210 GB +Post Compilation Analysis: -------------------------------------------------------------------------------- +Post Compilation Analysis: ================================================================================ + +Execution Analysis: ================================================================================ +Execution Analysis: Execution Cause +Execution Analysis: mark_step in parallel loader at step end +Execution Analysis: Graph Info: +Execution Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943 +Execution Analysis: Number of Graph Inputs: 35 +Execution Analysis: Number of Graph Outputs: 107 +Execution Analysis: Python Frame Triggered Execution: +Execution Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055) +Execution Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44) +Execution Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32) +Execution Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48) +Execution Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65) +Execution Analysis: (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73) +Execution Analysis: -------------------------------------------------------------------------------- +Execution Analysis: ================================================================================ +``` + +Some common causes of Compilation/Executation are 1. User manually call +`mark_step`. 2. [Parallel +loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51) +call `mark_step` for every x (configurable) batch. 3. Exiting a +[profiler StepTrace +region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171). +4. Dynamo decide to compile/execute the graph. 5. User trying to +access(often due to logging) the value of a tensor before the +`mark_step`. + +The executation caused by 1-4 are expected, and we want to avoid 5 by +either reduce the frequency of accessing tensor values or manually add a +`mark_step` before accessing. + +Users should expect to see this `Compilation Cause` + +`Executation Cause` pairs for first couple steps. After the model +stabilize users should expect to only see `Execution Cause`(you can +disable execution analysis by `PT_XLA_DEBUG_LEVEL=1`). To use +PyTorch/XLA efficiently, we expect the same models code to be run for +every step and compilation only happen once for every graph. If you keep +seeing `Compilation Cause`, you should try to dump the IR/HLO following +[this section](#common-debugging-environment-variables-combinations) and +compare the graphs for each step and understand the source of the +differences. + +Following section will explain how to get and understand a more detail +metrics report. + +## Get A Metrics Report + +Put the following line in your program to generate a report: + +``` python +import torch_xla.debug.metrics as met + +# For short report that only contains a few key metrics. +print(met.short_metrics_report()) +# For full report that includes all metrics. +print(met.metrics_report()) +``` + +## Understand The Metrics Report + +The report includes things like: - how many time we issue *XLA* +compilations and time spent on issuing. - how many times we execute and +time spent on execution - how many device data handles we create/destroy +etc. + +This information is reported in terms of percentiles of the samples. An +example is: + +``` sh +Metric: CompileTime + TotalSamples: 202 + Counter: 06m09s401ms746.001us + ValueRate: 778ms572.062us / second + Rate: 0.425201 / second + Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us +``` + +We also provide counters, which are named integer variables which track +internal software status. For example: + +``` sh +Counter: CachedSyncTensors + Value: 395 +``` + +In this report, any counter that starts with `aten::` indicates a +context switch between the XLA device and CPU, which can be a potential +performance optimization area in the model code. + +Counters are useful to understand which operations are routed back to +the CPU engine of *PyTorch*. They are fully qualified with their C++ +namespace: + + Counter: aten::nonzero + Value: 33 + +If you see `aten::` ops other than `nonzero` and `_local_scalar_dense`, +that usually means a missing lowering in PyTorch/XLA. Feel free to open +a feature request for it on [GitHub +issues](https://github.com/pytorch/xla/issues). + +## Clear The Metrics Report + +If you want to clear the metrics between steps/epochs, you can use + +``` python +import torch_xla.debug.metrics as met + +met.clear_all() +``` + +## PyTorch/XLA + Dynamo Debugging Tool + +You can enable the PyTorch/XLA + Dynamo debugging tool by setting +`XLA_DYNAMO_DEBUG=1`. + +## Performance Profiling + +To profile your workload in depth to understand bottlenecks please check +the following resources: + +- [Official + tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) +- [Colab + notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/pytorch-xla-profiling-colab.ipynb) +- [Sample MNIST training script with + profiling](https://github.com/pytorch/xla/blob/master/test/test_profile_mp_mnist.py) +- [Utility script for capturing performance + profiles](https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py) + +## Simple Benchmarking + +Take a look at: + +[examples/train_resnet_benchmark.py](https://github.com/pytorch/xla/blob/master/examples/train_resnet_benchmark.py) +for how to benchmark a PyTorch/XLA model. + +## Known Performance Caveats + +PyTorch/XLA behaves semantically like regular PyTorch and XLA tensors +share the full tensor interface with CPU & GPU tensors. However, +constraints in XLA/hardware and the lazy evaluation model suggest +certain patterns might result in bad performance. + +If your model shows bad performance, keep in mind the following caveats: + +1. **XLA/TPU yield degraded performance with too many recompilations.** + + XLA compilation is expensive. PyTorch/XLA automatically recompiles + the graph every time new shapes are encountered. Usually models + should stabilize within a few steps and you can see huge speedup for + the rest of training. + + In order to avoid recompilations, not only must shapes be constant, + but computations across XLA devices in all hosts should also be + constant. + + *Possible sources*: + + - Direct or indirect uses of `nonzero` introduce dynamic shapes; + for example, masked indexing `base[index]` where `index` is a + mask tensor. + - Loops with a different number of iterations between steps can + result in different execution graphs, thus require + recompilations. + + *Solution*: + + - Tensor shapes should be the same between iterations, or a low + number of shape variations should be used. + - Pad tensors to fixed sizes when possible. + +2. **Certain operations don't have native translations to XLA.** + + For these operations PyTorch/XLA automatically transfers to the CPU + memory, evaluates on CPU, and transfers the result back to the XLA + device. Doing too many such operations during the training step can + lead to significant slowdowns. + + *Possible sources*: + + - The `item()` operation explicitly asks to evaluate the result. + Don't use it unless it's necessary. + + *Solution*: + + - For most ops we can lower them to XLA to fix it. Checkout + [metrics report section](#metrics-report) to find out the + missing ops and open a feature request on + [GitHub](https://github.com/pytorch/xla/issues). + + - Even when a PyTorch tensor is known as a scalar, avoid using + tensor.item()\`. Keep it as a tensor and use tensor operations + on it. + + - Use `torch.where` to substitute control flow when applicable. + E.g. The control flow with `item()` used in + [clip_grad_norm](https://github.com/pytorch/pytorch/blob/de19eeee99a2a282fc441f637b23d8e50c75ecd1/torch/nn/utils/clip_grad.py#L33) + is problematic and impacts performance, so we have + [patched](https://github.com/pytorch/xla/blob/master/torch_patches/X10-clip_grad.diff) + `clip_grad_norm_` by calling `torch.where` instead, which gives + us a dramatic performance improvement. + + ``` python + ... + else: + device = parameters[0].device + total_norm = torch.zeros([], device=device if parameters else None) + for p in parameters: + param_norm = p.grad.data.norm(norm_type) ** norm_type + total_norm.add_(param_norm) + total_norm = (total_norm ** (1. / norm_type)) + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) + for p in parameters: + p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device))) + ``` + +3. **Iterators in \`\`torch_xla.distributed.data_parallel\`\` may drop + the last few batches in the input iterator.** + + This is to make sure we do the same amount of work on all XLA + devices. + + *Solution*: + + - When dataset is small, and there are too few steps, this may + result in a no-op epoch. Therefore, it is better to use small + batch sizes in those cases. + +## XLA Tensor Quirks + +1. **XLA tensor internals are opaque.** XLA tensors always appear to be + contiguous and without storage. Networks should not try to check the + strides of XLA tensors. +2. **XLA tensors should be moved to the CPU before saving them.** + Saving XLA tensors directly causes them to be loaded back on the + device(s) they were saved from. If a device is unavailable at load + time then the load will fail. Moving XLA tensors to the CPU before + saving them lets you decide which device(s) to put the loaded + tensors on. This is necessary if you want to load the tensors on a + machine without XLA devices. Care should be taken moving the XLA + tensors to the CPU before saving them, however, as moving tensors + across device types does not preserve view relationships. Instead, + views should be reconstructed as necessary after the tensors are + loaded. +3. **Copying an XLA Tensor with Python's copy.copy returns a deep copy, + not a shallow copy.** Use a view of an XLA tensor to get a shallow + copy of it. +4. **Handling shared weights.** Modules can share weights by setting + the Parameters of one module to another. This "tying" of module + weights should be done **AFTER** the modules are moved to an XLA + device. Otherwise two independent copies of the shared tensor will + be made on the XLA device. + +## More Debugging Tools + +We don't expect users to use tools in this section to debug their +models. But we might ask for them when you submit a bug report since +they provide additional information that metrics report doesn't have. + +- `print(torch_xla._XLAC._get_xla_tensors_text([res]))` where `res` is + the result tensor prints out the IR. +- `print(torch_xla._XLAC._get_xla_tensors_hlo([res]))` where `res` is + the result tensor prints out the generated XLA HLO. + +Note these functions must be called prior to `mark_step()`, otherwise +the tensor will already be materialized. + +### Environment Variables + +There are also a number of environment variables which control the +behavior of the *PyTorch/XLA* software stack. + +Setting such variables will cause different degrees of performance +degradation, so they should only be enabled for debugging. + +- `XLA_IR_DEBUG`: Enables the *Python* stack trace to be captured + where creating IR nodes, hence allowing to understand which + *PyTorch* operation was responsible for generating the IR. +- `XLA_HLO_DEBUG`: Enables the *Python* stack frame captured when + *XLA_IR_DEBUG* is active, to be propagated to the *XLA* *HLO* + metadata. +- `XLA_SAVE_TENSORS_FILE`: The path to a file which will be used to + dump the IR graphs during execution. Note that the file can become + really big if the option is left enabled and the *PyTorch* program + let run for long time. The graphs are appended to the file, so to + have a clean sheet from run to run, the file should be explicitly + removed. +- `XLA_SAVE_TENSORS_FMT`: The format of the graphs stored within the + *XLA_SAVE_TENSORS_FILE* file. Can be `text` (the default), `dot` + (the *Graphviz* format) or `hlo`. +- `XLA_FLAGS=--xla_dump_to`: If set to `=/tmp/dir_name`, XLA compiler + will dump the unoptimized and optimzed HLO per compilation. +- `XLA_METRICS_FILE`: If set, the path to a local file where the + internal metrics will be saved at every step. Metrics will be + appended to the file, if already existing. +- `XLA_SAVE_HLO_FILE`: If set, the path to a local file where, in case + of compilation/execution error, the offending HLO graph will be + saved. +- `XLA_SYNC_WAIT`: Forces the XLA tensor sync operation to wait for + its completion, before moving to the next step. +- `XLA_USE_EAGER_DEBUG_MODE`: Forces the XLA tensor to execute + eagerly, meaning compile and execute the torch operations one by + one. This is useful to bypass the long compilation time but overall + step time will be a lot slower and memory usage will be higher since + all compiler optimizaiton will be skipped. +- `TF_CPP_LOG_THREAD_ID`: If set to 1, the TF logs will show the + thread ID helping with debugging multithreaded processes. +- `TF_CPP_VMODULE`: Environment variable used for TF VLOGs and takes + the form of `TF_CPP_VMODULE=name=value,...`. Note that for VLOGs you + must set `TF_CPP_MIN_LOG_LEVEL=0`. +- `TF_CPP_MIN_LOG_LEVEL`: Level to print messages for. + `TF_CPP_MIN_LOG_LEVEL=0` will turn on INFO logging, + `TF_CPP_MIN_LOG_LEVEL=1` WARNING and so on. Our PyTorch/XLA + `TF_VLOG` uses `tensorflow::INFO` level by default so to see VLOGs + set `TF_CPP_MIN_LOG_LEVEL=0`. +- `XLA_DUMP_HLO_GRAPH`: If set to `=1` in case of a compilation or + execution error the offending HLO graph will be dumped as part of + the runtime error raised by `xla_util.cc`. + +### Common Debugging Environment Variables Combinations + +- Record the graph execution in the IR format + + XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="text" XLA_SAVE_TENSORS_FILE="/tmp/save1.ir" + +- Record the graph execution in the HLO format + + XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo" + +- Show debugging VLOG for runtime and graph compilation/execution + + TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3" + +### Reproducing PyTorch/XLA CI/CD unit test failures. + +You may see some test failures for a PR such as: + +To execute this test, run the following from the base repo dir: + +```bash +PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8 +``` + +Running this directly in the command line does not work. You need to set +the environment variable `TORCH_TEST_DEVICES` to your local +`pytorch/xla/test/pytorch_test_base.py`. For example: + +```bash +TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8 +``` +should work. diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md new file mode 100644 index 00000000000..7fdb6b05237 --- /dev/null +++ b/docs/source/learn/xla-overview.md @@ -0,0 +1,558 @@ +# Pytorch/XLA overview + +This section provides a brief overview of the basic details of PyTorch +XLA, which should help readers better understand the required +modifications and optimizations of code. + +Unlike regular PyTorch, which executes code line by line and does not +block execution until the value of a PyTorch tensor is fetched, PyTorch +XLA works differently. It iterates through the python code and records +the operations on (PyTorch) XLA tensors in an intermediate +representation (IR) graph until it encounters a barrier (discussed +below). This process of generating the IR graph is referred to as +tracing (LazyTensor tracing or code tracing). PyTorch XLA then converts +the IR graph to a lower-level machine-readable format called HLO +(High-Level Opcodes). HLO is a representation of a computation that is +specific to the XLA compiler and allows it to generate efficient code +for the hardware that it is running on. HLO is fed to the XLA compiler +for compilation and optimization. Compilation is then cached by PyTorch +XLA to be reused later if/when needed. The compilation of the graph is +done on the host (CPU), which is the machine that runs the Python code. +If there are multiple XLA devices, the host compiles the code for each +of the devices separately except when using SPMD (single-program, +multiple-data). For example, v4-8 has one host machine and [four +devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4). +In this case the host compiles the code for each of the four devices +separately. In case of pod slices, when there are multiple hosts, each +host does the compilation for XLA devices it is attached to. If SPMD is +used, then the code is compiled only once (for given shapes and +computations) on each host for all the devices. + +![img](../_static/img/pytorchXLA_flow.svg) + +For more details and examples, please refer to the [LazyTensor +guide](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/). + +The operations in the IR graph are executed only when values of tensors +are needed. This is referred to as evaluation or materialization of +tensors. Sometimes this is also called lazy evaluation and it can lead +to significant [performance +improvements](https://arxiv.org/pdf/2102.13267.pdf). + +The *synchronous* operations in Pytorch XLA, like printing, logging, +checkpointing or callbacks block tracing and result in slower execution. +In the case when an operation requires a specific value of an XLA +tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value +of that tensor is available to the host. Note that only the part of the +graph responsible for computing that tensor value is executed. These +operations do not cut the IR graph, but they trigger host-device +communication through `TransferFromDevice`, which results in slower +performance. + +A *barrier* is a special instruction that tells XLA to execute the IR +graph and materialize the tensors. This means that the PyTorch XLA +tensors will be evaluated, and the results will be available to the +host. The user-exposed barrier in Pytorch XLA is +[xm.mark_step()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), +which breaks the IR graph and results in code execution on the XLA +devices. One of the key properties of `xm.mark_step` is that unlike +synchronous operations it does not block the further tracing while the +device is executing the graph. However, it does block access to the +values of the tensors that are being materialized. + +The example in the LazyTensor guide illustrates what happens in a simple +case of adding two tensors. Now, suppose we have a for loop that adds +XLA tensors and uses the value later: + +``` python +for x, y in tensors_on_device: + z += x + y +``` + +Without a barrier, the Python tracing will result in a single graph that +wraps the addition of tensors `len(tensors_on_device)` times. This is +because the `for` loop is not captured by the tracing, so each iteration +of the loop will create a new subgraph corresponding to the computation +of `z += x+y` and add it to the graph. Here is an example when +`len(tensors_on_device)=3`. + +![img](../_static/img/IRgraph_no_markstep.png) + +However, introducing a barrier at the end of the loop will result in a +smaller graph that will be compiled once during the first pass inside +the `for` loop and will be reused for the next +`len(tensors_on_device)-1` iterations. The barrier will signal to the +tracing that the graph traced so far can be submitted for execution, and +if that graph has been seen before, a cached compiled program will be +reused. + +``` python +for x, y in tensors_on_device: + z += x + y + xm.mark_step() +``` + +In this case there will be a small graph that is used +`len(tensors_on_device)=3` times. + +![img](../_static/img/IRgraph_markstep.png) + +It is important to highlight that in PyTorch XLA Python code inside for +loops is traced and a new graph is constructed for each iteration if +there is a barrier at the end. This can be a significant performance +bottleneck. + +The XLA graphs can be reused when the same computation happens on the +same shapes of tensors. If the shapes of the inputs or intermediate +tensors change, then the XLA compiler will recompile a new graph with +the new tensor shapes. This means that if you have dynamic shapes or if +your code does not reuse tensor graphs, running your model on XLA will +not be suitable for that use case. Padding the input into a fixed shape +can be an option to help avoid dynamic shapes. Otherwise, a significant +amount of time will be spent by the compiler on optimizing and fusing +operations which will not be used again. + +The trade-off between graph size and compilation time is also important +to consider. If there is one large IR graph, the XLA compiler can spend +a lot of time on optimization and fusion of the ops. This can result in +a very long compilation time. However, the later execution may be much +faster, due to the optimizations that were performed during compilation. + +Sometimes it is worth breaking the IR graph with `xm.mark_step()`. As +explained above, this will result in a smaller graph that can be reused +later. However making graphs smaller can reduce optimizations that +otherwise could be done by the XLA compiler. + +Another important point to consider is +[MPDeviceLoader](https://github.com/pytorch/xla/blob/a1f822e2627a5639464273241821852677401026/torch_xla/distributed/parallel_loader.py#L186). +Once your code is running on an XLA device, consider wrapping the torch +dataloader with XLA `MPDeviceLoader` which preloads data to the device +to improve performance and includes `xm.mark_step()` in it. The latter +automatically breaks the iterations over batches of data and sends them +for execution. Note, if you are not using MPDeviceLoader, you might need +to set `barrier=True` in the `optimizer_step()` to enable +`xm.mark_step()` if running a training job or explicitly adding +`xm.mark_step()`. + +## TPU Setup + +Create TPU with base image to use nightly wheels or from the stable +release by specifying the `RUNTIME_VERSION`. + +``` bash +export ZONE=us-central2-b +export PROJECT_ID=your-project-id +export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, … +export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base +export TPU_NAME=your_tpu_name + +gcloud compute tpus tpu-vm create ${TPU_NAME} \ +--zone=${ZONE} \ +--accelerator-type=${ACCELERATOR_TYPE} \ +--version=${RUNTIME_VERSION} \ +--subnetwork=tpusubnet +``` + +If you have a single host VM (e.g. v4-8), you can ssh to your vm and run +the following commands from the vm directly. Otherwise, in case of TPU +pods, you can use `--worker=all --command=""` similar to + +``` bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--zone=us-central2-b \ +--worker=all \ +--command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl" +``` + +Next, if you are using base image, install nightly packages and required +libraries + +``` bash +pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl +​​pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl +sudo apt-get install libopenblas-dev -y + +sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific +``` + +## Converting code to PyTorch XLA + +General guidelines to modify your code: + +- Replace `cuda` with `xm.xla_device()` +- Remove progress bar, printing that would access the XLA tensor + values +- Reduce logging and callbacks that would access the XLA tensor values +- Wrap data loader with MPDeviceLoader +- Profile to further optimize the code + +Remember: each case is unique so you might need to do something +different for each case. + +### Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device + +As a first example consider the [inference +code](https://github.com/pytorch-tpu/stable-diffusion/blob/main/scripts/txt2img.py) +of the stable diffusion model in PyTorch Lightning which can be run from +command line as + +``` bash + python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" +``` + +For your reference, the diff of modifications described below can be +found +[here](https://github.com/pytorch-tpu/stable-diffusion/commit/57f398eb784387e244dc5fb78421aa5261abd1ef). +Let's go over them step by step. As in the general guideline above, +start with changes related to `cuda` device. This inference code is +written to run on GPUs and `cuda` can be found in multiple places. Start +making changes by removing `model.cuda()` from [this +line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L64), +and `precision_scope` from +[here](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L290). +Additionally, replace the `cuda` device in [this +line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L248) +with the `xla` device similar to the code below: + +Next, this particular configuration of the model is using +`FrozenCLIPEmbedder`, therefore we will modify this +[line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/modules/encoders/modules.py#L143) +as well. For simplicity we will directly define the `device` in this +tutorial, but you can pass the `device` value to the function as well. + +``` python + import torch_xla.core.xla_model as xm + self.device = xm.xla_device() +``` + +Another place in the code that has cuda specific code is DDIM scheduler. +Add `import torch_xla.core.xla_model as xm` on top of the file then +replace +[these](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L21-L22) +lines + +``` python +if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) +``` + +with + +``` python +device = xm.xla_device() +attr = attr.to(torch.device(device)) +``` + +Next, you can reduce device (TPU) and host (CPU) communication by +removing print statements, disabling progress bars, and reducing or +removing callbacks and logging. These operations require the device to +stop executing, falling back to the CPU, executing the +logging/callbacks, and then returning to the device. This can be a +significant performance bottleneck, especially on large models. + +After making these changes, the code will run on TPUs. However, the +performance will be very slow. This is because the XLA compiler tries to +build a single (huge) graph that wraps the number of inference steps (in +this case, 50) as there is no barrier inside the for loop. It is +difficult for the compiler to optimize the graph, and this leads to +significant performance degradation. As discussed above, breaking the +for loop with the barrier (xm.mark_step()) will result in a smaller +graph that is easier for the compiler to optimize. This will also allow +the compiler to reuse the graph from the previous step, which can +improve performance. + +Now the +[code](https://github.com/pytorch-tpu/stable-diffusion/blob/ss-inference/scripts/txt2img.py) +is ready to run on TPUs in a reasonable time. More optimization and +analysis can be done by [capturing a +profile](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) +and investigating further. However, this is not covered here. + +Note: if you are running on v4-8 TPU, then you have 4 available XLA +(TPU) devices. Running the code as above will only use one XLA device. +In order to run on all 4 devices you need to use `torch_xla.launch()` +function to spawn the code on all the devices. We will discuss a +`torch_xla.launch` in the next example. + +### Example 2. HF Stable Diffusion Inference + +Now, consider using [Stable Diffusion +Inference](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) +in the HuggingFace diffusers library for both the SD-XL and 2.1 versions +of the model. For your reference, the changes described below can be +found in this [repo](https://github.com/pytorch-tpu/diffusers). You can +clone the repo and run the inference using the following command on your +TPU VM: + +``` bash +(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git +(vm)$ cd diffusers/examples/text_to_image/ +(vm)$ python3 inference_tpu_single_device.py +``` + +### Running on a Single TPU device + +This section describes the changes that need to be made to the +[text_to_image inference +example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) +code to run it on TPUs. + +The original code uses Lora for inference, but this tutorial will not +use it. Instead, we will set the `model_id` argument to +`stabilityai/stable-diffusion-xl-base-0.9` when initializing the +pipeline. We will also use the default scheduler +(DPMSolverMultistepScheduler). However, similar changes can be made to +the other schedulers as well. + +``` bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . # pip install -e . + +cd examples/text_to_image/ +pip install -r requirements.txt +pip install invisible_watermark transformers accelerate safetensors +``` + +(If `accelerate` is not found, log out, log back in.) + +Log in to HF and agree to the [sd-xl 0.9 +license](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) +on the model card. Next, go to +[account→settings→access](https://huggingface.co/settings/tokens) token +and generate a new token. Copy the token and run the following command +with that specific token value on your vm + +``` bash +(vm)$ huggingface-cli login --token _your_copied_token__ +``` + +The HuggingFace readme provides PyTorch code that is written to run on +GPUs. To run it on TPUs, the first step is to change the CUDA device to +an XLA device. This can be done by replacing the line `pipe.to("cuda")` +with the following lines: + +``` python +import torch_xla.core.xla_model as xm +device = xm.xla_device() +pipe.to(device) +``` + +Additionally, it is important to note that the first time you run +inference with XLA, it will take a long time to compile. For example, +compilation time for stable diffusion XL model inference from +HuggingFace can take about an hour to compile, whereas the actual +inference may take only 5 seconds, depending on the batch size. +Likewise, a GPT-2 model can take about 10-15 mins to compile, after +which the training epoch time becomes much faster. This is because XLA +builds a graph of the computation that will be performed, and then +optimizes this graph for the specific hardware that it is running on. +However, once the graph has been compiled, it can be reused for +subsequent inferences, which will be much faster. Therefore, if you are +only running inference once, you may not benefit from using XLA. +However, if you are running inference multiple times, or if you are +running inference on a list of prompts, you will start to see the +advantages of XLA after the first few inferences. For example, if you +run inference on a list of 10 prompts, the first inference (maybe +two[^1]) may take a long time to compile, but the remaining inference +steps will be much faster. This is because XLA will reuse the graph that +it compiled for the first inference. + +If you try to run the code without making any additional changes, you +will notice that the compilation time is very long (\>6 hours). This is +because the XLA compiler tries to build a single graph for all of the +scheduler steps at once similar to what we have discussed in the +previous example. To make the code run faster, we need to break the +graph up into smaller pieces with `xm.mark_step()` and reuse them in the +next steps. This happens inside the `pipe.__call__` +[function](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L559) +in [these +lines](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L805-L839). +Disabling the progress bar, removing callbacks and adding +`xm.mark_step()` at the end of the for loop speeds up the code +significantly. Changes are provided in this +[commit](https://github.com/huggingface/diffusers/compare/main...pytorch-tpu:diffusers:main). + +Additionally, the `self.scheduler.step()` function, which by default +uses the `DPMSolverMultistepScheduler` scheduler, has a few issues that +are described in the [PyTorch XLA +caveats](https://pytorch.org/xla/release/2.0/index.html#known-performance-caveats). +The `.nonzero()` and `.item()` calls in this function send requests to +the CPU for tensor evaluation, which trigger device-host communication. +This is not desirable, as it can slow down the code. In this particular +case, we can avoid these calls by passing the index to the function +directly. This will prevent the function from sending requests to the +CPU, and will improve the performance of the code. Changes are available +in +[this](https://github.com/pytorch-tpu/diffusers/commit/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d) +commit. The code now is ready to be run on TPUs. + +## Profiling and performance analysis + +To further investigate the performance of the model, we can profile it +using the profiling +[guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). +As a rule of thumb, the profiling script should be run with the maximum +batch size that fits into the memory for [optimal memory +usage](https://cloud.google.com/tpu/docs/performance-guide). It also +helps to overlap tracing of the code with device execution which leads +to more optimal device usage. The duration of profiling should be long +enough to capture at least one step. Good performance of the model on +TPUs means that device-host communication is minimized and the device is +constantly running processes with no idle time. + +Starting a server in the `inference_tpu_*.py` file and running +`capture_profile.py` script as described in the guide will give us +information on processes that run on the devices. Currently, only one +XLA device is profiled. To better understand the TPU idle time (gaps in +the profile), profiling traces (`xp.Trace()`) should be added to the +code. The `xp.Trace()` measures the time it takes to trace the python +code on the host machine wrapped with the trace. For this example, +`xp.Trace()` traces were added inside the +[pipeline](https://github.com/ssusie/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) +and the [U-net +model](https://github.com/ssusie/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) +to measure the time to run specific sections of the code on the host +(CPU). + +If the gaps in the profile are due to Python code tracing that happens +on the host, then this might be a bottleneck and there is no further +straightforward optimization that can be done. Otherwise, the code +should be analyzed further to understand the caveats and improve the +performance further. Note that you cannot `xp.Trace()` wrap portions of +the code where `xm.mark_step()` is called. + +To illustrate this we can look at already captured profiles that were +uploaded to tensorboard following the profiling guide. + +Starting from Stable Diffusion model version 2.1 + +If we capture a profile without inserting any traces, we will see the +following: + +![Alt text](../_static/img/image.png) + +The single TPU device on v4-8, which has two cores, appears to be busy. +There are no significant gaps in their usage, except for a small one in +the middle. If we scroll up to try to find which process is occupying +the host machine, we will not find any information. Therefore, we will +add `xp.traces` to the pipeline +[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) +as well as the U-net +[function](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py). +The latter may not be useful for this particular use case, but it does +demonstrate how traces can be added in different places and how their +information is displayed in TensorBoard. + +If we add traces and re-capture the profile with the largest batch size +that can fit on the device (32 in this case), we will see that the gap +in the device is caused by a Python process that is running on the host +machine. + +![Alt text](../_static/img/image-1.png) + +We can use the appropriate tool to zoom in on the timeline and see which +process is running during that period. This is when the Python code +tracing happens on the host, and we cannot improve the tracing further +at this point. + +Now, let's examine the XL version of the model and do the same thing. We +will add traces to the pipeline +[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) +in the same way that we did for the 2.1 version and capture a profile. + +![Alt text](../_static/img/image-4.png) + +This time, in addition to the large gap in the middle, which is caused +by the `pipe_watermark` tracing, there are many small gaps between the +inference steps within [this +loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). + +First look closer into the large gap that is caused by `pipe_watermark`. +The gap is preceded with `TransferFromDevice` which indicates that +something is happening on the host machine that is waiting for +computation to finish before proceeding. Looking into watermark +[code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), +we can see that tensors are transferred to cpu and converted to numpy +arrays in order to be processed with `cv2` and `pywt` libraries later. +Since this part is not straightforward to optimize, we will leave this +as is. + +Now if we zoom in on the loop, we can see that the graph within the loop +is broken into smaller parts because the `TransferFromDevice` operation +happens. + +![Alt text](../_static/img/image-2.png) + +If we investigate the U-Net function and the scheduler, we can see that +the U-Net code does not contain any optimization targets for +PyTorch/XLA. However, there are `.item()` and `.nonzero()` calls inside +the +[scheduler.step](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L371). +We can +[rewrite](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/schedulers/scheduling_euler_discrete.py#L310) +the function to avoid those calls. If we fix this issue and rerun a +profile, we will not see much difference. However, since we have reduced +the device-host communication that was introducing smaller graphs, we +allowed the compiler to optimize the code better. The function +[scale_model_input](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L205) +has similar issues, and we can fix these by making the changes we made +above to the `step` function. Overall, since many of the gaps are caused +from python level code tracing and graph building, these gaps are not +possible to optimize with the current version of PyTorch XLA, but we may +see improvements in the future when dynamo is enabled in PyTorch XLA. + +## Running on Multiple TPU Devices + +To use multiple TPU devices, you can use the `torch_xla.launch` function +to spawn the function you ran on a single device to multiple devices. +The `torch_xla.launch` function will start processes on multiple TPU +devices and sync them when needed. This can be done by passing the +`index` argument to the function that runs on a single device. For +example, + +``` python +import torch_xla + +def my_function(index): + # function that runs on a single device + +torch_xla.launch(my_function, args=(0,)) +``` + +In this example, the `my_function` function will be spawned on 4 TPU +devices on v4-8, with each device being assigned an index from 0 to 3. +Note that by default, the launch() function will spawn preocesses on all +TPU devices. If you only want to run single process, set the argument +`launch(..., debug_single_process=True)`. + +[This +file](https://github.com/ssusie/diffusers/blob/main/examples/text_to_image/inference_tpu_multidevice.py) +illustrates how xmp.spawn can be used to run stable diffusion 2.1 +version on multiple TPU devices. For this version similar to the above +changes were made to the +[pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) +file. + +## Running on Pods + +Once you have the code for running on a single host device, there is no +further change needed. You can create the TPU pod, for example, by +following these +[instructions](https://cloud.google.com/tpu/docs/pytorch-pods#create-tpu-vm). +Then run your script with + +``` bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ + --zone=${ZONE} \ + --worker=all \ + --command="python3 your_script.py" +``` + +**Note:** + +0 and 1 are magic numbers in XLA and treated as constants in the +HLO. So if there is a random number generator in the code that can +generate these values, the code will compile for each value +separately. This can be disabled with `XLA_NO_SPECIAL_SCALARS=1` +environment variable. diff --git a/docs/source/multi_process_distributed.rst b/docs/source/multi_process_distributed.rst deleted file mode 100644 index f8f25e5c05a..00000000000 --- a/docs/source/multi_process_distributed.rst +++ /dev/null @@ -1,2 +0,0 @@ -.. mdinclude:: ../ddp.md -.. mdinclude:: ../fsdp.md \ No newline at end of file diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md new file mode 100644 index 00000000000..b5b1a3ffa79 --- /dev/null +++ b/docs/source/perf/amp.md @@ -0,0 +1,149 @@ +# Automatic Mixed Precision + +Pytorch/XLA's AMP extends [Pytorch's AMP +package](https://pytorch.org/docs/stable/amp.html) with support for +automatic mixed precision on `XLA:GPU` and `XLA:TPU` devices. AMP is +used to accelerate training and inference by executing certain +operations in `float32` and other operations in a lower precision +datatype (`float16` or `bfloat16` depending on hardware support). This +document describes how to use AMP on XLA devices and best practices. + +## AMP for XLA:TPU + +AMP on TPUs automatically casts operations to run in either `float32` or +`bfloat16` because TPUs natively support bfloat16. A simple TPU AMP +example is below: + +``` python +# Creates model and optimizer in default precision +model = Net().to(xm.xla_device()) +# Pytorch/XLA provides sync-free optimizers for improved performance +optimizer = syncfree.SGD(model.parameters(), ...) + +for input, target in data: + optimizer.zero_grad() + + # Enables autocasting for the forward pass + with autocast(xm.xla_device()): + output = model(input) + loss = loss_fn(output, target) + + # Exits the context manager before backward() + loss.backward() + xm.optimizer_step.(optimizer) +``` + +`autocast(xm.xla_device())` aliases `torch.autocast('xla')` when the XLA +Device is a TPU. Alternatively, if a script is only used with TPUs, then +`torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. + +Please file an issue or submit a pull request if there is an operator +that should be autocasted that is not included. + +### AMP for XLA:TPU Best Practices + +1. `autocast` should wrap only the forward pass(es) and loss + computation(s) of the network. Backward ops run in the same type + that autocast used for the corresponding forward ops. +2. Since TPU's use bfloat16 mixed precision, gradient scaling is not + necessary. +3. Pytorch/XLA provides modified version of + [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) + that avoid the additional sync between device and host. + +### Supported Operators + +AMP on TPUs operates like Pytorch's AMP. Rules for how autocasting is +applied is summarized below: + +Only out-of-place ops and Tensor methods are eligible to be autocasted. +In-place variants and calls that explicitly supply an out=... Tensor are +allowed in autocast-enabled regions, but won't go through autocasting. +For example, in an autocast-enabled region a.addmm(b, c) can autocast, +but a.addmm\_(b, c) and a.addmm(b, c, out=d) cannot. For best +performance and stability, prefer out-of-place ops in autocast-enabled +regions. + +Ops that run in float64 or non-floating-point dtypes are not eligible, +and will run in these types whether or not autocast is enabled. +Additionally, Ops called with an explicit dtype=... argument are not +eligible, and will produce output that respects the dtype argument. + +Ops not listed below do not go through autocasting. They run in the type +defined by their inputs. Autocasting may still change the type in which +unlisted ops run if they're downstream from autocasted ops. + +**Ops that autocast to `bfloat16`:** + +`__matmul__`, `addbmm`, `addmm`, `addmv`, `addr`, `baddbmm`,`bmm`, +`conv1d`, `conv2d`, `conv3d`, `conv_transpose1d`, `conv_transpose2d`, +`conv_transpose3d`, `linear`, `matmul`, `mm`, `relu`, `prelu`, +`max_pool2d` + +**Ops that autocast to `float32`:** + +`batch_norm`, `log_softmax`, `binary_cross_entropy`, +`binary_cross_entropy_with_logits`, `prod`, `cdist`, `trace`, `chloesky` +,`inverse`, `reflection_pad`, `replication_pad`, `mse_loss`, +`cosine_embbeding_loss`, `nll_loss`, `multilabel_margin_loss`, `qr`, +`svd`, `triangular_solve`, `linalg_svd`, `linalg_inv_ex` + +**Ops that autocast to widest input type:** + +`stack`, `cat`, `index_copy` + +## AMP for XLA:GPU + +AMP on XLA:GPU devices reuse Pytorch's AMP rules. See [Pytorch's AMP +documentation](https://pytorch.org/docs/stable/amp.html) for CUDA +specific behavior. A simple CUDA AMP example is below: + +``` python +# Creates model and optimizer in default precision +model = Net().to(xm.xla_device()) +# Pytorch/XLA provides sync-free optimizers for improved performance +optimizer = syncfree.SGD(model.parameters(), ...) +scaler = GradScaler() + +for input, target in data: + optimizer.zero_grad() + + # Enables autocasting for the forward pass + with autocast(xm.xla_device()): + output = model(input) + loss = loss_fn(output, target) + + # Exits the context manager before backward pass + scaler.scale(loss).backward() + gradients = xm._fetch_gradients(optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size()) + scaler.step(optimizer) + scaler.update() +``` + +`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the +XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is +only used with CUDA devices, then `torch.cuda.amp.autocast` can be +directly used, but requires `torch` is compiled with `cuda` support for +datatype of `torch.bfloat16`. We recommend using +`autocast(xm.xla_device())` on XLA:GPU as it does not require +`torch.cuda` support for any datatypes, including `torch.bfloat16`. + +### AMP for XLA:GPU Best Practices + +1. `autocast` should wrap only the forward pass(es) and loss + computation(s) of the network. Backward ops run in the same type + that autocast used for the corresponding forward ops. +2. Do not set `XLA_USE_F16` flag when using AMP on Cuda devices. This + will override the per-operator precision settings provided by AMP + and cause all operators to execute in float16. +3. Use gradient scaling to prevent float16 gradients from underflowing. +4. Pytorch/XLA provides modified version of + [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) + that avoid the additional sync between device and host. + +## Examples + +Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) +and [imagenet training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py) +demonstrate how AMP is used on both TPUs and GPUs. diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md new file mode 100644 index 00000000000..826895d5ac8 --- /dev/null +++ b/docs/source/perf/ddp.md @@ -0,0 +1,252 @@ +# How to do DistributedDataParallel(DDP) + +This document shows how to use torch.nn.parallel.DistributedDataParallel +in xla, and further describes its difference against the native xla data +parallel approach. You can find a minimum runnable example +[here](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_ddp.py). + +## Background / Motivation + +Customers have long requested the ability to use PyTorch's +DistributedDataParallel API with xla. And here we enable it as an +experimental feature. + +## How to use DistributedDataParallel + +For those who switched from the PyTorch eager mode to XLA, here are all +the changes you need to do to convert your eager DDP model into XLA +model. We assume that you already know how to use XLA [on a single +device](../API_GUIDE.md#running-on-a-single-xla-device). + +1. Import xla specific distributed packages: + + ``` python + import torch_xla + import torch_xla.runtime as xr + import torch_xla.distributed.xla_backend + ``` + +2. Init xla process group similar to other process groups such as nccl + and gloo. + + ``` python + dist.init_process_group("xla", rank=rank, world_size=world_size) + ``` + +3. Use xla specific APIs to get rank and world_size if you need to. + + ``` python + new_rank = xr.global_ordinal() + world_size = xr.world_size() + ``` + +4. Pass `gradient_as_bucket_view=True` to the DDP wrapper. + + ``` python + ddp_model = DDP(model, gradient_as_bucket_view=True) + ``` + +5. Finally launch your model with xla specific launcher. + + ``` python + torch_xla.launch(demo_fn) + ``` + +Here we have put everything together (the example is actually taken from +the [DDP +tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)). +The way you code it is pretty similar to the eager experience. Just with +xla specific touches on a single device plus the above five changes to +your script. + +``` python +import os +import sys +import tempfile +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim + +from torch.nn.parallel import DistributedDataParallel as DDP + +# additional imports for xla +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.distributed.xla_backend + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the xla process group + dist.init_process_group("xla", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 1000000) + self.relu = nn.ReLU() + self.net2 = nn.Linear(1000000, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + +def demo_basic(rank): + # xla specific APIs to get rank, world_size. + new_rank = xr.global_ordinal() + assert new_rank == rank + world_size = xr.world_size() + + print(f"Running basic DDP example on rank {rank}.") + setup(rank, world_size) + + # create model and move it to XLA device + device = xm.xla_device() + model = ToyModel().to(device) + # currently, graident_as_bucket_view is needed to make DDP work for xla + ddp_model = DDP(model, gradient_as_bucket_view=True) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = ddp_model(torch.randn(20, 10).to(device)) + labels = torch.randn(20, 5).to(device) + loss_fn(outputs, labels).backward() + optimizer.step() + # xla specific API to execute the graph + xm.mark_step() + + cleanup() + + +def run_demo(demo_fn): + # xla specific launcher + torch_xla.launch(demo_fn) + +if __name__ == "__main__": + run_demo(demo_basic) +``` + +## Benchmarking + +### Resnet50 with fake data + +The following results are collected with the command on a TPU VM V3-8 +environment with ToT PyTorch and PyTorch/XLA: + +``` bash +python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1 +``` + +And the statistical metrics are produced by using the script in this +[pull request](https://github.com/pytorch/xla/pull/4107). The unit for +the rate is images per second. + + + + + + + + + + + + + + + + + + + + + + + + + + +
TypeMeanMedian90th %Std deviationCV
xm.optimizer_step418.54419.22430.409.760.02
DDP395.97395.54407.137.600.02
+ +The performance difference between our native approach for distributed +data parallel and DistributedDataParallel wrapper is: 1 - 395.97 / +418.54 = 5.39%. This result seems reasonable given the DDP wrapper +introduces extra overheads on tracing the DDP runtime. + +### MNIST with fake data + +The following results are collected with the command: +`python test/test_train_mp_mnist.py --fake_data` on a TPU VM V3-8 +environment with ToT PyTorch and PyTorch/XLA. And the statistical +metrics are produced by using the script in this [pull +request](https://github.com/pytorch/xla/pull/4107). The unit for the +rate is images per second. + + + + + + + + + + + + + + + + + + + + + + + + + + +
TypeMeanMedian90th %Std DevCV
xm.optimizer_step17864.1920108.9624351.745866.830.33
DDP10701.3911770.0014313.783102.920.29
+ +The performance difference between our native approach for distributed +data parallel and DistributedDataParallel wrapper is: 1 - 14313.78 / +24351.74 = 41.22%. Here we compare 90th % instead since the dataset is +small and first a few rounds are heavily impacted by data loading. This +slowdown is huge but makes sense given the model is small. The +additional DDP runtime tracing overhead is hard to amortize. + +### MNIST with real data + +The following results are collected with the command n a TPU VM V3-8 +environment with ToT PyTorch and PyTorch/XLA: + +``` bash +python test/test_train_mp_mnist.py --logdir mnist/ o. +``` + +![](../_static/img/ddp_md_mnist_with_real_data.png) + +And we can observe that the DDP wrapper converges slower than the native +XLA approach even though it still achieves a high accuracy rate at +97.48% at the end. (The native approach achieves 99%.) + +## Disclaimer + +This feature is still experimental and under active development. Use it +in cautions and feel free to file any bugs to the [xla github +repo](https://github.com/pytorch/xla/). For those who are interested in +the native xla data parallel approach, here is the +[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing). + +Here are some of the known issues that are under investigation: \* +`gradient_as_bucket_view=True` needs to be enforced. \* There are some +issues while being used with `torch.utils.data.DataLoader`. +`test_train_mp_mnist.py` with real data crashes before exiting. diff --git a/docs/source/perf/dynamo.md b/docs/source/perf/dynamo.md new file mode 100644 index 00000000000..4e68d49a2e4 --- /dev/null +++ b/docs/source/perf/dynamo.md @@ -0,0 +1,233 @@ +# TorchDynamo integration in PyTorch XLA + +[TorchDynamo](https://pytorch.org/docs/stable/torch.compiler.html) is a +Python-level JIT compiler designed to make unmodified PyTorch programs +faster. It provides a clean API for compiler backends to hook in and its +biggest feature is to dynamically modify Python bytecode right before it +is executed. In the pytorch/xla 2.0 release, PyTorch/XLA provided an +experimental backend for the TorchDynamo for both inference and +training. + +The way that XLA bridge works is that Dynamo will provide a TorchFX +graph when it recognizes a model pattern and PyTorch/XLA will use +existing Lazy Tensor technology to compile the FX graph and return the +compiled function. + +## Integration + +Support for PyTorch/XLA and Dynamo currently exists by adding the +`backend='openxla'` argument to `torch.compile`. For example: + +``` python +import torch +import torch_xla.core.xla_model as xm + +def add(a, b): + a_xla = a.to(xm.xla_device()) + b_xla = b.to(xm.xla_device()) + return a_xla + b_xla + +compiled_code = torch.compile(add, backend='openxla') +print(compiled_code(torch.randn(10), torch.randn(10))) +``` + +## Inference + +Here is a small code example of running resnet18 with `torch.compile` + +``` python +import torch +import torchvision +import torch_xla.core.xla_model as xm + +def eval_model(loader): + device = xm.xla_device() + xla_resnet18 = torchvision.models.resnet18().to(device) + xla_resnet18.eval() + dynamo_resnet18 = torch.compile( + xla_resnet18, backend='openxla') + for data, _ in loader: + with torch.no_grad(): + output = dynamo_resnet18(data) +``` + +With the `torch.compile` you will see that PyTorch/XLA only traces the +resent18 model once during the init time and executes the compiled +binary every time `dynamo_resnet18` is invoked, instead of tracing the +model every time. Here is a inference speed analysis to compare Dynamo +and Lazy using torch bench on Cloud TPU v4-8 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
modelSpeed up
resnet182.59
resnet502.64
resnext50_32x4d1.91
alexnet1.28
mobilenet_v218.62
mnasnet1_02.68
vgg161.33
BERT_pytorch7.49
squeezenet1_12.29
timm_vision_transformer3.52
geomean3.04
+ +## Training + +PyTorch/XLA also supports Dynamo for training, but it is experimental +and we are working with the PyTorch Compiler team to iterate on the +implementation. Here is an example of training a resnet18 with +`torch.compile` + +``` python +import torch +import torchvision +import torch_xla.core.xla_model as xm + +def train_model(model, data, target, optimizer): + loss_fn = torch.nn.CrossEntropyLoss() + pred = model(data) + loss = loss_fn(pred, target) + loss.backward() + optimizer.step() + return pred + +def train_model_main(loader): + device = xm.xla_device() + xla_resnet18 = torchvision.models.resnet18().to(device) + xla_resnet18.train() + dynamo_train_model = torch.compile( + train_model, backend='openxla') + for data, target in loader: + xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2) + output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer) +``` + +We expect to extract and execute 3 graphs per training step instead of 1 +graph per training step if you use the Lazy tensor. Here is a training +speed analysis to compare Dynamo and Lazy using a torch bench on Cloud +TPU v4-8. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelSpeed up
resnet501.33
resnet181.33
BERT_pytorch3.07
resnext50_32x4d1.43
alexnet1.12
mobilenet_v21.4
mnasnet1_01.19
vgg160.81
timm_vision_transformer1.87
squeezenet1_11.41
geomean1.41
+ +> **NOTE:** We run each model's fwd and bwd for a single step and then +> collect the e2e time. In the real world we will run multiple steps at +> each training job which can easily hide the tracing cost from +> execution(since it is async). Lazy Tensor will have much better +> performance in that scenario. + +## Feature gaps + +There is one gap we want to call out that are preventing us from using the +TorchDynamo on larger scale models. + +TorchDynamo will trace forward and backward into separate graphs. For +PyTorch/XLA it is important to let the XLA compiler see the whole step as one +graph to best optimize the speed. There is also a fixed overhead to launch every +device execution which make executing multiple graphs per training step less +ideal. + +This gap compared to Lazy Tensor makes it less efficient in real world training +use cases, especially the tracing cost can be overlapped wit the execution in +training. + +## Take away + +TorchDynamo provides a really promising way for the compiler backend to +hide the complexity from the user and easily retrieve the modeling code +in a graph format. Compared with PyTorch/XLA's traditional Lazy Tensor +way of extracting the graph, TorchDynamo can skip the graph tracing for +every iteration, hence providing a much better inference response time. + +Most models supported by PyTorch/XLA, have seen significant speedup when +running inference with the new dynamo-xla bridge. Our community is +working hard to expand the set of supported models. Regarding the +training feature gaps mentioned above, the PyTorch/XLA community is +super excited to improve the training gap in our upcoming development +work. The team continues to heavily invest in TorchDynamo and work with +the upstream to mature the training story. diff --git a/docs/fori_loop.md b/docs/source/perf/fori_loop.md similarity index 57% rename from docs/fori_loop.md rename to docs/source/perf/fori_loop.md index c29e32e28b3..93f7e125efd 100644 --- a/docs/fori_loop.md +++ b/docs/source/perf/fori_loop.md @@ -1,24 +1,28 @@ -# `While_loop` optimize memory utilization and compilation +# Optimize memory utilization using `while_loop` -
+## `while_loop` -### `while_loop` -`while_loop` replace pure python `while` loop, PyTorch supported `while_loop` by -[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). -PyTorch/XLA provide experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`. +`while_loop` replace pure python `while` loop, PyTorch supported +`while_loop` by +[torch.\_higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). +PyTorch/XLA provide experimental XLA backend support for +`torch._higher_order_ops.while_loop` via `XLA::While`. -#### Usage: -```python +### Usage: + +``` python import torch_xla.experimental.fori_loop from torch._higher_order_ops.while_loop import while_loop result = while_loop(cond_fn, body_fn, init) ``` -- `cond_fn`: User-defined condition function. -- `body_fn`: User-defined loop body function. -- `init`: Initial values (tuple or list). -#### simple example with `while_loop`: -```bash +- `cond_fn`: User-defined condition function. +- `body_fn`: User-defined loop body function. +- `init`: Initial values (tuple or list). + +### simple example with `while_loop`: + +``` bash # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla @@ -42,13 +46,15 @@ FunctionalTensor(lvl=0, value=\ tensor(13, device='xla:0')) ``` -
+#### Control group test case -## Control group test case -For better compare difference between `pure python while loop` and `while_loop`, there is one test case called pure python `while` loop with similar logic: cumulative plus 1 for ten times: +For better compare difference between `pure python while loop` and +`while_loop`, there is one test case called pure python `while` loop +with similar logic: cumulative plus 1 for ten times: -### Control group example with pure python `while` loop -```bash +## Control group example with pure python `while` loop + +``` python # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla @@ -67,6 +73,7 @@ For better compare difference between `pure python while loop` and `while_loop`, tensor(51, device='xla:0') ``` - - -PyTorch/XLA would include `while_loop` support in 2.4 with test case, support for `fori_loop` would be added after 2.4. For `while_loop`, currently we only should force define `body_fn` with same `input` and `output(return args)` shape +PyTorch/XLA would include `while_loop` support in 2.4 with test case, +support for `fori_loop` would be added after 2.4. For `while_loop`, +currently we only should force define `body_fn` with same `input` and +`output(return args)` shape diff --git a/docs/source/perf/fsdp.md b/docs/source/perf/fsdp.md new file mode 100644 index 00000000000..377d4da8d37 --- /dev/null +++ b/docs/source/perf/fsdp.md @@ -0,0 +1,186 @@ +# Fully Sharded Data Parallel in PyTorch XLA + +Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for +sharding Module parameters across data-parallel workers. + +Example usage: + +``` python3 +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + +model = FSDP(my_module) +optim = torch.optim.Adam(model.parameters(), lr=0.0001) +output = model(x, y) +loss = output.sum() +loss.backward() +optim.step() +``` + +It is also possible to shard individual layers separately and have an +outer wrapper handle any leftover parameters. + +Notes: The `XlaFullyShardedDataParallel` class supports both the ZeRO-2 +optimizer (sharding gradients and optimizer states) and the ZeRO-3 +optimizer (sharding parameters, gradients, and optimizer states) in +. The ZeRO-3 optimizer should be +implemented via nested FSDP with `reshard_after_forward=True`. See +`test/test_train_mp_mnist_fsdp_with_ckpt.py` and +`test/test_train_mp_imagenet_fsdp.py` for an example. \* For large +models that cannot fit into a single TPU memory or the host CPU memory, +one should interleave submodule construction with inner FSDP wrapping. +See +[FSDPViTModel](https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) +for an example. a simple wrapper `checkpoint_module` is provided (based +on `torch_xla.utils.checkpoint.checkpoint` from +) to perform [gradient +checkpointing](https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs) +over a given `nn.Module` instance. See +`test/test_train_mp_mnist_fsdp_with_ckpt.py` and +`test/test_train_mp_imagenet_fsdp.py` for an example. Auto-wrapping +submodules: instead of manually nested FSDP wrapping, one can also +specify an `auto_wrap_policy` argument to automatically wrap the +submodules with inner FSDP. `size_based_auto_wrap_policy` in +`torch_xla.distributed.fsdp.wrap` is an example of `auto_wrap_policy` +callable, this policy wraps layers with the number of parameters larger +than 100M. `transformer_auto_wrap_policy` in +`torch_xla.distributed.fsdp.wrap` is an example of `auto_wrap_policy` +callable for transformer-like model architectures. + +For example, to automatically wrap all `torch.nn.Conv2d` submodules with +inner FSDP, one can use: + +``` python3 +from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy +auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d}) +``` + +Additionally, one can also specify an `auto_wrapper_callable` argument +to use a custom callable wrapper for the submodules (the default wrapper +is just the `XlaFullyShardedDataParallel` class itself). For example, +one can use the following to apply gradient checkpointing (i.e. +activation checkpointing/rematerialization) to each auto-wrapped +submodule. + +``` python3 +from torch_xla.distributed.fsdp import checkpoint_module +auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel( + checkpoint_module(m), *args, **kwargs) +``` + +- When stepping the optimizer, directly call `optimizer.step` and do + not call `xm.optimizer_step`. The latter reduces the gradient across + ranks, which is not needed for FSDP (where the parameters are + already sharded). +- When saving model and optimizer checkpoints during training, each + training process needs to save its own checkpoint of the (sharded) + model and optimizer state dicts (use `master_only=False` and set + different paths for each rank in `xm.save`). When resuming, it needs + to load the checkpoint for the corresponding rank. +- Please also save `model.get_shard_metadata()` along with + `model.state_dict()` as follows and use + `consolidate_sharded_model_checkpoints` to stitch the sharded model + checkpoints together into a full model state dict. See + `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example. + +``` python3 +ckpt = { + 'model': model.state_dict(), + 'shard_metadata': model.get_shard_metadata(), + 'optimizer': optimizer.state_dict(), +} +ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth' +xm.save(ckpt, ckpt_path, master_only=False) +``` + +- The checkpoint consolidation script can also be launched from the + command line as follows. + +``` bash +# consolidate the saved checkpoints via command line tool +python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \ + --ckpt_prefix /path/to/your_sharded_checkpoint_files \ + --ckpt_suffix "_rank-*-of-*.pth" +``` + +The implementation of this class is largely inspired by and mostly +follows the structure of `fairscale.nn.FullyShardedDataParallel` in +. One of +the biggest differences from `fairscale.nn.FullyShardedDataParallel` is +that in XLA we don't have explicit parameter storage, so here we resort +to a different approach to free full parameters for ZeRO-3. + +## Example training scripts on MNIST and ImageNet + +- Minimum example : + [examples/fsdp/train_resnet_fsdp_auto_wrap.py](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py) +- MNIST: + [test/test_train_mp_mnist_fsdp_with_ckpt.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py) + (it also tests checkpoint consolidation) +- ImageNet: + [test/test_train_mp_imagenet_fsdp.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py) + +### Installation + +FSDP is available on PyTorch/XLA 1.12 release and newer nightly. Please +refer to +for installation guide. + +### Clone PyTorch/XLA repo + +``` bash +git clone --recursive https://github.com/pytorch/pytorch +cd pytorch/ +git clone --recursive https://github.com/pytorch/xla.git +cd ~/ +``` + +### Train MNIST on v3-8 TPU + +It gets around 98.9 accuracy for 2 epochs: + +``` bash +python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \ + --batch_size 16 --drop_last --num_epochs 2 \ + --use_nested_fsdp --use_gradient_checkpointing +``` + +This script automatically tests checkpoint consolidation at the end. You +can also manually consolidate the sharded checkpoints via + +``` bash +# consolidate the saved checkpoints via command line tool +python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \ + --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \ + --ckpt_suffix "_rank-*-of-*.pth" +``` + +### Train ImageNet with ResNet-50 on v3-8 TPU + +It gets around 75.9 accuracy for 100 epochs; download +[ImageNet-1k](https://github.com/pytorch/examples/tree/master/imagenet#requirements) +to `/datasets/imagenet-1k`: + +``` bash +python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \ + --datadir /datasets/imagenet-1k --drop_last \ + --model resnet50 --test_set_batch_size 64 --eval_interval 10 \ + --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \ + --use_nested_fsdp +``` + +You can also add `--use_gradient_checkpointing` (which needs to be used +along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply +gradient checkpointing on the residual blocks. + +## Example training scripts on TPU pod (with 10 billion parameters) + +To train large models that cannot fit into a single TPU, one should +apply auto-wrap or manually wrap the submodules with inner FSDP when +building the entire model to implement the ZeRO-3 algorithm. + +Please see for an +example of sharded training of a Vision Transformer (ViT) model using +this XLA FSDP PR. diff --git a/docs/fsdpv2.md b/docs/source/perf/fsdpv2.md similarity index 57% rename from docs/fsdpv2.md rename to docs/source/perf/fsdpv2.md index 6ad04dc1eab..b6e07db4370 100644 --- a/docs/fsdpv2.md +++ b/docs/source/perf/fsdpv2.md @@ -1,13 +1,19 @@ -# Fully Sharded Data Parallel(FSDP) via SPMD +# Fully Sharded Data Parallel using SPMD -Fully Sharded Data Parallel via SPMD or FSDPv2 is an utility that re-expresses the famous FSDP algorithm in SPMD. [This](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/spmd_fully_sharded_data_parallel.py) is -an experimental feature that aiming to offer a familiar interface for users to enjoy all the benefits that SPMD brings into -the table. The design doc is [here](https://github.com/pytorch/xla/issues/6379). +Fully Sharded Data Parallel via SPMD or FSDPv2 is an utility that +re-expresses the famous FSDP algorithm in SPMD. +[This](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/spmd_fully_sharded_data_parallel.py) +is an experimental feature that aiming to offer a familiar interface for +users to enjoy all the benefits that SPMD brings into the table. The +design doc is [here](https://github.com/pytorch/xla/issues/6379). -Please review the [SPMD user guide](./spmd_basic.md) before proceeding. You can also find a minimum runnable example [here](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py). +Please review the [SPMD user guide](./spmd_basic.html) before +proceeding. You can also find a minimum runnable example +[here](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py). Example usage: -```python3 + +``` python3 import torch import torch_xla.core.xla_model as xm import torch_xla.distributed.spmd as xs @@ -31,8 +37,12 @@ loss = output.sum() loss.backward() optim.step() ``` -It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. Here is an example to autowrap each `DecoderLayer`. -```python3 + +It is also possible to shard individual layers separately and have an +outer wrapper handle any leftover parameters. Here is an example to +autowrap each `DecoderLayer`. + +``` python3 from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy # Apply FSDP sharding on each DecoderLayer layer. @@ -48,12 +58,16 @@ model = FSDPv2( ## Sharding output -To ensure the XLA compiler correctly implements the FSDP algorithm, we need to shard both weights and activations. This means sharding the output of the forward method. Since the forward function output can vary, we offer shard_output to shard activations in cases where your module output doesn't fall into one of these categories: -1. A single tensor -2. A tuple of tensors where the 0th element is the activation. +To ensure the XLA compiler correctly implements the FSDP algorithm, we +need to shard both weights and activations. This means sharding the +output of the forward method. Since the forward function output can +vary, we offer shard_output to shard activations in cases where your +module output doesn't fall into one of these categories: 1. A single +tensor 2. A tuple of tensors where the 0th element is the activation. Example usage: -```python3 + +``` python3 def shard_output(output, mesh): xs.mark_sharding(output.logits, mesh, ('fsdp', None, None)) @@ -62,14 +76,20 @@ model = FSDPv2(my_module, mesh, shard_output) ## Gradient checkpointing -Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper. Otherwise, recursively loop into children modules will end up with infinite loop. We will fix this issue in the future releases. +Currently, gradient checkpointing needs to be applied to the module +before the FSDP wrapper. Otherwise, recursively loop into children +modules will end up with infinite loop. We will fix this issue in the +future releases. Example usage: -```python3 + +``` python3 from torch_xla.distributed.fsdp import checkpoint_module model = FSDPv2(checkpoint_module(my_module), mesh) ``` ## HuggingFace Llama 2 Example -We have a fork of HF Llama 2 to demonstrate a potential integration [here](https://github.com/huggingface/transformers/compare/main...pytorch-tpu:transformers:llama2-spmd-fsdp). + +We have a fork of HF Llama 2 to demonstrate a potential integration +[here](https://github.com/huggingface/transformers/compare/main...pytorch-tpu:transformers:llama2-spmd-fsdp). diff --git a/docs/quantized_ops.md b/docs/source/perf/quantized_ops.md similarity index 52% rename from docs/quantized_ops.md rename to docs/source/perf/quantized_ops.md index ce94c0c0748..1d68def0569 100644 --- a/docs/quantized_ops.md +++ b/docs/source/perf/quantized_ops.md @@ -1,27 +1,40 @@ -# Quantized Operations for XLA device (Experimental feature) --------------------------- +# Quantized Operations for XLA (Experimental feature) -This document outlines how to utilize quantized operations to enable quantization on XLA devices. +This document outlines how to utilize quantized operations to enable +quantization on XLA devices. -XLA Quantized ops offer a high-level abstraction for quantized operations (e.g., blockwise int4 quantized matrix multiplication). These ops are analogous to quantized CUDA kernels ([example](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/gptq/q_gemm.cu)) in the CUDA ecosystem, providing similar functionality and performance benefits within the XLA framework. - -**NOTE:** Currently this is classified as experimental feature. It's API specifics -will change in the next (2.5) release. +XLA Quantized ops offer a high-level abstraction for quantized +operations (e.g., blockwise int4 quantized matrix multiplication). These +ops are analogous to quantized CUDA kernels +([example](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/gptq/q_gemm.cu)) +in the CUDA ecosystem, providing similar functionality and performance +benefits within the XLA framework. +**NOTE:** Currently this is classified as experimental feature. It's API +specifics will change in the next (2.5) release. ## How to use: -XLA quantized operations can be used as `torch op`, or a `torch.nn.Module` that wraps the `torch.op`. These 2 options give model developers the flexibility to choose the best way to integrate XLA quantized ops into their solution. +XLA quantized operations can be used as `torch op`, or a +`torch.nn.Module` that wraps the `torch.op`. These 2 options give model +developers the flexibility to choose the best way to integrate XLA +quantized ops into their solution. -Both `torch op` and `nn.Module` are compatible with `torch.compile( backend='openxla')`. +Both `torch op` and `nn.Module` are compatible with +`torch.compile( backend='openxla')`. ### Call XLA quantized op in model code -Users can call XLA quantized ops in the same way as calling other regular PyTorch ops. This provides maximum flexibility in integrating XLA quantized ops into their applications. The quantized ops work in both eager mode and Dynamo, with regular PyTorch CPU tensor and XLA tensor. +Users can call XLA quantized ops in the same way as calling other +regular PyTorch ops. This provides maximum flexibility in integrating +XLA quantized ops into their applications. The quantized ops work in +both eager mode and Dynamo, with regular PyTorch CPU tensor and XLA +tensor. -**Note** Please check the docstring of the quantized ops for the layout of the quantized weights. +**Note** Please check the docstring of the quantized ops for the layout +of the quantized weights. -```Python +``` python import torch import torch_xla.core.xla_model as xm import torch_xla.experimental.xla_quantized_matmul @@ -51,9 +64,10 @@ f_dynamo = torch.compile(f, backend="openxla") dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla) ``` -It's common to wrap the quantized op into a custom `nn.Module` in model developers model code: +It's common to wrap the quantized op into a custom `nn.Module` in model +developers model code: -```Python +``` python class MyQLinearForXLABackend(torch.nn.Module): def __init__(self): self.weight = ... @@ -66,7 +80,7 @@ class MyQLinearForXLABackend(torch.nn.Module): self.weight = processed_w self.scaler = processed_scaler - + def forward(self, x): # Do some random stuff with x ... @@ -77,9 +91,10 @@ class MyQLinearForXLABackend(torch.nn.Module): ### Module Swap -Alternatively, users can also use the `nn.Module` that wraps the XLA quantized ops and do module swap in the model code: +Alternatively, users can also use the `nn.Module` that wraps the XLA +quantized ops and do module swap in the model code: -```Python +``` python orig_model = MyModel() # Quantize the model and get quantized weights q_weights = quantize(orig_model) @@ -97,18 +112,61 @@ orig_model.linear = q_linear ### Matrix Multiply -| Weight Quantization Type | Activation Quantization Type | Dtype | Supported | -|---|---|---|---| -| per-channel (sym/asym) | N/A | W8A16 | Yes | -| per-channel (sym/asym) | N/A | W4A16 | Yes | -| per-channel | per-token | W8A8 | No | -| per-channel | per-token | W4A8 | No | -| blockwise (sym/asym) | N/A | W8A16 | Yes | -| blockwise (sym/asym) | N/A | W4A16 | Yes | -| blockwise | per-token | W8A8 | No | -| blockwise | per-token | W4A8 | No | - -**Note** `W[X]A[Y]` refers to Weight in `X`-bit, Activation in `Y`-bit. If `X/Y` is 4 or 8, it refers to `int4/8`. 16 for `bfloat16` format. - -### Embedding -To be added + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
WeightActivationDtypeSupported
per-channel (sym/asym)W8A16Yes
per-channel (sym/asym)N/AW8A8No
per-channelper-tokenW8A8No
per-channelper-tokenW4A8No
blockwise (sym/asym)N/AW8A16Yes
blockwise (sym/asym)N/AW8A16Yes
blockwiseper-tokenW8A8No
blockwiseper-tokenW4A8No
+ +**Note** `W[X]A[Y]` refers to Weight in `X`-bit, Activation in `Y`-bit. +If `X/Y` is 4 or 8, it refers to `int4/8`. 16 for `bfloat16` format. \ No newline at end of file diff --git a/docs/source/perf/recompilation.md b/docs/source/perf/recompilation.md new file mode 100644 index 00000000000..6d4a3ceed6f --- /dev/null +++ b/docs/source/perf/recompilation.md @@ -0,0 +1,176 @@ +# Source of recompilations in Pytorch/XLA + +## Let’s first start with some facts/constraints: + +1. Graph compilations in XLA are pretty expensive. +2. XLA handles static shape only. In other words, even for the same IR graph, XLA recompiles when input shape changes. +3. Recompilations hurts torch_xla perf a lot when it happens, and it’s hard to understand and debug from a normal python user POV. + +Often when recompilation happens we say we just need dynamic shape support and then rest assured that when dynamic shape is supported in the future, all the recompilations will be magically gone. But this is not true, XLA now has pretty good bounded dynamic shapes coverage already, but we still see recompilations and they are expected. + +This doc aims to provide a detailed explanation of a few common sources of recompilations, and what do we need to get rid of them. It will mainly focus on explaining the problem to beginners without any context. To make it easy to understand, the “solutions” proposed here may rely on impractical assumptions. + +## #1. From input dataset. + +Yes it’s pretty common that input dataset contains examples with different shapes, e.g. sentences with varying length or images with different sizes. Without normalization, it’ll cause recompilation for every new input shape. + +Tensorflow graph mode users are more used to do padding/bucketization (`tf.pad`) to normalize input shapes to one or a few buckets. But this is kinda anti-pattern for PyTorch eager frontend users (which is the same user lazy tensor frontend is trying to target) since different input shapes just doesn’t matter for eager CPU/CUDA backend. + +**Proposed workaround:** okay now let’s say we can work around this problem by teaching our users to do padding/bucketization (it’s hard in practice :P). What’s next? + +## #2. From operator output + +There are certain operators semantically are data-dependent and produce dynamic shape outputs: e.g. `torch.nonzero` returns indices of nonzero elements in its input tensor. So even your input tensors to this operator always have the same shape, it might produce different shape outputs and cause recompilations. + +### 2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension. + +**Proposed workaround:** let’s say now XLA supports bounded dynamic shape for all operators, is it good enough? + +* by bounded dynamic shape it means we can pad the tensor to a theoretical max, trading more memory usage for less recompilation/faster speed. + +Well, sort of. Let’s see the following example: + + +``` +a = torch.tensor([1, 2, 0, 1, 3], device='xla') +b = torch.nonzero(a) +c = b * 2 +d = c + 1 +print(torch_xla._XLAC._get_xla_tensors_text([d])) +``` + +In the example above every node below `b` in the graph (namely `c, d` and everything depend on them) will have dynamic shape, it’s pretty obvious that `b` has dynamic shape in dimension 0 as shown below: + + +``` + %9 = (s64[<=5,1]{1,0}, s64[]) aten::nonzero(%8), num_outputs=2 # b + %10 = s64[5,1]{1,0} aten::mul(%9.0, %3) # c + %11 = s64[5,1]{1,0} aten::add(%10, %2), ROOT=0 # d +``` + +Although it's not shown directly in the graph, `c & d` indeed also have dynamic shape (in other words, [5, 1] is just padded shape and it's masked). + +``` +print(torch_xla._XLAC._get_xla_tensor_dimension_size(d, 0)) # prints 4 instead of 5 +``` + +You can see that in this case as long as the input tensor `a` has shape `[5]` we only compile the graph once. Bounded dynamic shape support helped! + +### 2.2 what if real dimension is queried on a tensor with dynamic shape? + +This is actually pretty commonly used since not all PyTorch computation are done in the form of Tensors. + +For example, `tensor.size()` in PyTorch returns a tuple of ints instead of a Tensor of dtype=int. When `tensor` is a dynamic shape tensor, this op basically forces XLA to cut the graph and evaluate so that we can return the correct scalar (otherwise it’ll just return the padded shape which is wrong). + +What’s made it worse is that many PyTorch takes scalar inputs as well. After you do `s = tensor.size(0)` and use `s` in other operators it also becomes a dynamic source. In this case we probably know how to pad it and its upper bound, but we cannot do it since it’s not even a Tensor! + + +``` + a = torch.tensor([1, 2, 0, 1, 3], device='xla') + b = torch.nonzero(a) + s = a.size(0) # evaluation happens! nit: we use size() for simplicity, the actual API is _get_xla_tensor_dimension_size. + c = torch.rand(s, device='xla') # c can be of any shape between [0, 5] which causes more recompilations! + d = c + 1 +``` + +So this one is actually hard to solve without PyTorch frontend’s help. What do we need? + +In short, we need a Tensor world! + +For example, + +* `tensor.size()` should return a Tensor so that it can be a Tensor with dynamic shape and kept in the graph without early evaluation. +* Tensor accessor, e.g. for 2D tensor, `tensor[0][0]` now returns a value but this need to return a tensor as well. +* Implicitly this means all operators currently taking int/float/double as input need a Tensor overload as well. THIS IS A BIG ASK as it can easily explode our operator set. + * It’s easier if we can make scalar to Tensor conversion really cheap so that we can only care about the Tensor overload. + * In practice not all ops takes scalars from previous computation, so we’ve been adding Tensor variants by ad-hoc requests. + * This is also a common ask from tracing base approaches I think. + +Okay now that we assume every op in PyTorch has a Tensor verison we need, are we done? + +## #3. From control flow + +No! We actually only solved the problem without data dependent control flow... + +See the example below: + +``` +if x[0][0] == 3: + bla +else: + blabla +``` + +Even if `x[0][0]` was a Tensor, we need to execute/materialize its value for python interpreter to proceed. And different branch choices in multiple control flows combined means we have a lot of graph to compile as well! + +For now we just have no way to fix this. To fix it we need to lower the control flow from python to graph! Without too much thinking in implementation we can do this in two ways: + +* ask users to explicitly use a control flow op instead of python if/else/while/for. This is currently supported as [customized API in torch_xla](https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_builder.py#L563-L574) but not widely adopted in users’ code. (python users are used to if/else/for and it’s hard to switch them to a uglier API unless there’s a huge perf win). +* parse python source. code to get the control flow statement automatically. This is like Torchscript and somehow merge the torchscripted graph into the lazily trace graph properly (including shape info etc). I haven’t thought through the steps of how to implement this indeed :P + +But either solution above requires non-trivial amount of effort, either on user side or on the framework side. That’s why we currently just take the hit of early evaluation & multiple compilations as a short term solution given the bandwidth we have. + +Okay so now we assume that also have control flow lowered in the graph automagically, are we gold? + +YES! Now you have your whole computation represented in a graph of Tensor operations, including control flow so that compilers can now consume and do their smart tricks! But tbh at this point your program is no longer very PyTorch-y. + + +## Conclusion: + +There’re actually multiple sources of recompilation and bounded dynamic shape support cannot solve all of them. The proposed workarounds in this doc are definitely sometimes impractical, and there might be better ways to fix each source properly that I’m totally unaware of. But I hope as we keep smashing our way to an ideal lazy tensor stack in this doc, it’s now easier for you understand what’re the remaining blockers ahead of us. + + +## Appendix: + +1. NNC uses symbolic shapes, does that help? + +Yes but partially. By having symbolic shape, your compilation optimization no longer requires concrete shape values. In other words your generated kernel are more general than XLA’s static shape ones. + +And which exactly problem does it help? + +It helps with cases like #1 and #2.1. + + +``` +shape [3, 5] -> add -> transpose -> ... -> mul +shape [6, 2] -> add -> transpose -> ... -> mul + +# with symbolic shape +shape [x, y] -> add -> transpose -> ... -> mul +``` + +With symbolic shape your generated kernel doesn’t recompile as XLA does with static shapes. + +XLA solves this problem in the other way, by using padding/bucketization (for #1) and bounded dynamic shape (for #2.1). + +Brian Hirsh(@bdhirsh) asked some really good questions in the comment, moving here to make them more visible: + +2. Is it worth sticking a TORCH_WARN in the XLA kernels of ops that produce data-dependent output shapes? + +Yea torch_warn is useful in telling users "hey your program won't run blazing fast". But for these data dependent ops, there isn't an easy rewrite for them unless users change the logic in their model. (another example is torch.unique()) + +3. How ops like nonzero impact our ability to devirtualize sizes()? If we want to devirtualize sizes(), we’ll need to be able to eagerly compute sizes for each op - won’t that mean we’re forced to evaluate the graph every time we hit an op like nonzero? Vs. right now, it sounds like we don't actually force an evaluation when a user calls nonzero()? + +Yea great question! So in the current form it’s not a hard blocker since size() on XLA Tensors doesn’t carry source of truth size information. As shown in the example, the source of truth lives in IRValue and can be retrieved by `_get_xla_tensor_dimension_size` only. So if we decide to devirtualize size it’ll just enforce this discrepancy. + +As a followup if we have `size()` return Tensor instead of values as mentioned in the proposed workarounds above. In that case size() won’t be able to devirtualize since it becomes an operator (taking in Tensor and produce Tensor, have different implementations for different backends.) + +4. If I, e.g. call `torch.add(input, 1)` in a loop, where input varies in size from 1-1000, normally we would have to compile 1000 different graphs - but with dynamic shapes, it sounds like XLA will internally be able to generate a single graph where it says “use this graph if the input size is <=1000”. My question is: is “dynamic shape” a property of just the graph? Or of both the graph and the input. I.e. if my code were instead calling `x = torch.add(input, 1); x.sizes()` in a loop, does x have a dynamic shape at this point, meaning we’d need to run the graph to get the sizes? Or are we able to make it an eagerly computed property even in the presence of graphs with dynamic shapes. + +Yea in this case you'll compile 1000 different graphs. Dynamic shapes means its input has dynamic dimension in it. So when you query `x.sizes()` (currently need use get_dimention_size to get the correct size) it'll trigger *execution* (since the size didn't change it doesn't trigger recompilation). Without the line accessing size, it won't trigger any recompilation/execution when input has dynamic dimension. + +5. Would an alternative of making control flow available in the graph be just to come up with a way to ensure that XLA graphs don't include control flow? i.e. if we have a model with a single conditional in the middle, then get XLA to produce 3 graphs: 1 for everything before the conditional, 1 for the if branch, and 1 for the else branch. That would mean you don't get the exponential blowup of new graphs for every combination of paths taken, but (a) the graphs are smaller and provide fewer optimization opportunities, and (b) it would probably be pretty non-trivial to get XLA to recognize where a conditional path is taken. + +Great point! So if we could break them up into smaller graphs it's indeed feasible. But in practice this pattern is annoying: + +``` +y = +x = y + 2 +if x[0] == 2 : + z = y +1 +else: + z = y - 1 +``` + +Note you'll evaluate x using a subgraph when you hit control flow, but there might be previous variable included in the branch computation as well (like` y` is just one node smaller than x, but it wasn't materizalized when you evaluate `x`). So you're actually evaluating 1 small graph and two big graphs for this example. And with more control flow involved, y could get updated in multiple branches which still produces different combo of large graphs. + diff --git a/docs/spmd_advanced.md b/docs/source/perf/spmd_advanced.md similarity index 97% rename from docs/spmd_advanced.md rename to docs/source/perf/spmd_advanced.md index 369fdfe2570..cf4b840bb3f 100644 --- a/docs/spmd_advanced.md +++ b/docs/source/perf/spmd_advanced.md @@ -117,8 +117,8 @@ from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding generated_table = visualize_tensor_sharding(t, use_color=False) ``` - - visualize_tensor_sharding example on TPU v4-8(single-host) + + visualize_tensor_sharding example on TPU v4-8(single-host) - Code snippet used `visualize_sharding` and visualization result: @@ -129,8 +129,8 @@ sharding = '{devices=[2,2]0,1,2,3}' generated_table = visualize_sharding(sharding, use_color=False) ``` - - visualize_sharding example on TPU v4-8(single-host) + + visualize_sharding example on TPU v4-8(single-host) You could use these examples on TPU/GPU/CPU single-host and modify it to run on multi-host. And you could modify it to sharding-style `tiled`, `partial_replication` and `replicated`. diff --git a/docs/source/perf/spmd_basic.md b/docs/source/perf/spmd_basic.md new file mode 100644 index 00000000000..ced09a9b504 --- /dev/null +++ b/docs/source/perf/spmd_basic.md @@ -0,0 +1,116 @@ +# PyTorch/XLA SPMD User Guide + +In this user guide, we discuss how +[GSPMD](https://arxiv.org/abs/2105.04663) is integrated in PyTorch/XLA, +and provide a design overview to illustrate how the SPMD sharding +annotation API and its constructs work. + +## What is PyTorch/XLA SPMD? + +[GSPMD](https://arxiv.org/abs/2105.04663) is an automatic +parallelization system for common ML workloads. The XLA compiler will +transform the single device program into a partitioned one with proper +collectives, based on the user provided sharding hints. This feature +allows developers to write PyTorch programs as if they are on a single +large device without any custom sharded computation ops and/or +collective communications to scale. + +![Execution strategies](../_static/img/spmd_mode.png "image_tooltip") +_Figure 1. Comparison of two different execution strategies, (a) for non-SPMD and (b) for SPMD._ + +## How to use PyTorch/XLA SPMD? + +Here is an simple example of using SPMD + +``` python +import numpy as np +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import Mesh + + +# Enable XLA SPMD execution mode. +xr.use_spmd() + + +# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape. +num_devices = xr.global_runtime_device_count() +mesh_shape = (num_devices, 1) +device_ids = np.array(range(num_devices)) +mesh = Mesh(device_ids, mesh_shape, ('data', 'model')) + + +t = torch.randn(8, 4).to(xm.xla_device()) + + +# Mesh partitioning, each device holds 1/8-th of the input +partition_spec = ('data', 'model') +xs.mark_sharding(t, mesh, partition_spec) +``` + +Let's explain these concepts one by one + +### SPMD Mode + +In order to use SPMD, you need to enable it via `xr.use_spmd()`. In SPMD +mode there is only one logical device. Distributed computation and +collective is handled by the `mark_sharding`. Note that user can not mix +SPMD with other distributed libraries. + +### Mesh + +For a given cluster of devices, a physical mesh is a representation of +the interconnect topology. + +1. `mesh_shape` is a tuple that will be multiplied to the total number + of physical devices. +2. `device_ids` is almost always `np.array(range(num_devices))`. +3. Users are also encouraged to give each mesh dimension a name. In the + above example, the first mesh dimension is the `data` dimension and + the second mesh dimension is the `model` dimension. + +You can also check more mesh info via + +``` python + >>> mesh.shape() + OrderedDict([('data', 4), ('model', 1)]) +``` + +### Partition Spec + +partition_spec has the same rank as the input tensor. Each dimension +describes how the corresponding input tensor dimension is sharded across +the device mesh. In the above example tensor `t`'s fist dimension is +being sharded at `data` dimension and the second dimension is being +sharded at `model` dimension. + +User can also shard tensor that has different dimensions from the mesh +shape. + +``` python +t1 = torch.randn(8, 8, 16).to(device) +t2 = torch.randn(8).to(device) + +# First dimension is being replicated. +xs.mark_sharding(t1, mesh, (None, 'data', 'model')) + +# First dimension is being sharded at data dimension. +# model dimension is used for replication when omitted. +xs.mark_sharding(t2, mesh, ('data',)) + +# First dimension is sharded across both mesh axes. +xs.mark_sharding( t2, mesh, (('data', 'model'),)) +``` + +## Further Reading + +1. [Example](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py) + to use SPMD to express data parallism. +2. [Example](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py) + to use SPMD to express FSDP(Fully Sharded Data Parallel). +3. [SPMD advanced + topics](https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.rst) +4. [Spmd Distributed + Checkpoint](https://github.com/pytorch/xla/blob/master/docs/spmd_distributed_checkpoint.rst) diff --git a/docs/source/perf/spmd_distributed_checkpoint.md b/docs/source/perf/spmd_distributed_checkpoint.md new file mode 100644 index 00000000000..97e4f4198cd --- /dev/null +++ b/docs/source/perf/spmd_distributed_checkpoint.md @@ -0,0 +1,142 @@ +# Distributed Checkpointing + +PyTorch/XLA SPMD is compatible with the +[torch.distributed.checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) +library through a dedicated `Planner` instance. Users are able to +synchronously save and load checkpoints through this common interface. + +The SPMDSavePlanner and SPMDLoadPlanner +([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) +classes enable the `save` and `load` functions to operate directly on +the shards of an `XLAShardedTensor`, enabling all of the benefits of +distributed checkpointing in SPMD training. + +Here is a demonstration of the synchronous distributed checkpointing +API: + +``` python +import torch.distributed.checkpoint as dist_cp +import torch_xla.experimental.distributed_checkpoint as xc + +# Saving a state_dict +state_dict = { + "model": model.state_dict(), + "optim": optim.state_dict(), +} + +dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), + planner=xc.SPMDSavePlanner(), +) +... + +# Loading the model's state_dict from the checkpoint. The model should +# already be on the XLA device and have the desired sharding applied. +state_dict = { + "model": model.state_dict(), +} + +dist_cp.load( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + planner=xc.SPMDLoadPlanner(), +) +model.load_state_dict(state_dict["model"]) +``` + +## CheckpointManager + +The experimental +[CheckpointManager](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint/manager.py#L40) +interface provides a higher-level API over the +`torch.distributed.checkpoint` functions to enable a few key features: + +- **Managed checkpoints**: Each checkpoint taken by the + `CheckpointManager` is identified by the step at which it was taken. + All steps tracked are accessible through the + `CheckpointManager.all_steps` method, and any tracked steps can be + restored using `CheckpointManager.restore`. +- **Asynchronous checkpointing**: Checkpoints taken through the + `CheckpointManager.save_async` API are written to persistent storage + asynchronously to unblock training for the duration of the + checkpoint. The input sharded state_dict is first moved to CPU + before the checkpoint is dispatched to a background thread. +- **Auto-checkpointing on preemption**: On Cloud TPU, preemptions can + be detected and a checkpoint taken before the process is terminated. + To use, ensure your TPU is provisioned through a QueuedResource with + [Autocheckpointing + enabled](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/queued-resources/create#--autocheckpoint-enabled), + and ensure the `chkpt_on_preemption` parameter is set when + constructing the CheckpointManager (this option is enabled by + default). +- **FSSpec Support**: `CheckpointManager` uses an fsspec storage + backend to enable checkpointing directly to any fsspec-compatible + filesystem, including GCS. + +Example usage of the CheckpointManager is below: + +``` python +from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer + +# Create a CheckpointManager to checkpoint every 10 steps into GCS. +chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10) + +# Select a checkpoint to restore from, and restore if applicable +tracked_steps = chkpt_mgr.all_steps() +if tracked_steps: + # Choose the highest step + best_step = max(tracked_steps) + # Before restoring the checkpoint, the optimizer state must be primed + # to allow state to be loaded into it. + prime_optimizer(optim) + state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} + chkpt_mgr.restore(best_step, state_dict) + model.load_state_dict(state_dict['model']) + optim.load_state_dict(state_dict['optim']) + +# Call `save` or `save_async` every step within the train loop. These methods +# return True when a checkpoint is taken. +for step, data in enumerate(dataloader): + ... + state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} + if chkpt_mgr.save_async(step, state_dict): + print(f'Checkpoint taken at step {step}') +``` + +### Restoring Optimizer State + +In distributed checkpointing, the state_dicts are loaded in-place, and +only the required shards of the checkpoint are loaded. Since optimizer +states are lazily created, the state isn't present until the first +`optimizer.step` call, and attempts to load an unprimed optimizer will +fail. + +The utility method `prime_optimizer` is provided for this: it runs a +fake train step by setting all gradients to zero and calling +`optimizer.step`. *This is a destructive method and will touch both +model parameters and optimizer state*, so it should only be called just +prior to restoration. + +#### Process Groups + +To use `torch.distributed` APIs such as distributed checkpointing, a +process group is required. In SPMD mode, the `xla` backend is not +supported since the compiler is responsible for all collectives. + +Instead, a CPU process group such as `gloo` must be used. On TPUs, the +`xla://` init_method is still supported to discover the master IP, +global world size, and host rank. An example initialization is below: + +``` python +import torch.distributed as dist +# Import to register the `xla://` init_method +import torch_xla.distributed.xla_backend +import torch_xla.runtime as xr + +xr.use_spmd() + +# The `xla://` init_method will automatically discover master worker IP, rank, +# and global world size without requiring environment configuration on TPUs. +dist.init_process_group('gloo', init_method='xla://') +``` diff --git a/docs/source/perf/spmd_gpu.md b/docs/source/perf/spmd_gpu.md new file mode 100644 index 00000000000..cda25723aaa --- /dev/null +++ b/docs/source/perf/spmd_gpu.md @@ -0,0 +1,48 @@ +# Running SPMD on GPU + +PyTorch/XLA supports SPMD on NVIDIA GPU (single-node or multi-nodes). +The training/inference script remains the same as the one used for TPU, +such as this [ResNet +script](https://github.com/pytorch/xla/blob/1dc78948c0c9d018d8d0d2b4cce912552ab27083/test/spmd/test_train_spmd_imagenet.py). +To execute the script using SPMD, we leverage `torchrun`: + + PJRT_DEVICE=CUDA \ + torchrun \ + --nnodes=${NUM_GPU_MACHINES} \ + --node_rank=${RANK_OF_CURRENT_MACHINE} \ + --nproc_per_node=1 \ + --rdzv_endpoint=":" \ + training_or_inference_script_using_spmd.py + +- `--nnodes`: how many GPU machines to be used. +- `--node_rank`: the index of the current GPU machines. The value can + be 0, 1, ..., \${NUMBER_GPU_VM}-1. +- `--nproc_per_node`: the value must be 1 due to the SPMD requirement. +- `--rdzv_endpoint`: the endpoint of the GPU machine with + node_rank==0, in the form `host:port`. The host will be the internal + IP address. The `port` can be any available port on the machine. For + single-node training/inference, this parameter can be omitted. + +For example, if you want to train a ResNet model on 2 GPU machines using +SPMD, you can run the script below on the first machine: + + XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ + torchrun \ + --nnodes=2 \ + --node_rank=0 \ + --nproc_per_node=1 \ + --rdzv_endpoint=":12355" \ + pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 + +and run the following on the second machine: + + XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ + torchrun \ + --nnodes=2 \ + --node_rank=1 \ + --nproc_per_node=1 \ + --rdzv_endpoint=":12355" \ + pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 + +For more information, please refer to the [SPMD support on GPU +RFC](https://github.com/pytorch/xla/issues/6256). diff --git a/docs/source/quantized_ops.rst b/docs/source/quantized_ops.rst deleted file mode 100644 index 1ebe49105a3..00000000000 --- a/docs/source/quantized_ops.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../quantized_ops.md \ No newline at end of file diff --git a/docs/source/runtime.rst b/docs/source/runtime.rst deleted file mode 100644 index 3aca8f3dfe1..00000000000 --- a/docs/source/runtime.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../pjrt.md \ No newline at end of file diff --git a/docs/source/spmd.rst b/docs/source/spmd.rst deleted file mode 100644 index 6765a5d24a6..00000000000 --- a/docs/source/spmd.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. mdinclude:: ../spmd_basic.md -.. mdinclude:: ../fsdpv2.md -.. mdinclude:: ../spmd_advanced.md -.. mdinclude:: ../spmd_distributed_checkpoint.md \ No newline at end of file diff --git a/docs/source/torch_compile.rst b/docs/source/torch_compile.rst deleted file mode 100644 index 505163227f0..00000000000 --- a/docs/source/torch_compile.rst +++ /dev/null @@ -1 +0,0 @@ -.. mdinclude:: ../dynamo.md \ No newline at end of file diff --git a/docs/spmd_basic.md b/docs/spmd_basic.md deleted file mode 100644 index 342efbf5ef4..00000000000 --- a/docs/spmd_basic.md +++ /dev/null @@ -1,83 +0,0 @@ -# PyTorch/XLA SPMD User Guide - -In this user guide, we discuss how [GSPMD](https://arxiv.org/abs/2105.04663) is integrated in PyTorch/XLA, and provide a design overview to illustrate how the SPMD sharding annotation API and its constructs work. - -## What is PyTorch/XLA SPMD? -[GSPMD](https://arxiv.org/abs/2105.04663) is an automatic parallelization system for common ML workloads. The XLA compiler will transform the single device program into a partitioned one with proper collectives, based on the user provided sharding hints. This feature allows developers to write PyTorch programs as if they are on a single large device without any custom sharded computation ops and/or collective communications to scale. - -![alt_text](_static/img/spmd_mode.png "image_tooltip") -_Figure 1. Comparison of two different execution strategies, (a) for non-SPMD and (b) for SPMD._ - - -## How to use PyTorch/XLA SPMD? -Here is an simple example of using SPMD -```python -import numpy as np -import torch -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr -import torch_xla.distributed.spmd as xs -from torch_xla.distributed.spmd import Mesh - - -# Enable XLA SPMD execution mode. -xr.use_spmd() - - -# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape. -num_devices = xr.global_runtime_device_count() -mesh_shape = (num_devices, 1) -device_ids = np.array(range(num_devices)) -mesh = Mesh(device_ids, mesh_shape, ('data', 'model')) - - -t = torch.randn(8, 4).to(xm.xla_device()) - - -# Mesh partitioning, each device holds 1/8-th of the input -partition_spec = ('data', 'model') -xs.mark_sharding(t, mesh, partition_spec) -``` - -Let’s explain these concepts one by one - -### SPMD Mode -In order to use SPMD, you need to enable it via `xr.use_spmd()`. In SPMD mode there is only one logical device. Distributed computation and collective is handled by the `mark_sharding`. Note that user can not mix SPMD with other distributed libraries. - -### Mesh -For a given cluster of devices, a physical mesh is a representation of the interconnect topology. - -1. `mesh_shape` is a tuple that will be multiplied to the total number of physical devices. -2. `device_ids` is almost always `np.array(range(num_devices))`. -3. Users are also encouraged to give each mesh dimension a name. In the above example, the first mesh dimension is the `data` dimension and the second mesh dimension is the `model` dimension. - -You can also check more mesh info via -``` ->>> mesh.shape() -OrderedDict([('data', 4), ('model', 1)]) -``` - -### Partition Spec -partition_spec has the same rank as the input tensor. Each dimension describes how the corresponding input tensor dimension is sharded across the device mesh. In the above example tensor `t`’s fist dimension is being sharded at `data` dimension and the second dimension is being sharded at `model` dimension. - -User can also shard tensor that has different dimensions from the mesh shape. -```python -t1 = torch.randn(8, 8, 16).to(device) -t2 = torch.randn(8).to(device) - -# First dimension is being replicated. -xs.mark_sharding(t1, mesh, (None, 'data', 'model')) - -# First dimension is being sharded at data dimension. -# model dimension is used for replication when omitted. -xs.mark_sharding(t2, mesh, ('data',)) - -# First dimension is sharded across both mesh axes. -xs.mark_sharding( t2, mesh, (('data', 'model'),)) -``` - -## Further Reading -1. [Example](https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py) to use SPMD to express data parallism. -2. [Example](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py) to use SPMD to express FSDP(Fully Sharded Data Parallel). -3. [SPMD advanced topics](https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md) -4. [Spmd Distributed Checkpoint](https://github.com/pytorch/xla/blob/master/docs/spmd_distributed_checkpoint.md) \ No newline at end of file diff --git a/docs/spmd_distributed_checkpoint.md b/docs/spmd_distributed_checkpoint.md deleted file mode 100644 index 20cb8ed3db1..00000000000 --- a/docs/spmd_distributed_checkpoint.md +++ /dev/null @@ -1,125 +0,0 @@ -# Distributed Checkpointing -PyTorch/XLA SPMD is compatible with the [torch.distributed.checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) library through a dedicated `Planner` instance. Users are able to synchronously save and load checkpoints through this common interface. - -The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save` and `load` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training. - -Here is a demonstration of the synchronous distributed checkpointing API: - -```python -import torch.distributed.checkpoint as dist_cp -import torch_xla.experimental.distributed_checkpoint as xc - -# Saving a state_dict -state_dict = { - "model": model.state_dict(), - "optim": optim.state_dict(), -} - -dist_cp.save( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), - planner=xc.SPMDSavePlanner(), -) -... - -# Loading the model's state_dict from the checkpoint. The model should -# already be on the XLA device and have the desired sharding applied. -state_dict = { - "model": model.state_dict(), -} - -dist_cp.load( - state_dict=state_dict, - storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), - planner=xc.SPMDLoadPlanner(), -) -model.load_state_dict(state_dict["model"]) -``` - -#### CheckpointManager - -The experimental [CheckpointManager](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint/manager.py#L40) -interface provides a higher-level API over the `torch.distributed.checkpoint` -functions to enable a few key features: - -- **Managed checkpoints**: Each checkpoint taken by the `CheckpointManager` is -identified by the step at which it was taken. All steps tracked are accessible -through the `CheckpointManager.all_steps` method, and any tracked steps can be -restored using `CheckpointManager.restore`. -- **Asynchronous checkpointing**: Checkpoints taken through the -`CheckpointManager.save_async` API are written to persistent storage -asynchronously to unblock training for the duration of the checkpoint. The -input sharded state_dict is first moved to CPU before the checkpoint is -dispatched to a background thread. -- **Auto-checkpointing on preemption**: On Cloud TPU, preemptions can be detected -and a checkpoint taken before the process is terminated. To use, ensure your -TPU is provisioned through a QueuedResource with -[Autocheckpointing enabled](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/queued-resources/create#--autocheckpoint-enabled), -and ensure the `chkpt_on_preemption` parameter is set when constructing the -CheckpointManager (this option is enabled by default). -- **FSSpec Support**: `CheckpointManager` uses an fsspec storage backend to enable -checkpointing directly to any fsspec-compatible filesystem, including GCS. - -Example usage of the CheckpointManager is below: - -```python -from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer - -# Create a CheckpointManager to checkpoint every 10 steps into GCS. -chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10) - -# Select a checkpoint to restore from, and restore if applicable -tracked_steps = chkpt_mgr.all_steps() -if tracked_steps: - # Choose the highest step - best_step = max(tracked_steps) - # Before restoring the checkpoint, the optimizer state must be primed - # to allow state to be loaded into it. - prime_optimizer(optim) - state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} - chkpt_mgr.restore(best_step, state_dict) - model.load_state_dict(state_dict['model']) - optim.load_state_dict(state_dict['optim']) - -# Call `save` or `save_async` every step within the train loop. These methods -# return True when a checkpoint is taken. -for step, data in enumerate(dataloader): - ... - state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} - if chkpt_mgr.save_async(step, state_dict): - print(f'Checkpoint taken at step {step}') -``` - -##### Restoring Optimizer State - -In distributed checkpointing, the state_dicts are loaded in-place, and only the -required shards of the checkpoint are loaded. Since optimizer states are lazily -created, the state isn't present until the first `optimizer.step` call, and -attempts to load an unprimed optimizer will fail. - -The utility method `prime_optimizer` is provided for this: it runs a fake train -step by setting all gradients to zero and calling `optimizer.step`. *This is a -destructive method and will touch both model parameters and optimizer state*, -so it should only be called just prior to restoration. - -### Process Groups -To use `torch.distributed` APIs such as distributed checkpointing, a process -group is required. In SPMD mode, the `xla` backend is not supported since the -compiler is responsible for all collectives. - -Instead, a CPU process group such as `gloo` must be used. On TPUs, the `xla://` -init_method is still supported to discover the master IP, global world size, -and host rank. An example initialization is below: - -```python -import torch.distributed as dist -# Import to register the `xla://` init_method -import torch_xla.distributed.xla_backend -import torch_xla.runtime as xr - -xr.use_spmd() - -# The `xla://` init_method will automatically discover master worker IP, rank, -# and global world size without requiring environment configuration on TPUs. -dist.init_process_group('gloo', init_method='xla://') -``` \ No newline at end of file diff --git a/docs/spmd_gpu.md b/docs/spmd_gpu.md deleted file mode 100644 index ced0e03ec57..00000000000 --- a/docs/spmd_gpu.md +++ /dev/null @@ -1,40 +0,0 @@ -# Running SPMD on GPU - -PyTorch/XLA supports SPMD on NVIDIA GPU (single-node or multi-nodes). The training/inference script remains the same as the one used for TPU, such as this [ResNet script](https://github.com/pytorch/xla/blob/1dc78948c0c9d018d8d0d2b4cce912552ab27083/test/spmd/test_train_spmd_imagenet.py). To execute the script using SPMD, we leverage `torchrun`: - -``` -PJRT_DEVICE=CUDA \ -torchrun \ ---nnodes=${NUM_GPU_MACHINES} \ ---node_rank=${RANK_OF_CURRENT_MACHINE} \ ---nproc_per_node=1 \ ---rdzv_endpoint=":" \ -training_or_inference_script_using_spmd.py -``` -- `--nnodes`: how many GPU machines to be used. -- `--node_rank`: the index of the current GPU machines. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1. -- `--nproc_per_node`: the value must be 1 due to the SPMD requirement. -- `--rdzv_endpoint`: the endpoint of the GPU machine with node_rank==0, in the form `host:port`. The host will be the internal IP address. The `port` can be any available port on the machine. For single-node training/inference, this parameter can be omitted. - -For example, if you want to train a ResNet model on 2 GPU machines using SPMD, you can run the script below on the first machine: -``` -XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ -torchrun \ ---nnodes=2 \ ---node_rank=0 \ ---nproc_per_node=1 \ ---rdzv_endpoint=":12355" \ -pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 -``` -and run the following on the second machine: -``` -XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ -torchrun \ ---nnodes=2 \ ---node_rank=1 \ ---nproc_per_node=1 \ ---rdzv_endpoint=":12355" \ -pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 -``` - -For more information, please refer to the [SPMD support on GPU RFC](https://github.com/pytorch/xla/issues/6256). \ No newline at end of file From e5e6d52861731c2c67d7d286ab3a1e4518148d37 Mon Sep 17 00:00:00 2001 From: barney-s <6457279+barney-s@users.noreply.github.com> Date: Thu, 10 Oct 2024 08:56:53 -0700 Subject: [PATCH 4/9] Enabling log_normal tests (#8247) --- experimental/torch_xla2/test/test_ops.py | 2 +- experimental/torch_xla2/torch_xla2/decompositions.py | 1 + experimental/torch_xla2/torch_xla2/ops/jaten.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index fde60bd0205..577b4d9a756 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -43,7 +43,6 @@ "linalg.tensorsolve", "linalg.vector_norm", "linspace", - "log_normal", "logspace", "lu", "lu_solve", @@ -155,6 +154,7 @@ 'nn.functional.feature_alpha_dropout', 'cauchy', 'exponential', + 'log_normal', } atol_dict = {"linalg.eig": (2e0, 3e0), diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py index 354ac3d93bf..8eb813cd284 100644 --- a/experimental/torch_xla2/torch_xla2/decompositions.py +++ b/experimental/torch_xla2/torch_xla2/decompositions.py @@ -297,4 +297,5 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tens torch.ops.aten.nll_loss2d_backward, torch.ops.aten.bernoulli_.Tensor, torch.ops.aten.bernoulli_.float, + torch.ops.aten.log_normal, ]) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 28062b05615..6886ce6abc4 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -47,6 +47,7 @@ torch.ops.aten.logical_not_: torch.ops.aten.logical_not, torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze, torch.ops.aten.transpose_: torch.ops.aten.transpose, + torch.ops.aten.log_normal_: torch.ops.aten.log_normal, } From fe03cd23517b6b437d0c26d9c6da07e2a2d909d8 Mon Sep 17 00:00:00 2001 From: barney-s <6457279+barney-s@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:01:22 -0700 Subject: [PATCH 5/9] enable linalg.vector_norm tests (#8249) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 577b4d9a756..907d4da51a8 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -41,7 +41,6 @@ "linalg.matrix_norm", "linalg.matrix_power", "linalg.tensorsolve", - "linalg.vector_norm", "linspace", "logspace", "lu", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 6886ce6abc4..1753ef843ac 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1291,7 +1291,8 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): # Special cases (for efficiency and clarity) if ord == 0: if self.shape == (): - result = jnp.array(float(self != 0)) + # float sets it to float64. set it back to input type + result = jnp.astype(jnp.array(float(self != 0)), self.dtype) else: result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim) From 5a5165fd3f49e188868148708b0afef24825bfb7 Mon Sep 17 00:00:00 2001 From: barney-s <6457279+barney-s@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:10:30 -0700 Subject: [PATCH 6/9] enable linspace & logspace tests (#8236) Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com> --- experimental/torch_xla2/test/test_ops.py | 11 +++++++++-- experimental/torch_xla2/torch_xla2/ops/jaten.py | 15 +++++++++++++-- experimental/torch_xla2/torch_xla2/ops/op_base.py | 8 ++++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 907d4da51a8..6850a786292 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -41,8 +41,6 @@ "linalg.matrix_norm", "linalg.matrix_power", "linalg.tensorsolve", - "linspace", - "logspace", "lu", "lu_solve", "lu_unpack", @@ -260,6 +258,15 @@ def test_reference_eager(self, device, dtype, op): continue check_output = op.name not in random_ops + #print("[DEBUG] sample_input: ", sample_input) + + # TODO: this is a workaround to skip int64 cast for linspace + # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments + # we have opened a bug in pytorch: https://github.com/pytorch/pytorch/issues/137546 + if op.name == "linspace": + if 'dtype' in sample_input.kwargs: + if sample_input.kwargs['dtype'] == torch.int64: + sample_input.kwargs['dtype'] = torch.float if op.name == "special.polygamma": # The polygamma function is inaccurate for values < 1. # To avoid errors during testing, replace values below 1 with 1. diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 1753ef843ac..495269e0cf7 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -41,7 +41,8 @@ torch.ops.aten.random_: torch.ops.aten.uniform, torch.ops.aten.uniform_: torch.ops.aten.uniform, torch.ops.aten.relu_: torch.ops.aten.relu, - torch.ops.aten.squeeze_: torch.ops.aten.squeeze, + # squeeze_ is expected to change tensor's shape. So replace with new value + torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True), torch.ops.aten.clamp_: torch.ops.aten.clamp, torch.ops.aten.ceil_: torch.ops.aten.ceil, torch.ops.aten.logical_not_: torch.ops.aten.logical_not, @@ -52,6 +53,10 @@ def make_mutation(op): + if type(mutation_ops_to_functional[op]) is tuple: + return op_base.InplaceOp(mutation_ops_to_functional[op][0], + replace=mutation_ops_to_functional[op][1], + position_to_mutate=0) return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0) @@ -104,7 +109,13 @@ def _aten_add(x, y, *, alpha=1): @op(torch.ops.aten.copy_, is_jax_function=False) def _aten_copy(x, y, memory_format=None): - x._elem = y._elem.astype(x._elem.dtype) + if x.ndim == 1 and y.ndim == 0: + # case of torch.empty((1,)).copy_(tensor(N)) + # we need to return 0D tensor([N]) and not scalar tensor(N) + # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 + x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) + else: + x._elem = y._elem.astype(x._elem.dtype) return x diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 2c4176a361d..203ec5a3686 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -17,13 +17,17 @@ class InplaceOp: - def __init__(self, functional_op, position_to_mutate=0): + def __init__(self, functional_op, replace=False, position_to_mutate=0): self.functional = functional_op + self.replace = replace self.position_to_mutate = position_to_mutate def __call__(self, *args, **kwargs): to_mutate = args[0] - to_mutate.copy_(self.functional(*args, **kwargs)) + if self.replace: + to_mutate._elem = self.functional(*args, **kwargs)._elem + else: + to_mutate.copy_(self.functional(*args, **kwargs)) return to_mutate From ce7856febc3bd500b1664866c9a38271081cf00b Mon Sep 17 00:00:00 2001 From: David Huang Date: Thu, 10 Oct 2024 09:17:57 -0700 Subject: [PATCH 7/9] Fix op info test for linalg.inv and linalg.inv_ex (#8239) --- experimental/torch_xla2/test/test_ops.py | 2 -- experimental/torch_xla2/torch_xla2/ops/jaten.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 6850a786292..5ab64435c04 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -29,8 +29,6 @@ "linalg.cholesky", "linalg.cholesky_ex", "linalg.det", - "linalg.inv", - "linalg.inv_ex", "linalg.ldl_factor", "linalg.ldl_factor_ex", "linalg.ldl_solve", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 495269e0cf7..d4bfb1c7842 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2346,6 +2346,12 @@ def _aten_histc(input, bins=100, min=0, max=0): return hist +# Used by some linalg functions to raise an exception +# when check_errors == True. This is currently a no-op. +@op(torch.ops.aten._linalg_check_errors) +def _aten_linalg_check_errors(A, api_name, is_matrix): + ... + @op(torch.ops.aten.hypot) def _aten_hypot(input, other): return jnp.hypot(input, other) @@ -2371,6 +2377,10 @@ def _aten_linalg_eig(A): def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) +@op(torch.ops.aten.linalg_inv_ex) +def _aten_linalg_inv_ex(A): + return jnp.linalg.inv(A), jnp.zeros(A.shape[:-2], jnp.int32) + @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None): From 99c391089566b0a03e155f0418a3712cc49187b2 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:18:25 -0400 Subject: [PATCH 8/9] [torchxla2] add `argwhere` (#8073) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 5ab64435c04..811f8499199 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -108,7 +108,6 @@ "unique", "unravel_index", "var_mean", - "argwhere", "nanmean", "nn.functional.upsample_bilinear", "randint", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index d4bfb1c7842..2cccd696c53 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2543,6 +2543,9 @@ def _aten_nextafter(input, other, *, out=None): # aten.nonzero @op(torch.ops.aten.nonzero) def _aten_nonzero(x): + if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) + res = torch.empty(1, 0, dtype=torch.int64) + return jnp.array(res.numpy()) index_tuple = jnp.nonzero(x) index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] return jnp.concatenate(index_tuple, axis=-1) From ab192b115b9e1c21d2abf58c20909bf005c0c1b8 Mon Sep 17 00:00:00 2001 From: qihqi Date: Thu, 10 Oct 2024 14:29:12 -0700 Subject: [PATCH 9/9] Revert "Fix op info test for linalg.inv and linalg.inv_ex (#8239)" (#8252) --- experimental/torch_xla2/test/test_ops.py | 2 ++ experimental/torch_xla2/torch_xla2/ops/jaten.py | 10 ---------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 811f8499199..c8d29fa296e 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -29,6 +29,8 @@ "linalg.cholesky", "linalg.cholesky_ex", "linalg.det", + "linalg.inv", + "linalg.inv_ex", "linalg.ldl_factor", "linalg.ldl_factor_ex", "linalg.ldl_solve", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 2cccd696c53..8afefe01db1 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2346,12 +2346,6 @@ def _aten_histc(input, bins=100, min=0, max=0): return hist -# Used by some linalg functions to raise an exception -# when check_errors == True. This is currently a no-op. -@op(torch.ops.aten._linalg_check_errors) -def _aten_linalg_check_errors(A, api_name, is_matrix): - ... - @op(torch.ops.aten.hypot) def _aten_hypot(input, other): return jnp.hypot(input, other) @@ -2377,10 +2371,6 @@ def _aten_linalg_eig(A): def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) -@op(torch.ops.aten.linalg_inv_ex) -def _aten_linalg_inv_ex(A): - return jnp.linalg.inv(A), jnp.zeros(A.shape[:-2], jnp.int32) - @op(torch.ops.aten.linalg_lu) def _aten_linalg_lu(A, pivot=True, out=None):