Skip to content

Commit

Permalink
Update debugging.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 26, 2023
1 parent 8f3fdfc commit eda1da0
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions torch_xla/distributed/spmd/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,40 @@ def __init__(self, callback, module_context):
self.module_context = module_context


Color = Union[tuple[float, float, float], str]
ColorMap = Callable[[float], tuple[float, float, float, float]]
# Color = Union[tuple[float, float, float], str]
# ColorMap = Callable[[float], tuple[float, float, float, float]]


def _canonicalize_color(color: Color) -> str:
if isinstance(color, str):
return color
r, g, b = (int(a * 255) for a in color)
return f"#{r:02X}{g:02X}{b:02X}"
# def _canonicalize_color(color: Color) -> str:
# if isinstance(color, str):
# return color
# r, g, b = (int(a * 255) for a in color)
# return f"#{r:02X}{g:02X}{b:02X}"


def _get_text_color(color: str) -> str:
r, g, b = torch.map(lambda x: int(x, 16),
(color[1:3], color[3:5], color[5:7]))
if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
return "#000000"
return "#ffffff"
# def _get_text_color(color: str) -> str:
# r, g, b = torch.map(lambda x: int(x, 16),
# (color[1:3], color[3:5], color[5:7]))
# if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
# return "#000000"
# return "#ffffff"


def make_color_iter(color_map, num_rows, num_cols):
num_colors = num_rows * num_cols
color_values = np.linspace(0, 1, num_colors)
idx = 0
for _ in range(num_colors):
yield color_map(color_values[idx])
idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors
# def make_color_iter(color_map, num_rows, num_cols):
# num_colors = num_rows * num_cols
# color_values = np.linspace(0, 1, num_colors)
# idx = 0
# for _ in range(num_colors):
# yield color_map(color_values[idx])
# idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors


def visualize_sharding(shape: torch.Size,
sharding: str,
use_color: bool = True,
scale: float = 1.,
min_width: int = 9,
max_width: int = 80,
color_map: Optional[ColorMap] = None):
max_width: int = 80): # , color_map: Optional[ColorMap] = None):
"""Visualizes a ``Sharding`` using ``rich``."""
if not RICH_ENABLED:
raise ValueError("`visualize_sharding` requires `rich` to be installed.")
Expand Down Expand Up @@ -130,12 +129,12 @@ def visualize_sharding(shape: torch.Size,

console = rich.console.Console(width=max_width)
use_color = use_color and console.color_system is not None
if use_color and not color_map:
try:
import matplotlib as mpl
color_map = mpl.colormaps["tab20b"]
except ModuleNotFoundError:
use_color = False
# if use_color and not color_map:
# try:
# import matplotlib as mpl
# color_map = mpl.colormaps["tab20b"]
# except ModuleNotFoundError:
# use_color = False

base_height = int(3 * scale)
aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
Expand All @@ -152,7 +151,7 @@ def visualize_sharding(shape: torch.Size,
# set the device kind to TPU as default since `sharding` here is `str`, TODO(@manfei): get device kind from commands for TPU/GPU/CPU
device_kind = 'TPU' # next(iter(sharding.device_set)).platform.upper()

color_iter = make_color_iter(color_map, num_rows, num_cols)
# color_iter = make_color_iter(color_map, num_rows, num_cols)
table = rich.table.Table(
show_header=False,
show_lines=not use_color,
Expand All @@ -177,12 +176,12 @@ def visualize_sharding(shape: torch.Size,
bottom_padding = top_padding + remainder

if use_color:
color = _canonicalize_color(next(color_iter)[:3])
text_color = _get_text_color(color)
top_padding += 1
bottom_padding += 1
left_padding += 1
right_padding += 1
# color = _canonicalize_color(next(color_iter)[:3])
# text_color = _get_text_color(color)
# top_padding += 1
# bottom_padding += 1
# left_padding += 1
# right_padding += 1
else:
color = None
text_color = None
Expand Down

0 comments on commit eda1da0

Please sign in to comment.