Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 7, 2023
1 parent e75b6cf commit 978c5e5
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions torch_xla/distributed/spmd/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _canonicalize_color(color: Color) -> str:


def _get_text_color(color: str) -> str:
r, g, b = torch.map(lambda x: int(x, 16),
r, g, b = torch.map(lambda x: int(x, 16),
(color[1:3], color[3:5], color[5:7]))
if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
return "#000000"
Expand All @@ -71,7 +71,7 @@ def visualize_sharding(shape: torch.Size,
use_color: bool = True,
scale: float = 1.,
min_width: int = 9,
max_width: int = 80,
max_width: int = 80,
color_map: Optional[ColorMap] = None):
"""Visualizes a ``Sharding`` using ``rich``."""
if not RICH_ENABLED:
Expand Down Expand Up @@ -115,7 +115,7 @@ def visualize_sharding(shape: torch.Size,
len_after_dim_down = devices_len // last_dim_depth
for i in range(len_after_dim_down):
slices.setdefault((i // widths, i % widths),
device_indices_map[i:i+last_dim_depth])
device_indices_map[i:i + last_dim_depth])
elif sharding[-1] == "}":
# eg: '{devices=[2,2]0,1,2,3}' # 13
device_list = list(sharding[sharding.index(']') + 1:-1])
Expand Down Expand Up @@ -176,8 +176,8 @@ def visualize_sharding(shape: torch.Size,
for j in range(num_cols):
entry = f"{device_kind} " + str(
slices[i,
j])# "entry"# .join([str(s) for s in sorted(slices[i, j])])
width, maybe_height = widths, heights # widths[i, j], heights[i, j]
j]) # "entry"# .join([str(s) for s in sorted(slices[i, j])])
width, maybe_height = widths, heights # widths[i, j], heights[i, j]
width = int(width * base_width * height_to_width_ratio)
if maybe_height is None:
height = 1
Expand All @@ -202,12 +202,13 @@ def visualize_sharding(shape: torch.Size,
padding = tuple(max(x, 0) for x in padding) # type: ignore
col.append(
rich.padding.Padding(
rich.align.Align(entry, "center", vertical="middle"), padding,
style=rich.style.Style(bgcolor=color,
color=text_color)))
rich.align.Align(entry, "center", vertical="middle"),
padding,
style=rich.style.Style(bgcolor=color, color=text_color)))
table.add_row(*col)
console.print(table, end='\n\n')


def visualize_tensor_sharding(ter, **kwargs):
"""Visualizes an array's sharding."""
import torch_xla
Expand Down

0 comments on commit 978c5e5

Please sign in to comment.