From c4f2771849e0827badf934500d29cedc83d73b8d Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Mon, 11 Nov 2024 16:55:17 -0800 Subject: [PATCH] Improve some typing annotations (#8369) --- torch_xla/core/xla_model.py | 4 ++-- torch_xla/distributed/spmd/xla_sharding.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 607f1cb9c57..6e1936c258a 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -135,7 +135,7 @@ def is_master_ordinal(local: bool = True) -> bool: return ordinal == 0 -def master_print(*args: Tuple[Any, ...], +def master_print(*args: Any, fd: TextIO = sys.stdout, local: bool = False, flush: bool = False): @@ -984,7 +984,7 @@ def _reduce_scatter_coalesced( def add_step_closure(closure: Callable[..., Any], - args: Tuple[Any] = (), + args: Tuple[Any, ...] = (), run_async: bool = False): """Adds a closure to the list of the ones to be run at the end of the step. diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index a7b7ec758bd..7085325513e 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -575,7 +575,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], def mark_sharding( t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, - partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: + partition_spec: Tuple[Union[Tuple, int, str, None], + ...]) -> XLAShardedTensor: """ Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.