PyTorch/XLA 2.5 Release
Cloud TPUs now support the PyTorch 2.5 release via the PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.5 release, this release introduces several new features and PyTorch/XLA-specific bug fixes.
Highlights
We are excited to announce the release of PyTorch/XLA 2.5! PyTorch/XLA 2.5 supports the `torch_xla.compile` function, which improves the debugging experience during development, and aligns the distributed APIs with upstream PyTorch through traceable collective support for both Dynamo and non-Dynamo cases. Starting with PyTorch/XLA 2.5, we propose a clarified vision for deprecating the older torch_xla API in favor of moving towards the existing PyTorch API, providing a simplified developer experience.
If you’ve used vLLM for serving models on GPUs, you’ll now be able to seamlessly switch to its TPU backend. vLLM is a widely adopted inference framework that also serves as an excellent way to drive accelerator interoperability. With vLLM on TPU, users will retain the same vLLM interface we’ve grown to love, with direct integration with Hugging Face Models to make model experimentation easy.
STABLE FEATURES
Eager
- Increase the max number of in-flight operations to accommodate eager mode [#7263]
- Unify the logic that checks for eager mode [#7709]
- Update `eager.md` [#7710] (see the sketch after this list)
- Optimize execution for ops that have multiple outputs in eager mode [#7680]
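For context, a minimal sketch of enabling eager mode, assuming the `torch_xla.experimental.eager_mode` entry point described in `eager.md`:

```python
# A minimal sketch, assuming torch_xla.experimental.eager_mode as documented
# in eager.md; treat it as illustrative rather than canonical.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)  # ops execute immediately instead of being traced lazily

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)
print((a @ b).cpu())  # result is materialized without an explicit mark_step
```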
Quantization / Low Precision
- Asymmetric quantized `matmul` support [#7626]
- Add blockwise quantized dot support [#7605]
- Support `int4` weight in quantized matmul / linear [#7235]
- Support the `fp8e5m2` dtype [#7740] (see the fp8 sketch after this list)
- Add `fp8e4m3fn` support [#7842]
- Support dynamic activation quantization for per-channel quantized matmul [#7867]
- Enable cross-entropy loss for XLA autocast with FP32 precision [#8094]
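As a hedged illustration of the fp8 dtype support above, the sketch below simply casts an XLA tensor to `torch.float8_e5m2` and `torch.float8_e4m3fn`; the upcast-before-math pattern is an assumption, not an official recipe:

```python
# A minimal sketch, assuming fp8 casts are the intended entry point for the
# fp8e5m2 / fp8e4m3fn support listed above.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(128, 128, device=device)

x_e5m2 = x.to(torch.float8_e5m2)    # wider exponent range, lower precision
x_e4m3 = x.to(torch.float8_e4m3fn)  # narrower range, higher precision

# upcast before doing math that fp8 may not support directly
print(x_e5m2.to(torch.float32).abs().mean().item())
```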
Pallas Kernels
- Support `ab` for `flash_attention` [#7840]; the actual kernel is implemented in JAX (see the sketch after this list)
- Support the `logits_soft_cap` parameter in `paged_attention` [#7704]; the actual kernel is implemented in JAX
- Support `gmm` and `tgmm` `trace_pallas` caching [#7921]
- Cache flash attention tracing [#8026]
- Improve the user guide [#7625]
- Update the Pallas doc with `paged_attention` [#7591]
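A hedged sketch of calling the Pallas flash attention kernel with an attention bias (`ab`, per #7840); the exact signature lives in `torch_xla.experimental.custom_kernel` and may differ from what is assumed here:

```python
# A minimal sketch; argument names other than q/k/v/ab are assumptions about
# the torch_xla.experimental.custom_kernel API.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
batch, heads, seq, head_dim = 2, 8, 1024, 128
q = torch.randn(batch, heads, seq, head_dim, device=device)
k = torch.randn(batch, heads, seq, head_dim, device=device)
v = torch.randn(batch, heads, seq, head_dim, device=device)
ab = torch.zeros(batch, heads, seq, seq, device=device)  # additive attention bias

out = flash_attention(q, k, v, ab=ab)  # the kernel itself is implemented in JAX/Pallas
print(out.shape)
```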
StableHLO
- Add user guide for stablehlo composite op [#7826]
gSPMD
- Handle parameter wrapping for SPMD [#7604]
- Add a helper function to get a 1-D mesh [#7577] (see the sketch after this list)
- Support manual `all-reduce` [#7576]
- Expose `apply_backward_optimization_barrier` [#7477]
- Support reduce-scatter in manual sharding [#7231]
- Allow `MpDeviceLoader` to shard dictionaries of tensors [#8202]
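A minimal sketch of the 1-D mesh helper and SPMD sharding; the `get_1d_mesh` name and the `"data"` axis label are taken from the PR description and should be treated as assumptions:

```python
# A minimal sketch, assuming the get_1d_mesh helper from #7577; adjust to the
# actual signature in torch_xla.distributed.spmd.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()                    # switch the runtime into SPMD mode
mesh = xs.get_1d_mesh("data")    # one mesh axis spanning all devices

x = torch.randn(16, 128, device=xm.xla_device())
xs.mark_sharding(x, mesh, ("data", None))  # shard dim 0 across the mesh, replicate dim 1
```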
Dynamo
- Optimize Dynamo dynamic shape caching [#7726]
- Add support for dynamic shapes in Dynamo [#7676] (see the sketch after this list)
- Avoid unnecessary `set_attr` calls in Dynamo `optim_mode` [#7915]
- Fix the crash with the copy op in Dynamo [#7902]
- Optimize `_split_xla_args_tensor_sym_constant` [#7900]
- Optimize the Dynamo RNG seed update [#7884]
- Support `mark_dynamic` [#7812]
- Support `gmm` as a custom op for Dynamo [#7672]
- Fix Dynamo in-place copy [#7933]
- CPU time optimization for `GraphInputMatcher` [#7895]
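A minimal sketch of `mark_dynamic` with the `openxla` Dynamo backend; the shapes and model function are illustrative only:

```python
# A minimal sketch of the dynamic shape support above: mark a dimension as
# dynamic so recompilation is avoided when it changes.
import torch
import torch._dynamo
import torch_xla.core.xla_model as xm

def f(x):
    return torch.sin(x) + 1

compiled = torch.compile(f, backend="openxla")

x = torch.randn(8, 16, device=xm.xla_device())
torch._dynamo.mark_dynamic(x, 0)  # dim 0 may vary without retracing
print(compiled(x).shape)
```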
PJRT
- Improve device auto-detection [#7787]
- Move _xla_register_custom_call_target implementation into PjRtComputationClient [#7801]
- Handle SPMD case inside of ComputationClient::WaitDeviceOps [#7796]
GKE
Functionalization
- Add 1-layer gradient accumulation test to check aliasing [#7692]
AMP
- Fix norm data-type when using AMP [#7878]
BETA FEATURES
Op Lowering
- Lower `aten::_linalg_eigh` [#7674]
- Fall back for `_embedding_bag_backward` and force `sparse=false` [#7584]
- Support trilinear by using the upstream decomposition [#7586]
Higher order ops
- [Fori_loop] Update `randint` max range to support the `bool` dtype [#7632]
TorchBench Integration
- [benchmarks] API alignment with PyTorch profiler events [#7930]
- [benchmarks] Add an IR dump option when running TorchBench [#7927]
- [benchmarks] Use the same `matmul` precision between PyTorch and PyTorch/XLA [#7748]
- [benchmarks] Introduce a verifier to check model output correctness against native PyTorch [#7724, #7777]
- [benchmarks] Fix the moco model issue on XLA [#7257, #7598]
- Add type annotations for `benchmarks/` [#7289]
- Default to `CUDAGraphs` on for Inductor [#7749]
GPU
- Deprecate `XRT` for `XLA:CUDA` [#8006]
EXPERIMENTAL FEATURES
Backward compatibility & APIs that will be removed in the 2.7 release:
- Deprecate APIs (deprecated → new); a migration sketch follows this list:

  | Deprecated | New | PRs |
  | --- | --- | --- |
  | `xla_model.xrt_world_size()` | `runtime.world_size()` | [#7679][#7743] |
  | `xla_model.get_ordinal()` | `runtime.global_ordinal()` | [#7679] |
  | `xla_model.get_local_ordinal()` | `runtime.local_ordinal()` | [#7679] |

- Internalize APIs
  - `xla_model.parse_xla_device()` [#7675]
- Improvements
  - Automatic PJRT device detection when importing `torch_xla` [#7787]
  - Add a deprecated decorator [#7703]
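A migration sketch for the deprecations above, moving from the old `xla_model` helpers to the `runtime` module:

```python
# The runtime module replaces the old xla_model world-size/ordinal helpers.
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# before (deprecated, targeted for removal in the 2.7 release)
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

# after
world_size = xr.world_size()
rank = xr.global_ordinal()
```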
Distributed
Distributed API
We have aligned our distributed APIs with upstream PyTorch. Previously, we implemented custom distributed APIs such as `torch_xla.core.xla_model.all_reduce`. With traceable collective support, we now enable `torch.distributed.all_reduce` and similar functions in `torch_xla` for both Dynamo and non-Dynamo cases.
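A minimal sketch under these assumptions (the `xla` process-group backend, the `xla://` init method, and `torch_xla.launch()` from the list below); deployment details may vary:

```python
# A minimal multiprocess sketch using upstream collectives on XLA devices.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend

def _mp_fn(index):
    dist.init_process_group("xla", init_method="xla://")
    t = torch.ones(4, device=xm.xla_device()) * dist.get_rank()
    dist.all_reduce(t)  # upstream collective, traceable for Dynamo and non-Dynamo
    print(t.cpu())

if __name__ == "__main__":
    torch_xla.launch(_mp_fn)  # unified replacement for torchrun / xmp.spawn
```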
- Support upstream distributed APIs (`torch.distributed.*`) such as `all_reduce`, `all_gather`, `reduce_scatter_tensor`, and `all_to_all`; previously we used XLA-specific distributed APIs in `xla_model` [#7860, #7950, #8064]
- Introduce `torch_xla.launch()` to launch multiprocessing, unifying `torchrun` and `torch_xla.distributed.xla_multiprocessing.spawn()` [#7764, #7648, #7695]
- `torch.distributed.reduce_scatter_tensor()`: [#7950]
- Register SDP lower-precision autocast [#7299]
- Add a Python binding for `xla::DotGeneral` [#7863]
- Fix input/output aliasing for custom in-place ops [#7822]
torch_xla.compile
- Support `full_graph`, which errors out if more than one graph would be executed in the compiled region [#7776][#7789] (see the sketch after this list)
- Support dynamic shape detection, which prints a useful error message when the number of different graphs executed across runs exceeds a predefined limit [#7918]
- Support naming each compiled program, which makes debug messages more informative [#7802]
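A hedged sketch of these options; the `full_graph` and `name` keyword arguments follow the PR descriptions and may differ slightly in the released API:

```python
# A minimal sketch of torch_xla.compile with the options described above.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

def step(x):
    return torch.sin(x) + torch.cos(x)

# fail loudly if the region would execute more than one graph, and attach a
# name so debug messages are easier to attribute
compiled_step = torch_xla.compile(step, full_graph=True, name="sin_cos_step")

x = torch.randn(8, 8, device=xm.xla_device())
print(compiled_step(x).cpu())
```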
Usability & Debuggability
- Wheel name change to support pip>=24.1: [issue#7697]
- Add `tpu-info` as a dependency of `torch_xla[tpu]` and test it: [#7938][#7337]
- Support `torch_xla.manual_seed` (see the sketch after this list): [#7340]
- Support a callback on a tensor when its async execution is finished [#7984]
- Implement `torch.ops._c10d_functional.broadcast`: [#7770]
- Flags `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` will be removed in the 2.6 release [#7582][#7945]
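A hedged sketch of `torch_xla.manual_seed`, with `torch.autocast("xla", dtype=torch.bfloat16)` shown as one replacement path for the deprecated bf16 flags; the autocast pairing is an assumption, not official migration guidance:

```python
# A minimal sketch: seed the XLA device RNG, and use autocast instead of the
# XLA_USE_BF16 / XLA_DOWNCAST_BF16 environment flags slated for removal.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
torch_xla.manual_seed(42)  # reproducible RNG on the XLA device

with torch.autocast("xla", dtype=torch.bfloat16):
    x = torch.randn(4, 4, device=device)
    y = x @ x  # matmul runs in bf16 under autocast
print(y.dtype)
```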