From 5f3cba24b67a132b10ef1c4739d65332a822a06a Mon Sep 17 00:00:00 2001
From: torchxlabot2
Date: Wed, 30 Oct 2024 17:21:06 +0000
Subject: [PATCH] Update doc from commit 3efe1ebf827ff51d804656c9c5c241643268a32e

---
 master/_modules/index.html                    | 2 +-
 master/_modules/torch_xla/core/xla_model.html | 9 ++++-----
 master/_modules/torch_xla/debug/metrics.html  | 2 +-
 .../distributed/parallel_loader.html          | 2 +-
 .../distributed/spmd/xla_sharding.html        | 2 +-
 .../distributed/xla_multiprocessing.html      | 2 +-
 .../torch_xla/experimental/eager.html         | 2 +-
 master/_modules/torch_xla/runtime.html        | 2 +-
 master/_modules/torch_xla/torch_xla.html      | 2 +-
 master/_sources/perf/ddp.md.txt               | 3 +--
 master/accelerators/gpu.html                  | 2 +-
 master/accelerators/tpu.html                  | 2 +-
 master/contribute/bazel.html                  | 2 +-
 master/contribute/codegen_migration.html      | 2 +-
 master/contribute/configure-environment.html  | 2 +-
 master/contribute/op_lowering.html            | 2 +-
 master/contribute/plugins.html                | 2 +-
 master/features/distop.html                   | 2 +-
 master/features/pallas.html                   | 2 +-
 master/features/stablehlo.html                | 2 +-
 master/features/triton.html                   | 2 +-
 master/genindex.html                          | 2 +-
 master/index.html                             | 2 +-
 master/learn/api-guide.html                   | 2 +-
 master/learn/dynamic_shape.html               | 2 +-
 master/learn/eager.html                       | 2 +-
 master/learn/pjrt.html                        | 2 +-
 master/learn/pytorch-on-xla-devices.html      | 2 +-
 master/learn/troubleshoot.html                | 2 +-
 master/learn/xla-overview.html                | 2 +-
 master/notes/source_of_recompilation.html     | 2 +-
 master/objects.inv                            | Bin 1571 -> 1571 bytes
 master/perf/amp.html                          | 2 +-
 master/perf/ddp.html                          | 5 ++---
 master/perf/dynamo.html                       | 2 +-
 master/perf/fori_loop.html                    | 2 +-
 master/perf/fsdp.html                         | 2 +-
 master/perf/fsdpv2.html                       | 2 +-
 master/perf/quantized_ops.html                | 2 +-
 master/perf/recompilation.html                | 2 +-
 master/perf/spmd_advanced.html                | 2 +-
 master/perf/spmd_basic.html                   | 2 +-
 master/perf/spmd_distributed_checkpoint.html  | 2 +-
 master/perf/spmd_gpu.html                     | 2 +-
 master/py-modindex.html                       | 2 +-
 master/search.html                            | 2 +-
 master/searchindex.js                         | 2 +-
 47 files changed, 50 insertions(+), 53 deletions(-)

diff --git a/master/_modules/index.html b/master/_modules/index.html
index cb24026f2ee..e47b0e8ee74 100644
--- a/master/_modules/index.html
+++ b/master/_modules/index.html
@@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/core/xla_model.html b/master/_modules/torch_xla/core/xla_model.html index 261b8697428..202ac9f2cb1 100644 --- a/master/_modules/torch_xla/core/xla_model.html +++ b/master/_modules/torch_xla/core/xla_model.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
@@ -1722,10 +1722,9 @@

Source code for torch_xla.core.xla_model

         if sharding and tensor.dim() > 0 and (tensor.size()[0] %
                                               local_runtime_device_count) != 0:
           raise RuntimeError(
-              "When minibatch is configured, batch dimension of the tensor " +
-              "must be divisible by local runtime device count.input data shape "
-              +
-              f"={tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
+              "When minibatch is configured, the per-host batch size must be divisible "
+              + "by local runtime device count. Per host input data shape " +
+              f"= {tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
           )
 
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
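For context on the xla_model hunk above: the reworded error enforces a simple rule. When minibatch input sharding is configured, the per-host batch dimension must split evenly across the local runtime devices. A minimal standalone sketch of that rule follows; the helper name is hypothetical, not code from this patch.

```python
# Illustrative sketch of the divisibility rule behind the reworded error.
# `check_minibatch_batch_dim` is a hypothetical name, not a torch_xla API.
def check_minibatch_batch_dim(per_host_batch_size: int,
                              local_runtime_device_count: int) -> None:
    # With minibatch configured, each host feeds only its local devices, so
    # the per-host batch dimension must divide evenly across them.
    if per_host_batch_size % local_runtime_device_count != 0:
        raise RuntimeError(
            "When minibatch is configured, the per-host batch size must be "
            "divisible by local runtime device count. "
            f"batch = {per_host_batch_size}, "
            f"local_runtime_device_count = {local_runtime_device_count}")


check_minibatch_batch_dim(32, 4)   # OK: 32 % 4 == 0
# check_minibatch_batch_dim(30, 4) # raises, matching the message above
```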
diff --git a/master/_modules/torch_xla/debug/metrics.html b/master/_modules/torch_xla/debug/metrics.html
index ab017db3779..545563fff23 100644
--- a/master/_modules/torch_xla/debug/metrics.html
+++ b/master/_modules/torch_xla/debug/metrics.html
@@ -265,7 +265,7 @@
               
               
                 
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/distributed/parallel_loader.html b/master/_modules/torch_xla/distributed/parallel_loader.html index 90b525e9dc1..b4ad4d16b62 100644 --- a/master/_modules/torch_xla/distributed/parallel_loader.html +++ b/master/_modules/torch_xla/distributed/parallel_loader.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html index 9fb6cafe472..2a178726554 100644 --- a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html +++ b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/distributed/xla_multiprocessing.html b/master/_modules/torch_xla/distributed/xla_multiprocessing.html index 32c8e71fc9d..d361a261c97 100644 --- a/master/_modules/torch_xla/distributed/xla_multiprocessing.html +++ b/master/_modules/torch_xla/distributed/xla_multiprocessing.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/experimental/eager.html b/master/_modules/torch_xla/experimental/eager.html index d8ab06e72c2..5904737234a 100644 --- a/master/_modules/torch_xla/experimental/eager.html +++ b/master/_modules/torch_xla/experimental/eager.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/runtime.html b/master/_modules/torch_xla/runtime.html index 879694d2771..06594d9629d 100644 --- a/master/_modules/torch_xla/runtime.html +++ b/master/_modules/torch_xla/runtime.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/_modules/torch_xla/torch_xla.html b/master/_modules/torch_xla/torch_xla.html index f1bfd4d5e72..dd6ef337f64 100644 --- a/master/_modules/torch_xla/torch_xla.html +++ b/master/_modules/torch_xla/torch_xla.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
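The next hunk updates the DDP setup snippet in `perf/ddp.md`, dropping the `MASTER_ADDR`/`MASTER_PORT` environment variables in favor of `PJRT_DEVICE`. For context, a sketch of how the snippet reads after the patch; the imports are assumed from the surrounding document, and `'TPU'` is the device type the doc uses.

```python
# Sketch of the post-patch setup() from perf/ddp.md; the imports are assumed
# from the surrounding document rather than shown in this hunk.
import os

import torch.distributed as dist
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend


def setup(rank, world_size):
    os.environ['PJRT_DEVICE'] = 'TPU'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)
```

The same change appears again further down in the rendered page `master/perf/ddp.html`.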
diff --git a/master/_sources/perf/ddp.md.txt b/master/_sources/perf/ddp.md.txt
index 826895d5ac8..a84946cc728 100644
--- a/master/_sources/perf/ddp.md.txt
+++ b/master/_sources/perf/ddp.md.txt
@@ -77,8 +77,7 @@ import torch_xla.runtime as xr
 import torch_xla.distributed.xla_backend
 
 def setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
+    os.environ['PJRT_DEVICE'] = 'TPU'
 
     # initialize the xla process group
     dist.init_process_group("xla", rank=rank, world_size=world_size)
diff --git a/master/accelerators/gpu.html b/master/accelerators/gpu.html
index 4b23239a785..f5e1180b9c0 100644
--- a/master/accelerators/gpu.html
+++ b/master/accelerators/gpu.html
@@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/accelerators/tpu.html b/master/accelerators/tpu.html index 67e54752caf..54a6292a3d8 100644 --- a/master/accelerators/tpu.html +++ b/master/accelerators/tpu.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/contribute/bazel.html b/master/contribute/bazel.html index 5d7e15ae4c9..bdadfbb9155 100644 --- a/master/contribute/bazel.html +++ b/master/contribute/bazel.html @@ -266,7 +266,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/contribute/codegen_migration.html b/master/contribute/codegen_migration.html index fc9d188799d..107d05eb811 100644 --- a/master/contribute/codegen_migration.html +++ b/master/contribute/codegen_migration.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/contribute/configure-environment.html b/master/contribute/configure-environment.html index a3411588eaa..f898c715642 100644 --- a/master/contribute/configure-environment.html +++ b/master/contribute/configure-environment.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/contribute/op_lowering.html b/master/contribute/op_lowering.html index 1a379678771..57893f0c5c7 100644 --- a/master/contribute/op_lowering.html +++ b/master/contribute/op_lowering.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/contribute/plugins.html b/master/contribute/plugins.html index d8134f035b3..d51f036440b 100644 --- a/master/contribute/plugins.html +++ b/master/contribute/plugins.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/features/distop.html b/master/features/distop.html index 32768b2dda1..b1b9f9129b7 100644 --- a/master/features/distop.html +++ b/master/features/distop.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/features/pallas.html b/master/features/pallas.html index 921128dd423..6b0c1a6fdf3 100644 --- a/master/features/pallas.html +++ b/master/features/pallas.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/features/stablehlo.html b/master/features/stablehlo.html index 9b538067332..b554fbdc1c8 100644 --- a/master/features/stablehlo.html +++ b/master/features/stablehlo.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/features/triton.html b/master/features/triton.html index ab0ae30efe1..c65d5ec6f4f 100644 --- a/master/features/triton.html +++ b/master/features/triton.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/genindex.html b/master/genindex.html index 257c39374dc..1f7b8a0bc73 100644 --- a/master/genindex.html +++ b/master/genindex.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/index.html b/master/index.html index 77e079a3b6c..3b388ce58cc 100644 --- a/master/index.html +++ b/master/index.html @@ -266,7 +266,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/api-guide.html b/master/learn/api-guide.html index 34594bd245b..e71c934107a 100644 --- a/master/learn/api-guide.html +++ b/master/learn/api-guide.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/dynamic_shape.html b/master/learn/dynamic_shape.html index 85be1cb953d..f83c510cd98 100644 --- a/master/learn/dynamic_shape.html +++ b/master/learn/dynamic_shape.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/eager.html b/master/learn/eager.html index c38ee38b9f6..86a9aa7b2de 100644 --- a/master/learn/eager.html +++ b/master/learn/eager.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/pjrt.html b/master/learn/pjrt.html index 6644d757f92..c160c8fbc4a 100644 --- a/master/learn/pjrt.html +++ b/master/learn/pjrt.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/pytorch-on-xla-devices.html b/master/learn/pytorch-on-xla-devices.html index df7ac70441c..1c32f014e2e 100644 --- a/master/learn/pytorch-on-xla-devices.html +++ b/master/learn/pytorch-on-xla-devices.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/troubleshoot.html b/master/learn/troubleshoot.html index 8c945439aa5..31ecbdf3b92 100644 --- a/master/learn/troubleshoot.html +++ b/master/learn/troubleshoot.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/learn/xla-overview.html b/master/learn/xla-overview.html index f37267b1ba9..1eaac5ad551 100644 --- a/master/learn/xla-overview.html +++ b/master/learn/xla-overview.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/notes/source_of_recompilation.html b/master/notes/source_of_recompilation.html index 1fb8853f25a..5a0d2e6bd0f 100644 --- a/master/notes/source_of_recompilation.html +++ b/master/notes/source_of_recompilation.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/objects.inv b/master/objects.inv index f37b36f22393a974117c2240139f13873be42dd6..80d7e63189c892ed0754113dcece6d381040a8c1 100644 GIT binary patch delta 18 ZcmZ3?vzTW>0K0K&TB>1c(#G&JtN=S=26_Mh delta 18 ZcmZ3?vzTW>0K0`{s)>1$@y75otN=HC1@!;` diff --git a/master/perf/amp.html b/master/perf/amp.html index 2c3929b2c66..7384e5d9352 100644 --- a/master/perf/amp.html +++ b/master/perf/amp.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/ddp.html b/master/perf/ddp.html index 00c2113fa15..071c0438bf7 100644 --- a/master/perf/ddp.html +++ b/master/perf/ddp.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
@@ -486,8 +486,7 @@

How to use DistributedDataParallel
 import torch_xla.distributed.xla_backend
 
 def setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
+    os.environ['PJRT_DEVICE'] = 'TPU'
 
     # initialize the xla process group
     dist.init_process_group("xla", rank=rank, world_size=world_size)
diff --git a/master/perf/dynamo.html b/master/perf/dynamo.html
index 42e594824e1..fa76062a19f 100644
--- a/master/perf/dynamo.html
+++ b/master/perf/dynamo.html
@@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/fori_loop.html b/master/perf/fori_loop.html index 35a30689777..240446256f2 100644 --- a/master/perf/fori_loop.html +++ b/master/perf/fori_loop.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/fsdp.html b/master/perf/fsdp.html index 6d109e3bd60..5883efedb95 100644 --- a/master/perf/fsdp.html +++ b/master/perf/fsdp.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/fsdpv2.html b/master/perf/fsdpv2.html index 10d84e8936a..4dd4804e6ec 100644 --- a/master/perf/fsdpv2.html +++ b/master/perf/fsdpv2.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/quantized_ops.html b/master/perf/quantized_ops.html index ab1abafc132..64e7196323a 100644 --- a/master/perf/quantized_ops.html +++ b/master/perf/quantized_ops.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/recompilation.html b/master/perf/recompilation.html index 482ed4fa902..f7d35ae678a 100644 --- a/master/perf/recompilation.html +++ b/master/perf/recompilation.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/spmd_advanced.html b/master/perf/spmd_advanced.html index 92f34ce9d80..d9d480ffdda 100644 --- a/master/perf/spmd_advanced.html +++ b/master/perf/spmd_advanced.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/spmd_basic.html b/master/perf/spmd_basic.html index 28b26150f1d..b491a50c4cd 100644 --- a/master/perf/spmd_basic.html +++ b/master/perf/spmd_basic.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/spmd_distributed_checkpoint.html b/master/perf/spmd_distributed_checkpoint.html index cf5c534a2b6..7e194973203 100644 --- a/master/perf/spmd_distributed_checkpoint.html +++ b/master/perf/spmd_distributed_checkpoint.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/perf/spmd_gpu.html b/master/perf/spmd_gpu.html index 1d2697fe86f..19aeaa66028 100644 --- a/master/perf/spmd_gpu.html +++ b/master/perf/spmd_gpu.html @@ -267,7 +267,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/py-modindex.html b/master/py-modindex.html index 77d93f6e967..a7c27ad159d 100644 --- a/master/py-modindex.html +++ b/master/py-modindex.html @@ -268,7 +268,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/search.html b/master/search.html index 3e0fb087afc..f913b3a179a 100644 --- a/master/search.html +++ b/master/search.html @@ -265,7 +265,7 @@
- master (2.6.0+git89e47b3 ) + master (2.6.0+git3efe1eb )
diff --git a/master/searchindex.js b/master/searchindex.js index f2ec1bf4e75..7838b685bef 100644 --- a/master/searchindex.js +++ b/master/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["accelerators/gpu", "accelerators/tpu", "contribute/bazel", "contribute/codegen_migration", "contribute/configure-environment", "contribute/op_lowering", "contribute/plugins", "features/distop", "features/pallas", "features/stablehlo", "features/triton", "index", "learn/api-guide", "learn/dynamic_shape", "learn/eager", "learn/pjrt", "learn/pytorch-on-xla-devices", "learn/troubleshoot", "learn/xla-overview", "notes/source_of_recompilation", "perf/amp", "perf/ddp", "perf/dynamo", "perf/fori_loop", "perf/fsdp", "perf/fsdpv2", "perf/quantized_ops", "perf/recompilation", "perf/spmd_advanced", "perf/spmd_basic", "perf/spmd_distributed_checkpoint", "perf/spmd_gpu"], "filenames": ["accelerators/gpu.md", "accelerators/tpu.md", "contribute/bazel.md", "contribute/codegen_migration.md", "contribute/configure-environment.md", "contribute/op_lowering.md", "contribute/plugins.md", "features/distop.md", "features/pallas.md", "features/stablehlo.md", "features/triton.md", "index.rst", "learn/api-guide.rst", "learn/dynamic_shape.md", "learn/eager.md", "learn/pjrt.md", "learn/pytorch-on-xla-devices.md", "learn/troubleshoot.md", "learn/xla-overview.md", "notes/source_of_recompilation.md", "perf/amp.md", "perf/ddp.md", "perf/dynamo.md", "perf/fori_loop.md", "perf/fsdp.md", "perf/fsdpv2.md", "perf/quantized_ops.md", "perf/recompilation.md", "perf/spmd_advanced.md", "perf/spmd_basic.md", "perf/spmd_distributed_checkpoint.md", "perf/spmd_gpu.md"], "titles": ["Learn about GPUs", "Learn about TPUs", "Bazel in Pytorch/XLA", "Codegen migration Guide", "Configure a development environment", "OP Lowering Guide", "Custom Hardware Plugins", "Support of Torch Distributed API in PyTorch/XLA", "Custom Kernels via Pallas", "Torch Export to StableHLO", "Custom GPU Kernels via Triton", "PyTorch/XLA documentation", "PyTorch/XLA API", "Dynamic shape", "Eager Mode + Compile API", "PJRT Runtime", "PyTorch on XLA Devices", "Troubleshoot", "Pytorch/XLA overview", "Source of recompilations in torch_xla", "Automatic Mixed Precision", "How to do DistributedDataParallel(DDP)", "TorchDynamo integration in PyTorch XLA", "Optimize memory utilization using while_loop", "Fully Sharded Data Parallel in PyTorch XLA", "Fully Sharded Data Parallel using SPMD", "Quantized Operations for XLA (Experimental feature)", "Source of recompilations in Pytorch/XLA", "PyTorch/XLA SPMD advanced topics", "PyTorch/XLA SPMD User Guide", "Distributed Checkpointing", "Running SPMD on GPU"], "terms": {"For": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 31], "inform": [0, 1, 4, 10, 12, 14, 15, 16, 17, 18, 19, 27, 31], "googl": [0, 1, 8, 15, 16], "cloud": [0, 1, 2, 4, 6, 11, 15, 16, 22, 30], "see": [0, 1, 2, 3, 4, 5, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 27], "machin": [0, 2, 4, 15, 17, 18, 31], "type": [0, 4, 6, 9, 12, 15, 16, 17, 18, 20, 21], "ar": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 26, 27, 28, 29, 30], "custom": [1, 3, 4, 7, 11, 12, 19, 21, 24, 26, 27, 28, 29], "design": [1, 15, 16, 22, 25, 29], "ai": 1, "acceler": [1, 4, 12, 13, 15, 16, 18, 20], "which": [1, 2, 3, 5, 6, 7, 9, 12, 13, 15, 16, 17, 18, 19, 20, 22, 24, 25, 27, 28, 30], "optim": [1, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27], "train": [1, 8, 12, 13, 16, 17, 18, 20, 28, 30, 31], "infer": 
[1, 3, 12, 15, 20, 28, 31], "larg": [1, 13, 15, 18, 19, 24, 27, 29], "model": [1, 3, 5, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30, 31], "thei": [1, 2, 5, 6, 7, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "ideal": [1, 2, 3, 19, 22, 27], "varieti": 1, "us": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 14, 15, 16, 17, 18, 20, 22, 24, 28, 30, 31], "case": [1, 2, 3, 5, 9, 12, 15, 16, 17, 18, 22, 25, 28], "chatbot": 1, "code": [1, 3, 5, 9, 10, 12, 14, 15, 16, 17, 19, 21, 22, 27, 28], "gener": [1, 5, 12, 14, 15, 16, 17, 18, 19, 27], "media": 1, "content": [1, 12], "synthet": 1, "speech": 1, "vision": [1, 24], "servic": [1, 2, 15], "recommend": [1, 2, 3, 4, 5, 12, 14, 15, 16, 20, 28], "engin": [1, 17], "person": 1, "among": 1, "other": [1, 2, 3, 5, 8, 12, 13, 15, 16, 17, 18, 19, 20, 21, 26, 27, 29], "scale": [1, 9, 12, 15, 20, 22, 29], "cost": [1, 22], "effici": [1, 9, 17, 18, 22], "wide": [1, 5, 19, 27], "rang": [1, 5, 12, 15, 25, 28, 29], "workload": [1, 15, 16, 17, 28, 29], "span": [1, 3], "fine": 1, "tune": [1, 28], "provid": [1, 2, 3, 5, 6, 8, 9, 12, 16, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 29, 30], "versatil": 1, "lead": [1, 17, 18], "framework": [1, 9, 11, 14, 19, 26, 27], "includ": [1, 2, 5, 12, 15, 17, 18, 19, 20, 23, 27, 30], "pytorch": [1, 5, 10, 13, 14, 15, 19, 20, 21, 23, 26, 30, 31], "jax": [1, 6, 8, 9, 15], "tensorflow": [1, 2, 6, 9, 12, 15, 17, 19, 27], "seamlessli": 1, "orchestr": 1, "through": [1, 3, 5, 6, 7, 8, 16, 18, 19, 20, 27, 30], "integr": [1, 10, 11, 25, 26, 29], "kubernet": 1, "gke": 1, "leverag": [1, 10, 31], "dynam": [1, 3, 5, 11, 17, 18, 22], "schedul": [1, 18], "improv": [1, 15, 16, 17, 18, 20, 22, 28], "scalabl": 1, "all": [1, 2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 30], "need": [1, 2, 3, 5, 12, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 29], "simultan": 1, "look": [1, 3, 5, 16, 17, 18, 28], "simplest": 1, "wai": [1, 2, 5, 7, 8, 12, 15, 16, 18, 19, 21, 22, 26, 27, 28], "develop": [1, 2, 10, 11, 14, 16, 21, 22, 26, 29], "can": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 31], "also": [1, 2, 3, 5, 6, 7, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 28, 29], "vertex": 1, "fulli": [1, 11, 14, 15, 17, 29], "manag": [1, 8, 12, 20, 30], "platform": 1, "more": [1, 2, 3, 4, 5, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 27, 28, 29, 31], "introduct": [1, 8], "set": [1, 2, 7, 12, 15, 17, 18, 19, 20, 22, 24, 27, 28, 30], "up": [1, 2, 3, 15, 16, 18, 19, 22, 25, 27], "environ": [1, 2, 11, 15, 16, 18, 21, 28, 30], "resourc": [1, 12, 17], "i": [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 28, 30], "free": [2, 5, 13, 17, 20, 21, 24], "softwar": [2, 17], "tool": [2, 5, 18, 24], "autom": 2, "openxla": [2, 6, 14, 22, 26], "both": [2, 4, 5, 7, 9, 15, 18, 19, 20, 22, 24, 25, 26, 27, 29, 30], "make": [2, 4, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "good": [2, 3, 5, 18, 19, 27, 28], "fit": [2, 3, 18, 24], "well": [2, 3, 6, 9, 12, 15, 18, 19, 27, 29], "extern": [2, 4, 8], "seen": [2, 18, 22], "workspac": [2, 17], "file": [2, 4, 12, 15, 17, 18, 20, 21], "http_archiv": 2, "name": [2, 4, 5, 9, 12, 15, 17, 19, 25, 27, 28, 29], "org_tensorflow": 2, "strip_prefix": 2, "f7759359f8420d3ca7b9fd19493f2a01bd47b4ef": 2, "url": 2, "http": [2, 3, 4, 8, 10, 12, 15, 17, 18, 24, 28], "github": [2, 3, 4, 5, 10, 12, 15, 17, 18, 21, 24, 28], "com": [2, 3, 4, 8, 10, 12, 15, 17, 18, 24, 28], "archiv": 2, "tar": 2, "gz": 2, "pin": [2, 12], "updat": [2, 3, 7, 16, 18, 19, 20, 
27, 28], "point": [2, 3, 4, 5, 6, 9, 12, 18, 19, 20, 27], "thi": [2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "repositori": [2, 15], "differ": [2, 6, 12, 16, 17, 18, 19, 21, 23, 24, 27, 28, 29], "revis": 2, "patch": [2, 17], "mai": [2, 3, 6, 15, 16, 17, 18, 19, 20, 27, 28], "ad": [2, 5, 12, 16, 18, 19, 22, 23, 27, 28], "resolv": 2, "prepar": 2, "hermet": 2, "mechan": 2, "deploi": 2, "becaus": [2, 3, 9, 14, 15, 16, 18, 20, 28], "local": [2, 4, 12, 15, 16, 17, 28], "checkout": [2, 17], "ha": [2, 3, 4, 5, 8, 12, 14, 15, 16, 18, 19, 27, 28, 29], "built": [2, 4], "from": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17, 18, 20, 21, 22, 23, 24, 25, 28, 29, 30], "sourc": [2, 3, 5, 6, 9, 11, 12, 17], "instal": [2, 3, 4, 5, 6, 8, 9, 10, 15, 17, 18], "system": [2, 29], "version": [2, 3, 4, 8, 15, 18, 20, 28], "compat": [2, 9, 15, 26, 30], "e": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 20, 24, 26, 27, 28], "g": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 26, 27, 28, 30], "codegen": [2, 5, 11], "torchgen": [2, 3], "python": [2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 15, 17, 18, 19, 21, 22, 27, 28], "modul": [2, 8, 9, 12, 16, 17, 21, 24, 25, 28], "should": [2, 3, 4, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 24, 27, 28, 30], "The": [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "directori": [2, 3, 5, 9], "either": [2, 5, 7, 12, 15, 17, 19, 20, 27], "bzl": 2, "overriden": 2, "command": [2, 3, 4, 15, 16, 17, 18, 21, 24], "line": [2, 3, 12, 14, 16, 17, 18, 19, 24, 27], "override_repositori": 2, "path": [2, 6, 9, 12, 16, 17, 19, 24, 27], "export": [2, 3, 4, 5, 11, 15, 17, 18], "tf_repo": 2, "torch_repo": 2, "pleas": [2, 3, 5, 7, 9, 12, 15, 16, 17, 18, 20, 24, 25, 26, 28, 31], "sure": [2, 16, 17], "overridden": [2, 3], "appropri": [2, 18], "been": [2, 5, 12, 15, 16, 18, 19, 27, 28], "use_cuda": 2, "0": [2, 3, 4, 6, 9, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31], "setup": [2, 3, 6, 16, 21], "py": [2, 3, 4, 5, 7, 10, 13, 14, 15, 16, 17, 18, 21, 24, 28, 31], "bdist_wheel": 2, "expect": [2, 3, 6, 10, 14, 15, 17, 19, 22, 26, 27], "object": [2, 12, 28], "present": [2, 30], "new_local_repositori": 2, "build_fil": 2, "pytorch_local_dir": 2, "header": 2, "directli": [2, 3, 5, 6, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28, 30], "share": [2, 3, 6, 15, 16, 17, 28], "libtorch": 2, "so": [2, 3, 6, 10, 12, 13, 15, 16, 17, 18, 19, 24, 27, 30], "same": [2, 3, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 26, 27, 28, 29, 31], "where": [2, 4, 7, 8, 12, 13, 15, 16, 17, 18, 19, 24, 25, 27], "lib": [2, 6], "contain": [2, 3, 5, 6, 9, 10, 12, 15, 17, 18, 19, 27], "work": [2, 3, 7, 12, 13, 15, 16, 17, 18, 19, 21, 22, 26, 27, 28, 29], "": [2, 4, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 18, 20, 21, 22, 26, 28, 29, 30], "requir": [2, 3, 5, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28, 30, 31], "pass": [2, 5, 9, 10, 12, 15, 18, 20, 21, 28], "isystemextern": 2, "compil": [2, 5, 6, 9, 10, 11, 12, 13, 15, 18, 19, 20, 22, 25, 26, 27, 29, 30], "find": [2, 3, 5, 9, 15, 17, 18, 21, 25], "satisfi": [2, 28], "them": [2, 3, 5, 9, 12, 15, 16, 17, 18, 19, 27], "some": [2, 3, 5, 12, 13, 14, 15, 16, 17, 21, 26, 28], "user": [2, 4, 6, 9, 11, 14, 15, 16, 17, 18, 19, 22, 23, 25, 26, 27, 28, 30], "bring": [2, 3, 25], "pybind11": 2, "embed": 2, "link": [2, 3], "against": [2, 21], "libpython": 2, "instead": [2, 7, 12, 14, 15, 16, 17, 18, 19, 21, 22, 24, 27, 28, 30], "These": [2, 3, 5, 8, 15, 18, 26, 30], "pybind11_emb": 2, "option": [2, 3, 4, 6, 9, 12, 15, 
17, 18, 26, 28, 30], "transit": [2, 16], "simpl": [2, 3, 8, 15, 18, 20, 24, 29], "torch_xla": [2, 4, 5, 6, 7, 8, 9, 10, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], "csrc": [2, 5], "runtim": [2, 3, 4, 6, 11, 16, 17, 21, 24, 28, 29, 30], "configr": 2, "via": [2, 4, 11, 15, 23, 24, 25, 28, 29], "bazelrc": 2, "take": [2, 3, 9, 10, 12, 16, 17, 18, 19, 27, 28], "flag": [2, 3, 12, 13, 20], "config": [2, 4], "remote_cach": 2, "configur": [2, 3, 5, 11, 12, 15, 17, 18, 30], "gcloud": [2, 4, 15, 16, 18], "usual": [2, 3, 5, 14, 16, 17], "faster": [2, 15, 18, 19, 22, 27], "authent": [2, 15], "easi": [2, 15, 16, 19, 27], "express": [2, 25, 29], "complex": [2, 10, 22], "lot": [2, 16, 17, 18, 19, 27], "gain": [2, 15], "have": [2, 3, 4, 5, 6, 8, 9, 12, 15, 16, 17, 18, 19, 21, 22, 24, 25, 27, 28, 30], "singl": [2, 12, 14, 19, 21, 22, 24, 25, 27, 28, 29, 31], "graph": [2, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "everyth": [2, 19, 21, 27], "therefor": [2, 17, 18], "separ": [2, 3, 5, 16, 18, 22, 24, 25], "rest": [2, 15, 17, 19, 27], "plu": [2, 21, 23], "whole": [2, 12, 14, 19, 22, 27], "everythin": 2, "els": [2, 17, 19, 27], "enough": [2, 18, 19, 27], "normal": [2, 3, 15, 19, 25, 27, 28], "achiev": [2, 5, 14, 21], "invok": [2, 3, 22, 28], "standard": [2, 9], "c": [2, 3, 5, 12, 15, 17, 19, 20, 27], "bind": [2, 9], "simpli": [2, 15], "_xlac": [2, 10, 17, 19, 27], "client": [2, 6, 12, 15], "togeth": [2, 14, 15, 16, 21, 24, 28], "when": [2, 3, 5, 7, 10, 12, 13, 14, 15, 16, 17, 18, 20, 22, 24, 28, 29, 30], "chang": [2, 5, 13, 16, 17, 18, 19, 20, 21, 26, 27, 28], "abl": [2, 16, 19, 27, 30], "without": [2, 5, 12, 15, 17, 18, 28, 29, 30], "iter": [2, 12, 13, 16, 17, 18, 22, 28], "cycl": 2, "come": [2, 12, 19, 27], "There": [2, 3, 14, 16, 17, 18, 19, 21, 22, 27, 28], "plenti": 2, "backend": [2, 3, 7, 12, 14, 15, 19, 22, 23, 26, 27, 28, 30], "we": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 31], "our": [2, 3, 4, 5, 6, 7, 8, 9, 13, 15, 16, 17, 19, 20, 21, 22, 27, 28], "gc": [2, 30], "storag": [2, 4, 8, 16, 17, 18, 24, 30], "you": [2, 4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 24, 25, 28, 29, 31], "under": [2, 3, 5, 12, 15, 16, 21], "disabl": [2, 12, 14, 17, 18], "default": [2, 5, 12, 14, 15, 16, 17, 18, 20, 24, 28, 30], "speed": [2, 18, 19, 22, 27], "increment": [2, 3], "huge": [2, 17, 18, 19, 21, 27], "margin": 2, "almost": [2, 29], "alwai": [2, 15, 16, 17, 19, 27, 29], "enabl": [2, 10, 12, 13, 14, 17, 18, 20, 21, 26, 28, 29, 30], "ci": [2, 5], "To": [2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 24, 25, 27, 28, 30, 31], "ensur": [2, 9, 12, 19, 25, 27, 28, 30], "credenti": 2, "auth": [2, 15], "applic": [2, 17, 26, 30], "login": [2, 18], "launch": [2, 12, 15, 16, 18, 21, 22, 24], "browser": 2, "gcp": [2, 4, 15], "variou": [2, 10], "individu": [2, 24, 25, 29], "who": [2, 21], "access": [2, 3, 5, 8, 12, 15, 16, 17, 18, 19, 27, 30], "project": [2, 4, 6, 15, 16, 18], "one": [2, 3, 5, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 25, 27, 28, 29, 31], "onli": [2, 3, 5, 7, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 27, 29, 30], "specifi": [2, 7, 9, 12, 16, 18, 24, 28], "google_default_credenti": 2, "token": [2, 14, 18, 26], "out": [2, 5, 9, 12, 13, 14, 15, 16, 17, 18, 20, 22, 28], "box": [2, 5, 28], "log": [2, 17, 18], "permiss": 2, "add": [2, 3, 5, 9, 10, 12, 16, 17, 18, 19, 22, 23, 24, 27], "new": [2, 3, 4, 5, 7, 14, 16, 17, 18, 19, 22, 27, 28], "role": 2, "In": [2, 3, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 27, 28, 29, 
30], "account": [2, 18], "kei": [2, 4, 6, 15, 17, 18, 30], "google_credenti": 2, "On": [2, 15, 30], "docker": [2, 9], "network": [2, 12, 15, 16, 17, 20, 28], "cloudbuild": 2, "down": [2, 5, 18], "imag": [2, 15, 18, 19, 21, 24, 27], "do": [2, 3, 5, 11, 13, 15, 16, 17, 18, 19, 20, 24, 26, 27, 28], "doe": [2, 3, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28], "read": [2, 4, 5, 12, 15, 28], "write": [2, 5, 10, 12, 16, 29], "silo": 2, "each": [2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 22, 24, 25, 27, 28, 29, 30], "uniqu": [2, 16, 18, 19, 27], "benefit": [2, 18, 25, 26, 30], "consist": [2, 7, 9, 15], "remote_default_exec_properti": 2, "some_silo_kei": 2, "bazel_remote_cach": 2, "1": [2, 4, 6, 7, 8, 9, 12, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 28, 29, 31], "silo_nam": 2, "your": [2, 3, 6, 8, 9, 15, 16, 17, 18, 19, 21, 25, 27, 28, 30], "tpuvm_mod": 2, "gcloud_service_key_fil": 2, "application_default_credenti": 2, "json": [2, 9], "might": [2, 5, 12, 16, 17, 18, 19, 27], "help": [2, 17, 18, 19, 27], "too": [2, 17, 19, 27], "cannot": [2, 8, 18, 19, 20, 24, 27], "here": [2, 3, 5, 8, 9, 13, 16, 18, 19, 21, 22, 24, 25, 27, 28, 29, 30], "author": 2, "usernam": 2, "behavior": [2, 3, 5, 15, 16, 17, 20], "function": [2, 5, 6, 7, 8, 9, 10, 12, 14, 16, 17, 18, 22, 23, 25, 26, 30], "intend": 2, "first": [2, 3, 4, 9, 10, 12, 13, 15, 17, 18, 21, 28, 29, 30, 31], "time": [2, 3, 4, 12, 13, 15, 16, 17, 18, 19, 22, 23, 27, 28], "slow": [2, 17, 18], "scratch": [2, 3], "veri": [2, 6, 8, 14, 16, 18, 19, 27], "fast": [2, 19, 27], "onc": [2, 7, 12, 16, 17, 18, 19, 22, 27, 28], "again": [2, 3, 9, 16, 18], "bit": [2, 16, 26], "slower": [2, 17, 18, 21], "per": [2, 9, 12, 15, 16, 17, 20, 21, 22, 26], "until": [2, 12, 16, 18, 30], "next": [2, 12, 17, 18, 19, 26, 27, 28], "quit": 2, "current": [2, 6, 8, 9, 12, 13, 14, 15, 16, 18, 19, 21, 22, 23, 25, 26, 27, 28, 31], "migrat": [2, 11, 15], "futur": [2, 3, 4, 6, 9, 13, 15, 16, 17, 18, 19, 25, 27], "plafrom": 2, "cpp": [2, 5], "main": [2, 4, 7, 9, 10, 14, 15, 28], "Of": 2, "cours": 2, "pjrt": [2, 11, 12, 16, 28], "Not": 2, "environment": 2, "variabl": [2, 4, 13, 15, 18, 19, 27], "miss": [2, 5, 12, 17], "common": [2, 15, 19, 25, 26, 27, 29, 30], "part": [2, 3, 6, 10, 12, 14, 15, 17, 18, 28], "ones": [2, 12, 19, 27], "helper": [2, 3, 9, 12], "script": [2, 3, 4, 8, 15, 16, 17, 18, 20, 21, 31], "run_test": 2, "sh": 2, "r": [2, 18], "xla_client": 2, "pure": [2, 3], "easili": [2, 5, 19, 22, 27], "execut": [2, 10, 12, 14, 15, 16, 18, 19, 20, 21, 22, 27, 28, 29, 31], "parallel": [2, 11, 12, 15, 17, 21, 28, 29], "sinc": [2, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 30], "xrt": [2, 12], "port": [2, 15, 31], "gpu": [2, 5, 6, 8, 11, 13, 17, 18, 28], "tpu": [2, 3, 5, 6, 8, 11, 12, 13, 17, 21, 22, 23, 30, 31], "devic": [2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14, 15, 17, 19, 20, 21, 22, 23, 26, 27, 29, 30], "avail": [2, 12, 15, 16, 17, 18, 19, 24, 27, 31], "reason": [2, 3, 5, 14, 15, 18, 21], "bundl": 2, "target": [2, 9, 14, 15, 16, 18, 19, 20, 22, 27], "sequenti": [2, 12], "calcul": 2, "visual": [2, 28], "lcov": 2, "describ": [2, 3, 4, 9, 12, 16, 18, 20, 21, 29], "document": [2, 3, 4, 5, 6, 9, 15, 16, 20, 21, 26], "editor": 2, "choic": [2, 19, 27], "gutter": 2, "vscode": 2, "power": 2, "like": [2, 3, 4, 5, 8, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28], "clangd": 2, "refer": [2, 3, 5, 7, 8, 9, 10, 13, 15, 16, 18, 24, 26, 28, 31], "autocomplet": 2, "semant": [2, 5, 17, 19, 27], "understand": [2, 18, 19, 27], "underli": [2, 12, 16], "stack": [2, 16, 17, 19, 20, 27, 28], "combin": [2, 5, 12, 19, 27], 
"studio": 2, "extens": [2, 4, 5, 6], "featur": [2, 8, 13, 15, 17, 21, 25, 28, 29, 30], "assist": 2, "edit": 2, "As": [2, 3, 18, 19, 25, 27], "distutil": 2, "ltc": 3, "lazi": [3, 17, 18, 19, 22, 27, 28], "tensor": [3, 5, 7, 9, 12, 13, 15, 18, 20, 22, 23, 25, 26, 28, 29], "core": [3, 5, 7, 9, 12, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "clean": [3, 17, 22], "exist": [3, 9, 12, 14, 15, 16, 17, 22, 28], "stub": 3, "over": [3, 12, 14, 15, 16, 18, 24, 30], "6": [3, 4, 5, 9, 12, 17, 18, 19, 27], "were": [3, 16, 17, 18, 19, 27], "complet": [3, 12, 16, 17], "process": [3, 5, 6, 7, 10, 12, 14, 15, 17, 18, 21, 24, 26], "found": [3, 15, 18], "ref": [3, 4, 15], "replac": [3, 18, 23], "support": [3, 6, 8, 9, 10, 12, 13, 15, 19, 22, 23, 24, 27, 28, 30, 31], "NOT": 3, "introduc": [3, 7, 8, 14, 15, 17, 18, 21, 28], "ani": [3, 8, 9, 12, 13, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 29, 30, 31], "purpos": [3, 5, 26], "follow": [3, 5, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 31], "instruct": [3, 5, 18], "depend": [3, 4, 5, 13, 14, 16, 18, 19, 20, 27], "build": [3, 5, 16, 18, 24], "It": [3, 4, 5, 7, 12, 13, 14, 16, 18, 19, 22, 24, 25, 26, 27, 28], "experi": [3, 5, 14, 15, 21, 30], "workstat": [3, 5], "cpu": [3, 5, 7, 9, 12, 17, 18, 19, 24, 26, 27, 28, 30], "pjrt_devic": [3, 5, 6, 13, 15, 16, 17, 23, 31], "re": [3, 12, 14, 15, 17, 18, 19, 20, 23, 25, 27], "familiar": [3, 16, 25], "issu": [3, 5, 12, 14, 15, 16, 17, 18, 20, 21, 25], "3560": 3, "track": [3, 17, 30], "statu": [3, 17], "put": [3, 5, 16, 17, 21], "alia": [3, 7, 12], "avoid": [3, 17, 18, 20], "duplic": 3, "mention": [3, 5, 19, 22, 27], "below": [3, 5, 7, 9, 14, 15, 18, 19, 20, 27, 30, 31], "live": [3, 5, 12, 19, 27], "folder": [3, 4, 5], "except": [3, 5, 18, 28], "xla_native_funct": [3, 5], "yaml": [3, 5], "torch": [3, 4, 8, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30], "shape_infer": 3, "shape": [3, 5, 8, 9, 10, 11, 12, 17, 18, 23, 28, 29], "defin": [3, 5, 8, 10, 12, 18, 20, 23, 25, 28, 29], "input": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 23, 25, 28, 29, 30], "return": [3, 5, 6, 7, 8, 9, 12, 14, 16, 17, 18, 19, 21, 22, 23, 26, 27, 28, 30], "output": [3, 4, 7, 8, 9, 10, 12, 15, 16, 17, 20, 21, 22, 23, 24, 28], "manual": [3, 5, 8, 14, 17, 24], "gen_lazy_tensor": 3, "data": [3, 7, 9, 11, 12, 14, 15, 16, 18, 19, 20, 22, 27, 29, 30], "aten": [3, 5, 17, 19, 27], "specif": [3, 12, 16, 18, 20, 21, 26], "run_gen_lazy_tensor": 3, "dest": 3, "lazy_ir": 3, "class": [3, 6, 7, 9, 12, 21, 24, 26, 30], "genlazyir": 3, "back": [3, 5, 9, 12, 16, 17, 18, 28], "todai": [3, 13], "most": [3, 6, 12, 15, 17, 22], "categori": [3, 25], "goal": [3, 4, 5, 7, 14], "move": [3, 9, 12, 15, 17, 19, 21, 27, 30], "full_codegen": 3, "necessari": [3, 12, 17, 20], "call": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 27, 28, 30], "upstream": [3, 7, 13, 22], "api": [3, 5, 11, 15, 16, 19, 21, 22, 24, 26, 27, 28, 29, 30], "xlanativefunct": [3, 5], "column": 3, "declar": [3, 5], "anoth": [3, 9, 13, 16, 17, 18, 19, 27], "wrap": [3, 5, 6, 8, 9, 12, 14, 16, 18, 20, 24, 25, 26, 28], "around": [3, 15, 19, 24, 27], "xlatensor": [3, 5, 12, 28], "construct": [3, 5, 16, 18, 24, 28, 29, 30], "aten_xla_typ": [3, 5], "Will": 3, "method": [3, 9, 12, 15, 20, 25, 28, 30], "map": [3, 5, 7, 12], "node": [3, 5, 7, 10, 17, 19, 27, 31], "remov": [3, 15, 17, 18], "tensor_method": [3, 5], "possibl": [3, 15, 16, 17, 18, 24, 25, 28], "multipl": [3, 7, 9, 12, 14, 19, 22, 26, 27], "few": [3, 16, 17, 18, 19, 21, 27, 30], "simpler": 
[3, 15], "go": [3, 14, 16, 18, 20, 28], "unari": 3, "binari": [3, 6, 9, 22], "exampl": [3, 4, 5, 6, 7, 9, 12, 13, 14, 15, 16, 17, 19, 21, 22, 26, 27, 28, 29, 30, 31], "characterist": 3, "fallback": [3, 5], "_adaptive_avg_pool3d": 3, "condit": [3, 19, 23, 27], "issupportedadaptivepool": 3, "xlahelp": 3, "i64list": 3, "self": [3, 5, 6, 7, 9, 12, 18, 21, 26, 28], "size": [3, 7, 10, 13, 15, 16, 17, 18, 19, 27, 30], "output_size_list": 3, "pool_dim": 3, "nativ": [3, 5, 14, 15, 17, 20, 21, 28], "call_fallback_fn": 3, "xla_fallback": 3, "aten_op": 3, "output_s": 3, "wip": 3, "evolv": 3, "At": [3, 6, 12], "self_tensor": 3, "static": [3, 13, 19, 27], "bool": [3, 12], "sync_upd": 3, "sys_util": 3, "getenvbool": 3, "xla_tensor_update_sync": 3, "true": [3, 12, 14, 15, 18, 19, 21, 24, 27, 28, 30], "xla_check": 3, "dst_tensor": 3, "updatefromtensor": 3, "sync": [3, 12, 14, 17, 18, 20], "complic": [3, 5, 8], "an": [3, 4, 5, 6, 7, 8, 12, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30], "would": [3, 4, 5, 12, 15, 16, 17, 18, 19, 23, 27], "someth": [3, 18], "ab": [3, 24], "const": [3, 5, 7], "torch_lazy_fn_count": 3, "bridg": [3, 22], "atenfromxlatensor": 3, "getxlatensor": 3, "fail": [3, 12, 16, 17, 30], "explain": [3, 6, 16, 17, 18, 19, 27, 29], "later": [3, 18], "still": [3, 7, 15, 16, 19, 20, 21, 27, 30], "snippet": [3, 16, 28], "auto": [3, 5, 12, 24, 30], "common_devic": 3, "getxladevic": 3, "torch_internal_assert": 3, "xlatensorptr": 3, "lazy_self": 3, "getxlatensororcreateforwrappednumb": 3, "nodeptr": 3, "reusenod": 3, "getirvalu": 3, "makenod": 3, "cachenod": 3, "creat": [3, 9, 10, 12, 15, 17, 18, 20, 21, 28, 30], "std": [3, 7, 21], "get": [3, 5, 12, 13, 14, 15, 18, 19, 21, 24, 26, 27], "check": [3, 4, 5, 12, 16, 26, 29], "reus": [3, 16, 18, 20], "previou": [3, 15, 16, 18, 19, 27], "creation": [3, 12], "If": [3, 4, 5, 9, 12, 15, 16, 17, 18, 19, 26, 27, 28], "correspond": [3, 5, 7, 12, 18, 20, 24, 28, 29], "cach": [3, 8, 12, 13, 18], "newli": [3, 9], "And": [3, 19, 21, 27, 28], "within": [3, 9, 12, 16, 17, 18, 26, 30], "note": [3, 4, 7, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 29], "done": [3, 4, 8, 16, 17, 18, 19, 27], "public": [3, 15], "xlanod": 3, "xlavalu": 3, "opkind": [3, 5], "absoutputshap": 3, "num_output": [3, 19, 27], "mhash": 3, "string": [3, 7, 12, 28], "tostr": 3, "overrid": [3, 12, 20], "stringstream": 3, "ss": 3, "str": [3, 6, 12], "xlaopvector": 3, "loweringcontext": 3, "loctx": 3, "A": [3, 4, 6, 12, 15, 16, 18, 19, 20, 25, 26, 27, 28], "coupl": [3, 16, 17], "thing": [3, 17, 18], "keep": [3, 4, 13, 15, 17, 19, 27], "mind": [3, 15, 17], "clone": [3, 15, 17, 18], "even": [3, 12, 15, 16, 17, 19, 21, 27], "everi": [3, 5, 8, 9, 12, 15, 16, 17, 19, 22, 27, 28, 30], "outputshap": 3, "xla_shap": 3, "overli": 3, "simplifi": 3, "buildxxxop": 3, "slightli": [3, 5, 12], "better": [3, 5, 14, 15, 16, 17, 18, 19, 22, 23, 27], "maximumoutputshap": 3, "lower_for_shape_fn": 3, "absl": 3, "xlaop": [3, 5], "operand": 3, "promot": 3, "max": [3, 19, 27, 30], "second": [3, 10, 13, 15, 17, 18, 21, 29, 31], "inferoutputshap": 3, "comput": [3, 4, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "logic": [3, 12, 14, 19, 23, 27, 28, 29], "two": [3, 6, 12, 15, 17, 18, 19, 27, 28, 29], "xla_input": 3, "getoutputop": 3, "returnop": 3, "buildab": 3, "origin": [3, 9, 18], "genericop": 3, "modifi": [3, 18, 20, 22, 28], "abov": [3, 5, 6, 9, 13, 14, 15, 16, 17, 18, 19, 21, 22, 27, 29], "delet": 3, "sometim": [3, 18, 19, 27], "being": [3, 12, 16, 18, 21, 29], "tensor_op": 3, "cross": [3, 16, 28], "s1": [3, 
28], "sub": 3, "mul": [3, 19, 27], "u2": 3, "v3": [3, 16, 21], "u3": 3, "v2": [3, 4, 16], "irnod": 3, "those": [3, 5, 9, 12, 17, 18, 21], "long": [3, 14, 17, 18, 19, 21, 27], "term": [3, 10, 14, 17, 19, 27], "rid": [3, 19, 27], "composit": [3, 5], "end": [3, 5, 10, 12, 13, 15, 16, 17, 18, 21, 24, 25], "exp": 3, "pow": 3, "norm_exp": 3, "vector": [3, 10], "involv": [3, 19, 27, 28], "don": [3, 5, 13, 14, 15, 16, 17, 19, 24, 27], "t": [3, 5, 9, 12, 13, 14, 15, 16, 17, 19, 20, 24, 25, 27, 28, 29, 30], "build_cpp_test": 3, "skip": [3, 5, 17, 22], "desir": [3, 9, 18, 30], "test_ptxla": 3, "gtest_filt": 3, "atenxlatensortest": 3, "testab": 3, "correct": [3, 19, 27], "counter": [3, 5, 12, 17], "correctli": [3, 17, 25], "gt": [3, 4, 9, 15, 18], "erf": 3, "erfc": 3, "erfinv": 3, "pull": [3, 9, 20, 21, 24], "3659": 3, "binary_cross_entropi": [3, 20], "backward": [3, 5, 9, 14, 15, 16, 20, 21, 22, 24, 25], "3809": 3, "scalar": [3, 5, 17, 19, 27], "addcdiv": 3, "addcmul": 3, "3768": 3, "neg": 3, "index": [3, 4, 6, 12, 15, 16, 17, 18, 31], "amin": 3, "amax": 3, "3771": 3, "special": [3, 9, 10, 18, 28], "partial": [3, 19, 24, 25, 27], "adaptive_avgpool3d": 3, "3790": 3, "guid": [4, 9, 11, 15, 16, 18, 24, 25, 28], "interact": [4, 15], "start": [4, 14, 15, 16, 17, 18], "colab": [4, 17], "kaggl": 4, "preinstal": [4, 15], "ecosystem": [4, 26], "packag": [4, 10, 11, 16, 18, 20, 21], "date": 4, "list": [4, 5, 12, 18, 20, 23, 28], "readm": [4, 17, 18], "prerequisit": 4, "remot": 4, "quota": 4, "about": [4, 14, 15, 16, 18, 19, 27], "request": [4, 5, 12, 17, 18, 19, 20, 21, 27, 28], "offici": [4, 17], "ssh": [4, 15, 16, 18], "regist": [4, 5, 6, 7, 15, 30], "agent": 4, "alreadi": [4, 8, 10, 12, 17, 18, 19, 21, 24, 27, 30], "befor": [4, 7, 8, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 28, 30], "begin": [4, 28], "zone": [4, 15, 16, 18], "tpu_typ": 4, "8": [4, 9, 10, 12, 14, 15, 16, 18, 19, 21, 22, 26, 27, 28, 29], "vm": [4, 15, 16, 17, 18, 21], "assum": [4, 6, 8, 12, 16, 19, 21, 25, 27, 28], "id_ed25519": 4, "ubuntu2204": 4, "base": [4, 7, 12, 14, 15, 17, 18, 19, 24, 27, 28, 29], "metadata": [4, 17], "cat": [4, 20], "pub": 4, "ip": [4, 12, 15, 30, 31], "format": [4, 12, 17, 18, 22, 26], "valu": [4, 5, 9, 10, 12, 13, 15, 17, 18, 19, 23, 27, 28, 31], "networkendpoint": 4, "accessconfig": 4, "externalip": 4, "123": 4, "give": [4, 9, 17, 18, 26, 28, 29], "friendli": 4, "easier": [4, 14, 18, 19, 27], "echo": 4, "host": [4, 12, 15, 16, 17, 18, 20, 24, 30, 31], "n": [4, 12, 21, 26], "hostnam": 4, "test": [4, 6, 8, 9, 10, 13, 15, 21, 24, 31], "v": [4, 8, 9, 15, 19, 27], "palett": 4, "select": [4, 12, 15, 30], "visualstudio": 4, "doc": [4, 12, 14, 15, 16, 19, 25, 27, 28], "__": [4, 15], "just": [4, 8, 14, 15, 16, 19, 21, 24, 27, 30], "titl": [4, 15], "open": [4, 5, 6, 9, 15, 17], "window": 4, "termin": [4, 30], "mkdir": 4, "ptxla": 4, "Then": [4, 9, 18], "ui": 4, "venv": 4, "virtual": [4, 12], "latest": [4, 9], "releas": [4, 6, 7, 8, 15, 16, 17, 18, 22, 24, 25, 26, 28], "pip": [4, 8, 9, 10, 18], "numpi": [4, 8, 9, 12, 18, 29], "f": [4, 8, 9, 12, 16, 21, 24, 26, 30], "googleapi": [4, 8, 18], "libtpu": [4, 6, 15], "html": [4, 8, 15, 24], "import": [4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 28, 29, 30], "set_device_typ": 4, "print": [4, 9, 12, 13, 15, 16, 17, 18, 19, 21, 22, 27, 28, 30], "real_devic": 4, "run": [4, 5, 8, 10, 11, 12, 13, 14, 15, 19, 20, 21, 22, 26, 27, 30], "2": [4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 22, 23, 24, 26, 28, 31], "3": [4, 5, 6, 8, 9, 10, 12, 14, 17, 18, 22, 23, 24, 
26, 28], "4": [4, 6, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 26, 27, 28, 29], "5": [4, 7, 9, 12, 13, 17, 18, 19, 21, 24, 26, 27], "7": [4, 12, 17, 21, 22], "number": [4, 10, 12, 13, 14, 15, 17, 18, 24, 28, 29], "vari": [4, 15, 19, 25, 27], "That": [4, 19, 27], "now": [4, 7, 9, 10, 14, 15, 16, 18, 19, 27, 28], "realist": 4, "librari": [5, 6, 18, 29, 30], "offer": [5, 9, 25, 26], "implement": [5, 7, 8, 9, 14, 15, 17, 19, 22, 24, 25, 27], "xla": [5, 9, 10, 13, 14, 15, 19, 21, 23, 25, 30, 31], "its": [5, 7, 9, 13, 15, 16, 17, 21, 22, 24, 28, 29], "convert": [5, 12, 16, 21], "higher": [5, 17, 30], "level": [5, 17, 18, 22, 26, 30], "represent": [5, 12, 16, 18, 29], "hlo": [5, 12, 16, 17, 18], "beyond": 5, "scope": 5, "forward": [5, 9, 14, 20, 21, 22, 25, 26], "haven": [5, 19, 27], "yet": [5, 7], "caus": [5, 12, 14, 15, 16, 17, 18, 19, 20, 27], "signific": [5, 17, 18, 22], "slowdown": [5, 17, 21], "must": [5, 6, 7, 12, 15, 16, 17, 25, 30, 31], "best": [5, 8, 22, 26], "perform": [5, 7, 8, 9, 10, 12, 14, 16, 20, 21, 22, 24, 26, 28], "what": [5, 16, 18], "debug": [5, 14, 19, 26, 27], "pt": [5, 15, 16, 17, 18], "profil": [5, 15], "_ctc_loss": [5, 17], "_ctc_loss_backward": [5, 17], "contribut": 5, "definit": [5, 16, 19, 27], "native_funct": 5, "after": [5, 7, 9, 12, 15, 16, 17, 18, 19, 23, 27, 28], "kernel": [5, 9, 11, 19, 26, 27], "aten_fallback": 5, "h": 5, "search": 5, "repo": [5, 16, 17, 18, 21], "sequenc": [5, 12], "explicitli": [5, 16, 17, 18, 19, 20, 27], "compos": 5, "match": [5, 9, 12, 16, 17], "serv": 5, "interfac": [5, 6, 16, 17, 25, 30], "machineri": 5, "registerxla": 5, "registerautogradxla": 5, "entri": [5, 6, 9], "pytorch_xla": 5, "world": [5, 8, 15, 19, 22, 27, 30], "written": [5, 18, 30], "paramet": [5, 12, 15, 16, 17, 20, 21, 25, 28, 30, 31], "result": [5, 7, 12, 13, 15, 16, 17, 18, 21, 23, 28], "dispatch": [5, 30], "wrapper": [5, 16, 21, 24, 25], "inplac": [5, 12, 28], "ir": [5, 9, 12, 17, 18, 19, 27], "insid": [5, 9, 16, 18, 28], "stand": 5, "intermedi": [5, 15, 17, 18], "smaller": [5, 18, 19, 27], "inherit": 5, "dai": 5, "addit": [5, 6, 10, 15, 16, 17, 18, 20, 21], "unless": [5, 17, 19, 27], "want": [5, 12, 14, 15, 16, 17, 18, 19, 22, 27, 28, 31], "verifi": 5, "test_oper": 5, "test_aten_xla_tensor": 5, "yield": [5, 16, 17], "break": [5, 18, 19, 27], "grasp": 5, "capabl": 5, "how": [5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 27, 28, 31], "similar": [5, 15, 18, 21, 23, 26], "minim": [5, 18], "pr": [5, 17, 24], "vanilla": 5, "lerp": 5, "variant": [5, 12, 19, 20, 27], "lerp_": 5, "scalar_out": 5, "tensor_out": 5, "prototyp": [5, 9, 28], "weight": [5, 9, 12, 17, 25, 26], "lerp_out": 5, "howev": [5, 8, 9, 17, 18, 28], "namespac": [5, 17], "wrapper_scalar_lerp": 5, "No": [5, 13, 15, 19, 26, 27], "deviceguard": 5, "omit": [5, 15, 29, 31], "anonym": 5, "wrapper_scalar_lerp_": 5, "wrapper_scalar_lerp__tmp": 5, "_copy_from": 5, "m": [5, 7, 9, 19, 24, 27], "impl": [5, 7, 9], "torch_fn": 5, "automat": [5, 6, 11, 12, 15, 16, 17, 18, 19, 24, 27, 29, 30], "u": [5, 15, 17, 18, 19, 22, 27], "explicit": [5, 20, 24], "place": [5, 7, 12, 18, 20, 28, 30], "ll": [5, 19, 27], "interned_str": 5, "symbol": [5, 19, 27], "submit": [5, 17, 18, 20], "team": [6, 22], "direclti": 6, "tf": [6, 17, 19, 27], "close": 6, "expos": [6, 15, 16, 18, 28], "deviceplugin": 6, "handl": [6, 14, 17, 19, 24, 25, 27, 28, 29], "short": [6, 17, 19, 27], "pjrtclient": 6, "mirror": 6, "pjrt_api": 6, "straightforward": [6, 12, 18], "detail": [6, 7, 8, 9, 12, 13, 15, 16, 17, 18, 19, 27], "concret": [6, 19, 27], 
"placehold": 6, "pjrt_library_path": 6, "extra": [6, 21, 25], "multiprocess": [6, 12, 15, 16], "compon": 6, "least": [6, 18], "cpuplugin": 6, "def": [6, 7, 8, 9, 10, 12, 14, 15, 16, 18, 21, 22, 23, 25, 26], "library_path": 6, "o": [6, 9, 15, 21], "join": [6, 12], "dirnam": 6, "__file__": 6, "pjrt_c_api_cpu_plugin": 6, "identifi": [6, 12, 30], "exmapl": 6, "pyproject": 6, "toml": 6, "torch_xla_cpu_plugin": 6, "With": [6, 8, 9, 13, 15, 19, 22, 27], "initi": [6, 7, 9, 12, 15, 16, 18, 21, 23, 30], "experiment": [6, 8, 9, 10, 11, 13, 14, 15, 16, 21, 22, 23, 25, 28, 30], "state": [6, 12, 24], "becom": [6, 8, 9, 15, 17, 18, 19, 27], "stabl": [6, 15, 24], "xla_model": [7, 9, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "adopt": [7, 19, 27], "traceabl": [7, 12], "commun": [7, 8, 12, 15, 16, 18, 22, 29], "reimplement": [7, 15], "_c10d_function": 7, "figur": [7, 13, 29], "show": [7, 9, 15, 16, 17, 21], "all_reduc": [7, 12, 20], "between": [7, 13, 15, 16, 17, 18, 19, 20, 21, 23, 27, 28], "processgroupxla": 7, "deriv": 7, "processgroup": 7, "xla_backend": [7, 15, 21, 30], "_create_xla_process_group": 7, "prefix_stor": 7, "rank": [7, 12, 15, 21, 24, 29, 30], "timeout": 7, "assert": [7, 21], "xr": [7, 12, 15, 16, 20, 21, 24, 25, 28, 29, 30], "is_spmd": [7, 12], "spmd": [7, 11, 16, 18, 30], "group": [7, 12, 15, 21, 28], "_register_xla_backend": 7, "dist": [7, 15, 21, 30], "register_backend": 7, "allreduc": 7, "all_reduce_opt": 7, "allgath": 7, "output_tensors_list": 7, "input_tensor": 7, "opt": [7, 16], "none": [7, 8, 9, 12, 17, 25, 28, 29], "_mp_fn": [7, 15, 16], "init_process_group": [7, 15, 21, 30], "init_method": [7, 15, 30], "progress": [7, 18], "instanc": [7, 8, 12, 24, 30], "blob": [7, 10, 12, 15, 28], "distributed_c10d": 7, "_exception_logg": 7, "all_gath": [7, 12, 15], "tensor_list": 7, "async_op": 7, "fals": [7, 9, 12, 16, 24, 28], "_get_default_group": 7, "certain": [7, 17, 19, 20, 27], "remap": 7, "_functional_collect": 7, "all_reduce_inplac": 7, "eventu": 7, "reach": [7, 12], "rewrit": [7, 18, 19, 27, 28], "reduceop": 7, "group_nam": 7, "torch_library_impl": 7, "four": [7, 18], "oper": [7, 10, 11, 12, 15, 16, 17, 18, 30], "align": [7, 14], "while": [7, 9, 12, 18, 19, 21, 27], "signatur": 7, "remain": [7, 16, 18, 19, 27, 31], "restrict": 7, "appli": [7, 12, 20, 24, 25, 30], "usag": [7, 12, 17, 18, 19, 24, 25, 27, 30], "test_collective_ops_tpu": 7, "demonstr": [7, 18, 20, 25, 30], "scenario": [7, 22], "sum": [7, 12, 20, 24, 25], "reduct": [7, 12], "aggreg": 7, "all_gather_into_tensor": 7, "gather": [7, 12, 28], "reduce_scatter_tensor": 7, "reduc": [7, 12, 13, 14, 15, 16, 17, 18, 24], "across": [7, 12, 15, 16, 17, 24, 29], "all_to_all_singl": 7, "output_split_s": 7, "input_split_s": 7, "although": [7, 15, 19, 27], "accept": [7, 28], "argument": [7, 9, 10, 12, 18, 20, 22, 24], "limit": [7, 12, 15, 16], "reflect": 7, "compromis": 7, "maintain": 7, "constraint": [7, 15, 17], "alltoal": [7, 12], "rise": 8, "openai": [8, 10], "triton": [8, 11], "popular": 8, "order": [8, 12, 16, 17, 18, 28, 29], "pariti": 8, "continu": [8, 15, 22], "push": 8, "let": [8, 15, 16, 17, 18, 22, 29], "custom_kernel": 8, "jax_import_guard": 8, "pl": [8, 15, 16, 28], "jnp": 8, "add_vectors_kernel": 8, "x_ref": 8, "y_ref": 8, "o_ref": 8, "x": [8, 9, 10, 12, 16, 17, 18, 19, 21, 23, 24, 25, 26, 27, 28, 29], "y": [8, 10, 12, 17, 18, 19, 24, 25, 26, 27, 28], "jit": [8, 10, 22], "add_vector": 8, "arrai": [8, 12, 18, 25, 29], "pallas_cal": 8, "out_shap": 8, "shapedtypestruct": 8, "dtype": [8, 9, 10, 15, 19, 20, 26, 27], "otherwis": 
[8, 12, 17, 18, 19, 25, 27], "program": [8, 9, 10, 12, 17, 18, 19, 22, 27, 28, 29], "hang": 8, "lock": 8, "q": [8, 9], "randn": [8, 9, 12, 14, 15, 16, 21, 22, 26, 28, 29], "128": [8, 9, 15, 24, 26, 31], "k": [8, 9, 17], "make_kernel_from_palla": 8, "pt_kernel": 8, "lambda": [8, 24], "liner": 8, "flash": [8, 10], "attent": [8, 10], "besid": 8, "op": [8, 9, 11, 12, 14, 17, 18, 19, 20, 27, 28, 29], "suppor": 8, "flash_attent": 8, "paged_attent": 8, "queri": [8, 15], "squeez": 8, "dim": [8, 12], "key_cach": 8, "value_cach": 8, "context_len": 8, "block_tabl": 8, "pages_per_compute_block": 8, "megacore_mod": 8, "vllm": 8, "util": [8, 11, 12, 16, 17, 21, 24, 25, 26, 30], "effect": [8, 12], "memori": [8, 11, 12, 13, 17, 18, 19, 24, 27], "kv": 8, "proper": [8, 29], "jax_nightly_releas": 8, "jaxlib_nightly_releas": 8, "exported_program_to_stablehlo": 9, "xm": [9, 12, 14, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 29], "torchvis": [9, 14, 22], "xla_devic": [9, 12, 15, 16, 17, 18, 20, 21, 22, 23, 26, 29], "resnet18": [9, 14, 22], "sampl": [9, 12, 15, 17], "tupl": [9, 12, 19, 23, 25, 27, 29], "sample_input": 9, "224": [9, 14], "stablehlo_program": 9, "callabl": [9, 12, 24], "get_stablehlo_text": 9, "get_stablehlo_bytecod": [9, 12], "sample_input_xla": 9, "output2": 9, "allclos": 9, "atol": 9, "1e": [9, 17, 22], "One": [9, 12, 13, 18, 24], "tmp": [9, 16, 17, 24], "stablehlo_dir": 9, "empti": [9, 12], "doesn": [9, 16, 17, 19, 25, 27], "load": [9, 10, 12, 15, 17, 21, 24, 26, 30], "stablehlographmodul": 9, "stablehlo_program2": 9, "output3": 9, "server": [9, 12, 15, 18], "env": [9, 12, 15, 28], "nightli": [9, 17, 18, 24, 28], "resnet_tf": 9, "p": [9, 15, 17, 19, 27], "8500": 9, "mount": [9, 16], "model_nam": 9, "accomplish": 9, "tf_saved_model_integr": 9, "save_torch_module_as_tf_saved_model": 9, "nn": [9, 12, 15, 16, 21, 22, 24, 26, 28], "trace": [9, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 27, 28], "exported_model": 9, "exportedprogram": 9, "pathlik": 9, "stablehloexportopt": 9, "alias": [9, 17, 20], "save_torch_model_as_stablehlo": 9, "torchmodel": 9, "arg": [9, 12, 16, 18, 23, 24], "constant": [9, 17, 18, 28], "ndarrai": [9, 12], "human": 9, "readabl": [9, 18], "mlir": 9, "form": [9, 15, 17, 19, 27, 31], "posit": [9, 12], "meta": 9, "l__fn___layers_15_feed_forward_w2": 9, "l__fn___layers_13_feed_forward_w1": 9, "l__fn___layers_3_attention_wo": 9, "l__fn___layers_12_ffn_norm_weight": 9, "l__fn___layers_25_attention_wo": 9, "serial": [9, 15, 16], "stablehlofunc": 9, "stage": 9, "guarante": [9, 12], "plan": [9, 13, 15], "major": 9, "agre": [9, 18], "scaled_dot_product_attent": 9, "decompos": 9, "low": [9, 13, 17], "dure": [9, 12, 17, 18, 22, 24, 28], "lower": [9, 11, 17, 18, 19, 20, 27], "captur": [9, 12, 17, 18], "downstream": [9, 20], "ml": [9, 29], "crucial": 9, "geneart": 9, "pattern": [9, 17, 19, 22, 27], "bunch": 9, "challeng": 9, "error": [9, 12, 16, 17], "prone": 9, "robust": 9, "outlin": [9, 26], "stablehlocompositebuild": 9, "arbitari": 9, "region": [9, 12, 14, 17, 20, 28], "non": [9, 12, 14, 19, 20, 27, 29], "hardcod": [9, 28], "store": [9, 10, 12, 17], "attribut": 9, "retriev": [9, 12, 16, 19, 22, 27, 28], "pratic": 9, "scaled_product_attent": 9, "mark_pattern_util": 9, "__init__": [9, 21, 26], "super": [9, 21, 22], "q_proj": 9, "linear": [9, 12, 15, 16, 20, 21, 26], "bia": 9, "k_proj": 9, "v_proj": 9, "builder": 9, "b": [9, 12, 15, 18, 19, 20, 22, 27, 29], "sdpa": 9, "25": [9, 13], "other_attr": 9, "val": 9, "mark_input": 9, "attn_out": 9, "mark_output": 9, "input_arg": 9, "10": [9, 12, 15, 16, 
17, 18, 19, 21, 22, 23, 26, 27, 30], "stablehlo_gm": 9, "shown": [9, 15, 19, 27], "irtohlo": 9, "56": 9, "mhlo": 9, "cross_program_prefetch": 9, "input_output_alia": 9, "is_dynam": 9, "use_auto_spmd_partit": 9, "func": 9, "arg0": 9, "10x8x128xf32": 9, "arg1": 9, "128x128xf32": 9, "arg2": 9, "arg3": 9, "9": [9, 18, 19, 21, 24, 27], "composite_attribut": 9, "500000e": 9, "01": [9, 10], "f32": 9, "decomposit": 9, "11": [9, 17, 19, 27], "privat": [9, 15], "actual": [9, 14, 18, 19, 21, 27, 28], "encapsul": 9, "propag": [9, 17], "high": [10, 13, 18, 21, 26], "deep": [10, 11, 17], "learn": [10, 15], "languag": 10, "empow": 10, "full": [10, 12, 16, 17, 24], "potenti": [10, 12, 15, 17, 25], "given": [10, 12, 17, 18, 19, 21, 24, 27, 29], "add_kernel": 10, "x_ptr": 10, "pointer": 10, "y_ptr": 10, "output_ptr": 10, "n_element": 10, "block_siz": 10, "tl": 10, "constexpr": 10, "element": [10, 12, 19, 25, 27, 28], "tutori": [10, 17, 18, 21, 28], "l28": 10, "pid": 10, "program_id": 10, "axi": [10, 12, 25], "block_start": 10, "offset": 10, "arang": 10, "mask": [10, 17, 19, 27], "xla_triton": 10, "16": [10, 16, 18, 24, 26, 29], "int64": 10, "empty_lik": 10, "grid": 10, "cdiv": 10, "triton_cal": 10, "itself": [10, 12, 24], "kwarg": [10, 12, 24, 28], "payload": [10, 12, 15], "regard": [10, 16, 22], "buffer": [10, 12], "_xla_gpu_custom_cal": 10, "dep": 10, "connect": [11, 12, 15, 28], "overview": [11, 29], "eager": [11, 12, 19, 21, 26, 27], "mode": [11, 12, 19, 21, 26, 27, 28, 30], "troubleshoot": 11, "palla": 11, "stablehlo": [11, 12], "mix": [11, 12, 29], "precis": 11, "advanc": [11, 29], "topic": [11, 29], "distribut": [11, 16, 17, 21, 24, 25, 28, 29], "checkpoint": [11, 15, 18, 24, 29], "distributeddataparallel": [11, 15], "ddp": [11, 15], "torchdynamo": 11, "while_loop": 11, "shard": [11, 12, 29, 30], "quantiz": 11, "recompil": [11, 13, 14, 16, 17, 18], "hardwar": [11, 12, 17, 18, 20], "plugin": [11, 15], "bazel": 11, "int": [12, 15, 19, 27, 28], "device_count": [12, 28], "address": [12, 15, 28, 31], "wait": [12, 17, 18], "pend": [12, 14], "whether": [12, 16, 20], "block": [12, 18, 24, 28], "finish": [12, 18], "full_graph": 12, "num_different_graphs_allow": 12, "lazytensor": [12, 14, 18], "repres": [12, 15, 19, 27], "happen": [12, 14, 15, 16, 17, 18, 19, 27], "decid": [12, 17, 19, 27], "funciton": 12, "act": [12, 16], "context": [12, 15, 17, 19, 20, 27], "throw": [12, 16], "info": [12, 17, 19, 27, 29], "exit": [12, 17, 20, 21], "pt_xla_debug": 12, "messag": [12, 17], "dump": [12, 17], "allow": [12, 16, 17, 18, 20, 28, 29, 30], "rais": [12, 17], "exceed": 12, "foo": 12, "sin": 12, "co": 12, "foo2": 12, "compiled_foo2": 12, "manual_se": [12, 15], "seed": 12, "random": [12, 14, 15, 18, 26], "integ": [12, 17], "rng": [12, 15], "device_typ": 12, "local_process_count": 12, "local_device_count": 12, "total": [12, 19, 27, 29], "addressable_device_count": 12, "visibl": [12, 19, 27], "global_device_count": 12, "global_runtime_device_count": [12, 25, 28, 29], "especi": [12, 15, 18, 22, 28], "world_siz": [12, 15, 20, 21, 24, 28], "particip": [12, 15], "job": [12, 18, 22], "global_ordin": [12, 15, 16, 21, 24], "global": [12, 15, 16, 28, 30], "ordin": [12, 16], "thread": [12, 15, 16, 17, 30], "predict": 12, "relationship": [12, 16, 17], "worker": [12, 15, 16, 18, 24, 30], "id": [12, 15, 17, 18], "nor": 12, "contigu": [12, 16, 17], "local_ordin": 12, "get_master_ip": 12, "master": [12, 15, 16, 30], "discoveri": 12, "use_spmd": [12, 28, 29, 30], "forc": [12, 15, 17, 19, 23, 27], "mean": [12, 15, 16, 17, 18, 19, 21, 25, 
27, 28], "replic": [12, 28, 29], "spmd_advanc": 12, "md": [12, 15], "initialize_cach": [12, 16], "readonli": [12, 16], "persist": [12, 16, 30], "devkind": 12, "cuda": [12, 15, 16, 18, 19, 20, 26, 27, 31], "deprec": 12, "xla_device_hw": 12, "union": 12, "real": [12, 22], "is_master_ordin": 12, "multi": [12, 13, 28, 31], "num_host": 12, "boolean": 12, "indic": [12, 17, 18, 19, 27], "reduce_typ": 12, "float": [12, 19, 20, 27], "pin_layout": 12, "reduce_sum": 12, "reduce_mul": 12, "reduce_and": 12, "reduce_or": 12, "reduce_min": 12, "reduce_max": 12, "replica": [12, 15], "layout": [12, 26], "pine": 12, "prevent": [12, 18, 20, 22, 28], "corrupt": 12, "unpin": 12, "hlomodul": 12, "constrain": [12, 15], "hold": [12, 28, 29], "along": [12, 24], "dimens": [12, 13, 28, 29], "all_to_al": 12, "split_dimens": 12, "concat_dimens": 12, "split_count": 12, "www": 12, "org": [12, 15, 24], "operation_semant": 12, "upon": 12, "split": 12, "concat": 12, "count": [12, 17], "add_step_closur": 12, "closur": 12, "run_async": 12, "step": [12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 30], "mani": [12, 15, 17, 18, 19, 27, 31], "report": 12, "consol": 12, "post": [12, 17], "tensorboard": [12, 18], "etc": [12, 14, 17, 19, 27, 28], "intermediari": 12, "inspect": 12, "typic": 12, "barrier": [12, 15, 16, 18], "materi": [12, 17, 18, 19, 27, 28], "queu": 12, "though": [12, 16, 21], "advis": 12, "throttl": 12, "event": 12, "asynchron": [12, 28, 30], "wait_device_op": 12, "async": [12, 22], "whose": [12, 13], "optimizer_step": [12, 16, 18, 20, 21, 24], "optimizer_arg": 12, "dict": [12, 24], "gradid": 12, "parallelload": [12, 28], "dataparallel": 12, "loader": [12, 17, 18, 22], "dictionari": 12, "gradient": [12, 16, 20, 24, 30], "save": [12, 17, 24, 30], "file_or_path": 12, "textio": 12, "master_onli": [12, 24], "global_mast": 12, "transfer": [12, 15, 17, 18, 28], "care": [12, 16, 17, 19, 27], "taken": [12, 16, 17, 19, 21, 27, 30], "view": [12, 16, 17], "recreat": [12, 16], "destin": [12, 16], "nest": [12, 24], "locat": 12, "control": [12, 13, 16, 17, 28], "obj_to_sav": 12, "path_to_sav": 12, "rendezv": 12, "tag": [12, 15], "byte": 12, "mesh": [12, 15, 25], "xla_rendezv": 12, "sent": [12, 17], "exchang": 12, "mesh_reduc": 12, "reduce_fn": 12, "toxlatensorarena": 12, "receiv": 12, "copi": [12, 15, 16, 17, 18], "np": [12, 25, 29], "accuraci": [12, 21, 24], "test_accuraci": 12, "set_rng_stat": 12, "get_rng_stat": 12, "get_memory_info": 12, "memoryinfo": 12, "bytes_us": 12, "290816": 12, "bytes_limit": 12, "34088157184": 12, "peak_bytes_us": 12, "500816": 12, "get_stablehlo": 12, "var": [12, 28], "xla_hlo_debug": [12, 17], "root": [12, 19, 27], "bytecod": [12, 22], "parallel_load": [12, 15, 16, 17], "mpdeviceload": [12, 16, 18, 28], "dataload": [12, 16, 18, 21, 28, 30], "background": [12, 30], "upload": [12, 18, 28], "per_device_load": [12, 28], "constructor": 12, "train_device_load": 12, "train_load": [12, 16, 28], "xla_multiprocess": 12, "spawn": [12, 15, 16, 18], "fn": 12, "nproc": [12, 15], "daemon": 12, "start_method": 12, "moment": 12, "maximum": [12, 13, 18, 26], "valueerror": 12, "mark_shard": [12, 25, 28, 29], "xlashardedtensor": [12, 30], "partition_spec": [12, 28, 29], "annot": [12, 28, 29], "partit": [12, 28], "spec": [12, 28], "intern": [12, 15, 16, 17, 19, 27, 28, 31], "spmdpartition": [12, 28], "topologi": [12, 16, 28, 29], "device_mesh": [12, 28], "mesh_shap": [12, 25, 28, 29], "ax": [12, 28, 29], "impact": [12, 15, 17, 19, 21, 27], "dynamo_custom_op": 12, "dynamo": [12, 18, 22, 26], "recogniz": 12, 
"num_devic": [12, 25, 28, 29], "device_id": [12, 25, 28, 29], "32": [12, 17, 18], "clear_shard": 12, "clear": 12, "cast": [12, 20], "t1": [12, 16, 17, 29], "get_1d_mesh": 12, "set_global_mesh": 12, "get_global_mesh": 12, "axis_nam": [12, 28], "v4": [12, 14, 15, 16, 18, 22, 28], "ravel": 12, "reshap": 12, "fill": 12, "assign": [12, 16, 18], "Its": 12, "length": [12, 19, 27], "len": [12, 18], "get_xla_supported_devic": 12, "get_logical_mesh": 12, "ordereddict": [12, 28, 29], "hybridmesh": [12, 28], "ici_mesh_shap": [12, 28], "dcn_mesh_shap": [12, 28], "hybrid": 12, "ici": 12, "dcn": [12, 28], "increas": 12, "intens": 12, "mdl": 12, "inner": [12, 24, 28], "outer": [12, 24, 25, 28], "slice": [12, 18, 28], "fsdp": [12, 24, 25, 28, 29], "eager_mod": [12, 14], "wa": [12, 15, 17, 18, 19, 27, 30], "d": [12, 13, 19, 20, 27], "eagerli": [12, 14, 16, 17, 19, 27], "metric": [12, 21], "metrics_report": [12, 17], "short_metrics_report": [12, 17], "counter_nam": 12, "metric_nam": 12, "activ": [12, 16, 17, 21, 24, 25, 26], "counter_valu": 12, "metric_data": 12, "total_sampl": 12, "accumul": 12, "retain": 12, "circular": 12, "natur": 13, "in_tensor": 13, "randint": [13, 26], "out_tensor": 13, "nonzero": [13, 17, 18, 19, 27], "word": [13, 19, 27], "further": [13, 18, 21], "categor": 13, "unbound": 13, "alloc": 13, "infinit": [13, 25], "phase": 13, "layer": [13, 14, 24, 25, 28], "perceptron": 13, "mlp": 13, "xla_experiment": 13, "masked_select": 13, "masked_scatt": 13, "your_script": [13, 18], "100": [13, 17, 24], "29": [13, 21, 22], "49": [13, 22], "20": [13, 16, 17, 21, 26], "03": 13, "102": 13, "hit": [13, 19, 27], "198": 13, "1953": 13, "motiv": 13, "excess": 13, "half": 13, "drop": [13, 17], "try": [13, 17, 18, 19, 27], "python3": [13, 15, 16, 17, 18, 24], "test_dynamic_shape_model": 13, "testdynamicshapemodel": 13, "test_backward_pass_with_dynamic_input": 13, "expand": [13, 22], "feel": [13, 17, 21], "review": [13, 25], "rfc": [13, 28, 31], "64": [14, 22, 24], "mark_step": [14, 15, 16, 17, 18, 21], "drawback": 14, "approach": [14, 19, 21, 24, 27], "often": [14, 17, 19, 27], "confus": 14, "preprocess": [14, 26], "small": [14, 17, 18, 19, 21, 22, 27], "leak": 14, "expens": [14, 17, 19, 27], "hard": [14, 19, 21, 22, 27], "why": [14, 19, 27], "mitig": 14, "ux": 14, "mark": [14, 16], "compiled_model": 14, "right": [14, 19, 22, 27], "awai": 14, "pretti": [14, 16, 19, 21, 27], "straight": 14, "enter": 14, "reenabl": 14, "perfomr": 14, "compar": [14, 15, 16, 17, 21, 22, 23], "recommen": 14, "overhad": 14, "step_fn": 14, "loss_fn": [14, 15, 16, 20, 21, 22], "zero_grad": [14, 15, 16, 20, 21], "logit": [14, 25], "loss": [14, 15, 16, 20, 22, 24, 25], "ask": [14, 17, 19, 27], "refactor": 14, "decod": 14, "much": [14, 15, 16, 18, 19, 22, 27], "llama2": 14, "fake": [14, 30], "chip": [14, 15], "300": [14, 17], "observ": [14, 15, 21], "147": 14, "65": [14, 17], "45": 14, "train_decoder_only_bas": [14, 17], "perfomran": 14, "tri": [14, 18], "resnet50": [14, 15, 16, 22, 24], "exepct": 14, "loop": [14, 16, 17, 18, 19, 25, 27, 30], "meant": 14, "encount": [15, 17, 18], "bug": [15, 17, 21], "r2": [15, 17, 28], "init": [15, 16, 21, 22, 23], "renam": 15, "torchrun": [15, 16, 31], "xpu": 15, "neuron": 15, "xrt_tpu_config": 15, "30": [15, 24], "thousand": 15, "preview": 15, "safe": 15, "section": [15, 16, 17, 18, 28], "broadcast": 15, "broadcast_master_param": 15, "pjrt_backend": 15, "diff": [15, 18], "42": 15, "gradient_as_bucket_view": [15, 21], "mseloss": [15, 21], "sgd": [15, 16, 20, 21, 22], "lr": [15, 16, 21, 22, 24, 25], 
"001": [15, 21], "confirm": 15, "__name__": [15, 16, 21], "__main__": [15, 16, 21], "localservic": 15, "localhost": [15, 21], "51011": 15, "master_addr": [15, 21], "master_port": [15, 21], "12355": [15, 21, 31], "Or": [15, 16, 19, 27], "overhead": [15, 21, 22], "grpc": 15, "torchbench": 15, "35": [15, 17], "tpuvm": [15, 16, 18, 28], "2048": 15, "mnist": [15, 16, 17, 20], "test_train_mp_mnist": [15, 21], "fake_data": [15, 17, 21, 31], "alpha": [15, 16], "central2": [15, 18], "git": [15, 17, 18, 24], "depth": [15, 17], "branch": [15, 17, 19, 27], "test_train_mp_imagenet": [15, 17, 21], "batch_siz": [15, 24, 31], "256": 15, "num_epoch": [15, 21, 24], "By": [15, 19, 27], "tpu_process_bound": 15, "tpu_visible_chip": 15, "r1": 15, "13": [15, 16, 21, 23], "docker_imag": 15, "gcr": 15, "io": [15, 24], "sudo": [15, 18], "rm": 15, "privileg": 15, "net": [15, 18, 20], "gpu_num_devic": 15, "nnode": [15, 31], "num_gpu_devic": 15, "pjrt_distribut": 15, "physic": [15, 28, 29], "12": [15, 17, 22, 24], "number_gpu_vm": [15, 31], "node_rank": [15, 31], "current_node_rank": 15, "nproc_per_nod": [15, 31], "number_local_gpu_devic": 15, "rdzv_endpoint": [15, 31], "internal_ip_address": 15, "multinode_train": 15, "endpoint": [15, 31], "machine_0": 15, "machine_1": 15, "machine_0_internal_ip_address": [15, 31], "ident": 15, "page": 15, "mostli": [15, 24], "interchang": 15, "perspect": [15, 16], "subtl": 15, "importantli": 15, "architectur": [15, 24], "thu": [15, 17], "batch": [15, 16, 17, 18, 28], "latenc": 15, "deseri": 15, "send": [15, 16, 18, 28], "direct": [15, 17], "independ": [15, 16, 17], "significantli": [15, 16, 18], "xla_dist": 15, "scp": [15, 16], "sdk": 15, "collect": [15, 21, 22, 29, 30], "enhanc": 15, "stabil": [15, 17, 20], "xmp": [15, 16, 18], "substanti": 15, "practic": [15, 19, 25, 27], "unreli": 15, "due": [15, 17, 18, 31], "inbound": 15, "could": [15, 18, 19, 27, 28], "failur": 15, "entir": [15, 24], "restart": 15, "impos": 15, "middl": [15, 18, 19, 27], "unwant": 15, "permit": 15, "subset": 15, "old": 15, "alter": 15, "synchron": [15, 16, 18, 28, 30], "consid": [15, 18], "all_gather_object": 15, "gloo": [15, 21, 30], "subgroup": 15, "monitor": 15, "_": [15, 22, 23], "altern": [15, 19, 20, 26, 27], "less": [15, 19, 22, 27], "reliabl": 15, "than": [15, 17, 19, 21, 24, 27], "strongli": 15, "_all_gath": 15, "int32": 15, "zeros_lik": 15, "get_world_s": 15, "averag": 15, "task": 15, "175": 15, "chart": 15, "breakdown": 15, "tfrt": 15, "legaci": 15, "streamexecutor": 15, "tpu_legaci": 15, "comparison": [15, 29], "regular": [16, 17, 18, 26], "t0": 16, "matrix": 16, "multipli": [16, 29], "mm": [16, 20], "neural": 16, "l_in": 16, "l_out": 16, "floattensor": 16, "highlight": [16, 18], "nllloss": 16, "momentum": 16, "switch": [16, 17, 19, 21, 27], "acquir": 16, "mp_device_load": 16, "three": 16, "multithread": [16, 17], "own": [16, 24], "onto": 16, "preload": [16, 18], "overlap": [16, 18, 22, 28], "batches_per_execut": 16, "consolid": [16, 24], "all_reduce_gradi": 16, "parent": 16, "talk": 16, "basi": 16, "howto": 16, "focu": [16, 19, 27], "train_mnist_xla": 16, "outsid": 16, "infrastructur": 16, "awar": 16, "fakedata": 16, "But": [16, 17, 19, 27], "immedi": [16, 28], "hand": 16, "record": [16, 17, 18], "defer": 16, "fuse": [16, 18], "invis": 16, "caller": 16, "insert": [16, 18], "paper": 16, "opaqu": [16, 17], "appear": [16, 17, 18], "unlik": [16, 18], "adjust": 16, "preserv": [16, 17], "appreci": 16, "accommod": 16, "previous": 16, "state_dict": [16, 24, 30], "footprint": 16, "xser": 16, "stream": 16, 
"amount": [16, 17, 18, 19, 27], "restor": 16, "load_state_dict": [16, 30], "unavail": [16, 17], "consum": [16, 19, 27], "disk": 16, "occur": 16, "your_cache_path": 16, "mp_fn": 16, "xla_cache_": 16, "runnabl": [16, 21, 25], "subject": 17, "peculiar": 17, "detial": 17, "__version__": 17, "cu121": 17, "t2": [17, 29], "200": 17, "rx": 17, "conclud": 17, "diagnos": 17, "extrem": 17, "pt_xla_debug_level": 17, "slip": 17, "analyz": [17, 18], "summari": 17, "compiletim": 17, "frequent": 17, "21": 17, "transferfromdevicetim": 17, "23": 17, "hash": 17, "c74c3b91b855b2b123f833b0d5f86943": 17, "107": 17, "frame": 17, "trigger": [17, 18, 19, 27], "dk3": 17, "1055": 17, "44": 17, "__next__": 17, "train_loop_fn": 17, "48": [17, 21], "start_train": 17, "73": 17, "548000": 17, "gb": 17, "922460": 17, "547871": 17, "124478": 17, "028210": 17, "steptrac": 17, "frequenc": 17, "pair": 17, "met": 17, "spent": [17, 18], "destroi": 17, "percentil": 17, "totalsampl": 17, "202": 17, "06m09s401ms746": 17, "001u": 17, "valuer": 17, "778ms572": 17, "062u": 17, "rate": [17, 21], "425201": 17, "001ms32": 17, "778u": 17, "001ms61": 17, "283u": 17, "001ms79": 17, "236u": 17, "001ms110": 17, "973u": 17, "50": [17, 18, 23], "001ms228": 17, "773u": 17, "80": 17, "001ms339": 17, "183u": 17, "90": 17, "001ms434": 17, "305u": 17, "95": 17, "002ms921": 17, "063u": 17, "99": [17, 21], "21s102ms853": 17, "173u": 17, "cachedsynctensor": 17, "395": [17, 21], "area": 17, "rout": 17, "qualifi": 17, "33": [17, 21, 22], "_local_scalar_dens": 17, "epoch": [17, 18, 24], "clear_al": 17, "xla_dynamo_debug": 17, "bottleneck": [17, 18], "notebook": 17, "train_resnet_benchmark": 17, "behav": 17, "evalu": [17, 18, 19, 27], "suggest": 17, "bad": 17, "degrad": [17, 18], "speedup": [17, 22], "indirect": 17, "solut": [17, 19, 26, 27], "variat": 17, "pad": [17, 18, 19, 27], "fix": [17, 18, 22, 25], "translat": 17, "item": [17, 18], "substitut": 17, "flow": 17, "clip_grad_norm": 17, "problemat": 17, "clip_grad_norm_": 17, "dramat": 17, "total_norm": 17, "zero": [17, 24, 30], "param_norm": 17, "grad": 17, "norm": 17, "norm_typ": 17, "add_": 17, "clip_coef": 17, "max_norm": 17, "mul_": 17, "data_parallel": 17, "last": 17, "dataset": [17, 21, 24], "stride": 17, "reconstruct": 17, "shallow": 17, "ty": 17, "made": [17, 18, 19, 27, 28], "_get_xla_tensors_text": [17, 19, 27], "_get_xla_tensors_hlo": 17, "prior": [17, 30], "degre": 17, "xla_ir_debug": 17, "henc": [17, 22], "respons": [17, 18, 22, 30], "xla_save_tensors_fil": 17, "realli": [17, 19, 22, 27], "big": [17, 19, 27], "left": 17, "append": 17, "sheet": 17, "xla_save_tensors_fmt": 17, "text": 17, "dot": 17, "graphviz": 17, "xla_flag": 17, "xla_dump_to": 17, "dir_nam": 17, "unoptim": 17, "optimz": 17, "xla_metrics_fil": 17, "xla_save_hlo_fil": 17, "offend": 17, "xla_sync_wait": 17, "xla_use_eager_debug_mod": 17, "bypass": 17, "overal": [17, 18], "optimizaiton": 17, "tf_cpp_log_thread_id": 17, "tf_cpp_vmodul": 17, "vlog": 17, "tf_cpp_min_log_level": 17, "turn": 17, "warn": 17, "tf_vlog": 17, "xla_dump_hlo_graph": 17, "xla_util": 17, "cc": 17, "save1": 17, "xla_graph_executor": 17, "pjrt_computation_cli": 17, "dir": 17, "pytorch_test_with_slow": 17, "test_torch": 17, "test_put_xla_uint8": 17, "torch_test_devic": 17, "pytorch_test_bas": 17, "brief": 18, "basic": [18, 19, 21, 27], "reader": 18, "modif": 18, "fetch": 18, "discuss": [18, 29], "opcod": 18, "fed": 18, "attach": [18, 28], "callback": 18, "xla_tensor_z": 18, "cut": [18, 19, 27], "transferfromdevic": 18, "tell": [18, 19, 27], "properti": [18, 
19, 27], "illustr": [18, 29], "suppos": 18, "tensors_on_devic": 18, "z": [18, 19, 27], "subgraph": [18, 19, 27], "signal": 18, "far": 18, "suitabl": 18, "trade": [18, 19, 27], "off": 18, "spend": 18, "fusion": 18, "worth": [18, 19, 27], "latter": [18, 24], "wheel": [18, 24], "runtime_vers": 18, "project_id": 18, "accelerator_typ": 18, "tpu_nam": 18, "your_tpu_nam": 18, "subnetwork": 18, "tpusubnet": 18, "pip3": 18, "cp38": 18, "linux_x86_64": 18, "whl": 18, "apt": 18, "libopenbla": 18, "dev": [18, 21], "libgl1": 18, "guidelin": 18, "bar": 18, "rememb": 18, "txt2img": 18, "prompt": 18, "photograph": 18, "astronaut": 18, "ride": 18, "hors": 18, "relat": 18, "precision_scop": 18, "addition": [18, 20, 24], "particular": 18, "frozenclipembedd": 18, "simplic": [18, 19, 27], "ddim": 18, "top": 18, "attr": 18, "statement": [18, 19, 27], "stop": 18, "fall": [18, 25], "difficult": 18, "readi": 18, "investig": [18, 21], "cover": [18, 28], "huggingfac": 18, "sd": 18, "xl": 18, "cd": [18, 24], "text_to_imag": 18, "inference_tpu_single_devic": 18, "lora": 18, "model_id": 18, "stabilityai": 18, "pipelin": 18, "dpmsolvermultistepschedul": 18, "txt": 18, "invisible_watermark": 18, "transform": [18, 24, 29], "safetensor": 18, "licens": 18, "card": 18, "cli": 18, "_your_copied_token__": 18, "pipe": 18, "hour": 18, "wherea": 18, "likewis": 18, "gpt": 18, "15": 18, "min": 18, "subsequ": 18, "advantag": 18, "mayb": 18, "notic": 18, "piec": 18, "__call__": 18, "commit": 18, "caveat": 18, "rule": [18, 20], "thumb": 18, "durat": [18, 30], "constantli": 18, "idl": 18, "inference_tpu_": 18, "capture_profil": 18, "gap": 18, "xp": 18, "measur": 18, "portion": 18, "busi": 18, "scroll": 18, "occupi": 18, "displai": 18, "largest": 18, "zoom": 18, "timelin": 18, "period": 18, "examin": 18, "did": 18, "pipe_watermark": 18, "closer": 18, "preced": 18, "proceed": [18, 25], "watermark": 18, "cv2": 18, "pywt": 18, "leav": 18, "broken": 18, "rerun": 18, "scale_model_input": 18, "ran": 18, "my_funct": 18, "preocess": 18, "debug_single_process": 18, "magic": [18, 19, 27], "treat": 18, "xla_no_special_scalar": 18, "hurt": [19, 27], "perf": [19, 27], "pov": [19, 27], "sai": [19, 27], "assur": [19, 27], "gone": [19, 27], "coverag": [19, 27], "aim": [19, 25, 27], "explan": [19, 27], "mainli": [19, 27], "problem": [19, 27], "beginn": [19, 27], "propos": [19, 27], "reli": [19, 27], "impract": [19, 27], "assumpt": [19, 27], "ye": [19, 26, 27], "sentenc": [19, 27], "bucket": [19, 27, 30], "kinda": [19, 27], "anti": [19, 27], "frontend": [19, 27], "matter": [19, 27], "workaround": [19, 27], "okai": [19, 27], "teach": [19, 27], "produc": [19, 20, 21, 27], "theoret": [19, 27], "sort": [19, 27], "obviou": [19, 27], "s64": [19, 27], "inde": [19, 27], "_get_xla_tensor_dimension_s": [19, 27], "commonli": [19, 27], "wrong": [19, 27], "wors": [19, 27], "probabl": [19, 27], "know": [19, 21, 27], "upper": [19, 27], "nit": [19, 27], "rand": [19, 27], "solv": [19, 27], "kept": [19, 27], "earli": [19, 27], "accessor": [19, 27], "2d": [19, 25, 27], "implicitli": [19, 27], "doubl": [19, 27], "overload": [19, 27], "explod": [19, 27], "convers": [19, 27], "cheap": [19, 27], "ve": [19, 27], "hoc": [19, 27], "think": [19, 27], "verison": [19, 27], "bla": [19, 27], "blabla": [19, 27], "interpret": [19, 27], "proce": [19, 27], "uglier": [19, 27], "win": [19, 27], "pars": [19, 27], "torchscript": [19, 27], "somehow": [19, 27], "merg": [19, 27], "lazili": [19, 27, 28, 30], "properli": [19, 27], "thought": [19, 27], "trivial": [19, 27], "effort": [19, 27, 28], 
"side": [19, 27], "bandwidth": [19, 27], "automag": [19, 27], "gold": [19, 27], "smart": [19, 27], "trick": [19, 27], "tbh": [19, 27], "longer": [19, 27], "unawar": [19, 27], "hope": [19, 27], "smash": [19, 27], "blocker": [19, 27], "ahead": [19, 27], "nnc": [19, 27], "exactli": [19, 27], "transpos": [19, 27], "brian": [19, 27], "hirsh": [19, 27], "bdhirsh": [19, 27], "question": [19, 27], "comment": [19, 27], "stick": [19, 27], "torch_warn": [19, 27], "yea": [19, 27], "hei": [19, 27], "won": [19, 20, 27], "blaze": [19, 27], "isn": [19, 27, 30], "abil": [19, 21, 27], "devirtu": [19, 27], "sound": [19, 27], "great": [19, 27], "carri": [19, 27, 28], "truth": [19, 27], "irvalu": [19, 27], "enforc": [19, 21, 27], "discrep": [19, 27], "followup": [19, 27], "1000": [19, 27], "my": [19, 27, 30], "presenc": [19, 27], "get_dimention_s": [19, 27], "didn": [19, 27], "exponenti": [19, 27], "blowup": [19, 27], "fewer": [19, 27], "opportun": [19, 27], "recogn": [19, 22, 27], "feasibl": [19, 27], "annoi": [19, 27], "wasn": [19, 27], "materiz": [19, 27], "combo": [19, 27], "extend": 20, "float32": 20, "datatyp": 20, "float16": 20, "bfloat16": [20, 26], "syncfre": 20, "autocast": 20, "summar": 20, "elig": 20, "suppli": 20, "addmm": 20, "addmm_": 20, "prefer": 20, "float64": 20, "respect": 20, "unlist": 20, "__matmul__": 20, "addbmm": 20, "addmv": 20, "addr": 20, "baddbmm": 20, "bmm": 20, "conv1d": 20, "conv2d": [20, 24], "conv3d": 20, "conv_transpose1d": 20, "conv_transpose2d": 20, "conv_transpose3d": 20, "matmul": 20, "relu": [20, 21], "prelu": 20, "max_pool2d": 20, "batch_norm": 20, "log_softmax": 20, "binary_cross_entropy_with_logit": 20, "prod": 20, "cdist": 20, "chloeski": 20, "invers": 20, "reflection_pad": 20, "replication_pad": 20, "mse_loss": 20, "cosine_embbeding_loss": 20, "nll_loss": 20, "multilabel_margin_loss": 20, "qr": 20, "svd": 20, "triangular_solv": 20, "linalg_svd": 20, "linalg_inv_ex": 20, "widest": 20, "index_copi": 20, "scaler": [20, 26], "gradscal": 20, "_fetch_gradi": 20, "xla_use_f16": 20, "underflow": 20, "imagenet": 20, "minimum": [21, 24, 25], "nccl": 21, "new_rank": 21, "ddp_model": 21, "final": [21, 28], "launcher": 21, "demo_fn": 21, "touch": [21, 30], "five": 21, "sy": 21, "tempfil": 21, "cleanup": 21, "destroy_process_group": 21, "toymodel": 21, "net1": 21, "1000000": 21, "net2": 21, "demo_bas": 21, "graident_as_bucket_view": 21, "label": 21, "run_demo": 21, "tot": 21, "statist": 21, "unit": 21, "median": 21, "90th": 21, "deviat": 21, "cv": 21, "418": 21, "54": 21, "419": 21, "22": 21, "430": 21, "40": 21, "76": 21, "02": 21, "97": 21, "407": 21, "60": 21, "39": 21, "seem": 21, "17864": 21, "19": [21, 22], "20108": 21, "96": 21, "24351": 21, "74": 21, "5866": 21, "83": 21, "10701": 21, "11770": 21, "00": 21, "14313": 21, "78": 21, "3102": 21, "92": 21, "41": [21, 22], "round": 21, "heavili": [21, 22], "sens": 21, "amort": 21, "logdir": 21, "converg": 21, "caution": 21, "interest": 21, "known": 21, "crash": 21, "unmodifi": 22, "hook": 22, "biggest": [22, 24], "torchfx": 22, "technologi": 22, "fx": 22, "a_xla": 22, "b_xla": 22, "compiled_cod": 22, "eval_model": 22, "xla_resnet18": 22, "eval": 22, "dynamo_resnet18": 22, "no_grad": 22, "resent18": 22, "analysi": 22, "bench": 22, "59": 22, "resnext50_32x4d": 22, "91": 22, "alexnet": 22, "28": 22, "mobilenet_v2": 22, "18": 22, "62": 22, "mnasnet1_0": 22, "68": 22, "vgg16": 22, "bert_pytorch": 22, "squeezenet1_1": 22, "timm_vision_transform": 22, "52": 22, "geomean": 22, "04": 22, "train_model": 22, "crossentropyloss": 22, 
"pred": 22, "train_model_main": 22, "dynamo_train_model": 22, "xla_optim": 22, "weight_decai": 22, "extract": 22, "07": 22, "43": 22, "81": 22, "87": 22, "fwd": 22, "bwd": 22, "e2": 22, "hide": 22, "larger": [22, 24], "wit": 22, "promis": 22, "tradit": 22, "excit": 22, "upcom": [22, 28], "invest": 22, "matur": 22, "stori": 22, "_higher_order_op": 23, "fori_loop": 23, "cond_fn": 23, "body_fn": 23, "bodi": 23, "iteri": 23, "init_v": 23, "functionaltensor": 23, "lvl": 23, "cumul": 23, "ten": 23, "51": 23, "xlafullyshardeddataparallel": 24, "my_modul": [24, 25], "adam": [24, 25], "0001": [24, 25], "leftov": [24, 25], "arxiv": 24, "1910": 24, "02054": 24, "reshard_after_forward": 24, "test_train_mp_mnist_fsdp_with_ckpt": 24, "test_train_mp_imagenet_fsdp": 24, "interleav": 24, "submodul": 24, "fsdpvitmodel": 24, "checkpoint_modul": [24, 25], "3524": 24, "auto_wrap_polici": [24, 25], "size_based_auto_wrap_polici": 24, "polici": [24, 28], "100m": 24, "transformer_auto_wrap_polici": [24, 25], "transformer_layer_cl": [24, 25], "auto_wrapper_cal": 24, "remateri": 24, "resum": 24, "get_shard_metadata": 24, "consolidate_sharded_model_checkpoint": 24, "stitch": 24, "ckpt": 24, "shard_metadata": 24, "ckpt_path": 24, "pth": 24, "consolidate_sharded_ckpt": 24, "ckpt_prefix": 24, "your_sharded_checkpoint_fil": 24, "ckpt_suffix": 24, "_rank": 24, "inspir": 24, "structur": [24, 28], "fairscal": 24, "fullyshardeddataparallel": 24, "readthedoc": 24, "en": 24, "resort": 24, "train_resnet_fsdp_auto_wrap": 24, "newer": 24, "recurs": [24, 25], "98": 24, "drop_last": 24, "use_nested_fsdp": 24, "use_gradient_checkpoint": 24, "final_ckpt": 24, "75": 24, "download": 24, "1k": 24, "datadir": 24, "test_set_batch_s": 24, "eval_interv": 24, "num_warmup_epoch": 24, "lr_scheduler_divide_every_n_epoch": 24, "lr_scheduler_divisor": 24, "residu": 24, "algorithm": [24, 25], "ronghanghu": 24, "vit_10b_fsdp_exampl": 24, "vit": 24, "fsdpv2": 25, "famou": 25, "enjoi": 25, "tabl": 25, "spmd_fully_sharded_data_parallel": 25, "spmdfullyshardeddataparallel": 25, "autowrap": 25, "decoderlay": 25, "functool": 25, "decoder_only_model": 25, "shard_output": 25, "0th": 25, "children": 25, "fork": 25, "hf": 25, "abstract": [26, 28], "blockwis": 26, "int4": 26, "analog": 26, "classifi": 26, "flexibl": 26, "choos": [26, 30], "docstr": 26, "xla_quantized_matmul": 26, "n_input_featur": 26, "n_output_featur": 26, "w_int": 26, "127": 26, "int8": 26, "matmul_output": 26, "quantized_matmul": 26, "x_xla": 26, "w_int_xla": 26, "scaler_xla": 26, "matmul_output_xla": 26, "w": 26, "f_dynamo": 26, "dynamo_out_xla": 26, "myqlinearforxlabackend": 26, "load_weight": 26, "processed_w": 26, "processed_scal": 26, "stuff": 26, "orig_model": 26, "mymodel": 26, "q_weight": 26, "q_weights_for_xla": 26, "process_for_xla": 26, "q_linear": 26, "xlaquantizedlinear": 26, "in_featur": 26, "out_featur": 26, "load_quantized_weight": 26, "channel": 26, "sym": 26, "asym": 26, "w8a16": 26, "w8a8": 26, "w4a8": 26, "gspmd": [28, 29], "proced": 28, "src": [28, 30], "_input_sharding_": 28, "4d": 28, "input_shard": 28, "shardingspec": 28, "input_mesh": 28, "s2": 28, "s3": 28, "s4": 28, "_after": 28, "_the": 28, "unnecessari": 28, "forth": 28, "techniqu": 28, "decis": 28, "nice": 28, "arrang": 28, "center": 28, "multislic": 28, "denot": 28, "delai": 28, "subclass": 28, "__torch_dispatch__": 28, "global_tensor": 28, "strictli": 28, "local_shard": 28, "xlashard": 28, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 28, "xla_sharded_tensor": 28, "l12": 28, "ongo": 28, "distributedtensor": 
28, "proof": 28, "concept": [28, 29], "distribute_tensor": 28, "devicemesh": 28, "big_tensor": 28, "100000": 28, "88": 28, "my_dtensor": 28, "stai": 28, "dynamo_mark_shard": 28, "placement": 28, "visualize_tensor_shard": 28, "visualize_shard": 28, "rich": 28, "2x2": 28, "generated_t": 28, "use_color": 28, "style": 28, "tile": 28, "partial_repl": 28, "envvar": 28, "xla_auto_spmd": 28, "_tensor": 28, "distribute_modul": 28, "auto_polici": 28, "mymodul": 28, "sharded_model": 28, "behvaior": 28, "xla_auto_use_group_shard": 28, "reshard": 28, "xla_auto_spmd_mesh": 28, "unset": 28, "hint": 29, "strategi": 29, "th": 29, "cluster": 29, "interconnect": 29, "encourag": 29, "fist": 29, "paral": 29, "dedic": 30, "planner": 30, "spmdsaveplann": 30, "spmdloadplann": 30, "dist_cp": 30, "distributed_checkpoint": 30, "xc": 30, "storage_writ": 30, "filesystemwrit": 30, "checkpoint_dir": 30, "storage_read": 30, "filesystemread": 30, "all_step": 30, "save_async": 30, "unblock": 30, "preemption": 30, "detect": 30, "provis": 30, "queuedresourc": 30, "autocheckpoint": 30, "chkpt_on_preempt": 30, "fsspec": 30, "filesystem": 30, "prime_optim": 30, "chkpt_mgr": 30, "tracked_step": 30, "highest": 30, "best_step": 30, "prime": 30, "enumer": 30, "attempt": 30, "unprim": 30, "destruct": 30, "discov": 30, "nvidia": 31, "resnet": 31, "num_gpu_machin": 31, "rank_of_current_machin": 31, "machine_0_ip_address": 31, "training_or_inference_script_using_spmd": 31, "xla_use_spmd": 31, "test_train_spmd_imagenet": 31}, "objects": {"": [[12, 0, 0, "-", "torch_xla"]], "torch_xla": [[12, 1, 1, "", "compile"], [12, 1, 1, "", "device"], [12, 1, 1, "", "device_count"], [12, 1, 1, "", "devices"], [12, 0, 0, "-", "experimental"], [12, 1, 1, "", "manual_seed"], [12, 0, 0, "-", "runtime"], [12, 1, 1, "", "sync"]], "torch_xla.core": [[12, 0, 0, "-", "xla_model"]], "torch_xla.core.xla_model": [[12, 1, 1, "", "add_step_closure"], [12, 1, 1, "", "all_gather"], [12, 1, 1, "", "all_reduce"], [12, 1, 1, "", "all_to_all"], [12, 1, 1, "", "get_memory_info"], [12, 1, 1, "", "get_rng_state"], [12, 1, 1, "", "get_stablehlo"], [12, 1, 1, "", "get_stablehlo_bytecode"], [12, 1, 1, "", "is_master_ordinal"], [12, 1, 1, "", "mesh_reduce"], [12, 1, 1, "", "optimizer_step"], [12, 1, 1, "", "rendezvous"], [12, 1, 1, "", "save"], [12, 1, 1, "", "set_rng_state"], [12, 1, 1, "", "wait_device_ops"], [12, 1, 1, "", "xla_device"], [12, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[12, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[12, 1, 1, "", "counter_names"], [12, 1, 1, "", "counter_value"], [12, 1, 1, "", "metric_data"], [12, 1, 1, "", "metric_names"], [12, 1, 1, "", "metrics_report"], [12, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[12, 0, 0, "-", "parallel_loader"], [12, 0, 0, "-", "spmd"], [12, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[12, 2, 1, "", "MpDeviceLoader"]], "torch_xla.distributed.spmd": [[12, 2, 1, "", "HybridMesh"], [12, 2, 1, "", "Mesh"], [12, 1, 1, "", "clear_sharding"], [12, 1, 1, "", "get_1d_mesh"], [12, 1, 1, "", "get_global_mesh"], [12, 1, 1, "", "mark_sharding"], [12, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[12, 1, 1, "", "spawn"]], "torch_xla.experimental": [[12, 1, 1, "", "eager_mode"]], "torch_xla.runtime": [[12, 1, 1, "", "addressable_device_count"], [12, 1, 1, "", "device_type"], [12, 1, 1, "", "get_master_ip"], [12, 1, 1, "", "global_device_count"], [12, 1, 1, "", "global_ordinal"], [12, 1, 1, "", "global_runtime_device_count"], 
[12, 1, 1, "", "initialize_cache"], [12, 1, 1, "", "is_spmd"], [12, 1, 1, "", "local_device_count"], [12, 1, 1, "", "local_ordinal"], [12, 1, 1, "", "local_process_count"], [12, 1, 1, "", "use_spmd"], [12, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"]}, "titleterms": {"learn": [0, 1, 11], "about": [0, 1, 11], "gpu": [0, 10, 15, 20, 31], "tpu": [1, 4, 15, 16, 18, 20, 24, 28], "bazel": 2, "pytorch": [2, 3, 4, 6, 7, 8, 9, 11, 12, 16, 17, 18, 22, 24, 27, 28, 29], "xla": [2, 3, 4, 6, 7, 8, 11, 12, 16, 17, 18, 20, 22, 24, 26, 27, 28, 29], "depend": [2, 8, 10], "how": [2, 21, 26, 29], "build": 2, "librari": 2, "torch": [2, 7, 9, 15, 28], "plugin": [2, 6], "remot": 2, "cach": [2, 16], "run": [2, 3, 9, 16, 17, 18, 28, 31], "test": [2, 3, 5, 17, 23], "code": [2, 4, 18, 26], "coverag": 2, "languag": 2, "server": 2, "codegen": 3, "migrat": 3, "guid": [3, 5, 29], "befor": [3, 5], "you": [3, 5, 19, 27], "start": [3, 5, 19, 27], "file": [3, 5, 9], "structur": [3, 5], "old": 3, "op": [3, 5, 7, 26], "lower": [3, 5, 7], "step": [3, 4], "1": [3, 18, 19, 27], "identifi": 3, "2": [3, 18, 19, 25, 27], "inspect": 3, "gener": [3, 9], "lazyir": 3, "h": 3, "3": [3, 19, 27], "implement": [3, 6], "miss": 3, "ir": 3, "function": 3, "torch_xla": [3, 12, 19], "csrc": 3, "ops_xla_shape_fn": 3, "cpp": 3, "4": 3, "ops_lower_fn": 3, "5": 3, "cleanup": 3, "verifi": 3, "result": 3, "sampl": 3, "pr": 3, "configur": 4, "develop": 4, "environ": [4, 17], "visual": 4, "studio": 4, "creat": [4, 16], "connect": 4, "your": 4, "set": 4, "up": 4, "workspac": 4, "next": 4, "understand": [5, 17], "oper": [5, 9, 19, 20, 26, 27], "unit": [5, 17], "tip": 5, "custom": [6, 8, 10], "hardwar": 6, "pjrt": [6, 15], "c": 6, "api": [6, 7, 12, 14], "packag": 6, "support": [7, 20, 26], "distribut": [7, 12, 15, 30], "collect": 7, "stack": 7, "non": 7, "dynamo": [7, 17], "case": [7, 19, 23, 27], "descript": 7, "kernel": [8, 10], "via": [8, 10], "palla": 8, "adopt": 8, "abov": 8, "compat": 8, "us": [8, 19, 21, 23, 25, 26, 27, 29], "built": 8, "flashattent": 8, "exampl": [8, 18, 20, 23, 24, 25], "usag": [8, 14, 23], "integr": [8, 22, 28], "pagedattent": 8, "export": 9, "stablehlo": 9, "save": [9, 16], "bytecod": 9, "disk": 9, "convert": [9, 18], "serv": 9, "common": [9, 17], "wrapper": 9, "i": [9, 19, 27, 29], "want": 9, "directli": 9, "tf": 9, "saved_model": 9, "format": 9, "without": [9, 19, 27], "need": 9, "an": [9, 16], "separ": 9, "command": 9, "other": 9, "produc": 9, "save_as_stablehlo": 9, "preserv": 9, "high": 9, "level": 9, "composit": 9, "triton": 10, "document": 11, "acceler": 11, "featur": [11, 22, 26], "improv": 11, "workload": 11, "perform": [11, 15, 17, 18], "contribut": 11, "runtim": [12, 15], "xla_model": 12, "spmd": [12, 25, 28, 29, 31], "experiment": [12, 26], "debug": [12, 17, 28], "dynam": [13, 19, 27], "shape": [13, 19, 27], "bound": [13, 19, 27], "eager": 14, "mode": [14, 29], "compil": [14, 16, 17, 28], "basic": 14, "infer": [14, 18, 22], "train": [14, 15, 22, 24], "benchmark": [14, 17, 21], "tl": 15, "dr": 15, "benefit": 15, "quickstart": 15, "cpu": [15, 16], "pod": [15, 16, 18, 24, 28], "docker": 15, "singl": [15, 16, 18], "node": 15, "multi": [15, 16], "differ": 15, "from": [15, 16, 19, 27], "xrt": 15, "multithread": 15, "v2": 15, "v3": [15, 24], "chang": 15, "xm": 15, "rendezv": 15, "new": 15, "devic": [16, 18, 28], "tensor": [16, 17, 19, 27], 
"ar": 16, "model": [16, 26], "multipl": [16, 18], "process": [16, 30], "deep": 16, "dive": 16, "lazi": 16, "memori": [16, 23], "layout": 16, "move": 16, "load": [16, 28], "further": [16, 29], "read": [16, 29], "troubleshoot": 17, "saniti": 17, "check": 17, "version": 17, "A": 17, "simpl": [17, 23], "calcul": 17, "resnet": [17, 24], "With": 17, "fake": [17, 21], "data": [17, 21, 24, 25, 28], "tool": [17, 28], "auto": [17, 28], "metric": 17, "analysi": [17, 18], "execut": 17, "get": 17, "report": 17, "The": 17, "clear": 17, "profil": [17, 18], "known": 17, "caveat": 17, "quirk": 17, "more": 17, "variabl": 17, "combin": 17, "reproduc": 17, "ci": 17, "cd": 17, "failur": 17, "overview": 18, "setup": 18, "stabl": 18, "diffus": 18, "lightn": 18, "hf": 18, "sourc": [19, 27], "recompil": [19, 27], "let": [19, 27], "": [19, 27], "first": [19, 27], "some": [19, 27], "fact": [19, 27], "constraint": [19, 27], "input": [19, 27], "dataset": [19, 27], "output": [19, 25, 27], "can": [19, 27], "fix": [19, 27], "when": [19, 27], "queri": [19, 27], "its": [19, 27], "real": [19, 21, 27], "dimens": [19, 27], "what": [19, 27, 29], "control": [19, 23, 27], "flow": [19, 27], "conclus": [19, 27], "appendix": [19, 27], "automat": 20, "mix": 20, "precis": 20, "amp": 20, "best": 20, "practic": 20, "do": 21, "distributeddataparallel": 21, "ddp": 21, "background": 21, "motiv": 21, "resnet50": 21, "mnist": [21, 24], "disclaim": 21, "torchdynamo": 22, "gap": 22, "take": 22, "awai": 22, "optim": [23, 28, 30], "util": 23, "while_loop": 23, "group": [23, 30], "pure": 23, "python": 23, "while": 23, "loop": 23, "fulli": [24, 25], "shard": [24, 25, 28], "parallel": [24, 25], "script": 24, "imagenet": 24, "instal": 24, "clone": 24, "repo": 24, "8": 24, "50": 24, "10": 24, "billion": 24, "paramet": 24, "gradient": 25, "checkpoint": [25, 30], "huggingfac": 25, "llama": 25, "quantiz": 26, "call": 26, "modul": 26, "swap": 26, "matrix": 26, "multipli": 26, "advanc": 28, "topic": 28, "awar": 28, "host": 28, "virtual": 28, "hybrid": 28, "mesh": [28, 29], "xlashardedtensor": 28, "dtensor": 28, "activ": 28, "user": 29, "partit": 29, "spec": 29, "checkpointmanag": 30, "restor": 30, "state": 30}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Learn about GPUs": [[0, "learn-about-gpus"]], "Learn about TPUs": [[1, "learn-about-tpus"]], "Bazel in Pytorch/XLA": [[2, "bazel-in-pytorch-xla"]], "Bazel dependencies": [[2, "bazel-dependencies"]], "How to build XLA libraries": [[2, "how-to-build-xla-libraries"]], "How to build the Torch/XLA plugin": [[2, "how-to-build-the-torch-xla-plugin"]], "Remote caching": [[2, "remote-caching"]], "Running tests": [[2, "running-tests"]], "Code coverage": [[2, "code-coverage"]], "Language Server": [[2, "language-server"]], "Building PyTorch/XLA": [[2, "building-pytorch-xla"]], "Codegen migration Guide": [[3, "codegen-migration-guide"]], "Before you start": [[3, "before-you-start"], [5, "before-you-start"]], "File structure": [[3, "file-structure"], [5, "file-structure"]], "PyTorch Codegen files": [[3, "pytorch-codegen-files"]], "PyTorch/XLA Codegen files": [[3, "pytorch-xla-codegen-files"]], "PyTorch/XLA Old Op Lowering files": [[3, "pytorch-xla-old-op-lowering-files"]], 
"Codegen step by step": [[3, "codegen-step-by-step"]], "1. Identify the op": [[3, "identify-the-op"]], "2. Codegen the op and inspect the generated file": [[3, "codegen-the-op-and-inspect-the-generated-file"]], "LazyIr.h": [[3, "lazyir-h"]], "3. Implement the missing IR function": [[3, "implement-the-missing-ir-function"]], "torch_xla/csrc/ops/ops_xla_shape_fn.h": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-h"]], "torch_xla/csrc/ops/ops_xla_shape_fn.cpp": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-cpp"]], "4. Implement the lowering function": [[3, "implement-the-lowering-function"]], "torch_xla/csrc/ops/ops_lower_fn.cpp": [[3, "torch-xla-csrc-ops-ops-lower-fn-cpp"]], "5. Cleanup": [[3, "cleanup"]], "Run the test and verify the result": [[3, "run-the-test-and-verify-the-result"]], "Sample PRs": [[3, "sample-prs"]], "Configure a development environment": [[4, "configure-a-development-environment"]], "Visual Studio Code": [[4, "visual-studio-code"]], "Creating and connecting to your TPU": [[4, "creating-and-connecting-to-your-tpu"]], "Setting up a Visual Studio Code workspace with PyTorch/XLA": [[4, "setting-up-a-visual-studio-code-workspace-with-pytorch-xla"]], "Next steps": [[4, "next-steps"]], "OP Lowering Guide": [[5, "op-lowering-guide"]], "Understanding the operation": [[5, "understanding-the-operation"]], "Unit Test": [[5, "unit-test"]], "Tips": [[5, "tips"]], "Custom Hardware Plugins": [[6, "custom-hardware-plugins"]], "Implementing a PJRT Plugin": [[6, "implementing-a-pjrt-plugin"]], "PJRT C API Implementation": [[6, "pjrt-c-api-implementation"]], "PyTorch/XLA Plugin Package": [[6, "pytorch-xla-plugin-package"]], "Support of Torch Distributed API in PyTorch/XLA": [[7, "support-of-torch-distributed-api-in-pytorch-xla"]], "Collective ops lowering": [[7, "collective-ops-lowering"]], "Collective ops lowering stack": [[7, "collective-ops-lowering-stack"]], "non-Dynamo case": [[7, "non-dynamo-case"]], "Dynamo case": [[7, "dynamo-case"]], "API description": [[7, "api-description"]], "Custom Kernels via Pallas": [[8, "custom-kernels-via-pallas"]], "Adopt the above kernel to be compatible with PyTorch/XLA": [[8, "adopt-the-above-kernel-to-be-compatible-with-pytorch-xla"]], "Use built-in kernels": [[8, "use-built-in-kernels"]], "FlashAttention": [[8, "id1"]], "Example usage": [[8, "example-usage"], [8, "id3"]], "Integration Example": [[8, "integration-example"], [8, "id4"]], "PagedAttention": [[8, "id2"]], "Dependencies": [[8, "dependencies"], [10, "dependencies"]], "Torch Export to StableHLO": [[9, "torch-export-to-stablehlo"]], "Saving StableHLO bytecodes to disk": [[9, "saving-stablehlo-bytecodes-to-disk"]], "Convert saved StableHLO for serving": [[9, "convert-saved-stablehlo-for-serving"]], "Common wrappers": [[9, "common-wrappers"]], "I want to save directly tf.saved_model format without needing to run an separate command.": [[9, "i-want-to-save-directly-tf-saved-model-format-without-needing-to-run-an-separate-command"]], "Other common wrappers": [[9, "other-common-wrappers"]], "Files produced by save_as_stablehlo.": [[9, "files-produced-by-save-as-stablehlo"]], "Preserving High-Level PyTorch Operations in StableHLO by generating stablehlo.composite": [[9, "preserving-high-level-pytorch-operations-in-stablehlo-by-generating-stablehlo-composite"]], "Custom GPU Kernels via Triton": [[10, "custom-gpu-kernels-via-triton"]], "PyTorch/XLA documentation": [[11, "pytorch-xla-documentation"]], "Learn about Pytorch/XLA": [[11, null]], "Learn about accelerators": [[11, null]], "PyTorch/XLA features": 
[[11, null]], "Improve Pytorch/XLA workload performance": [[11, null]], "Contribute to Pytorch/XLA": [[11, null]], "PyTorch/XLA API": [[12, "pytorch-xla-api"]], "torch_xla": [[12, "module-torch_xla"]], "runtime": [[12, "module-torch_xla.runtime"]], "xla_model": [[12, "module-torch_xla.core.xla_model"]], "distributed": [[12, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[12, "module-torch_xla.distributed.spmd"]], "experimental": [[12, "module-torch_xla.experimental"]], "debug": [[12, "module-torch_xla.debug.metrics"]], "Dynamic shape": [[13, "dynamic-shape"]], "Bounded dynamic shape": [[13, "bounded-dynamic-shape"]], "Eager Mode + Compile API": [[14, "eager-mode-compile-api"]], "Basic Usage": [[14, "basic-usage"]], "Inference": [[14, "inference"], [22, "inference"]], "Training": [[14, "training"], [22, "training"]], "Benchmark": [[14, "benchmark"]], "PJRT Runtime": [[15, "pjrt-runtime"]], "TL;DR": [[15, "tl-dr"]], "Benefits": [[15, "benefits"]], "Quickstart": [[15, "quickstart"]], "CPU": [[15, "cpu"]], "TPU": [[15, "tpu"]], "Pods": [[15, "pods"]], "Docker": [[15, "docker"]], "GPU": [[15, "gpu"]], "Single-node GPU training": [[15, "single-node-gpu-training"]], "Multi-node GPU training": [[15, "multi-node-gpu-training"]], "Differences from XRT": [[15, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[15, "id3"]], "Changes to xm.rendezvous": [[15, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[15, "pjrt-and-torch-distributed"]], "Performance": [[15, "performance"]], "New TPU runtime": [[15, "new-tpu-runtime"]], "PyTorch on XLA Devices": [[16, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[16, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[16, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[16, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[16, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[16, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[16, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[16, "id3"]], "XLA Tensors are Lazy": [[16, "xla-tensors-are-lazy"]], "Memory Layout": [[16, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[16, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[16, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[16, "compilation-caching"]], "Further Reading": [[16, "further-reading"], [29, "further-reading"]], "Troubleshoot": [[17, "troubleshoot"]], "Sanity Check": [[17, "sanity-check"]], "Check PyTorch/XLA Version": [[17, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[17, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[17, "run-resnet-with-fake-data"]], "Performance Debugging": [[17, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[17, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics Analysis": [[17, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[17, "compilation-execution-analysis"]], "Get A Metrics Report": [[17, "get-a-metrics-report"]], "Understand The Metrics Report": [[17, "understand-the-metrics-report"]], "Clear The Metrics Report": [[17, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[17, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[17, "performance-profiling"]], "Simple Benchmarking": [[17, "simple-benchmarking"]], "Known Performance Caveats": [[17, "known-performance-caveats"]], "XLA 
Tensor Quirks": [[17, "xla-tensor-quirks"]], "More Debugging Tools": [[17, "more-debugging-tools"]], "Environment Variables": [[17, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[17, "common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[17, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Pytorch/XLA overview": [[18, "pytorch-xla-overview"]], "TPU Setup": [[18, "tpu-setup"]], "Converting code to PyTorch XLA": [[18, "converting-code-to-pytorch-xla"]], "Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device": [[18, "example-1-stable-diffusion-inference-in-pytorch-lightning-on-a-single-tpu-device"]], "Example 2. HF Stable Diffusion Inference": [[18, "example-2-hf-stable-diffusion-inference"]], "Running on a Single TPU device": [[18, "running-on-a-single-tpu-device"]], "Profiling and performance analysis": [[18, "profiling-and-performance-analysis"]], "Running on Multiple TPU Devices": [[18, "running-on-multiple-tpu-devices"]], "Running on Pods": [[18, "running-on-pods"]], "Source of recompilations in torch_xla": [[19, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[19, "lets-first-start-with-some-facts-constraints"], [27, "lets-first-start-with-some-facts-constraints"]], "#1. From input dataset.": [[19, "from-input-dataset"], [27, "from-input-dataset"]], "#2. From operator output": [[19, "from-operator-output"], [27, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[19, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"], [27, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[19, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"], [27, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. 
From control flow": [[19, "from-control-flow"], [27, "from-control-flow"]], "Conclusion:": [[19, "conclusion"], [27, "conclusion"]], "Appendix:": [[19, "appendix"], [27, "appendix"]], "Automatic Mixed Precision": [[20, "automatic-mixed-precision"]], "AMP for XLA:TPU": [[20, "amp-for-xla-tpu"]], "AMP for XLA:TPU Best Practices": [[20, "amp-for-xla-tpu-best-practices"]], "Supported Operators": [[20, "supported-operators"]], "AMP for XLA:GPU": [[20, "amp-for-xla-gpu"]], "AMP for XLA:GPU Best Practices": [[20, "amp-for-xla-gpu-best-practices"]], "Examples": [[20, "examples"]], "How to do DistributedDataParallel(DDP)": [[21, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[21, "background-motivation"]], "How to use DistributedDataParallel": [[21, "how-to-use-distributeddataparallel"]], "Benchmarking": [[21, "benchmarking"]], "Resnet50 with fake data": [[21, "resnet50-with-fake-data"]], "MNIST with fake data": [[21, "mnist-with-fake-data"]], "MNIST with real data": [[21, "mnist-with-real-data"]], "Disclaimer": [[21, "disclaimer"]], "TorchDynamo integration in PyTorch XLA": [[22, "torchdynamo-integration-in-pytorch-xla"]], "Integration": [[22, "integration"]], "Feature gaps": [[22, "feature-gaps"]], "Take away": [[22, "take-away"]], "Optimize memory utilization using while_loop": [[23, "optimize-memory-utilization-using-while-loop"]], "while_loop": [[23, "while-loop"]], "Usage:": [[23, "usage"]], "simple example with while_loop:": [[23, "simple-example-with-while-loop"]], "Control group test case": [[23, "control-group-test-case"]], "Control group example with pure python while loop": [[23, "control-group-example-with-pure-python-while-loop"]], "Fully Sharded Data Parallel in PyTorch XLA": [[24, "fully-sharded-data-parallel-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[24, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[24, "installation"]], "Clone PyTorch/XLA repo": [[24, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[24, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[24, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[24, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Fully Sharded Data Parallel using SPMD": [[25, "fully-sharded-data-parallel-using-spmd"]], "Sharding output": [[25, "sharding-output"]], "Gradient checkpointing": [[25, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[25, "huggingface-llama-2-example"]], "Quantized Operations for XLA (Experimental feature)": [[26, "quantized-operations-for-xla-experimental-feature"]], "How to use:": [[26, "how-to-use"]], "Call XLA quantized op in model code": [[26, "call-xla-quantized-op-in-model-code"]], "Module Swap": [[26, "module-swap"]], "Supported Quantized Operations:": [[26, "supported-quantized-operations"]], "Matrix Multiply": [[26, "matrix-multiply"]], "Source of recompilations in Pytorch/XLA": [[27, "source-of-recompilations-in-pytorch-xla"]], "PyTorch/XLA SPMD advanced topics": [[28, "pytorch-xla-spmd-advanced-topics"]], "Sharding-Aware Host-to-Device Data Loading": [[28, "sharding-aware-host-to-device-data-loading"]], "Virtual Device Optimization": [[28, "virtual-device-optimization"]], "Hybrid Mesh": [[28, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[28, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[28, "xlashardedtensor"]], "DTensor Integration": [[28, "dtensor-integration"]], "Activation Sharding for torch.compile": 
[[28, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[28, "spmd-debugging-tool"]], "Auto-Sharding": [[28, "auto-sharding"]], "PyTorch/XLA SPMD User Guide": [[29, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[29, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[29, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[29, "spmd-mode"]], "Mesh": [[29, "mesh"]], "Partition Spec": [[29, "partition-spec"]], "Distributed Checkpointing": [[30, "distributed-checkpointing"]], "CheckpointManager": [[30, "checkpointmanager"]], "Restoring Optimizer State": [[30, "restoring-optimizer-state"]], "Process Groups": [[30, "process-groups"]], "Running SPMD on GPU": [[31, "running-spmd-on-gpu"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.Mesh"]], "mpdeviceloader (class in torch_xla.distributed.parallel_loader)": [[12, "torch_xla.distributed.parallel_loader.MpDeviceLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[12, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[12, "torch_xla.device"]], "device_count() (in module torch_xla)": [[12, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[12, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[12, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.initialize_cache"]], 
"is_master_ordinal() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.is_master_ordinal"]], "is_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[12, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metrics_report"]], "module": [[12, "module-torch_xla"], [12, "module-torch_xla.core.xla_model"], [12, "module-torch_xla.debug.metrics"], [12, "module-torch_xla.distributed.parallel_loader"], [12, "module-torch_xla.distributed.spmd"], [12, "module-torch_xla.distributed.xla_multiprocessing"], [12, "module-torch_xla.experimental"], [12, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.optimizer_step"]], "rendezvous() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.short_metrics_report"]], "spawn() (in module torch_xla.distributed.xla_multiprocessing)": [[12, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[12, "torch_xla.sync"]], "torch_xla": [[12, "module-torch_xla"]], "torch_xla.core.xla_model": [[12, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[12, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[12, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[12, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[12, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[12, "module-torch_xla.experimental"]], "torch_xla.runtime": [[12, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["accelerators/gpu", "accelerators/tpu", "contribute/bazel", "contribute/codegen_migration", "contribute/configure-environment", "contribute/op_lowering", "contribute/plugins", 
"features/distop", "features/pallas", "features/stablehlo", "features/triton", "index", "learn/api-guide", "learn/dynamic_shape", "learn/eager", "learn/pjrt", "learn/pytorch-on-xla-devices", "learn/troubleshoot", "learn/xla-overview", "notes/source_of_recompilation", "perf/amp", "perf/ddp", "perf/dynamo", "perf/fori_loop", "perf/fsdp", "perf/fsdpv2", "perf/quantized_ops", "perf/recompilation", "perf/spmd_advanced", "perf/spmd_basic", "perf/spmd_distributed_checkpoint", "perf/spmd_gpu"], "filenames": ["accelerators/gpu.md", "accelerators/tpu.md", "contribute/bazel.md", "contribute/codegen_migration.md", "contribute/configure-environment.md", "contribute/op_lowering.md", "contribute/plugins.md", "features/distop.md", "features/pallas.md", "features/stablehlo.md", "features/triton.md", "index.rst", "learn/api-guide.rst", "learn/dynamic_shape.md", "learn/eager.md", "learn/pjrt.md", "learn/pytorch-on-xla-devices.md", "learn/troubleshoot.md", "learn/xla-overview.md", "notes/source_of_recompilation.md", "perf/amp.md", "perf/ddp.md", "perf/dynamo.md", "perf/fori_loop.md", "perf/fsdp.md", "perf/fsdpv2.md", "perf/quantized_ops.md", "perf/recompilation.md", "perf/spmd_advanced.md", "perf/spmd_basic.md", "perf/spmd_distributed_checkpoint.md", "perf/spmd_gpu.md"], "titles": ["Learn about GPUs", "Learn about TPUs", "Bazel in Pytorch/XLA", "Codegen migration Guide", "Configure a development environment", "OP Lowering Guide", "Custom Hardware Plugins", "Support of Torch Distributed API in PyTorch/XLA", "Custom Kernels via Pallas", "Torch Export to StableHLO", "Custom GPU Kernels via Triton", "PyTorch/XLA documentation", "PyTorch/XLA API", "Dynamic shape", "Eager Mode + Compile API", "PJRT Runtime", "PyTorch on XLA Devices", "Troubleshoot", "Pytorch/XLA overview", "Source of recompilations in torch_xla", "Automatic Mixed Precision", "How to do DistributedDataParallel(DDP)", "TorchDynamo integration in PyTorch XLA", "Optimize memory utilization using while_loop", "Fully Sharded Data Parallel in PyTorch XLA", "Fully Sharded Data Parallel using SPMD", "Quantized Operations for XLA (Experimental feature)", "Source of recompilations in Pytorch/XLA", "PyTorch/XLA SPMD advanced topics", "PyTorch/XLA SPMD User Guide", "Distributed Checkpointing", "Running SPMD on GPU"], "terms": {"For": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 31], "inform": [0, 1, 4, 10, 12, 14, 15, 16, 17, 18, 19, 27, 31], "googl": [0, 1, 8, 15, 16], "cloud": [0, 1, 2, 4, 6, 11, 15, 16, 22, 30], "see": [0, 1, 2, 3, 4, 5, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 27], "machin": [0, 2, 4, 15, 17, 18, 31], "type": [0, 4, 6, 9, 12, 15, 16, 17, 18, 20, 21], "ar": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 26, 27, 28, 29, 30], "custom": [1, 3, 4, 7, 11, 12, 19, 21, 24, 26, 27, 28, 29], "design": [1, 15, 16, 22, 25, 29], "ai": 1, "acceler": [1, 4, 12, 13, 15, 16, 18, 20], "which": [1, 2, 3, 5, 6, 7, 9, 12, 13, 15, 16, 17, 18, 19, 20, 22, 24, 25, 27, 28, 30], "optim": [1, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27], "train": [1, 8, 12, 13, 16, 17, 18, 20, 28, 30, 31], "infer": [1, 3, 12, 15, 20, 28, 31], "larg": [1, 13, 15, 18, 19, 24, 27, 29], "model": [1, 3, 5, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30, 31], "thei": [1, 2, 5, 6, 7, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "ideal": [1, 2, 3, 19, 22, 27], "varieti": 1, "us": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 14, 15, 16, 17, 18, 20, 22, 24, 28, 30, 31], "case": 
[1, 2, 3, 5, 9, 12, 15, 16, 17, 18, 22, 25, 28], "chatbot": 1, "code": [1, 3, 5, 9, 10, 12, 14, 15, 16, 17, 19, 21, 22, 27, 28], "gener": [1, 5, 12, 14, 15, 16, 17, 18, 19, 27], "media": 1, "content": [1, 12], "synthet": 1, "speech": 1, "vision": [1, 24], "servic": [1, 2, 15], "recommend": [1, 2, 3, 4, 5, 12, 14, 15, 16, 20, 28], "engin": [1, 17], "person": 1, "among": 1, "other": [1, 2, 3, 5, 8, 12, 13, 15, 16, 17, 18, 19, 20, 21, 26, 27, 29], "scale": [1, 9, 12, 15, 20, 22, 29], "cost": [1, 22], "effici": [1, 9, 17, 18, 22], "wide": [1, 5, 19, 27], "rang": [1, 5, 12, 15, 25, 28, 29], "workload": [1, 15, 16, 17, 28, 29], "span": [1, 3], "fine": 1, "tune": [1, 28], "provid": [1, 2, 3, 5, 6, 8, 9, 12, 16, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 29, 30], "versatil": 1, "lead": [1, 17, 18], "framework": [1, 9, 11, 14, 19, 26, 27], "includ": [1, 2, 5, 12, 15, 17, 18, 19, 20, 23, 27, 30], "pytorch": [1, 5, 10, 13, 14, 15, 19, 20, 21, 23, 26, 30, 31], "jax": [1, 6, 8, 9, 15], "tensorflow": [1, 2, 6, 9, 12, 15, 17, 19, 27], "seamlessli": 1, "orchestr": 1, "through": [1, 3, 5, 6, 7, 8, 16, 18, 19, 20, 27, 30], "integr": [1, 10, 11, 25, 26, 29], "kubernet": 1, "gke": 1, "leverag": [1, 10, 31], "dynam": [1, 3, 5, 11, 17, 18, 22], "schedul": [1, 18], "improv": [1, 15, 16, 17, 18, 20, 22, 28], "scalabl": 1, "all": [1, 2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 30], "need": [1, 2, 3, 5, 12, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 29], "simultan": 1, "look": [1, 3, 5, 16, 17, 18, 28], "simplest": 1, "wai": [1, 2, 5, 7, 8, 12, 15, 16, 18, 19, 21, 22, 26, 27, 28], "develop": [1, 2, 10, 11, 14, 16, 21, 22, 26, 29], "can": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 31], "also": [1, 2, 3, 5, 6, 7, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 28, 29], "vertex": 1, "fulli": [1, 11, 14, 15, 17, 29], "manag": [1, 8, 12, 20, 30], "platform": 1, "more": [1, 2, 3, 4, 5, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 27, 28, 29, 31], "introduct": [1, 8], "set": [1, 2, 7, 12, 15, 17, 18, 19, 20, 22, 24, 27, 28, 30], "up": [1, 2, 3, 15, 16, 18, 19, 22, 25, 27], "environ": [1, 2, 11, 15, 16, 18, 21, 28, 30], "resourc": [1, 12, 17], "i": [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 28, 30], "free": [2, 5, 13, 17, 20, 21, 24], "softwar": [2, 17], "tool": [2, 5, 18, 24], "autom": 2, "openxla": [2, 6, 14, 22, 26], "both": [2, 4, 5, 7, 9, 15, 18, 19, 20, 22, 24, 25, 26, 27, 29, 30], "make": [2, 4, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "good": [2, 3, 5, 18, 19, 27, 28], "fit": [2, 3, 18, 24], "well": [2, 3, 6, 9, 12, 15, 18, 19, 27, 29], "extern": [2, 4, 8], "seen": [2, 18, 22], "workspac": [2, 17], "file": [2, 4, 12, 15, 17, 18, 20, 21], "http_archiv": 2, "name": [2, 4, 5, 9, 12, 15, 17, 19, 25, 27, 28, 29], "org_tensorflow": 2, "strip_prefix": 2, "f7759359f8420d3ca7b9fd19493f2a01bd47b4ef": 2, "url": 2, "http": [2, 3, 4, 8, 10, 12, 15, 17, 18, 24, 28], "github": [2, 3, 4, 5, 10, 12, 15, 17, 18, 21, 24, 28], "com": [2, 3, 4, 8, 10, 12, 15, 17, 18, 24, 28], "archiv": 2, "tar": 2, "gz": 2, "pin": [2, 12], "updat": [2, 3, 7, 16, 18, 19, 20, 27, 28], "point": [2, 3, 4, 5, 6, 9, 12, 18, 19, 20, 27], "thi": [2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "repositori": [2, 15], "differ": [2, 6, 12, 16, 17, 18, 19, 21, 23, 24, 27, 28, 29], "revis": 2, "patch": [2, 17], "mai": [2, 3, 6, 15, 16, 17, 18, 19, 20, 27, 28], "ad": [2, 5, 12, 16, 18, 19, 22, 23, 
27, 28], "resolv": 2, "prepar": 2, "hermet": 2, "mechan": 2, "deploi": 2, "becaus": [2, 3, 9, 14, 15, 16, 18, 20, 28], "local": [2, 4, 12, 15, 16, 17, 28], "checkout": [2, 17], "ha": [2, 3, 4, 5, 8, 12, 14, 15, 16, 18, 19, 27, 28, 29], "built": [2, 4], "from": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17, 18, 20, 21, 22, 23, 24, 25, 28, 29, 30], "sourc": [2, 3, 5, 6, 9, 11, 12, 17], "instal": [2, 3, 4, 5, 6, 8, 9, 10, 15, 17, 18], "system": [2, 29], "version": [2, 3, 4, 8, 15, 18, 20, 28], "compat": [2, 9, 15, 26, 30], "e": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 20, 24, 26, 27, 28], "g": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 26, 27, 28, 30], "codegen": [2, 5, 11], "torchgen": [2, 3], "python": [2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 15, 17, 18, 19, 21, 22, 27, 28], "modul": [2, 8, 9, 12, 16, 17, 21, 24, 25, 28], "should": [2, 3, 4, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 24, 27, 28, 30], "The": [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "directori": [2, 3, 5, 9], "either": [2, 5, 7, 12, 15, 17, 19, 20, 27], "bzl": 2, "overriden": 2, "command": [2, 3, 4, 15, 16, 17, 18, 21, 24], "line": [2, 3, 12, 14, 16, 17, 18, 19, 24, 27], "override_repositori": 2, "path": [2, 6, 9, 12, 16, 17, 19, 24, 27], "export": [2, 3, 4, 5, 11, 15, 17, 18], "tf_repo": 2, "torch_repo": 2, "pleas": [2, 3, 5, 7, 9, 12, 15, 16, 17, 18, 20, 24, 25, 26, 28, 31], "sure": [2, 16, 17], "overridden": [2, 3], "appropri": [2, 18], "been": [2, 5, 12, 15, 16, 18, 19, 27, 28], "use_cuda": 2, "0": [2, 3, 4, 6, 9, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31], "setup": [2, 3, 6, 16, 21], "py": [2, 3, 4, 5, 7, 10, 13, 14, 15, 16, 17, 18, 21, 24, 28, 31], "bdist_wheel": 2, "expect": [2, 3, 6, 10, 14, 15, 17, 19, 22, 26, 27], "object": [2, 12, 28], "present": [2, 30], "new_local_repositori": 2, "build_fil": 2, "pytorch_local_dir": 2, "header": 2, "directli": [2, 3, 5, 6, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28, 30], "share": [2, 3, 6, 15, 16, 17, 28], "libtorch": 2, "so": [2, 3, 6, 10, 12, 13, 15, 16, 17, 18, 19, 24, 27, 30], "same": [2, 3, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 26, 27, 28, 29, 31], "where": [2, 4, 7, 8, 12, 13, 15, 16, 17, 18, 19, 24, 25, 27], "lib": [2, 6], "contain": [2, 3, 5, 6, 9, 10, 12, 15, 17, 18, 19, 27], "work": [2, 3, 7, 12, 13, 15, 16, 17, 18, 19, 21, 22, 26, 27, 28, 29], "": [2, 4, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 18, 20, 21, 22, 26, 28, 29, 30], "requir": [2, 3, 5, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28, 30, 31], "pass": [2, 5, 9, 10, 12, 15, 18, 20, 21, 28], "isystemextern": 2, "compil": [2, 5, 6, 9, 10, 11, 12, 13, 15, 18, 19, 20, 22, 25, 26, 27, 29, 30], "find": [2, 3, 5, 9, 15, 17, 18, 21, 25], "satisfi": [2, 28], "them": [2, 3, 5, 9, 12, 15, 16, 17, 18, 19, 27], "some": [2, 3, 5, 12, 13, 14, 15, 16, 17, 21, 26, 28], "user": [2, 4, 6, 9, 11, 14, 15, 16, 17, 18, 19, 22, 23, 25, 26, 27, 28, 30], "bring": [2, 3, 25], "pybind11": 2, "embed": 2, "link": [2, 3], "against": [2, 21], "libpython": 2, "instead": [2, 7, 12, 14, 15, 16, 17, 18, 19, 21, 22, 24, 27, 28, 30], "These": [2, 3, 5, 8, 15, 18, 26, 30], "pybind11_emb": 2, "option": [2, 3, 4, 6, 9, 12, 15, 17, 18, 26, 28, 30], "transit": [2, 16], "simpl": [2, 3, 8, 15, 18, 20, 24, 29], "torch_xla": [2, 4, 5, 6, 7, 8, 9, 10, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], "csrc": [2, 5], "runtim": [2, 3, 4, 6, 11, 16, 17, 21, 24, 28, 29, 30], "configr": 2, "via": [2, 4, 11, 15, 23, 24, 25, 28, 29], "bazelrc": 2, "take": [2, 3, 9, 10, 12, 16, 17, 18, 19, 
27, 28], "flag": [2, 3, 12, 13, 20], "config": [2, 4], "remote_cach": 2, "configur": [2, 3, 5, 11, 12, 15, 17, 18, 30], "gcloud": [2, 4, 15, 16, 18], "usual": [2, 3, 5, 14, 16, 17], "faster": [2, 15, 18, 19, 22, 27], "authent": [2, 15], "easi": [2, 15, 16, 19, 27], "express": [2, 25, 29], "complex": [2, 10, 22], "lot": [2, 16, 17, 18, 19, 27], "gain": [2, 15], "have": [2, 3, 4, 5, 6, 8, 9, 12, 15, 16, 17, 18, 19, 21, 22, 24, 25, 27, 28, 30], "singl": [2, 12, 14, 19, 21, 22, 24, 25, 27, 28, 29, 31], "graph": [2, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "everyth": [2, 19, 21, 27], "therefor": [2, 17, 18], "separ": [2, 3, 5, 16, 18, 22, 24, 25], "rest": [2, 15, 17, 19, 27], "plu": [2, 21, 23], "whole": [2, 12, 14, 19, 22, 27], "everythin": 2, "els": [2, 17, 19, 27], "enough": [2, 18, 19, 27], "normal": [2, 3, 15, 19, 25, 27, 28], "achiev": [2, 5, 14, 21], "invok": [2, 3, 22, 28], "standard": [2, 9], "c": [2, 3, 5, 12, 15, 17, 19, 20, 27], "bind": [2, 9], "simpli": [2, 15], "_xlac": [2, 10, 17, 19, 27], "client": [2, 6, 12, 15], "togeth": [2, 14, 15, 16, 21, 24, 28], "when": [2, 3, 5, 7, 10, 12, 13, 14, 15, 16, 17, 18, 20, 22, 24, 28, 29, 30], "chang": [2, 5, 13, 16, 17, 18, 19, 20, 21, 26, 27, 28], "abl": [2, 16, 19, 27, 30], "without": [2, 5, 12, 15, 17, 18, 28, 29, 30], "iter": [2, 12, 13, 16, 17, 18, 22, 28], "cycl": 2, "come": [2, 12, 19, 27], "There": [2, 3, 14, 16, 17, 18, 19, 21, 22, 27, 28], "plenti": 2, "backend": [2, 3, 7, 12, 14, 15, 19, 22, 23, 26, 27, 28, 30], "we": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 31], "our": [2, 3, 4, 5, 6, 7, 8, 9, 13, 15, 16, 17, 19, 20, 21, 22, 27, 28], "gc": [2, 30], "storag": [2, 4, 8, 16, 17, 18, 24, 30], "you": [2, 4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 24, 25, 28, 29, 31], "under": [2, 3, 5, 12, 15, 16, 21], "disabl": [2, 12, 14, 17, 18], "default": [2, 5, 12, 14, 15, 16, 17, 18, 20, 24, 28, 30], "speed": [2, 18, 19, 22, 27], "increment": [2, 3], "huge": [2, 17, 18, 19, 21, 27], "margin": 2, "almost": [2, 29], "alwai": [2, 15, 16, 17, 19, 27, 29], "enabl": [2, 10, 12, 13, 14, 17, 18, 20, 21, 26, 28, 29, 30], "ci": [2, 5], "To": [2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 24, 25, 27, 28, 30, 31], "ensur": [2, 9, 12, 19, 25, 27, 28, 30], "credenti": 2, "auth": [2, 15], "applic": [2, 17, 26, 30], "login": [2, 18], "launch": [2, 12, 15, 16, 18, 21, 22, 24], "browser": 2, "gcp": [2, 4, 15], "variou": [2, 10], "individu": [2, 24, 25, 29], "who": [2, 21], "access": [2, 3, 5, 8, 12, 15, 16, 17, 18, 19, 27, 30], "project": [2, 4, 6, 15, 16, 18], "one": [2, 3, 5, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 25, 27, 28, 29, 31], "onli": [2, 3, 5, 7, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 27, 29, 30], "specifi": [2, 7, 9, 12, 16, 18, 24, 28], "google_default_credenti": 2, "token": [2, 14, 18, 26], "out": [2, 5, 9, 12, 13, 14, 15, 16, 17, 18, 20, 22, 28], "box": [2, 5, 28], "log": [2, 17, 18], "permiss": 2, "add": [2, 3, 5, 9, 10, 12, 16, 17, 18, 19, 22, 23, 24, 27], "new": [2, 3, 4, 5, 7, 14, 16, 17, 18, 19, 22, 27, 28], "role": 2, "In": [2, 3, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 27, 28, 29, 30], "account": [2, 18], "kei": [2, 4, 6, 15, 17, 18, 30], "google_credenti": 2, "On": [2, 15, 30], "docker": [2, 9], "network": [2, 12, 15, 16, 17, 20, 28], "cloudbuild": 2, "down": [2, 5, 18], "imag": [2, 15, 18, 19, 21, 24, 27], "do": [2, 3, 5, 11, 13, 15, 16, 17, 18, 19, 20, 24, 26, 27, 28], "doe": [2, 3, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28], "read": [2, 4, 5, 
12, 15, 28], "write": [2, 5, 10, 12, 16, 29], "silo": 2, "each": [2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 22, 24, 25, 27, 28, 29, 30], "uniqu": [2, 16, 18, 19, 27], "benefit": [2, 18, 25, 26, 30], "consist": [2, 7, 9, 15], "remote_default_exec_properti": 2, "some_silo_kei": 2, "bazel_remote_cach": 2, "1": [2, 4, 6, 7, 8, 9, 12, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 28, 29, 31], "silo_nam": 2, "your": [2, 3, 6, 8, 9, 15, 16, 17, 18, 19, 21, 25, 27, 28, 30], "tpuvm_mod": 2, "gcloud_service_key_fil": 2, "application_default_credenti": 2, "json": [2, 9], "might": [2, 5, 12, 16, 17, 18, 19, 27], "help": [2, 17, 18, 19, 27], "too": [2, 17, 19, 27], "cannot": [2, 8, 18, 19, 20, 24, 27], "here": [2, 3, 5, 8, 9, 13, 16, 18, 19, 21, 22, 24, 25, 27, 28, 29, 30], "author": 2, "usernam": 2, "behavior": [2, 3, 5, 15, 16, 17, 20], "function": [2, 5, 6, 7, 8, 9, 10, 12, 14, 16, 17, 18, 22, 23, 25, 26, 30], "intend": 2, "first": [2, 3, 4, 9, 10, 12, 13, 15, 17, 18, 21, 28, 29, 30, 31], "time": [2, 3, 4, 12, 13, 15, 16, 17, 18, 19, 22, 23, 27, 28], "slow": [2, 17, 18], "scratch": [2, 3], "veri": [2, 6, 8, 14, 16, 18, 19, 27], "fast": [2, 19, 27], "onc": [2, 7, 12, 16, 17, 18, 19, 22, 27, 28], "again": [2, 3, 9, 16, 18], "bit": [2, 16, 26], "slower": [2, 17, 18, 21], "per": [2, 9, 12, 15, 16, 17, 20, 21, 22, 26], "until": [2, 12, 16, 18, 30], "next": [2, 12, 17, 18, 19, 26, 27, 28], "quit": 2, "current": [2, 6, 8, 9, 12, 13, 14, 15, 16, 18, 19, 21, 22, 23, 25, 26, 27, 28, 31], "migrat": [2, 11, 15], "futur": [2, 3, 4, 6, 9, 13, 15, 16, 17, 18, 19, 25, 27], "plafrom": 2, "cpp": [2, 5], "main": [2, 4, 7, 9, 10, 14, 15, 28], "Of": 2, "cours": 2, "pjrt": [2, 11, 12, 16, 28], "Not": 2, "environment": 2, "variabl": [2, 4, 13, 15, 18, 19, 27], "miss": [2, 5, 12, 17], "common": [2, 15, 19, 25, 26, 27, 29, 30], "part": [2, 3, 6, 10, 12, 14, 15, 17, 18, 28], "ones": [2, 12, 19, 27], "helper": [2, 3, 9, 12], "script": [2, 3, 4, 8, 15, 16, 17, 18, 20, 21, 31], "run_test": 2, "sh": 2, "r": [2, 18], "xla_client": 2, "pure": [2, 3], "easili": [2, 5, 19, 22, 27], "execut": [2, 10, 12, 14, 15, 16, 18, 19, 20, 21, 22, 27, 28, 29, 31], "parallel": [2, 11, 12, 15, 17, 21, 28, 29], "sinc": [2, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 30], "xrt": [2, 12], "port": [2, 15, 31], "gpu": [2, 5, 6, 8, 11, 13, 17, 18, 28], "tpu": [2, 3, 5, 6, 8, 11, 12, 13, 17, 21, 22, 23, 30, 31], "devic": [2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14, 15, 17, 19, 20, 21, 22, 23, 26, 27, 29, 30], "avail": [2, 12, 15, 16, 17, 18, 19, 24, 27, 31], "reason": [2, 3, 5, 14, 15, 18, 21], "bundl": 2, "target": [2, 9, 14, 15, 16, 18, 19, 20, 22, 27], "sequenti": [2, 12], "calcul": 2, "visual": [2, 28], "lcov": 2, "describ": [2, 3, 4, 9, 12, 16, 18, 20, 21, 29], "document": [2, 3, 4, 5, 6, 9, 15, 16, 20, 21, 26], "editor": 2, "choic": [2, 19, 27], "gutter": 2, "vscode": 2, "power": 2, "like": [2, 3, 4, 5, 8, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28], "clangd": 2, "refer": [2, 3, 5, 7, 8, 9, 10, 13, 15, 16, 18, 24, 26, 28, 31], "autocomplet": 2, "semant": [2, 5, 17, 19, 27], "understand": [2, 18, 19, 27], "underli": [2, 12, 16], "stack": [2, 16, 17, 19, 20, 27, 28], "combin": [2, 5, 12, 19, 27], "studio": 2, "extens": [2, 4, 5, 6], "featur": [2, 8, 13, 15, 17, 21, 25, 28, 29, 30], "assist": 2, "edit": 2, "As": [2, 3, 18, 19, 25, 27], "distutil": 2, "ltc": 3, "lazi": [3, 17, 18, 19, 22, 27, 28], "tensor": [3, 5, 7, 9, 12, 13, 15, 18, 20, 22, 23, 25, 26, 28, 29], "core": [3, 5, 7, 9, 12, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "clean": [3, 17, 22], "exist": 
[3, 9, 12, 14, 15, 16, 17, 22, 28], "stub": 3, "over": [3, 12, 14, 15, 16, 18, 24, 30], "6": [3, 4, 5, 9, 12, 17, 18, 19, 27], "were": [3, 16, 17, 18, 19, 27], "complet": [3, 12, 16, 17], "process": [3, 5, 6, 7, 10, 12, 14, 15, 17, 18, 21, 24, 26], "found": [3, 15, 18], "ref": [3, 4, 15], "replac": [3, 18, 23], "support": [3, 6, 8, 9, 10, 12, 13, 15, 19, 22, 23, 24, 27, 28, 30, 31], "NOT": 3, "introduc": [3, 7, 8, 14, 15, 17, 18, 21, 28], "ani": [3, 8, 9, 12, 13, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 29, 30, 31], "purpos": [3, 5, 26], "follow": [3, 5, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 31], "instruct": [3, 5, 18], "depend": [3, 4, 5, 13, 14, 16, 18, 19, 20, 27], "build": [3, 5, 16, 18, 24], "It": [3, 4, 5, 7, 12, 13, 14, 16, 18, 19, 22, 24, 25, 26, 27, 28], "experi": [3, 5, 14, 15, 21, 30], "workstat": [3, 5], "cpu": [3, 5, 7, 9, 12, 17, 18, 19, 24, 26, 27, 28, 30], "pjrt_devic": [3, 5, 6, 13, 15, 16, 17, 21, 23, 31], "re": [3, 12, 14, 15, 17, 18, 19, 20, 23, 25, 27], "familiar": [3, 16, 25], "issu": [3, 5, 12, 14, 15, 16, 17, 18, 20, 21, 25], "3560": 3, "track": [3, 17, 30], "statu": [3, 17], "put": [3, 5, 16, 17, 21], "alia": [3, 7, 12], "avoid": [3, 17, 18, 20], "duplic": 3, "mention": [3, 5, 19, 22, 27], "below": [3, 5, 7, 9, 14, 15, 18, 19, 20, 27, 30, 31], "live": [3, 5, 12, 19, 27], "folder": [3, 4, 5], "except": [3, 5, 18, 28], "xla_native_funct": [3, 5], "yaml": [3, 5], "torch": [3, 4, 8, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30], "shape_infer": 3, "shape": [3, 5, 8, 9, 10, 11, 12, 17, 18, 23, 28, 29], "defin": [3, 5, 8, 10, 12, 18, 20, 23, 25, 28, 29], "input": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 23, 25, 28, 29, 30], "return": [3, 5, 6, 7, 8, 9, 12, 14, 16, 17, 18, 19, 21, 22, 23, 26, 27, 28, 30], "output": [3, 4, 7, 8, 9, 10, 12, 15, 16, 17, 20, 21, 22, 23, 24, 28], "manual": [3, 5, 8, 14, 17, 24], "gen_lazy_tensor": 3, "data": [3, 7, 9, 11, 12, 14, 15, 16, 18, 19, 20, 22, 27, 29, 30], "aten": [3, 5, 17, 19, 27], "specif": [3, 12, 16, 18, 20, 21, 26], "run_gen_lazy_tensor": 3, "dest": 3, "lazy_ir": 3, "class": [3, 6, 7, 9, 12, 21, 24, 26, 30], "genlazyir": 3, "back": [3, 5, 9, 12, 16, 17, 18, 28], "todai": [3, 13], "most": [3, 6, 12, 15, 17, 22], "categori": [3, 25], "goal": [3, 4, 5, 7, 14], "move": [3, 9, 12, 15, 17, 19, 21, 27, 30], "full_codegen": 3, "necessari": [3, 12, 17, 20], "call": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 27, 28, 30], "upstream": [3, 7, 13, 22], "api": [3, 5, 11, 15, 16, 19, 21, 22, 24, 26, 27, 28, 29, 30], "xlanativefunct": [3, 5], "column": 3, "declar": [3, 5], "anoth": [3, 9, 13, 16, 17, 18, 19, 27], "wrap": [3, 5, 6, 8, 9, 12, 14, 16, 18, 20, 24, 25, 26, 28], "around": [3, 15, 19, 24, 27], "xlatensor": [3, 5, 12, 28], "construct": [3, 5, 16, 18, 24, 28, 29, 30], "aten_xla_typ": [3, 5], "Will": 3, "method": [3, 9, 12, 15, 20, 25, 28, 30], "map": [3, 5, 7, 12], "node": [3, 5, 7, 10, 17, 19, 27, 31], "remov": [3, 15, 17, 18], "tensor_method": [3, 5], "possibl": [3, 15, 16, 17, 18, 24, 25, 28], "multipl": [3, 7, 9, 12, 14, 19, 22, 26, 27], "few": [3, 16, 17, 18, 19, 21, 27, 30], "simpler": [3, 15], "go": [3, 14, 16, 18, 20, 28], "unari": 3, "binari": [3, 6, 9, 22], "exampl": [3, 4, 5, 6, 7, 9, 12, 13, 14, 15, 16, 17, 19, 21, 22, 26, 27, 28, 29, 30, 31], "characterist": 3, "fallback": [3, 5], "_adaptive_avg_pool3d": 3, "condit": [3, 19, 23, 27], "issupportedadaptivepool": 3, "xlahelp": 3, "i64list": 3, "self": [3, 5, 6, 7, 9, 12, 18, 21, 26, 28], "size": 
[3, 7, 10, 13, 15, 16, 17, 18, 19, 27, 30], "output_size_list": 3, "pool_dim": 3, "nativ": [3, 5, 14, 15, 17, 20, 21, 28], "call_fallback_fn": 3, "xla_fallback": 3, "aten_op": 3, "output_s": 3, "wip": 3, "evolv": 3, "At": [3, 6, 12], "self_tensor": 3, "static": [3, 13, 19, 27], "bool": [3, 12], "sync_upd": 3, "sys_util": 3, "getenvbool": 3, "xla_tensor_update_sync": 3, "true": [3, 12, 14, 15, 18, 19, 21, 24, 27, 28, 30], "xla_check": 3, "dst_tensor": 3, "updatefromtensor": 3, "sync": [3, 12, 14, 17, 18, 20], "complic": [3, 5, 8], "an": [3, 4, 5, 6, 7, 8, 12, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30], "would": [3, 4, 5, 12, 15, 16, 17, 18, 19, 23, 27], "someth": [3, 18], "ab": [3, 24], "const": [3, 5, 7], "torch_lazy_fn_count": 3, "bridg": [3, 22], "atenfromxlatensor": 3, "getxlatensor": 3, "fail": [3, 12, 16, 17, 30], "explain": [3, 6, 16, 17, 18, 19, 27, 29], "later": [3, 18], "still": [3, 7, 15, 16, 19, 20, 21, 27, 30], "snippet": [3, 16, 28], "auto": [3, 5, 12, 24, 30], "common_devic": 3, "getxladevic": 3, "torch_internal_assert": 3, "xlatensorptr": 3, "lazy_self": 3, "getxlatensororcreateforwrappednumb": 3, "nodeptr": 3, "reusenod": 3, "getirvalu": 3, "makenod": 3, "cachenod": 3, "creat": [3, 9, 10, 12, 15, 17, 18, 20, 21, 28, 30], "std": [3, 7, 21], "get": [3, 5, 12, 13, 14, 15, 18, 19, 21, 24, 26, 27], "check": [3, 4, 5, 12, 16, 26, 29], "reus": [3, 16, 18, 20], "previou": [3, 15, 16, 18, 19, 27], "creation": [3, 12], "If": [3, 4, 5, 9, 12, 15, 16, 17, 18, 19, 26, 27, 28], "correspond": [3, 5, 7, 12, 18, 20, 24, 28, 29], "cach": [3, 8, 12, 13, 18], "newli": [3, 9], "And": [3, 19, 21, 27, 28], "within": [3, 9, 12, 16, 17, 18, 26, 30], "note": [3, 4, 7, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 29], "done": [3, 4, 8, 16, 17, 18, 19, 27], "public": [3, 15], "xlanod": 3, "xlavalu": 3, "opkind": [3, 5], "absoutputshap": 3, "num_output": [3, 19, 27], "mhash": 3, "string": [3, 7, 12, 28], "tostr": 3, "overrid": [3, 12, 20], "stringstream": 3, "ss": 3, "str": [3, 6, 12], "xlaopvector": 3, "loweringcontext": 3, "loctx": 3, "A": [3, 4, 6, 12, 15, 16, 18, 19, 20, 25, 26, 27, 28], "coupl": [3, 16, 17], "thing": [3, 17, 18], "keep": [3, 4, 13, 15, 17, 19, 27], "mind": [3, 15, 17], "clone": [3, 15, 17, 18], "even": [3, 12, 15, 16, 17, 19, 21, 27], "everi": [3, 5, 8, 9, 12, 15, 16, 17, 19, 22, 27, 28, 30], "outputshap": 3, "xla_shap": 3, "overli": 3, "simplifi": 3, "buildxxxop": 3, "slightli": [3, 5, 12], "better": [3, 5, 14, 15, 16, 17, 18, 19, 22, 23, 27], "maximumoutputshap": 3, "lower_for_shape_fn": 3, "absl": 3, "xlaop": [3, 5], "operand": 3, "promot": 3, "max": [3, 19, 27, 30], "second": [3, 10, 13, 15, 17, 18, 21, 29, 31], "inferoutputshap": 3, "comput": [3, 4, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "logic": [3, 12, 14, 19, 23, 27, 28, 29], "two": [3, 6, 12, 15, 17, 18, 19, 27, 28, 29], "xla_input": 3, "getoutputop": 3, "returnop": 3, "buildab": 3, "origin": [3, 9, 18], "genericop": 3, "modifi": [3, 18, 20, 22, 28], "abov": [3, 5, 6, 9, 13, 14, 15, 16, 17, 18, 19, 21, 22, 27, 29], "delet": 3, "sometim": [3, 18, 19, 27], "being": [3, 12, 16, 18, 21, 29], "tensor_op": 3, "cross": [3, 16, 28], "s1": [3, 28], "sub": 3, "mul": [3, 19, 27], "u2": 3, "v3": [3, 16, 21], "u3": 3, "v2": [3, 4, 16], "irnod": 3, "those": [3, 5, 9, 12, 17, 18, 21], "long": [3, 14, 17, 18, 19, 21, 27], "term": [3, 10, 14, 17, 19, 27], "rid": [3, 19, 27], "composit": [3, 5], "end": [3, 5, 10, 12, 13, 15, 16, 17, 18, 21, 24, 25], "exp": 3, "pow": 3, "norm_exp": 3, "vector": [3, 10], "involv": [3, 
19, 27, 28], "don": [3, 5, 13, 14, 15, 16, 17, 19, 24, 27], "t": [3, 5, 9, 12, 13, 14, 15, 16, 17, 19, 20, 24, 25, 27, 28, 29, 30], "build_cpp_test": 3, "skip": [3, 5, 17, 22], "desir": [3, 9, 18, 30], "test_ptxla": 3, "gtest_filt": 3, "atenxlatensortest": 3, "testab": 3, "correct": [3, 19, 27], "counter": [3, 5, 12, 17], "correctli": [3, 17, 25], "gt": [3, 4, 9, 15, 18], "erf": 3, "erfc": 3, "erfinv": 3, "pull": [3, 9, 20, 21, 24], "3659": 3, "binary_cross_entropi": [3, 20], "backward": [3, 5, 9, 14, 15, 16, 20, 21, 22, 24, 25], "3809": 3, "scalar": [3, 5, 17, 19, 27], "addcdiv": 3, "addcmul": 3, "3768": 3, "neg": 3, "index": [3, 4, 6, 12, 15, 16, 17, 18, 31], "amin": 3, "amax": 3, "3771": 3, "special": [3, 9, 10, 18, 28], "partial": [3, 19, 24, 25, 27], "adaptive_avgpool3d": 3, "3790": 3, "guid": [4, 9, 11, 15, 16, 18, 24, 25, 28], "interact": [4, 15], "start": [4, 14, 15, 16, 17, 18], "colab": [4, 17], "kaggl": 4, "preinstal": [4, 15], "ecosystem": [4, 26], "packag": [4, 10, 11, 16, 18, 20, 21], "date": 4, "list": [4, 5, 12, 18, 20, 23, 28], "readm": [4, 17, 18], "prerequisit": 4, "remot": 4, "quota": 4, "about": [4, 14, 15, 16, 18, 19, 27], "request": [4, 5, 12, 17, 18, 19, 20, 21, 27, 28], "offici": [4, 17], "ssh": [4, 15, 16, 18], "regist": [4, 5, 6, 7, 15, 30], "agent": 4, "alreadi": [4, 8, 10, 12, 17, 18, 19, 21, 24, 27, 30], "befor": [4, 7, 8, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 28, 30], "begin": [4, 28], "zone": [4, 15, 16, 18], "tpu_typ": 4, "8": [4, 9, 10, 12, 14, 15, 16, 18, 19, 21, 22, 26, 27, 28, 29], "vm": [4, 15, 16, 17, 18, 21], "assum": [4, 6, 8, 12, 16, 19, 21, 25, 27, 28], "id_ed25519": 4, "ubuntu2204": 4, "base": [4, 7, 12, 14, 15, 17, 18, 19, 24, 27, 28, 29], "metadata": [4, 17], "cat": [4, 20], "pub": 4, "ip": [4, 12, 15, 30, 31], "format": [4, 12, 17, 18, 22, 26], "valu": [4, 5, 9, 10, 12, 13, 15, 17, 18, 19, 23, 27, 28, 31], "networkendpoint": 4, "accessconfig": 4, "externalip": 4, "123": 4, "give": [4, 9, 17, 18, 26, 28, 29], "friendli": 4, "easier": [4, 14, 18, 19, 27], "echo": 4, "host": [4, 12, 15, 16, 17, 18, 20, 24, 30, 31], "n": [4, 12, 21, 26], "hostnam": 4, "test": [4, 6, 8, 9, 10, 13, 15, 21, 24, 31], "v": [4, 8, 9, 15, 19, 27], "palett": 4, "select": [4, 12, 15, 30], "visualstudio": 4, "doc": [4, 12, 14, 15, 16, 19, 25, 27, 28], "__": [4, 15], "just": [4, 8, 14, 15, 16, 19, 21, 24, 27, 30], "titl": [4, 15], "open": [4, 5, 6, 9, 15, 17], "window": 4, "termin": [4, 30], "mkdir": 4, "ptxla": 4, "Then": [4, 9, 18], "ui": 4, "venv": 4, "virtual": [4, 12], "latest": [4, 9], "releas": [4, 6, 7, 8, 15, 16, 17, 18, 22, 24, 25, 26, 28], "pip": [4, 8, 9, 10, 18], "numpi": [4, 8, 9, 12, 18, 29], "f": [4, 8, 9, 12, 16, 21, 24, 26, 30], "googleapi": [4, 8, 18], "libtpu": [4, 6, 15], "html": [4, 8, 15, 24], "import": [4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 28, 29, 30], "set_device_typ": 4, "print": [4, 9, 12, 13, 15, 16, 17, 18, 19, 21, 22, 27, 28, 30], "real_devic": 4, "run": [4, 5, 8, 10, 11, 12, 13, 14, 15, 19, 20, 21, 22, 26, 27, 30], "2": [4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 22, 23, 24, 26, 28, 31], "3": [4, 5, 6, 8, 9, 10, 12, 14, 17, 18, 22, 23, 24, 26, 28], "4": [4, 6, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 26, 27, 28, 29], "5": [4, 7, 9, 12, 13, 17, 18, 19, 21, 24, 26, 27], "7": [4, 12, 17, 21, 22], "number": [4, 10, 12, 13, 14, 15, 17, 18, 24, 28, 29], "vari": [4, 15, 19, 25, 27], "That": [4, 19, 27], "now": [4, 7, 9, 10, 14, 15, 16, 18, 19, 27, 28], "realist": 4, "librari": [5, 6, 18, 29, 30], "offer": [5, 
9, 25, 26], "implement": [5, 7, 8, 9, 14, 15, 17, 19, 22, 24, 25, 27], "xla": [5, 9, 10, 13, 14, 15, 19, 21, 23, 25, 30, 31], "its": [5, 7, 9, 13, 15, 16, 17, 21, 22, 24, 28, 29], "convert": [5, 12, 16, 21], "higher": [5, 17, 30], "level": [5, 17, 18, 22, 26, 30], "represent": [5, 12, 16, 18, 29], "hlo": [5, 12, 16, 17, 18], "beyond": 5, "scope": 5, "forward": [5, 9, 14, 20, 21, 22, 25, 26], "haven": [5, 19, 27], "yet": [5, 7], "caus": [5, 12, 14, 15, 16, 17, 18, 19, 20, 27], "signific": [5, 17, 18, 22], "slowdown": [5, 17, 21], "must": [5, 6, 7, 12, 15, 16, 17, 25, 30, 31], "best": [5, 8, 22, 26], "perform": [5, 7, 8, 9, 10, 12, 14, 16, 20, 21, 22, 24, 26, 28], "what": [5, 16, 18], "debug": [5, 14, 19, 26, 27], "pt": [5, 15, 16, 17, 18], "profil": [5, 15], "_ctc_loss": [5, 17], "_ctc_loss_backward": [5, 17], "contribut": 5, "definit": [5, 16, 19, 27], "native_funct": 5, "after": [5, 7, 9, 12, 15, 16, 17, 18, 19, 23, 27, 28], "kernel": [5, 9, 11, 19, 26, 27], "aten_fallback": 5, "h": 5, "search": 5, "repo": [5, 16, 17, 18, 21], "sequenc": [5, 12], "explicitli": [5, 16, 17, 18, 19, 20, 27], "compos": 5, "match": [5, 9, 12, 16, 17], "serv": 5, "interfac": [5, 6, 16, 17, 25, 30], "machineri": 5, "registerxla": 5, "registerautogradxla": 5, "entri": [5, 6, 9], "pytorch_xla": 5, "world": [5, 8, 15, 19, 22, 27, 30], "written": [5, 18, 30], "paramet": [5, 12, 15, 16, 17, 20, 21, 25, 28, 30, 31], "result": [5, 7, 12, 13, 15, 16, 17, 18, 21, 23, 28], "dispatch": [5, 30], "wrapper": [5, 16, 21, 24, 25], "inplac": [5, 12, 28], "ir": [5, 9, 12, 17, 18, 19, 27], "insid": [5, 9, 16, 18, 28], "stand": 5, "intermedi": [5, 15, 17, 18], "smaller": [5, 18, 19, 27], "inherit": 5, "dai": 5, "addit": [5, 6, 10, 15, 16, 17, 18, 20, 21], "unless": [5, 17, 19, 27], "want": [5, 12, 14, 15, 16, 17, 18, 19, 22, 27, 28, 31], "verifi": 5, "test_oper": 5, "test_aten_xla_tensor": 5, "yield": [5, 16, 17], "break": [5, 18, 19, 27], "grasp": 5, "capabl": 5, "how": [5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 27, 28, 31], "similar": [5, 15, 18, 21, 23, 26], "minim": [5, 18], "pr": [5, 17, 24], "vanilla": 5, "lerp": 5, "variant": [5, 12, 19, 20, 27], "lerp_": 5, "scalar_out": 5, "tensor_out": 5, "prototyp": [5, 9, 28], "weight": [5, 9, 12, 17, 25, 26], "lerp_out": 5, "howev": [5, 8, 9, 17, 18, 28], "namespac": [5, 17], "wrapper_scalar_lerp": 5, "No": [5, 13, 15, 19, 26, 27], "deviceguard": 5, "omit": [5, 15, 29, 31], "anonym": 5, "wrapper_scalar_lerp_": 5, "wrapper_scalar_lerp__tmp": 5, "_copy_from": 5, "m": [5, 7, 9, 19, 24, 27], "impl": [5, 7, 9], "torch_fn": 5, "automat": [5, 6, 11, 12, 15, 16, 17, 18, 19, 24, 27, 29, 30], "u": [5, 15, 17, 18, 19, 22, 27], "explicit": [5, 20, 24], "place": [5, 7, 12, 18, 20, 28, 30], "ll": [5, 19, 27], "interned_str": 5, "symbol": [5, 19, 27], "submit": [5, 17, 18, 20], "team": [6, 22], "direclti": 6, "tf": [6, 17, 19, 27], "close": 6, "expos": [6, 15, 16, 18, 28], "deviceplugin": 6, "handl": [6, 14, 17, 19, 24, 25, 27, 28, 29], "short": [6, 17, 19, 27], "pjrtclient": 6, "mirror": 6, "pjrt_api": 6, "straightforward": [6, 12, 18], "detail": [6, 7, 8, 9, 12, 13, 15, 16, 17, 18, 19, 27], "concret": [6, 19, 27], "placehold": 6, "pjrt_library_path": 6, "extra": [6, 21, 25], "multiprocess": [6, 12, 15, 16], "compon": 6, "least": [6, 18], "cpuplugin": 6, "def": [6, 7, 8, 9, 10, 12, 14, 15, 16, 18, 21, 22, 23, 25, 26], "library_path": 6, "o": [6, 9, 15, 21], "join": [6, 12], "dirnam": 6, "__file__": 6, "pjrt_c_api_cpu_plugin": 6, "identifi": [6, 12, 30], "exmapl": 6, "pyproject": 6, 
"toml": 6, "torch_xla_cpu_plugin": 6, "With": [6, 8, 9, 13, 15, 19, 22, 27], "initi": [6, 7, 9, 12, 15, 16, 18, 21, 23, 30], "experiment": [6, 8, 9, 10, 11, 13, 14, 15, 16, 21, 22, 23, 25, 28, 30], "state": [6, 12, 24], "becom": [6, 8, 9, 15, 17, 18, 19, 27], "stabl": [6, 15, 24], "xla_model": [7, 9, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "adopt": [7, 19, 27], "traceabl": [7, 12], "commun": [7, 8, 12, 15, 16, 18, 22, 29], "reimplement": [7, 15], "_c10d_function": 7, "figur": [7, 13, 29], "show": [7, 9, 15, 16, 17, 21], "all_reduc": [7, 12, 20], "between": [7, 13, 15, 16, 17, 18, 19, 20, 21, 23, 27, 28], "processgroupxla": 7, "deriv": 7, "processgroup": 7, "xla_backend": [7, 15, 21, 30], "_create_xla_process_group": 7, "prefix_stor": 7, "rank": [7, 12, 15, 21, 24, 29, 30], "timeout": 7, "assert": [7, 21], "xr": [7, 12, 15, 16, 20, 21, 24, 25, 28, 29, 30], "is_spmd": [7, 12], "spmd": [7, 11, 16, 18, 30], "group": [7, 12, 15, 21, 28], "_register_xla_backend": 7, "dist": [7, 15, 21, 30], "register_backend": 7, "allreduc": 7, "all_reduce_opt": 7, "allgath": 7, "output_tensors_list": 7, "input_tensor": 7, "opt": [7, 16], "none": [7, 8, 9, 12, 17, 25, 28, 29], "_mp_fn": [7, 15, 16], "init_process_group": [7, 15, 21, 30], "init_method": [7, 15, 30], "progress": [7, 18], "instanc": [7, 8, 12, 24, 30], "blob": [7, 10, 12, 15, 28], "distributed_c10d": 7, "_exception_logg": 7, "all_gath": [7, 12, 15], "tensor_list": 7, "async_op": 7, "fals": [7, 9, 12, 16, 24, 28], "_get_default_group": 7, "certain": [7, 17, 19, 20, 27], "remap": 7, "_functional_collect": 7, "all_reduce_inplac": 7, "eventu": 7, "reach": [7, 12], "rewrit": [7, 18, 19, 27, 28], "reduceop": 7, "group_nam": 7, "torch_library_impl": 7, "four": [7, 18], "oper": [7, 10, 11, 12, 15, 16, 17, 18, 30], "align": [7, 14], "while": [7, 9, 12, 18, 19, 21, 27], "signatur": 7, "remain": [7, 16, 18, 19, 27, 31], "restrict": 7, "appli": [7, 12, 20, 24, 25, 30], "usag": [7, 12, 17, 18, 19, 24, 25, 27, 30], "test_collective_ops_tpu": 7, "demonstr": [7, 18, 20, 25, 30], "scenario": [7, 22], "sum": [7, 12, 20, 24, 25], "reduct": [7, 12], "aggreg": 7, "all_gather_into_tensor": 7, "gather": [7, 12, 28], "reduce_scatter_tensor": 7, "reduc": [7, 12, 13, 14, 15, 16, 17, 18, 24], "across": [7, 12, 15, 16, 17, 24, 29], "all_to_all_singl": 7, "output_split_s": 7, "input_split_s": 7, "although": [7, 15, 19, 27], "accept": [7, 28], "argument": [7, 9, 10, 12, 18, 20, 22, 24], "limit": [7, 12, 15, 16], "reflect": 7, "compromis": 7, "maintain": 7, "constraint": [7, 15, 17], "alltoal": [7, 12], "rise": 8, "openai": [8, 10], "triton": [8, 11], "popular": 8, "order": [8, 12, 16, 17, 18, 28, 29], "pariti": 8, "continu": [8, 15, 22], "push": 8, "let": [8, 15, 16, 17, 18, 22, 29], "custom_kernel": 8, "jax_import_guard": 8, "pl": [8, 15, 16, 28], "jnp": 8, "add_vectors_kernel": 8, "x_ref": 8, "y_ref": 8, "o_ref": 8, "x": [8, 9, 10, 12, 16, 17, 18, 19, 21, 23, 24, 25, 26, 27, 28, 29], "y": [8, 10, 12, 17, 18, 19, 24, 25, 26, 27, 28], "jit": [8, 10, 22], "add_vector": 8, "arrai": [8, 12, 18, 25, 29], "pallas_cal": 8, "out_shap": 8, "shapedtypestruct": 8, "dtype": [8, 9, 10, 15, 19, 20, 26, 27], "otherwis": [8, 12, 17, 18, 19, 25, 27], "program": [8, 9, 10, 12, 17, 18, 19, 22, 27, 28, 29], "hang": 8, "lock": 8, "q": [8, 9], "randn": [8, 9, 12, 14, 15, 16, 21, 22, 26, 28, 29], "128": [8, 9, 15, 24, 26, 31], "k": [8, 9, 17], "make_kernel_from_palla": 8, "pt_kernel": 8, "lambda": [8, 24], "liner": 8, "flash": [8, 10], "attent": [8, 10], "besid": 8, "op": [8, 9, 11, 12, 14, 17, 
18, 19, 20, 27, 28, 29], "suppor": 8, "flash_attent": 8, "paged_attent": 8, "queri": [8, 15], "squeez": 8, "dim": [8, 12], "key_cach": 8, "value_cach": 8, "context_len": 8, "block_tabl": 8, "pages_per_compute_block": 8, "megacore_mod": 8, "vllm": 8, "util": [8, 11, 12, 16, 17, 21, 24, 25, 26, 30], "effect": [8, 12], "memori": [8, 11, 12, 13, 17, 18, 19, 24, 27], "kv": 8, "proper": [8, 29], "jax_nightly_releas": 8, "jaxlib_nightly_releas": 8, "exported_program_to_stablehlo": 9, "xm": [9, 12, 14, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 29], "torchvis": [9, 14, 22], "xla_devic": [9, 12, 15, 16, 17, 18, 20, 21, 22, 23, 26, 29], "resnet18": [9, 14, 22], "sampl": [9, 12, 15, 17], "tupl": [9, 12, 19, 23, 25, 27, 29], "sample_input": 9, "224": [9, 14], "stablehlo_program": 9, "callabl": [9, 12, 24], "get_stablehlo_text": 9, "get_stablehlo_bytecod": [9, 12], "sample_input_xla": 9, "output2": 9, "allclos": 9, "atol": 9, "1e": [9, 17, 22], "One": [9, 12, 13, 18, 24], "tmp": [9, 16, 17, 24], "stablehlo_dir": 9, "empti": [9, 12], "doesn": [9, 16, 17, 19, 25, 27], "load": [9, 10, 12, 15, 17, 21, 24, 26, 30], "stablehlographmodul": 9, "stablehlo_program2": 9, "output3": 9, "server": [9, 12, 15, 18], "env": [9, 12, 15, 28], "nightli": [9, 17, 18, 24, 28], "resnet_tf": 9, "p": [9, 15, 17, 19, 27], "8500": 9, "mount": [9, 16], "model_nam": 9, "accomplish": 9, "tf_saved_model_integr": 9, "save_torch_module_as_tf_saved_model": 9, "nn": [9, 12, 15, 16, 21, 22, 24, 26, 28], "trace": [9, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 27, 28], "exported_model": 9, "exportedprogram": 9, "pathlik": 9, "stablehloexportopt": 9, "alias": [9, 17, 20], "save_torch_model_as_stablehlo": 9, "torchmodel": 9, "arg": [9, 12, 16, 18, 23, 24], "constant": [9, 17, 18, 28], "ndarrai": [9, 12], "human": 9, "readabl": [9, 18], "mlir": 9, "form": [9, 15, 17, 19, 27, 31], "posit": [9, 12], "meta": 9, "l__fn___layers_15_feed_forward_w2": 9, "l__fn___layers_13_feed_forward_w1": 9, "l__fn___layers_3_attention_wo": 9, "l__fn___layers_12_ffn_norm_weight": 9, "l__fn___layers_25_attention_wo": 9, "serial": [9, 15, 16], "stablehlofunc": 9, "stage": 9, "guarante": [9, 12], "plan": [9, 13, 15], "major": 9, "agre": [9, 18], "scaled_dot_product_attent": 9, "decompos": 9, "low": [9, 13, 17], "dure": [9, 12, 17, 18, 22, 24, 28], "lower": [9, 11, 17, 18, 19, 20, 27], "captur": [9, 12, 17, 18], "downstream": [9, 20], "ml": [9, 29], "crucial": 9, "geneart": 9, "pattern": [9, 17, 19, 22, 27], "bunch": 9, "challeng": 9, "error": [9, 12, 16, 17], "prone": 9, "robust": 9, "outlin": [9, 26], "stablehlocompositebuild": 9, "arbitari": 9, "region": [9, 12, 14, 17, 20, 28], "non": [9, 12, 14, 19, 20, 27, 29], "hardcod": [9, 28], "store": [9, 10, 12, 17], "attribut": 9, "retriev": [9, 12, 16, 19, 22, 27, 28], "pratic": 9, "scaled_product_attent": 9, "mark_pattern_util": 9, "__init__": [9, 21, 26], "super": [9, 21, 22], "q_proj": 9, "linear": [9, 12, 15, 16, 20, 21, 26], "bia": 9, "k_proj": 9, "v_proj": 9, "builder": 9, "b": [9, 12, 15, 18, 19, 20, 22, 27, 29], "sdpa": 9, "25": [9, 13], "other_attr": 9, "val": 9, "mark_input": 9, "attn_out": 9, "mark_output": 9, "input_arg": 9, "10": [9, 12, 15, 16, 17, 18, 19, 21, 22, 23, 26, 27, 30], "stablehlo_gm": 9, "shown": [9, 15, 19, 27], "irtohlo": 9, "56": 9, "mhlo": 9, "cross_program_prefetch": 9, "input_output_alia": 9, "is_dynam": 9, "use_auto_spmd_partit": 9, "func": 9, "arg0": 9, "10x8x128xf32": 9, "arg1": 9, "128x128xf32": 9, "arg2": 9, "arg3": 9, "9": [9, 18, 19, 21, 24, 27], "composite_attribut": 9, "500000e": 9, 
"01": [9, 10], "f32": 9, "decomposit": 9, "11": [9, 17, 19, 27], "privat": [9, 15], "actual": [9, 14, 18, 19, 21, 27, 28], "encapsul": 9, "propag": [9, 17], "high": [10, 13, 18, 21, 26], "deep": [10, 11, 17], "learn": [10, 15], "languag": 10, "empow": 10, "full": [10, 12, 16, 17, 24], "potenti": [10, 12, 15, 17, 25], "given": [10, 12, 17, 18, 19, 21, 24, 27, 29], "add_kernel": 10, "x_ptr": 10, "pointer": 10, "y_ptr": 10, "output_ptr": 10, "n_element": 10, "block_siz": 10, "tl": 10, "constexpr": 10, "element": [10, 12, 19, 25, 27, 28], "tutori": [10, 17, 18, 21, 28], "l28": 10, "pid": 10, "program_id": 10, "axi": [10, 12, 25], "block_start": 10, "offset": 10, "arang": 10, "mask": [10, 17, 19, 27], "xla_triton": 10, "16": [10, 16, 18, 24, 26, 29], "int64": 10, "empty_lik": 10, "grid": 10, "cdiv": 10, "triton_cal": 10, "itself": [10, 12, 24], "kwarg": [10, 12, 24, 28], "payload": [10, 12, 15], "regard": [10, 16, 22], "buffer": [10, 12], "_xla_gpu_custom_cal": 10, "dep": 10, "connect": [11, 12, 15, 28], "overview": [11, 29], "eager": [11, 12, 19, 21, 26, 27], "mode": [11, 12, 19, 21, 26, 27, 28, 30], "troubleshoot": 11, "palla": 11, "stablehlo": [11, 12], "mix": [11, 12, 29], "precis": 11, "advanc": [11, 29], "topic": [11, 29], "distribut": [11, 16, 17, 21, 24, 25, 28, 29], "checkpoint": [11, 15, 18, 24, 29], "distributeddataparallel": [11, 15], "ddp": [11, 15], "torchdynamo": 11, "while_loop": 11, "shard": [11, 12, 29, 30], "quantiz": 11, "recompil": [11, 13, 14, 16, 17, 18], "hardwar": [11, 12, 17, 18, 20], "plugin": [11, 15], "bazel": 11, "int": [12, 15, 19, 27, 28], "device_count": [12, 28], "address": [12, 15, 28, 31], "wait": [12, 17, 18], "pend": [12, 14], "whether": [12, 16, 20], "block": [12, 18, 24, 28], "finish": [12, 18], "full_graph": 12, "num_different_graphs_allow": 12, "lazytensor": [12, 14, 18], "repres": [12, 15, 19, 27], "happen": [12, 14, 15, 16, 17, 18, 19, 27], "decid": [12, 17, 19, 27], "funciton": 12, "act": [12, 16], "context": [12, 15, 17, 19, 20, 27], "throw": [12, 16], "info": [12, 17, 19, 27, 29], "exit": [12, 17, 20, 21], "pt_xla_debug": 12, "messag": [12, 17], "dump": [12, 17], "allow": [12, 16, 17, 18, 20, 28, 29, 30], "rais": [12, 17], "exceed": 12, "foo": 12, "sin": 12, "co": 12, "foo2": 12, "compiled_foo2": 12, "manual_se": [12, 15], "seed": 12, "random": [12, 14, 15, 18, 26], "integ": [12, 17], "rng": [12, 15], "device_typ": 12, "local_process_count": 12, "local_device_count": 12, "total": [12, 19, 27, 29], "addressable_device_count": 12, "visibl": [12, 19, 27], "global_device_count": 12, "global_runtime_device_count": [12, 25, 28, 29], "especi": [12, 15, 18, 22, 28], "world_siz": [12, 15, 20, 21, 24, 28], "particip": [12, 15], "job": [12, 18, 22], "global_ordin": [12, 15, 16, 21, 24], "global": [12, 15, 16, 28, 30], "ordin": [12, 16], "thread": [12, 15, 16, 17, 30], "predict": 12, "relationship": [12, 16, 17], "worker": [12, 15, 16, 18, 24, 30], "id": [12, 15, 17, 18], "nor": 12, "contigu": [12, 16, 17], "local_ordin": 12, "get_master_ip": 12, "master": [12, 15, 16, 30], "discoveri": 12, "use_spmd": [12, 28, 29, 30], "forc": [12, 15, 17, 19, 23, 27], "mean": [12, 15, 16, 17, 18, 19, 21, 25, 27, 28], "replic": [12, 28, 29], "spmd_advanc": 12, "md": [12, 15], "initialize_cach": [12, 16], "readonli": [12, 16], "persist": [12, 16, 30], "devkind": 12, "cuda": [12, 15, 16, 18, 19, 20, 26, 27, 31], "deprec": 12, "xla_device_hw": 12, "union": 12, "real": [12, 22], "is_master_ordin": 12, "multi": [12, 13, 28, 31], "num_host": 12, "boolean": 12, "indic": [12, 17, 18, 
19, 27], "reduce_typ": 12, "float": [12, 19, 20, 27], "pin_layout": 12, "reduce_sum": 12, "reduce_mul": 12, "reduce_and": 12, "reduce_or": 12, "reduce_min": 12, "reduce_max": 12, "replica": [12, 15], "layout": [12, 26], "pine": 12, "prevent": [12, 18, 20, 22, 28], "corrupt": 12, "unpin": 12, "hlomodul": 12, "constrain": [12, 15], "hold": [12, 28, 29], "along": [12, 24], "dimens": [12, 13, 28, 29], "all_to_al": 12, "split_dimens": 12, "concat_dimens": 12, "split_count": 12, "www": 12, "org": [12, 15, 24], "operation_semant": 12, "upon": 12, "split": 12, "concat": 12, "count": [12, 17], "add_step_closur": 12, "closur": 12, "run_async": 12, "step": [12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 30], "mani": [12, 15, 17, 18, 19, 27, 31], "report": 12, "consol": 12, "post": [12, 17], "tensorboard": [12, 18], "etc": [12, 14, 17, 19, 27, 28], "intermediari": 12, "inspect": 12, "typic": 12, "barrier": [12, 15, 16, 18], "materi": [12, 17, 18, 19, 27, 28], "queu": 12, "though": [12, 16, 21], "advis": 12, "throttl": 12, "event": 12, "asynchron": [12, 28, 30], "wait_device_op": 12, "async": [12, 22], "whose": [12, 13], "optimizer_step": [12, 16, 18, 20, 21, 24], "optimizer_arg": 12, "dict": [12, 24], "gradid": 12, "parallelload": [12, 28], "dataparallel": 12, "loader": [12, 17, 18, 22], "dictionari": 12, "gradient": [12, 16, 20, 24, 30], "save": [12, 17, 24, 30], "file_or_path": 12, "textio": 12, "master_onli": [12, 24], "global_mast": 12, "transfer": [12, 15, 17, 18, 28], "care": [12, 16, 17, 19, 27], "taken": [12, 16, 17, 19, 21, 27, 30], "view": [12, 16, 17], "recreat": [12, 16], "destin": [12, 16], "nest": [12, 24], "locat": 12, "control": [12, 13, 16, 17, 28], "obj_to_sav": 12, "path_to_sav": 12, "rendezv": 12, "tag": [12, 15], "byte": 12, "mesh": [12, 15, 25], "xla_rendezv": 12, "sent": [12, 17], "exchang": 12, "mesh_reduc": 12, "reduce_fn": 12, "toxlatensorarena": 12, "receiv": 12, "copi": [12, 15, 16, 17, 18], "np": [12, 25, 29], "accuraci": [12, 21, 24], "test_accuraci": 12, "set_rng_stat": 12, "get_rng_stat": 12, "get_memory_info": 12, "memoryinfo": 12, "bytes_us": 12, "290816": 12, "bytes_limit": 12, "34088157184": 12, "peak_bytes_us": 12, "500816": 12, "get_stablehlo": 12, "var": [12, 28], "xla_hlo_debug": [12, 17], "root": [12, 19, 27], "bytecod": [12, 22], "parallel_load": [12, 15, 16, 17], "mpdeviceload": [12, 16, 18, 28], "dataload": [12, 16, 18, 21, 28, 30], "background": [12, 30], "upload": [12, 18, 28], "per_device_load": [12, 28], "constructor": 12, "train_device_load": 12, "train_load": [12, 16, 28], "xla_multiprocess": 12, "spawn": [12, 15, 16, 18], "fn": 12, "nproc": [12, 15], "daemon": 12, "start_method": 12, "moment": 12, "maximum": [12, 13, 18, 26], "valueerror": 12, "mark_shard": [12, 25, 28, 29], "xlashardedtensor": [12, 30], "partition_spec": [12, 28, 29], "annot": [12, 28, 29], "partit": [12, 28], "spec": [12, 28], "intern": [12, 15, 16, 17, 19, 27, 28, 31], "spmdpartition": [12, 28], "topologi": [12, 16, 28, 29], "device_mesh": [12, 28], "mesh_shap": [12, 25, 28, 29], "ax": [12, 28, 29], "impact": [12, 15, 17, 19, 21, 27], "dynamo_custom_op": 12, "dynamo": [12, 18, 22, 26], "recogniz": 12, "num_devic": [12, 25, 28, 29], "device_id": [12, 25, 28, 29], "32": [12, 17, 18], "clear_shard": 12, "clear": 12, "cast": [12, 20], "t1": [12, 16, 17, 29], "get_1d_mesh": 12, "set_global_mesh": 12, "get_global_mesh": 12, "axis_nam": [12, 28], "v4": [12, 14, 15, 16, 18, 22, 28], "ravel": 12, "reshap": 12, "fill": 12, "assign": [12, 16, 18], "Its": 12, "length": [12, 19, 27], 
"len": [12, 18], "get_xla_supported_devic": 12, "get_logical_mesh": 12, "ordereddict": [12, 28, 29], "hybridmesh": [12, 28], "ici_mesh_shap": [12, 28], "dcn_mesh_shap": [12, 28], "hybrid": 12, "ici": 12, "dcn": [12, 28], "increas": 12, "intens": 12, "mdl": 12, "inner": [12, 24, 28], "outer": [12, 24, 25, 28], "slice": [12, 18, 28], "fsdp": [12, 24, 25, 28, 29], "eager_mod": [12, 14], "wa": [12, 15, 17, 18, 19, 27, 30], "d": [12, 13, 19, 20, 27], "eagerli": [12, 14, 16, 17, 19, 27], "metric": [12, 21], "metrics_report": [12, 17], "short_metrics_report": [12, 17], "counter_nam": 12, "metric_nam": 12, "activ": [12, 16, 17, 21, 24, 25, 26], "counter_valu": 12, "metric_data": 12, "total_sampl": 12, "accumul": 12, "retain": 12, "circular": 12, "natur": 13, "in_tensor": 13, "randint": [13, 26], "out_tensor": 13, "nonzero": [13, 17, 18, 19, 27], "word": [13, 19, 27], "further": [13, 18, 21], "categor": 13, "unbound": 13, "alloc": 13, "infinit": [13, 25], "phase": 13, "layer": [13, 14, 24, 25, 28], "perceptron": 13, "mlp": 13, "xla_experiment": 13, "masked_select": 13, "masked_scatt": 13, "your_script": [13, 18], "100": [13, 17, 24], "29": [13, 21, 22], "49": [13, 22], "20": [13, 16, 17, 21, 26], "03": 13, "102": 13, "hit": [13, 19, 27], "198": 13, "1953": 13, "motiv": 13, "excess": 13, "half": 13, "drop": [13, 17], "try": [13, 17, 18, 19, 27], "python3": [13, 15, 16, 17, 18, 24], "test_dynamic_shape_model": 13, "testdynamicshapemodel": 13, "test_backward_pass_with_dynamic_input": 13, "expand": [13, 22], "feel": [13, 17, 21], "review": [13, 25], "rfc": [13, 28, 31], "64": [14, 22, 24], "mark_step": [14, 15, 16, 17, 18, 21], "drawback": 14, "approach": [14, 19, 21, 24, 27], "often": [14, 17, 19, 27], "confus": 14, "preprocess": [14, 26], "small": [14, 17, 18, 19, 21, 22, 27], "leak": 14, "expens": [14, 17, 19, 27], "hard": [14, 19, 21, 22, 27], "why": [14, 19, 27], "mitig": 14, "ux": 14, "mark": [14, 16], "compiled_model": 14, "right": [14, 19, 22, 27], "awai": 14, "pretti": [14, 16, 19, 21, 27], "straight": 14, "enter": 14, "reenabl": 14, "perfomr": 14, "compar": [14, 15, 16, 17, 21, 22, 23], "recommen": 14, "overhad": 14, "step_fn": 14, "loss_fn": [14, 15, 16, 20, 21, 22], "zero_grad": [14, 15, 16, 20, 21], "logit": [14, 25], "loss": [14, 15, 16, 20, 22, 24, 25], "ask": [14, 17, 19, 27], "refactor": 14, "decod": 14, "much": [14, 15, 16, 18, 19, 22, 27], "llama2": 14, "fake": [14, 30], "chip": [14, 15], "300": [14, 17], "observ": [14, 15, 21], "147": 14, "65": [14, 17], "45": 14, "train_decoder_only_bas": [14, 17], "perfomran": 14, "tri": [14, 18], "resnet50": [14, 15, 16, 22, 24], "exepct": 14, "loop": [14, 16, 17, 18, 19, 25, 27, 30], "meant": 14, "encount": [15, 17, 18], "bug": [15, 17, 21], "r2": [15, 17, 28], "init": [15, 16, 21, 22, 23], "renam": 15, "torchrun": [15, 16, 31], "xpu": 15, "neuron": 15, "xrt_tpu_config": 15, "30": [15, 24], "thousand": 15, "preview": 15, "safe": 15, "section": [15, 16, 17, 18, 28], "broadcast": 15, "broadcast_master_param": 15, "pjrt_backend": 15, "diff": [15, 18], "42": 15, "gradient_as_bucket_view": [15, 21], "mseloss": [15, 21], "sgd": [15, 16, 20, 21, 22], "lr": [15, 16, 21, 22, 24, 25], "001": [15, 21], "confirm": 15, "__name__": [15, 16, 21], "__main__": [15, 16, 21], "localservic": 15, "localhost": 15, "51011": 15, "master_addr": 15, "master_port": 15, "12355": [15, 31], "Or": [15, 16, 19, 27], "overhead": [15, 21, 22], "grpc": 15, "torchbench": 15, "35": [15, 17], "tpuvm": [15, 16, 18, 28], "2048": 15, "mnist": [15, 16, 17, 20], "test_train_mp_mnist": 
[15, 21], "fake_data": [15, 17, 21, 31], "alpha": [15, 16], "central2": [15, 18], "git": [15, 17, 18, 24], "depth": [15, 17], "branch": [15, 17, 19, 27], "test_train_mp_imagenet": [15, 17, 21], "batch_siz": [15, 24, 31], "256": 15, "num_epoch": [15, 21, 24], "By": [15, 19, 27], "tpu_process_bound": 15, "tpu_visible_chip": 15, "r1": 15, "13": [15, 16, 21, 23], "docker_imag": 15, "gcr": 15, "io": [15, 24], "sudo": [15, 18], "rm": 15, "privileg": 15, "net": [15, 18, 20], "gpu_num_devic": 15, "nnode": [15, 31], "num_gpu_devic": 15, "pjrt_distribut": 15, "physic": [15, 28, 29], "12": [15, 17, 22, 24], "number_gpu_vm": [15, 31], "node_rank": [15, 31], "current_node_rank": 15, "nproc_per_nod": [15, 31], "number_local_gpu_devic": 15, "rdzv_endpoint": [15, 31], "internal_ip_address": 15, "multinode_train": 15, "endpoint": [15, 31], "machine_0": 15, "machine_1": 15, "machine_0_internal_ip_address": [15, 31], "ident": 15, "page": 15, "mostli": [15, 24], "interchang": 15, "perspect": [15, 16], "subtl": 15, "importantli": 15, "architectur": [15, 24], "thu": [15, 17], "batch": [15, 16, 17, 18, 28], "latenc": 15, "deseri": 15, "send": [15, 16, 18, 28], "direct": [15, 17], "independ": [15, 16, 17], "significantli": [15, 16, 18], "xla_dist": 15, "scp": [15, 16], "sdk": 15, "collect": [15, 21, 22, 29, 30], "enhanc": 15, "stabil": [15, 17, 20], "xmp": [15, 16, 18], "substanti": 15, "practic": [15, 19, 25, 27], "unreli": 15, "due": [15, 17, 18, 31], "inbound": 15, "could": [15, 18, 19, 27, 28], "failur": 15, "entir": [15, 24], "restart": 15, "impos": 15, "middl": [15, 18, 19, 27], "unwant": 15, "permit": 15, "subset": 15, "old": 15, "alter": 15, "synchron": [15, 16, 18, 28, 30], "consid": [15, 18], "all_gather_object": 15, "gloo": [15, 21, 30], "subgroup": 15, "monitor": 15, "_": [15, 22, 23], "altern": [15, 19, 20, 26, 27], "less": [15, 19, 22, 27], "reliabl": 15, "than": [15, 17, 19, 21, 24, 27], "strongli": 15, "_all_gath": 15, "int32": 15, "zeros_lik": 15, "get_world_s": 15, "averag": 15, "task": 15, "175": 15, "chart": 15, "breakdown": 15, "tfrt": 15, "legaci": 15, "streamexecutor": 15, "tpu_legaci": 15, "comparison": [15, 29], "regular": [16, 17, 18, 26], "t0": 16, "matrix": 16, "multipli": [16, 29], "mm": [16, 20], "neural": 16, "l_in": 16, "l_out": 16, "floattensor": 16, "highlight": [16, 18], "nllloss": 16, "momentum": 16, "switch": [16, 17, 19, 21, 27], "acquir": 16, "mp_device_load": 16, "three": 16, "multithread": [16, 17], "own": [16, 24], "onto": 16, "preload": [16, 18], "overlap": [16, 18, 22, 28], "batches_per_execut": 16, "consolid": [16, 24], "all_reduce_gradi": 16, "parent": 16, "talk": 16, "basi": 16, "howto": 16, "focu": [16, 19, 27], "train_mnist_xla": 16, "outsid": 16, "infrastructur": 16, "awar": 16, "fakedata": 16, "But": [16, 17, 19, 27], "immedi": [16, 28], "hand": 16, "record": [16, 17, 18], "defer": 16, "fuse": [16, 18], "invis": 16, "caller": 16, "insert": [16, 18], "paper": 16, "opaqu": [16, 17], "appear": [16, 17, 18], "unlik": [16, 18], "adjust": 16, "preserv": [16, 17], "appreci": 16, "accommod": 16, "previous": 16, "state_dict": [16, 24, 30], "footprint": 16, "xser": 16, "stream": 16, "amount": [16, 17, 18, 19, 27], "restor": 16, "load_state_dict": [16, 30], "unavail": [16, 17], "consum": [16, 19, 27], "disk": 16, "occur": 16, "your_cache_path": 16, "mp_fn": 16, "xla_cache_": 16, "runnabl": [16, 21, 25], "subject": 17, "peculiar": 17, "detial": 17, "__version__": 17, "cu121": 17, "t2": [17, 29], "200": 17, "rx": 17, "conclud": 17, "diagnos": 17, "extrem": 17, 
"pt_xla_debug_level": 17, "slip": 17, "analyz": [17, 18], "summari": 17, "compiletim": 17, "frequent": 17, "21": 17, "transferfromdevicetim": 17, "23": 17, "hash": 17, "c74c3b91b855b2b123f833b0d5f86943": 17, "107": 17, "frame": 17, "trigger": [17, 18, 19, 27], "dk3": 17, "1055": 17, "44": 17, "__next__": 17, "train_loop_fn": 17, "48": [17, 21], "start_train": 17, "73": 17, "548000": 17, "gb": 17, "922460": 17, "547871": 17, "124478": 17, "028210": 17, "steptrac": 17, "frequenc": 17, "pair": 17, "met": 17, "spent": [17, 18], "destroi": 17, "percentil": 17, "totalsampl": 17, "202": 17, "06m09s401ms746": 17, "001u": 17, "valuer": 17, "778ms572": 17, "062u": 17, "rate": [17, 21], "425201": 17, "001ms32": 17, "778u": 17, "001ms61": 17, "283u": 17, "001ms79": 17, "236u": 17, "001ms110": 17, "973u": 17, "50": [17, 18, 23], "001ms228": 17, "773u": 17, "80": 17, "001ms339": 17, "183u": 17, "90": 17, "001ms434": 17, "305u": 17, "95": 17, "002ms921": 17, "063u": 17, "99": [17, 21], "21s102ms853": 17, "173u": 17, "cachedsynctensor": 17, "395": [17, 21], "area": 17, "rout": 17, "qualifi": 17, "33": [17, 21, 22], "_local_scalar_dens": 17, "epoch": [17, 18, 24], "clear_al": 17, "xla_dynamo_debug": 17, "bottleneck": [17, 18], "notebook": 17, "train_resnet_benchmark": 17, "behav": 17, "evalu": [17, 18, 19, 27], "suggest": 17, "bad": 17, "degrad": [17, 18], "speedup": [17, 22], "indirect": 17, "solut": [17, 19, 26, 27], "variat": 17, "pad": [17, 18, 19, 27], "fix": [17, 18, 22, 25], "translat": 17, "item": [17, 18], "substitut": 17, "flow": 17, "clip_grad_norm": 17, "problemat": 17, "clip_grad_norm_": 17, "dramat": 17, "total_norm": 17, "zero": [17, 24, 30], "param_norm": 17, "grad": 17, "norm": 17, "norm_typ": 17, "add_": 17, "clip_coef": 17, "max_norm": 17, "mul_": 17, "data_parallel": 17, "last": 17, "dataset": [17, 21, 24], "stride": 17, "reconstruct": 17, "shallow": 17, "ty": 17, "made": [17, 18, 19, 27, 28], "_get_xla_tensors_text": [17, 19, 27], "_get_xla_tensors_hlo": 17, "prior": [17, 30], "degre": 17, "xla_ir_debug": 17, "henc": [17, 22], "respons": [17, 18, 22, 30], "xla_save_tensors_fil": 17, "realli": [17, 19, 22, 27], "big": [17, 19, 27], "left": 17, "append": 17, "sheet": 17, "xla_save_tensors_fmt": 17, "text": 17, "dot": 17, "graphviz": 17, "xla_flag": 17, "xla_dump_to": 17, "dir_nam": 17, "unoptim": 17, "optimz": 17, "xla_metrics_fil": 17, "xla_save_hlo_fil": 17, "offend": 17, "xla_sync_wait": 17, "xla_use_eager_debug_mod": 17, "bypass": 17, "overal": [17, 18], "optimizaiton": 17, "tf_cpp_log_thread_id": 17, "tf_cpp_vmodul": 17, "vlog": 17, "tf_cpp_min_log_level": 17, "turn": 17, "warn": 17, "tf_vlog": 17, "xla_dump_hlo_graph": 17, "xla_util": 17, "cc": 17, "save1": 17, "xla_graph_executor": 17, "pjrt_computation_cli": 17, "dir": 17, "pytorch_test_with_slow": 17, "test_torch": 17, "test_put_xla_uint8": 17, "torch_test_devic": 17, "pytorch_test_bas": 17, "brief": 18, "basic": [18, 19, 21, 27], "reader": 18, "modif": 18, "fetch": 18, "discuss": [18, 29], "opcod": 18, "fed": 18, "attach": [18, 28], "callback": 18, "xla_tensor_z": 18, "cut": [18, 19, 27], "transferfromdevic": 18, "tell": [18, 19, 27], "properti": [18, 19, 27], "illustr": [18, 29], "suppos": 18, "tensors_on_devic": 18, "z": [18, 19, 27], "subgraph": [18, 19, 27], "signal": 18, "far": 18, "suitabl": 18, "trade": [18, 19, 27], "off": 18, "spend": 18, "fusion": 18, "worth": [18, 19, 27], "latter": [18, 24], "wheel": [18, 24], "runtime_vers": 18, "project_id": 18, "accelerator_typ": 18, "tpu_nam": 18, "your_tpu_nam": 18, 
"subnetwork": 18, "tpusubnet": 18, "pip3": 18, "cp38": 18, "linux_x86_64": 18, "whl": 18, "apt": 18, "libopenbla": 18, "dev": [18, 21], "libgl1": 18, "guidelin": 18, "bar": 18, "rememb": 18, "txt2img": 18, "prompt": 18, "photograph": 18, "astronaut": 18, "ride": 18, "hors": 18, "relat": 18, "precision_scop": 18, "addition": [18, 20, 24], "particular": 18, "frozenclipembedd": 18, "simplic": [18, 19, 27], "ddim": 18, "top": 18, "attr": 18, "statement": [18, 19, 27], "stop": 18, "fall": [18, 25], "difficult": 18, "readi": 18, "investig": [18, 21], "cover": [18, 28], "huggingfac": 18, "sd": 18, "xl": 18, "cd": [18, 24], "text_to_imag": 18, "inference_tpu_single_devic": 18, "lora": 18, "model_id": 18, "stabilityai": 18, "pipelin": 18, "dpmsolvermultistepschedul": 18, "txt": 18, "invisible_watermark": 18, "transform": [18, 24, 29], "safetensor": 18, "licens": 18, "card": 18, "cli": 18, "_your_copied_token__": 18, "pipe": 18, "hour": 18, "wherea": 18, "likewis": 18, "gpt": 18, "15": 18, "min": 18, "subsequ": 18, "advantag": 18, "mayb": 18, "notic": 18, "piec": 18, "__call__": 18, "commit": 18, "caveat": 18, "rule": [18, 20], "thumb": 18, "durat": [18, 30], "constantli": 18, "idl": 18, "inference_tpu_": 18, "capture_profil": 18, "gap": 18, "xp": 18, "measur": 18, "portion": 18, "busi": 18, "scroll": 18, "occupi": 18, "displai": 18, "largest": 18, "zoom": 18, "timelin": 18, "period": 18, "examin": 18, "did": 18, "pipe_watermark": 18, "closer": 18, "preced": 18, "proceed": [18, 25], "watermark": 18, "cv2": 18, "pywt": 18, "leav": 18, "broken": 18, "rerun": 18, "scale_model_input": 18, "ran": 18, "my_funct": 18, "preocess": 18, "debug_single_process": 18, "magic": [18, 19, 27], "treat": 18, "xla_no_special_scalar": 18, "hurt": [19, 27], "perf": [19, 27], "pov": [19, 27], "sai": [19, 27], "assur": [19, 27], "gone": [19, 27], "coverag": [19, 27], "aim": [19, 25, 27], "explan": [19, 27], "mainli": [19, 27], "problem": [19, 27], "beginn": [19, 27], "propos": [19, 27], "reli": [19, 27], "impract": [19, 27], "assumpt": [19, 27], "ye": [19, 26, 27], "sentenc": [19, 27], "bucket": [19, 27, 30], "kinda": [19, 27], "anti": [19, 27], "frontend": [19, 27], "matter": [19, 27], "workaround": [19, 27], "okai": [19, 27], "teach": [19, 27], "produc": [19, 20, 21, 27], "theoret": [19, 27], "sort": [19, 27], "obviou": [19, 27], "s64": [19, 27], "inde": [19, 27], "_get_xla_tensor_dimension_s": [19, 27], "commonli": [19, 27], "wrong": [19, 27], "wors": [19, 27], "probabl": [19, 27], "know": [19, 21, 27], "upper": [19, 27], "nit": [19, 27], "rand": [19, 27], "solv": [19, 27], "kept": [19, 27], "earli": [19, 27], "accessor": [19, 27], "2d": [19, 25, 27], "implicitli": [19, 27], "doubl": [19, 27], "overload": [19, 27], "explod": [19, 27], "convers": [19, 27], "cheap": [19, 27], "ve": [19, 27], "hoc": [19, 27], "think": [19, 27], "verison": [19, 27], "bla": [19, 27], "blabla": [19, 27], "interpret": [19, 27], "proce": [19, 27], "uglier": [19, 27], "win": [19, 27], "pars": [19, 27], "torchscript": [19, 27], "somehow": [19, 27], "merg": [19, 27], "lazili": [19, 27, 28, 30], "properli": [19, 27], "thought": [19, 27], "trivial": [19, 27], "effort": [19, 27, 28], "side": [19, 27], "bandwidth": [19, 27], "automag": [19, 27], "gold": [19, 27], "smart": [19, 27], "trick": [19, 27], "tbh": [19, 27], "longer": [19, 27], "unawar": [19, 27], "hope": [19, 27], "smash": [19, 27], "blocker": [19, 27], "ahead": [19, 27], "nnc": [19, 27], "exactli": [19, 27], "transpos": [19, 27], "brian": [19, 27], "hirsh": [19, 27], "bdhirsh": [19, 27], 
"question": [19, 27], "comment": [19, 27], "stick": [19, 27], "torch_warn": [19, 27], "yea": [19, 27], "hei": [19, 27], "won": [19, 20, 27], "blaze": [19, 27], "isn": [19, 27, 30], "abil": [19, 21, 27], "devirtu": [19, 27], "sound": [19, 27], "great": [19, 27], "carri": [19, 27, 28], "truth": [19, 27], "irvalu": [19, 27], "enforc": [19, 21, 27], "discrep": [19, 27], "followup": [19, 27], "1000": [19, 27], "my": [19, 27, 30], "presenc": [19, 27], "get_dimention_s": [19, 27], "didn": [19, 27], "exponenti": [19, 27], "blowup": [19, 27], "fewer": [19, 27], "opportun": [19, 27], "recogn": [19, 22, 27], "feasibl": [19, 27], "annoi": [19, 27], "wasn": [19, 27], "materiz": [19, 27], "combo": [19, 27], "extend": 20, "float32": 20, "datatyp": 20, "float16": 20, "bfloat16": [20, 26], "syncfre": 20, "autocast": 20, "summar": 20, "elig": 20, "suppli": 20, "addmm": 20, "addmm_": 20, "prefer": 20, "float64": 20, "respect": 20, "unlist": 20, "__matmul__": 20, "addbmm": 20, "addmv": 20, "addr": 20, "baddbmm": 20, "bmm": 20, "conv1d": 20, "conv2d": [20, 24], "conv3d": 20, "conv_transpose1d": 20, "conv_transpose2d": 20, "conv_transpose3d": 20, "matmul": 20, "relu": [20, 21], "prelu": 20, "max_pool2d": 20, "batch_norm": 20, "log_softmax": 20, "binary_cross_entropy_with_logit": 20, "prod": 20, "cdist": 20, "chloeski": 20, "invers": 20, "reflection_pad": 20, "replication_pad": 20, "mse_loss": 20, "cosine_embbeding_loss": 20, "nll_loss": 20, "multilabel_margin_loss": 20, "qr": 20, "svd": 20, "triangular_solv": 20, "linalg_svd": 20, "linalg_inv_ex": 20, "widest": 20, "index_copi": 20, "scaler": [20, 26], "gradscal": 20, "_fetch_gradi": 20, "xla_use_f16": 20, "underflow": 20, "imagenet": 20, "minimum": [21, 24, 25], "nccl": 21, "new_rank": 21, "ddp_model": 21, "final": [21, 28], "launcher": 21, "demo_fn": 21, "touch": [21, 30], "five": 21, "sy": 21, "tempfil": 21, "cleanup": 21, "destroy_process_group": 21, "toymodel": 21, "net1": 21, "1000000": 21, "net2": 21, "demo_bas": 21, "graident_as_bucket_view": 21, "label": 21, "run_demo": 21, "tot": 21, "statist": 21, "unit": 21, "median": 21, "90th": 21, "deviat": 21, "cv": 21, "418": 21, "54": 21, "419": 21, "22": 21, "430": 21, "40": 21, "76": 21, "02": 21, "97": 21, "407": 21, "60": 21, "39": 21, "seem": 21, "17864": 21, "19": [21, 22], "20108": 21, "96": 21, "24351": 21, "74": 21, "5866": 21, "83": 21, "10701": 21, "11770": 21, "00": 21, "14313": 21, "78": 21, "3102": 21, "92": 21, "41": [21, 22], "round": 21, "heavili": [21, 22], "sens": 21, "amort": 21, "logdir": 21, "converg": 21, "caution": 21, "interest": 21, "known": 21, "crash": 21, "unmodifi": 22, "hook": 22, "biggest": [22, 24], "torchfx": 22, "technologi": 22, "fx": 22, "a_xla": 22, "b_xla": 22, "compiled_cod": 22, "eval_model": 22, "xla_resnet18": 22, "eval": 22, "dynamo_resnet18": 22, "no_grad": 22, "resent18": 22, "analysi": 22, "bench": 22, "59": 22, "resnext50_32x4d": 22, "91": 22, "alexnet": 22, "28": 22, "mobilenet_v2": 22, "18": 22, "62": 22, "mnasnet1_0": 22, "68": 22, "vgg16": 22, "bert_pytorch": 22, "squeezenet1_1": 22, "timm_vision_transform": 22, "52": 22, "geomean": 22, "04": 22, "train_model": 22, "crossentropyloss": 22, "pred": 22, "train_model_main": 22, "dynamo_train_model": 22, "xla_optim": 22, "weight_decai": 22, "extract": 22, "07": 22, "43": 22, "81": 22, "87": 22, "fwd": 22, "bwd": 22, "e2": 22, "hide": 22, "larger": [22, 24], "wit": 22, "promis": 22, "tradit": 22, "excit": 22, "upcom": [22, 28], "invest": 22, "matur": 22, "stori": 22, "_higher_order_op": 23, "fori_loop": 23, 
"cond_fn": 23, "body_fn": 23, "bodi": 23, "iteri": 23, "init_v": 23, "functionaltensor": 23, "lvl": 23, "cumul": 23, "ten": 23, "51": 23, "xlafullyshardeddataparallel": 24, "my_modul": [24, 25], "adam": [24, 25], "0001": [24, 25], "leftov": [24, 25], "arxiv": 24, "1910": 24, "02054": 24, "reshard_after_forward": 24, "test_train_mp_mnist_fsdp_with_ckpt": 24, "test_train_mp_imagenet_fsdp": 24, "interleav": 24, "submodul": 24, "fsdpvitmodel": 24, "checkpoint_modul": [24, 25], "3524": 24, "auto_wrap_polici": [24, 25], "size_based_auto_wrap_polici": 24, "polici": [24, 28], "100m": 24, "transformer_auto_wrap_polici": [24, 25], "transformer_layer_cl": [24, 25], "auto_wrapper_cal": 24, "remateri": 24, "resum": 24, "get_shard_metadata": 24, "consolidate_sharded_model_checkpoint": 24, "stitch": 24, "ckpt": 24, "shard_metadata": 24, "ckpt_path": 24, "pth": 24, "consolidate_sharded_ckpt": 24, "ckpt_prefix": 24, "your_sharded_checkpoint_fil": 24, "ckpt_suffix": 24, "_rank": 24, "inspir": 24, "structur": [24, 28], "fairscal": 24, "fullyshardeddataparallel": 24, "readthedoc": 24, "en": 24, "resort": 24, "train_resnet_fsdp_auto_wrap": 24, "newer": 24, "recurs": [24, 25], "98": 24, "drop_last": 24, "use_nested_fsdp": 24, "use_gradient_checkpoint": 24, "final_ckpt": 24, "75": 24, "download": 24, "1k": 24, "datadir": 24, "test_set_batch_s": 24, "eval_interv": 24, "num_warmup_epoch": 24, "lr_scheduler_divide_every_n_epoch": 24, "lr_scheduler_divisor": 24, "residu": 24, "algorithm": [24, 25], "ronghanghu": 24, "vit_10b_fsdp_exampl": 24, "vit": 24, "fsdpv2": 25, "famou": 25, "enjoi": 25, "tabl": 25, "spmd_fully_sharded_data_parallel": 25, "spmdfullyshardeddataparallel": 25, "autowrap": 25, "decoderlay": 25, "functool": 25, "decoder_only_model": 25, "shard_output": 25, "0th": 25, "children": 25, "fork": 25, "hf": 25, "abstract": [26, 28], "blockwis": 26, "int4": 26, "analog": 26, "classifi": 26, "flexibl": 26, "choos": [26, 30], "docstr": 26, "xla_quantized_matmul": 26, "n_input_featur": 26, "n_output_featur": 26, "w_int": 26, "127": 26, "int8": 26, "matmul_output": 26, "quantized_matmul": 26, "x_xla": 26, "w_int_xla": 26, "scaler_xla": 26, "matmul_output_xla": 26, "w": 26, "f_dynamo": 26, "dynamo_out_xla": 26, "myqlinearforxlabackend": 26, "load_weight": 26, "processed_w": 26, "processed_scal": 26, "stuff": 26, "orig_model": 26, "mymodel": 26, "q_weight": 26, "q_weights_for_xla": 26, "process_for_xla": 26, "q_linear": 26, "xlaquantizedlinear": 26, "in_featur": 26, "out_featur": 26, "load_quantized_weight": 26, "channel": 26, "sym": 26, "asym": 26, "w8a16": 26, "w8a8": 26, "w4a8": 26, "gspmd": [28, 29], "proced": 28, "src": [28, 30], "_input_sharding_": 28, "4d": 28, "input_shard": 28, "shardingspec": 28, "input_mesh": 28, "s2": 28, "s3": 28, "s4": 28, "_after": 28, "_the": 28, "unnecessari": 28, "forth": 28, "techniqu": 28, "decis": 28, "nice": 28, "arrang": 28, "center": 28, "multislic": 28, "denot": 28, "delai": 28, "subclass": 28, "__torch_dispatch__": 28, "global_tensor": 28, "strictli": 28, "local_shard": 28, "xlashard": 28, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 28, "xla_sharded_tensor": 28, "l12": 28, "ongo": 28, "distributedtensor": 28, "proof": 28, "concept": [28, 29], "distribute_tensor": 28, "devicemesh": 28, "big_tensor": 28, "100000": 28, "88": 28, "my_dtensor": 28, "stai": 28, "dynamo_mark_shard": 28, "placement": 28, "visualize_tensor_shard": 28, "visualize_shard": 28, "rich": 28, "2x2": 28, "generated_t": 28, "use_color": 28, "style": 28, "tile": 28, "partial_repl": 28, "envvar": 28, 
"xla_auto_spmd": 28, "_tensor": 28, "distribute_modul": 28, "auto_polici": 28, "mymodul": 28, "sharded_model": 28, "behvaior": 28, "xla_auto_use_group_shard": 28, "reshard": 28, "xla_auto_spmd_mesh": 28, "unset": 28, "hint": 29, "strategi": 29, "th": 29, "cluster": 29, "interconnect": 29, "encourag": 29, "fist": 29, "paral": 29, "dedic": 30, "planner": 30, "spmdsaveplann": 30, "spmdloadplann": 30, "dist_cp": 30, "distributed_checkpoint": 30, "xc": 30, "storage_writ": 30, "filesystemwrit": 30, "checkpoint_dir": 30, "storage_read": 30, "filesystemread": 30, "all_step": 30, "save_async": 30, "unblock": 30, "preemption": 30, "detect": 30, "provis": 30, "queuedresourc": 30, "autocheckpoint": 30, "chkpt_on_preempt": 30, "fsspec": 30, "filesystem": 30, "prime_optim": 30, "chkpt_mgr": 30, "tracked_step": 30, "highest": 30, "best_step": 30, "prime": 30, "enumer": 30, "attempt": 30, "unprim": 30, "destruct": 30, "discov": 30, "nvidia": 31, "resnet": 31, "num_gpu_machin": 31, "rank_of_current_machin": 31, "machine_0_ip_address": 31, "training_or_inference_script_using_spmd": 31, "xla_use_spmd": 31, "test_train_spmd_imagenet": 31}, "objects": {"": [[12, 0, 0, "-", "torch_xla"]], "torch_xla": [[12, 1, 1, "", "compile"], [12, 1, 1, "", "device"], [12, 1, 1, "", "device_count"], [12, 1, 1, "", "devices"], [12, 0, 0, "-", "experimental"], [12, 1, 1, "", "manual_seed"], [12, 0, 0, "-", "runtime"], [12, 1, 1, "", "sync"]], "torch_xla.core": [[12, 0, 0, "-", "xla_model"]], "torch_xla.core.xla_model": [[12, 1, 1, "", "add_step_closure"], [12, 1, 1, "", "all_gather"], [12, 1, 1, "", "all_reduce"], [12, 1, 1, "", "all_to_all"], [12, 1, 1, "", "get_memory_info"], [12, 1, 1, "", "get_rng_state"], [12, 1, 1, "", "get_stablehlo"], [12, 1, 1, "", "get_stablehlo_bytecode"], [12, 1, 1, "", "is_master_ordinal"], [12, 1, 1, "", "mesh_reduce"], [12, 1, 1, "", "optimizer_step"], [12, 1, 1, "", "rendezvous"], [12, 1, 1, "", "save"], [12, 1, 1, "", "set_rng_state"], [12, 1, 1, "", "wait_device_ops"], [12, 1, 1, "", "xla_device"], [12, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[12, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[12, 1, 1, "", "counter_names"], [12, 1, 1, "", "counter_value"], [12, 1, 1, "", "metric_data"], [12, 1, 1, "", "metric_names"], [12, 1, 1, "", "metrics_report"], [12, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[12, 0, 0, "-", "parallel_loader"], [12, 0, 0, "-", "spmd"], [12, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[12, 2, 1, "", "MpDeviceLoader"]], "torch_xla.distributed.spmd": [[12, 2, 1, "", "HybridMesh"], [12, 2, 1, "", "Mesh"], [12, 1, 1, "", "clear_sharding"], [12, 1, 1, "", "get_1d_mesh"], [12, 1, 1, "", "get_global_mesh"], [12, 1, 1, "", "mark_sharding"], [12, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[12, 1, 1, "", "spawn"]], "torch_xla.experimental": [[12, 1, 1, "", "eager_mode"]], "torch_xla.runtime": [[12, 1, 1, "", "addressable_device_count"], [12, 1, 1, "", "device_type"], [12, 1, 1, "", "get_master_ip"], [12, 1, 1, "", "global_device_count"], [12, 1, 1, "", "global_ordinal"], [12, 1, 1, "", "global_runtime_device_count"], [12, 1, 1, "", "initialize_cache"], [12, 1, 1, "", "is_spmd"], [12, 1, 1, "", "local_device_count"], [12, 1, 1, "", "local_ordinal"], [12, 1, 1, "", "local_process_count"], [12, 1, 1, "", "use_spmd"], [12, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", 
"function", "Python function"], "2": ["py", "class", "Python class"]}, "titleterms": {"learn": [0, 1, 11], "about": [0, 1, 11], "gpu": [0, 10, 15, 20, 31], "tpu": [1, 4, 15, 16, 18, 20, 24, 28], "bazel": 2, "pytorch": [2, 3, 4, 6, 7, 8, 9, 11, 12, 16, 17, 18, 22, 24, 27, 28, 29], "xla": [2, 3, 4, 6, 7, 8, 11, 12, 16, 17, 18, 20, 22, 24, 26, 27, 28, 29], "depend": [2, 8, 10], "how": [2, 21, 26, 29], "build": 2, "librari": 2, "torch": [2, 7, 9, 15, 28], "plugin": [2, 6], "remot": 2, "cach": [2, 16], "run": [2, 3, 9, 16, 17, 18, 28, 31], "test": [2, 3, 5, 17, 23], "code": [2, 4, 18, 26], "coverag": 2, "languag": 2, "server": 2, "codegen": 3, "migrat": 3, "guid": [3, 5, 29], "befor": [3, 5], "you": [3, 5, 19, 27], "start": [3, 5, 19, 27], "file": [3, 5, 9], "structur": [3, 5], "old": 3, "op": [3, 5, 7, 26], "lower": [3, 5, 7], "step": [3, 4], "1": [3, 18, 19, 27], "identifi": 3, "2": [3, 18, 19, 25, 27], "inspect": 3, "gener": [3, 9], "lazyir": 3, "h": 3, "3": [3, 19, 27], "implement": [3, 6], "miss": 3, "ir": 3, "function": 3, "torch_xla": [3, 12, 19], "csrc": 3, "ops_xla_shape_fn": 3, "cpp": 3, "4": 3, "ops_lower_fn": 3, "5": 3, "cleanup": 3, "verifi": 3, "result": 3, "sampl": 3, "pr": 3, "configur": 4, "develop": 4, "environ": [4, 17], "visual": 4, "studio": 4, "creat": [4, 16], "connect": 4, "your": 4, "set": 4, "up": 4, "workspac": 4, "next": 4, "understand": [5, 17], "oper": [5, 9, 19, 20, 26, 27], "unit": [5, 17], "tip": 5, "custom": [6, 8, 10], "hardwar": 6, "pjrt": [6, 15], "c": 6, "api": [6, 7, 12, 14], "packag": 6, "support": [7, 20, 26], "distribut": [7, 12, 15, 30], "collect": 7, "stack": 7, "non": 7, "dynamo": [7, 17], "case": [7, 19, 23, 27], "descript": 7, "kernel": [8, 10], "via": [8, 10], "palla": 8, "adopt": 8, "abov": 8, "compat": 8, "us": [8, 19, 21, 23, 25, 26, 27, 29], "built": 8, "flashattent": 8, "exampl": [8, 18, 20, 23, 24, 25], "usag": [8, 14, 23], "integr": [8, 22, 28], "pagedattent": 8, "export": 9, "stablehlo": 9, "save": [9, 16], "bytecod": 9, "disk": 9, "convert": [9, 18], "serv": 9, "common": [9, 17], "wrapper": 9, "i": [9, 19, 27, 29], "want": 9, "directli": 9, "tf": 9, "saved_model": 9, "format": 9, "without": [9, 19, 27], "need": 9, "an": [9, 16], "separ": 9, "command": 9, "other": 9, "produc": 9, "save_as_stablehlo": 9, "preserv": 9, "high": 9, "level": 9, "composit": 9, "triton": 10, "document": 11, "acceler": 11, "featur": [11, 22, 26], "improv": 11, "workload": 11, "perform": [11, 15, 17, 18], "contribut": 11, "runtim": [12, 15], "xla_model": 12, "spmd": [12, 25, 28, 29, 31], "experiment": [12, 26], "debug": [12, 17, 28], "dynam": [13, 19, 27], "shape": [13, 19, 27], "bound": [13, 19, 27], "eager": 14, "mode": [14, 29], "compil": [14, 16, 17, 28], "basic": 14, "infer": [14, 18, 22], "train": [14, 15, 22, 24], "benchmark": [14, 17, 21], "tl": 15, "dr": 15, "benefit": 15, "quickstart": 15, "cpu": [15, 16], "pod": [15, 16, 18, 24, 28], "docker": 15, "singl": [15, 16, 18], "node": 15, "multi": [15, 16], "differ": 15, "from": [15, 16, 19, 27], "xrt": 15, "multithread": 15, "v2": 15, "v3": [15, 24], "chang": 15, "xm": 15, "rendezv": 15, "new": 15, "devic": [16, 18, 28], "tensor": [16, 17, 19, 27], "ar": 16, "model": [16, 26], "multipl": [16, 18], "process": [16, 30], "deep": 16, "dive": 16, "lazi": 16, "memori": [16, 23], "layout": 16, "move": 16, "load": [16, 28], "further": [16, 29], "read": [16, 29], "troubleshoot": 17, "saniti": 17, "check": 17, "version": 17, "A": 17, "simpl": [17, 23], "calcul": 17, "resnet": [17, 24], "With": 17, "fake": [17, 21], 
"data": [17, 21, 24, 25, 28], "tool": [17, 28], "auto": [17, 28], "metric": 17, "analysi": [17, 18], "execut": 17, "get": 17, "report": 17, "The": 17, "clear": 17, "profil": [17, 18], "known": 17, "caveat": 17, "quirk": 17, "more": 17, "variabl": 17, "combin": 17, "reproduc": 17, "ci": 17, "cd": 17, "failur": 17, "overview": 18, "setup": 18, "stabl": 18, "diffus": 18, "lightn": 18, "hf": 18, "sourc": [19, 27], "recompil": [19, 27], "let": [19, 27], "": [19, 27], "first": [19, 27], "some": [19, 27], "fact": [19, 27], "constraint": [19, 27], "input": [19, 27], "dataset": [19, 27], "output": [19, 25, 27], "can": [19, 27], "fix": [19, 27], "when": [19, 27], "queri": [19, 27], "its": [19, 27], "real": [19, 21, 27], "dimens": [19, 27], "what": [19, 27, 29], "control": [19, 23, 27], "flow": [19, 27], "conclus": [19, 27], "appendix": [19, 27], "automat": 20, "mix": 20, "precis": 20, "amp": 20, "best": 20, "practic": 20, "do": 21, "distributeddataparallel": 21, "ddp": 21, "background": 21, "motiv": 21, "resnet50": 21, "mnist": [21, 24], "disclaim": 21, "torchdynamo": 22, "gap": 22, "take": 22, "awai": 22, "optim": [23, 28, 30], "util": 23, "while_loop": 23, "group": [23, 30], "pure": 23, "python": 23, "while": 23, "loop": 23, "fulli": [24, 25], "shard": [24, 25, 28], "parallel": [24, 25], "script": 24, "imagenet": 24, "instal": 24, "clone": 24, "repo": 24, "8": 24, "50": 24, "10": 24, "billion": 24, "paramet": 24, "gradient": 25, "checkpoint": [25, 30], "huggingfac": 25, "llama": 25, "quantiz": 26, "call": 26, "modul": 26, "swap": 26, "matrix": 26, "multipli": 26, "advanc": 28, "topic": 28, "awar": 28, "host": 28, "virtual": 28, "hybrid": 28, "mesh": [28, 29], "xlashardedtensor": 28, "dtensor": 28, "activ": 28, "user": 29, "partit": 29, "spec": 29, "checkpointmanag": 30, "restor": 30, "state": 30}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Learn about GPUs": [[0, "learn-about-gpus"]], "Learn about TPUs": [[1, "learn-about-tpus"]], "Bazel in Pytorch/XLA": [[2, "bazel-in-pytorch-xla"]], "Bazel dependencies": [[2, "bazel-dependencies"]], "How to build XLA libraries": [[2, "how-to-build-xla-libraries"]], "How to build the Torch/XLA plugin": [[2, "how-to-build-the-torch-xla-plugin"]], "Remote caching": [[2, "remote-caching"]], "Running tests": [[2, "running-tests"]], "Code coverage": [[2, "code-coverage"]], "Language Server": [[2, "language-server"]], "Building PyTorch/XLA": [[2, "building-pytorch-xla"]], "Codegen migration Guide": [[3, "codegen-migration-guide"]], "Before you start": [[3, "before-you-start"], [5, "before-you-start"]], "File structure": [[3, "file-structure"], [5, "file-structure"]], "PyTorch Codegen files": [[3, "pytorch-codegen-files"]], "PyTorch/XLA Codegen files": [[3, "pytorch-xla-codegen-files"]], "PyTorch/XLA Old Op Lowering files": [[3, "pytorch-xla-old-op-lowering-files"]], "Codegen step by step": [[3, "codegen-step-by-step"]], "1. Identify the op": [[3, "identify-the-op"]], "2. Codegen the op and inspect the generated file": [[3, "codegen-the-op-and-inspect-the-generated-file"]], "LazyIr.h": [[3, "lazyir-h"]], "3. 
Implement the missing IR function": [[3, "implement-the-missing-ir-function"]], "torch_xla/csrc/ops/ops_xla_shape_fn.h": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-h"]], "torch_xla/csrc/ops/ops_xla_shape_fn.cpp": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-cpp"]], "4. Implement the lowering function": [[3, "implement-the-lowering-function"]], "torch_xla/csrc/ops/ops_lower_fn.cpp": [[3, "torch-xla-csrc-ops-ops-lower-fn-cpp"]], "5. Cleanup": [[3, "cleanup"]], "Run the test and verify the result": [[3, "run-the-test-and-verify-the-result"]], "Sample PRs": [[3, "sample-prs"]], "Configure a development environment": [[4, "configure-a-development-environment"]], "Visual Studio Code": [[4, "visual-studio-code"]], "Creating and connecting to your TPU": [[4, "creating-and-connecting-to-your-tpu"]], "Setting up a Visual Studio Code workspace with PyTorch/XLA": [[4, "setting-up-a-visual-studio-code-workspace-with-pytorch-xla"]], "Next steps": [[4, "next-steps"]], "OP Lowering Guide": [[5, "op-lowering-guide"]], "Understanding the operation": [[5, "understanding-the-operation"]], "Unit Test": [[5, "unit-test"]], "Tips": [[5, "tips"]], "Custom Hardware Plugins": [[6, "custom-hardware-plugins"]], "Implementing a PJRT Plugin": [[6, "implementing-a-pjrt-plugin"]], "PJRT C API Implementation": [[6, "pjrt-c-api-implementation"]], "PyTorch/XLA Plugin Package": [[6, "pytorch-xla-plugin-package"]], "Support of Torch Distributed API in PyTorch/XLA": [[7, "support-of-torch-distributed-api-in-pytorch-xla"]], "Collective ops lowering": [[7, "collective-ops-lowering"]], "Collective ops lowering stack": [[7, "collective-ops-lowering-stack"]], "non-Dynamo case": [[7, "non-dynamo-case"]], "Dynamo case": [[7, "dynamo-case"]], "API description": [[7, "api-description"]], "Custom Kernels via Pallas": [[8, "custom-kernels-via-pallas"]], "Adopt the above kernel to be compatible with PyTorch/XLA": [[8, "adopt-the-above-kernel-to-be-compatible-with-pytorch-xla"]], "Use built-in kernels": [[8, "use-built-in-kernels"]], "FlashAttention": [[8, "id1"]], "Example usage": [[8, "example-usage"], [8, "id3"]], "Integration Example": [[8, "integration-example"], [8, "id4"]], "PagedAttention": [[8, "id2"]], "Dependencies": [[8, "dependencies"], [10, "dependencies"]], "Torch Export to StableHLO": [[9, "torch-export-to-stablehlo"]], "Saving StableHLO bytecodes to disk": [[9, "saving-stablehlo-bytecodes-to-disk"]], "Convert saved StableHLO for serving": [[9, "convert-saved-stablehlo-for-serving"]], "Common wrappers": [[9, "common-wrappers"]], "I want to save directly tf.saved_model format without needing to run an separate command.": [[9, "i-want-to-save-directly-tf-saved-model-format-without-needing-to-run-an-separate-command"]], "Other common wrappers": [[9, "other-common-wrappers"]], "Files produced by save_as_stablehlo.": [[9, "files-produced-by-save-as-stablehlo"]], "Preserving High-Level PyTorch Operations in StableHLO by generating stablehlo.composite": [[9, "preserving-high-level-pytorch-operations-in-stablehlo-by-generating-stablehlo-composite"]], "Custom GPU Kernels via Triton": [[10, "custom-gpu-kernels-via-triton"]], "PyTorch/XLA documentation": [[11, "pytorch-xla-documentation"]], "Learn about Pytorch/XLA": [[11, null]], "Learn about accelerators": [[11, null]], "PyTorch/XLA features": [[11, null]], "Improve Pytorch/XLA workload performance": [[11, null]], "Contribute to Pytorch/XLA": [[11, null]], "PyTorch/XLA API": [[12, "pytorch-xla-api"]], "torch_xla": [[12, "module-torch_xla"]], "runtime": [[12, 
"module-torch_xla.runtime"]], "xla_model": [[12, "module-torch_xla.core.xla_model"]], "distributed": [[12, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[12, "module-torch_xla.distributed.spmd"]], "experimental": [[12, "module-torch_xla.experimental"]], "debug": [[12, "module-torch_xla.debug.metrics"]], "Dynamic shape": [[13, "dynamic-shape"]], "Bounded dynamic shape": [[13, "bounded-dynamic-shape"]], "Eager Mode + Compile API": [[14, "eager-mode-compile-api"]], "Basic Usage": [[14, "basic-usage"]], "Inference": [[14, "inference"], [22, "inference"]], "Training": [[14, "training"], [22, "training"]], "Benchmark": [[14, "benchmark"]], "PJRT Runtime": [[15, "pjrt-runtime"]], "TL;DR": [[15, "tl-dr"]], "Benefits": [[15, "benefits"]], "Quickstart": [[15, "quickstart"]], "CPU": [[15, "cpu"]], "TPU": [[15, "tpu"]], "Pods": [[15, "pods"]], "Docker": [[15, "docker"]], "GPU": [[15, "gpu"]], "Single-node GPU training": [[15, "single-node-gpu-training"]], "Multi-node GPU training": [[15, "multi-node-gpu-training"]], "Differences from XRT": [[15, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[15, "id3"]], "Changes to xm.rendezvous": [[15, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[15, "pjrt-and-torch-distributed"]], "Performance": [[15, "performance"]], "New TPU runtime": [[15, "new-tpu-runtime"]], "PyTorch on XLA Devices": [[16, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[16, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[16, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[16, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[16, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[16, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[16, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[16, "id3"]], "XLA Tensors are Lazy": [[16, "xla-tensors-are-lazy"]], "Memory Layout": [[16, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[16, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[16, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[16, "compilation-caching"]], "Further Reading": [[16, "further-reading"], [29, "further-reading"]], "Troubleshoot": [[17, "troubleshoot"]], "Sanity Check": [[17, "sanity-check"]], "Check PyTorch/XLA Version": [[17, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[17, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[17, "run-resnet-with-fake-data"]], "Performance Debugging": [[17, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[17, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics Analysis": [[17, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[17, "compilation-execution-analysis"]], "Get A Metrics Report": [[17, "get-a-metrics-report"]], "Understand The Metrics Report": [[17, "understand-the-metrics-report"]], "Clear The Metrics Report": [[17, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[17, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[17, "performance-profiling"]], "Simple Benchmarking": [[17, "simple-benchmarking"]], "Known Performance Caveats": [[17, "known-performance-caveats"]], "XLA Tensor Quirks": [[17, "xla-tensor-quirks"]], "More Debugging Tools": [[17, "more-debugging-tools"]], "Environment Variables": [[17, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[17, 
"common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[17, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Pytorch/XLA overview": [[18, "pytorch-xla-overview"]], "TPU Setup": [[18, "tpu-setup"]], "Converting code to PyTorch XLA": [[18, "converting-code-to-pytorch-xla"]], "Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device": [[18, "example-1-stable-diffusion-inference-in-pytorch-lightning-on-a-single-tpu-device"]], "Example 2. HF Stable Diffusion Inference": [[18, "example-2-hf-stable-diffusion-inference"]], "Running on a Single TPU device": [[18, "running-on-a-single-tpu-device"]], "Profiling and performance analysis": [[18, "profiling-and-performance-analysis"]], "Running on Multiple TPU Devices": [[18, "running-on-multiple-tpu-devices"]], "Running on Pods": [[18, "running-on-pods"]], "Source of recompilations in torch_xla": [[19, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[19, "lets-first-start-with-some-facts-constraints"], [27, "lets-first-start-with-some-facts-constraints"]], "#1. From input dataset.": [[19, "from-input-dataset"], [27, "from-input-dataset"]], "#2. From operator output": [[19, "from-operator-output"], [27, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[19, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"], [27, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[19, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"], [27, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. 
From control flow": [[19, "from-control-flow"], [27, "from-control-flow"]], "Conclusion:": [[19, "conclusion"], [27, "conclusion"]], "Appendix:": [[19, "appendix"], [27, "appendix"]], "Automatic Mixed Precision": [[20, "automatic-mixed-precision"]], "AMP for XLA:TPU": [[20, "amp-for-xla-tpu"]], "AMP for XLA:TPU Best Practices": [[20, "amp-for-xla-tpu-best-practices"]], "Supported Operators": [[20, "supported-operators"]], "AMP for XLA:GPU": [[20, "amp-for-xla-gpu"]], "AMP for XLA:GPU Best Practices": [[20, "amp-for-xla-gpu-best-practices"]], "Examples": [[20, "examples"]], "How to do DistributedDataParallel(DDP)": [[21, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[21, "background-motivation"]], "How to use DistributedDataParallel": [[21, "how-to-use-distributeddataparallel"]], "Benchmarking": [[21, "benchmarking"]], "Resnet50 with fake data": [[21, "resnet50-with-fake-data"]], "MNIST with fake data": [[21, "mnist-with-fake-data"]], "MNIST with real data": [[21, "mnist-with-real-data"]], "Disclaimer": [[21, "disclaimer"]], "TorchDynamo integration in PyTorch XLA": [[22, "torchdynamo-integration-in-pytorch-xla"]], "Integration": [[22, "integration"]], "Feature gaps": [[22, "feature-gaps"]], "Take away": [[22, "take-away"]], "Optimize memory utilization using while_loop": [[23, "optimize-memory-utilization-using-while-loop"]], "while_loop": [[23, "while-loop"]], "Usage:": [[23, "usage"]], "simple example with while_loop:": [[23, "simple-example-with-while-loop"]], "Control group test case": [[23, "control-group-test-case"]], "Control group example with pure python while loop": [[23, "control-group-example-with-pure-python-while-loop"]], "Fully Sharded Data Parallel in PyTorch XLA": [[24, "fully-sharded-data-parallel-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[24, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[24, "installation"]], "Clone PyTorch/XLA repo": [[24, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[24, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[24, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[24, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Fully Sharded Data Parallel using SPMD": [[25, "fully-sharded-data-parallel-using-spmd"]], "Sharding output": [[25, "sharding-output"]], "Gradient checkpointing": [[25, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[25, "huggingface-llama-2-example"]], "Quantized Operations for XLA (Experimental feature)": [[26, "quantized-operations-for-xla-experimental-feature"]], "How to use:": [[26, "how-to-use"]], "Call XLA quantized op in model code": [[26, "call-xla-quantized-op-in-model-code"]], "Module Swap": [[26, "module-swap"]], "Supported Quantized Operations:": [[26, "supported-quantized-operations"]], "Matrix Multiply": [[26, "matrix-multiply"]], "Source of recompilations in Pytorch/XLA": [[27, "source-of-recompilations-in-pytorch-xla"]], "PyTorch/XLA SPMD advanced topics": [[28, "pytorch-xla-spmd-advanced-topics"]], "Sharding-Aware Host-to-Device Data Loading": [[28, "sharding-aware-host-to-device-data-loading"]], "Virtual Device Optimization": [[28, "virtual-device-optimization"]], "Hybrid Mesh": [[28, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[28, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[28, "xlashardedtensor"]], "DTensor Integration": [[28, "dtensor-integration"]], "Activation Sharding for torch.compile": 
[[28, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[28, "spmd-debugging-tool"]], "Auto-Sharding": [[28, "auto-sharding"]], "PyTorch/XLA SPMD User Guide": [[29, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[29, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[29, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[29, "spmd-mode"]], "Mesh": [[29, "mesh"]], "Partition Spec": [[29, "partition-spec"]], "Distributed Checkpointing": [[30, "distributed-checkpointing"]], "CheckpointManager": [[30, "checkpointmanager"]], "Restoring Optimizer State": [[30, "restoring-optimizer-state"]], "Process Groups": [[30, "process-groups"]], "Running SPMD on GPU": [[31, "running-spmd-on-gpu"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.Mesh"]], "mpdeviceloader (class in torch_xla.distributed.parallel_loader)": [[12, "torch_xla.distributed.parallel_loader.MpDeviceLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[12, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[12, "torch_xla.device"]], "device_count() (in module torch_xla)": [[12, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[12, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[12, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.initialize_cache"]], 
"is_master_ordinal() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.is_master_ordinal"]], "is_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[12, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metrics_report"]], "module": [[12, "module-torch_xla"], [12, "module-torch_xla.core.xla_model"], [12, "module-torch_xla.debug.metrics"], [12, "module-torch_xla.distributed.parallel_loader"], [12, "module-torch_xla.distributed.spmd"], [12, "module-torch_xla.distributed.xla_multiprocessing"], [12, "module-torch_xla.experimental"], [12, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.optimizer_step"]], "rendezvous() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.short_metrics_report"]], "spawn() (in module torch_xla.distributed.xla_multiprocessing)": [[12, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[12, "torch_xla.sync"]], "torch_xla": [[12, "module-torch_xla"]], "torch_xla.core.xla_model": [[12, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[12, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[12, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[12, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[12, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[12, "module-torch_xla.experimental"]], "torch_xla.runtime": [[12, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file