diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py index 6c397b4f8dd..44562d3e263 100644 --- a/torch_xla/distributed/spmd/debugging.py +++ b/torch_xla/distributed/spmd/debugging.py @@ -36,32 +36,32 @@ def __init__(self, callback, module_context): self.module_context = module_context -Color = Union[tuple[float, float, float], str] -ColorMap = Callable[[float], tuple[float, float, float, float]] +# Color = Union[tuple[float, float, float], str] +# ColorMap = Callable[[float], tuple[float, float, float, float]] -def _canonicalize_color(color: Color) -> str: - if isinstance(color, str): - return color - r, g, b = (int(a * 255) for a in color) - return f"#{r:02X}{g:02X}{b:02X}" +# def _canonicalize_color(color: Color) -> str: +# if isinstance(color, str): +# return color +# r, g, b = (int(a * 255) for a in color) +# return f"#{r:02X}{g:02X}{b:02X}" -def _get_text_color(color: str) -> str: - r, g, b = torch.map(lambda x: int(x, 16), - (color[1:3], color[3:5], color[5:7])) - if (r * 0.299 + g * 0.587 + b * 0.114) > 186: - return "#000000" - return "#ffffff" +# def _get_text_color(color: str) -> str: +# r, g, b = torch.map(lambda x: int(x, 16), +# (color[1:3], color[3:5], color[5:7])) +# if (r * 0.299 + g * 0.587 + b * 0.114) > 186: +# return "#000000" +# return "#ffffff" -def make_color_iter(color_map, num_rows, num_cols): - num_colors = num_rows * num_cols - color_values = np.linspace(0, 1, num_colors) - idx = 0 - for _ in range(num_colors): - yield color_map(color_values[idx]) - idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors +# def make_color_iter(color_map, num_rows, num_cols): +# num_colors = num_rows * num_cols +# color_values = np.linspace(0, 1, num_colors) +# idx = 0 +# for _ in range(num_colors): +# yield color_map(color_values[idx]) +# idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors def visualize_sharding(shape: torch.Size, @@ -69,8 +69,7 @@ def visualize_sharding(shape: torch.Size, use_color: bool = True, scale: float = 1., min_width: int = 9, - max_width: int = 80, - color_map: Optional[ColorMap] = None): + max_width: int = 80): # , color_map: Optional[ColorMap] = None): """Visualizes a ``Sharding`` using ``rich``.""" if not RICH_ENABLED: raise ValueError("`visualize_sharding` requires `rich` to be installed.") @@ -130,12 +129,12 @@ def visualize_sharding(shape: torch.Size, console = rich.console.Console(width=max_width) use_color = use_color and console.color_system is not None - if use_color and not color_map: - try: - import matplotlib as mpl - color_map = mpl.colormaps["tab20b"] - except ModuleNotFoundError: - use_color = False + # if use_color and not color_map: + # try: + # import matplotlib as mpl + # color_map = mpl.colormaps["tab20b"] + # except ModuleNotFoundError: + # use_color = False base_height = int(3 * scale) aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0] @@ -152,7 +151,7 @@ def visualize_sharding(shape: torch.Size, # set the device kind to TPU as default since `sharding` here is `str`, TODO(@manfei): get device kind from commands for TPU/GPU/CPU device_kind = 'TPU' # next(iter(sharding.device_set)).platform.upper() - color_iter = make_color_iter(color_map, num_rows, num_cols) + # color_iter = make_color_iter(color_map, num_rows, num_cols) table = rich.table.Table( show_header=False, show_lines=not use_color, @@ -177,12 +176,12 @@ def visualize_sharding(shape: torch.Size, bottom_padding = top_padding + remainder if use_color: - color = _canonicalize_color(next(color_iter)[:3]) - text_color = _get_text_color(color) - top_padding += 1 - bottom_padding += 1 - left_padding += 1 - right_padding += 1 + # color = _canonicalize_color(next(color_iter)[:3]) + # text_color = _get_text_color(color) + # top_padding += 1 + # bottom_padding += 1 + # left_padding += 1 + # right_padding += 1 else: color = None text_color = None