Skip to content

Commit

Permalink
reset xla_sharding file
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Nov 22, 2024
1 parent c0c10f1 commit 195a6e1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],

def mark_sharding(
t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
partition_spec: Tuple[Union[Tuple, int, str, None],
...]) -> XLAShardedTensor:
"""
Annotates the tensor provided with XLA partition spec. Internally,
it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
Expand Down

0 comments on commit 195a6e1

Please sign in to comment.