PyTorch/XLA 2.4 Release
Cloud TPUs now support the PyTorch 2.4 release via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch 2.4, this release introduces several new features and PyTorch/XLA-specific bug fixes.
🚀 The PyTorch/XLA 2.4 release delivers a 4% speedup (geometric mean) over the 2.3 release on torchbench evaluation benchmarks on TPUs, using the openxla_eval dynamo backend.
Highlights
We are excited to announce the release of PyTorch/XLA 2.4! PyTorch/XLA 2.4 offers improved support for custom kernels written with Pallas, including kernels such as FlashAttention and Group Matrix Multiplication that can be used like any other torch operator, as well as inference support for the PagedAttention kernel. We also add experimental support for eager mode, which compiles and executes each operator immediately for a better debugging and development experience.
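As a rough illustration of calling a Pallas kernel like a regular torch operator, here is a minimal sketch; it assumes the kernel is exposed as torch_xla.experimental.custom_kernel.flash_attention and that the code runs on a TPU device, with illustrative shapes and dtypes:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# Query/key/value in (batch, num_heads, seq_len, head_dim) layout.
q = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
k = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
v = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)

# The Pallas FlashAttention kernel is invoked like any other torch op.
out = flash_attention(q, k, v, causal=True)
xm.mark_step()  # materialize the result
```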
Stable Features
PJRT
- Enable dynamic plugins by default #7270
GSPMD
- Support manual sharding and introduce high level manual sharding APIs #6915, #6931
- Support SPMDFullToShardShape, SPMDShardToFullShape #6922, #6925
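The manual sharding APIs above build on the existing GSPMD mesh machinery. Below is a minimal sketch of that flow (standard mesh creation plus mark_sharding); the new manual-sharding helpers from #6915/#6931 are not shown, and their exact names may differ:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
# A 2-D logical mesh over all devices: ('data', 'model').
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 128, device=xm.xla_device())
# Shard dim 0 along the 'data' axis; dim 1 stays replicated.
xs.mark_sharding(t, mesh, ('data', None))
```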
Torch Compile
- Add a DynamoSyncInputExecuteTime counter #6813
- Fix runtime error when run dynamo with a profiler scope #6913
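For context, a minimal sketch of running a function through the dynamo backends mentioned above, using the general-purpose openxla backend (openxla_eval is the inference-only variant referenced in the benchmark numbers):

```python
import torch
import torch_xla.core.xla_model as xm

def fn(a, b):
    return torch.sin(a) + a @ b

device = xm.xla_device()
a = torch.randn(128, 128, device=device)
b = torch.randn(128, 128, device=device)

# Dynamo traces fn and hands the graph to XLA for compilation.
compiled = torch.compile(fn, backend="openxla")
out = compiled(a, b)
```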
Export
- Add fx passes to support unbounded dynamism #6653
- Add dynamism support to conv1d, view, softmax #6653
- Add dynamism support to aten.embedding and aten.split_with_sizes #6781
- Inline all scalars by default in export path #6803
- Run shape propagation for inserted fx nodes #6805
- Add an option to not generate weights #6909
- Support export custom op to stablehlo custom call #7017
- Support array attribute in stablehlo composite #6840
- Add option to export FX Node metadata to StableHLO #7046
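A minimal sketch of the export path these changes feed into, assuming torch.export plus torch_xla.stablehlo.exported_program_to_stablehlo and its get_stablehlo_text accessor; the module and inputs are illustrative:

```python
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Export with torch.export, then lower the ExportedProgram to StableHLO.
exported = export(MyModule(), (torch.randn(4, 8),))
shlo = exported_program_to_stablehlo(exported)
print(shlo.get_stablehlo_text('forward'))
```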
Beta Features
Pallas
- Support FlashAttention backward kernels #6870
- Make FlashAttention a torch.autograd.Function #6886
- Remove torch.empty in tracing to avoid allocating extra memory #6897
- Integrate FlashAttention with SPMD #6935
- Support scaling factor for attention weights in FlashAttention #7035
- Support segment ids in FlashAttention #6943
- Enable PagedAttention through Pallas #6912
- Properly support PagedAttention dynamo code path #7022
- Support megacore_mode in PagedAttention #7060
- Add Megablocks’ Group Matrix Multiplication kernel #6940, #7117, #7120, #7119, #7133, #7151
- Support histogram #7115, #7202
- Support tgmm #7137
- Make repeat_with_fixed_output_size not OOM on VMEM #7145
- Introduce GMM torch.autograd.function #7152
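A rough sketch of calling the Group Matrix Multiplication kernel, assuming it is exposed as torch_xla.experimental.custom_kernel.gmm with a (lhs, rhs, group_sizes) signature; shapes are illustrative:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
# Grouped matmul: consecutive row-groups of lhs are multiplied by the
# corresponding expert matrix in rhs; group_sizes lists rows per group.
lhs = torch.randn(16, 64, dtype=torch.bfloat16, device=device)     # (tokens, in_dim)
rhs = torch.randn(4, 64, 32, dtype=torch.bfloat16, device=device)  # (groups, in_dim, out_dim)
group_sizes = torch.tensor([4, 4, 4, 4], dtype=torch.int32, device=device)

out = gmm(lhs, rhs, group_sizes)  # (tokens, out_dim)
```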
CoreAtenOpSet
- Lower embedding_bag_forward_only #6951
- Implement Repeat with fixed output shape #7114
- Add int8 per channel weight-only quantized matmul #7201
FSDP via SPMD
- Support multislice #7044
- Allow sharding on the maximal dimension of the weights #7134
- Apply optimization-barrier to all params and buffers during grad checkpointing #7206
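A minimal sketch of wrapping a module with the SPMD-based FSDP implementation these changes extend, assuming the class is exposed as torch_xla.experimental.spmd_fully_sharded_data_parallel.SpmdFullyShardedDataParallel and accepts a mesh argument:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 1-D mesh: parameters are sharded along the 'fsdp' axis.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))

model = torch.nn.Linear(1024, 1024).to(xm.xla_device())
model = FSDPv2(model, mesh=mesh)
```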
Distributed Checkpoint
- Add optimizer priming for distributed checkpointing #6572
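A rough sketch of optimizer priming together with the checkpoint manager, assuming both are exposed under torch_xla.experimental.distributed_checkpoint as prime_optimizer and CheckpointManager; the checkpoint path is hypothetical:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import (
    CheckpointManager,
    prime_optimizer,
)

model = torch.nn.Linear(128, 128).to(xm.xla_device())
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Priming materializes the optimizer state on device so a checkpoint can
# be restored before the first real optimizer step has run.
optim = prime_optimizer(optim)

ckpt_mgr = CheckpointManager('gs://my-bucket/ckpts', save_interval=10)
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
ckpt_mgr.save(0, state_dict)
```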
Usability
- Add xla.sync as a better name for mark_step. See #6399. #6914
- Add xla.step context manager to handle exceptions better. See #6751. #7068
- Implement ComputationClient::GetMemoryInfo for getting TPU memory allocation #7086
- Dump HLO HBM usage info #7085
- Add function for retrieving fallback operations #7116
- Deprecate XLA_USE_BF16 and XLA_USE_FP16 #7150
- Add PT_XLA_DEBUG_LEVEL to make it easier to distinguish between execution cause and compilation cause #7149
- Warn when using persistent cache with debug env vars #7175
- Add experimental MLIR debuginfo writer API #6799
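A minimal sketch of the sync/step helpers mentioned above, assuming they are surfaced on the top-level torch_xla module as torch_xla.sync() and torch_xla.step():

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(8, 8, device=device)

# torch_xla.sync() is the clearer-named equivalent of xm.mark_step().
y = x @ x
torch_xla.sync()

# torch_xla.step() wraps a step and syncs on exit, handling exceptions
# raised inside the block more gracefully than a bare mark_step.
with torch_xla.step():
    z = y + 1
```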
GPU CUDA Fallback
- Add dlpack support #7025
- Make from_dlpack handle cuda synchronization implicitly for input tensors that have __dlpack__ and __dlpack_device__ attributes #7125
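A rough sketch of the DLPack interop, assuming the helpers are exposed as torch_xla.utils.dlpack.from_dlpack/to_dlpack and that from_dlpack accepts a CUDA tensor directly, handling stream synchronization implicitly per #7125:

```python
import torch
from torch_xla.utils.dlpack import from_dlpack, to_dlpack

# Zero-copy hand-off of a CUDA tensor to an XLA:GPU tensor. Passing the
# tensor itself (it has __dlpack__/__dlpack_device__) lets from_dlpack
# synchronize with the producing CUDA stream implicitly.
cuda_t = torch.arange(8, device='cuda')
xla_t = from_dlpack(cuda_t)

# And back: export the XLA tensor as a DLPack capsule for other consumers.
capsule = to_dlpack(xla_t)
cuda_view = torch.from_dlpack(capsule)
```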
Distributed
- Switch all_reduce to use the new functional collective op #6887
- Allow user to configure distributed runtime service. #7204
- Use dest_offsets directly in LoadPlanner #7243
Experimental Features
Eager Mode
- Enable Eager mode for PyTorch/XLA #7611
- Support eager mode with torch.compile #7649
- Eagerly execute inplace ops in eager mode #7666
- Support eager mode for multi-process training #7668
- Handle random seed for eager mode #7669
- Enable SPMD with eager mode #7673
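A minimal sketch of turning on the experimental eager mode and mixing it with torch.compile, assuming the switch is exposed as torch_xla.experimental.eager_mode:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla import experimental as xla_experimental

# Each op is compiled and executed immediately instead of being deferred
# until the next mark_step/sync.
xla_experimental.eager_mode(True)

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
print(a @ a)  # runs eagerly, no explicit sync needed

# Performance-critical regions can still be compiled as a whole graph.
compiled_fn = torch.compile(lambda x: torch.sin(x) + x, backend="openxla")
print(compiled_fn(a))
```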
Triton
While Loop
- Prepare for torch while_loop signature change. #6872
- Implement fori_loop as a wrapper around while_loop #6850
- Complete fori_loop/while_loop and additional test case #7306
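A minimal sketch of the torch while_loop higher-order op running on an XLA device, assuming the while_loop(cond_fn, body_fn, carried_inputs) signature; fori_loop is a thin wrapper over the same primitive:

```python
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(i, x):
    return i < 10            # must return a scalar boolean tensor

def body_fn(i, x):
    return i + 1, x + 1.0    # returns the updated loop carry

i0 = torch.tensor(0, device=device)
x0 = torch.zeros(4, device=device)
# The loop is lowered to a single XLA While op instead of being unrolled.
i_final, x_final = while_loop(cond_fn, body_fn, (i0, x0))
```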
Bug Fixes and Improvements
- Fix type promotion for pow #6745
- Fix vector norm lowering #6883
- Manually init absl log to avoid log spam #6890
- Fix pixel_shuffle returning an empty tensor #6907
- Make nms fall back to the CPU implementation by default #6933
- Fix torch.full scalar type #7010
- Handle multiple inplace update input output aliasing #7023
- Fix overflow for div arguments. #7081
- Add data_type promotion to gelu_backward, stack #7090, #7091
- Fix index of 0-element tensor by 0-element tensor #7113
- Fix output data-type for upsample_bilinear #7168
- Fix a data-type issue for the mul operation by converting inputs to the result type #7130
- Make clip_grad_norm_ follow input’s dtype #7205