
Commit 58a7763

format

ManfeiBai authored Dec 3, 2023
1 parent 603fcbf commit 58a7763
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions torch_xla/distributed/spmd/debugging.py
@@ -69,7 +69,8 @@ def visualize_sharding(sharding: str,
     if len(device_list_original) == 2 and device_list_original[1] == '}':
       try:
         device_list_original_first = device_list_original[0]
-        device_list = device_list_original_first[device_list_original_first.index(']') + 1:]
+        device_list = device_list_original_first[device_list_original_first.
+                                                 index(']') + 1:]
         device_indices_map = [int(s) for s in device_list.split(',')]
         heights = int(sharding_spac[1])
         widths = int(sharding_spac[3])
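
Note on the hunk above: the re-wrapped statement strips everything up to and including the closing `]` of the mesh-shape prefix, leaving only the comma-separated device ids. A minimal standalone sketch; the input literal is an assumption, since this two-element branch only tells us the first split element carries the prefix plus device ids without a trailing `}`:

# Sketch of the slicing re-wrapped in the hunk above (input is assumed).
device_list_original_first = '{devices=[2,1,2]0,1,2,3 '
device_list = device_list_original_first[device_list_original_first.
                                         index(']') + 1:]
device_indices_map = [int(s) for s in device_list.split(',')]
print(device_indices_map)  # [0, 1, 2, 3]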
@@ -81,24 +82,28 @@ def visualize_sharding(sharding: str,
               (i // widths, i % widths),
               device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth])
       except:
-        raise ValueError("sharding ", sharding, " is not organized as expected")
+        raise ValueError("sharding ", sharding,
+                         " is not organized as expected")
     else:
       # eg: '{devices=[2,2]0,1,2,3}'
       try:
         assert device_list_original[0][-1] == '}'
       except:
-        raise ValueError("sharding ", sharding, " is not organized as expected")
+        raise ValueError("sharding ", sharding,
+                         " is not organized as expected")
       try:
         device_list_original_first = device_list_original[0]
-        device_list = device_list_original_first[device_list_original_first.index(']') + 1:-1]
+        device_list = device_list_original_first[device_list_original_first.
+                                                 index(']') + 1:-1]
         device_indices_map = [int(i) for i in device_list.split(',')]
         heights = int(sharding_spac[1])
         widths = int(sharding_spac[3])
         devices_len = len(device_indices_map)
         for i in range(devices_len):
           slices.setdefault((i // widths, i % widths), device_indices_map[i])
       except:
-        raise ValueError("sharding ", sharding, " is not organized as expected")
+        raise ValueError("sharding ", sharding,
+                         " is not organized as expected")
   else:
     raise ValueError("sharding length should >= 0")
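
To make the string-munging in this branch concrete, here is what it computes for the spec in the diff's own inline comment, sketched standalone. The upstream derivation of `device_list_original` and `sharding_spac` is not shown in this hunk, so the setup below is an assumption: the whole spec stands in for `device_list_original[0]`, and `sharding_spac` is taken to be the bracketed mesh shape.

# Sketch of the second branch for the spec '{devices=[2,2]0,1,2,3}'.
sharding = '{devices=[2,2]0,1,2,3}'
sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1]  # '[2,2]' (assumed)
device_list_original_first = sharding  # stands in for device_list_original[0]
device_list = device_list_original_first[device_list_original_first.
                                         index(']') + 1:-1]  # '0,1,2,3'
device_indices_map = [int(i) for i in device_list.split(',')]  # [0, 1, 2, 3]
heights = int(sharding_spac[1])  # 2
widths = int(sharding_spac[3])  # 2
slices = {}
for i in range(len(device_indices_map)):
  slices.setdefault((i // widths, i % widths), device_indices_map[i])
print(slices)  # {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}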

@@ -152,8 +157,10 @@ def visualize_sharding(sharding: str,
 
 def visualize_tensor_sharding(t, **kwargs):
   """Visualizes an array's sharding."""
+
+  # XLAShardedTensor is-a torch.Tensor
   def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
     return t.global_tensor if isinstance(t, XLAShardedTensor) else t
 
   sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t))
   return visualize_sharding(sharding, **kwargs)
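
For orientation, a hedged usage sketch of the entry point this commit touches. The mesh construction and partition spec follow torch_xla's SPMD API of this era, but the device count and spec below are illustrative assumptions, not part of this commit.

# Illustrative only: shard a tensor over a 2D mesh, then visualize it.
# Assumes an SPMD-enabled environment with an even number of devices.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 2, 2), ('x', 'y'))
t = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x', 'y'))
visualize_tensor_sharding(t)  # renders the per-device tiling as a table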
