-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
from collections.abc import Sequence | ||
import functools | ||
import string | ||
import sys | ||
from typing import Any, Callable, Optional, Union | ||
import weakref | ||
|
||
import numpy as np | ||
import torch | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.runtime as xr | ||
import torch_xla.experimental.xla_sharding as xs | ||
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor | ||
from torch_xla.experimental.xla_sharding import Mesh | ||
|
||
# `rich` is an optional dependency that is only needed for rendering the
# sharding tables; record whether it is importable instead of failing at
# module import time.
# pytype: disable=import-error
try:
    import rich
    import rich.align
    import rich.box
    import rich.console
    import rich.padding
    import rich.style
    import rich.table
    RICH_ENABLED = True
except ImportError:
    # Only a missing/unimportable package disables the feature. Anything else
    # (e.g. KeyboardInterrupt) must propagate rather than be swallowed.
    RICH_ENABLED = False
|
||
# Sharding visualization
# Registry with weakly-referenced values: an entry disappears on its own once
# nothing else keeps the stored value alive.
# NOTE(review): presumably maps a key to a `ShardingCallbackInfo`; nothing in
# the visible part of this module reads or writes it -- confirm against callers.
sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore
# Name used to identify the "inspect sharding" call.
_INSPECT_SHARDING_CALL_NAME = "InspectSharding"
|
||
class ShardingCallbackInfo:
    """Bundle a sharding-inspection callback with its module context.

    Attributes:
      callback: The object passed as ``callback``; stored unchanged.
      module_context: The object passed as ``module_context``; stored unchanged.
    """

    def __init__(self, callback, module_context):
        # Plain record: both values are kept verbatim, with no validation.
        self.module_context = module_context
        self.callback = callback
|
||
# A colour is either an (r, g, b) triple of floats or an already-formatted
# colour string.
Color = Union[tuple[float, float, float], str]
# A colour map takes a float and returns four floats; callers in this module
# use only the first three components.
ColorMap = Callable[[float], tuple[float, float, float, float]]


def _canonicalize_color(color: Color) -> str:
    """Return ``color`` as a string usable as a colour specification.

    Args:
      color: Either a string, which is returned untouched, or an iterable of
        exactly three numeric channels. Each channel is multiplied by 255 and
        truncated to an int.

    Returns:
      The input string, or an upper-case ``#RRGGBB`` hex string built from the
      three scaled channels.

    Raises:
      ValueError: If a non-string ``color`` does not have exactly three
        components.
    """
    if not isinstance(color, str):
        r, g, b = [int(channel * 255) for channel in color]
        return f"#{r:02X}{g:02X}{b:02X}"
    return color
|
||
def _get_text_color(color: str) -> str:
    """Pick a text colour (black or white) that contrasts with ``color``.

    Args:
      color: Background colour as a ``#RRGGBB`` hex string.

    Returns:
      ``"#000000"`` when the weighted channel sum of the background exceeds
      186, otherwise ``"#ffffff"``.

    Raises:
      ValueError: If any of the three two-character channel fields is not
        valid hexadecimal.
    """
    # Decode the three channel fields with plain Python; `torch` has no `map`
    # function, so the previous `torch.map(...)` call failed at runtime.
    r, g, b = (int(color[start:start + 2], 16) for start in (1, 3, 5))
    # Weighted sum of the channels; green contributes most, blue least.
    if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
        return "#000000"
    return "#ffffff"
|
||
def make_color_iter(color_map, num_rows, num_cols):
    """Yield ``num_rows * num_cols`` colours sampled from ``color_map``.

    The unit interval is sampled at ``num_rows * num_cols`` evenly spaced
    points and visited with a fixed stride of roughly half the sample count,
    so consecutive yields tend to come from distant parts of the colour map.

    Args:
      color_map: Callable applied to each sampled value; its result is yielded
        unchanged.
      num_rows: Number of grid rows.
      num_cols: Number of grid columns.

    Yields:
      ``color_map(sample)`` once per grid cell.
    """
    total = num_rows * num_cols
    samples = np.linspace(0, 1, total)
    # Half the sample count, plus one extra step when the count is even.
    stride = total // 2 + (1 if total % 2 == 0 else 0)
    position = 0
    for _ in range(total):
        yield color_map(samples[position])
        position = (position + stride) % total
|
||
# Draw the sharding layout as a table.
def visualize_sharding(shape: torch.Size, sharding: str,
                       use_color: bool = True, scale: float = 1.,
                       min_width: int = 9, max_width: int = 80,
                       color_map: Optional[ColorMap] = None):
    """Visualizes a sharding spec string as a table using ``rich``.

    Args:
      shape: Shape of the sharded tensor; must have 1 or 2 dimensions.
      sharding: Sharding spec string, in one of the forms
        ``'{replicated}'``, ``'{devices=[2,2]0,1,2,3}'`` or
        ``'{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'``.
      use_color: Colour each cell's background. Turned off automatically when
        the console has no colour system, or when no ``color_map`` is given
        and ``matplotlib`` cannot be imported.
      scale: Multiplier applied to the base cell height (10 rows).
      min_width: Lower bound on the width of a cell, in characters.
      max_width: Console width and upper bound on the width of a cell.
      color_map: Callable mapping a float in [0, 1] to a colour whose first
        three components are used. Defaults to matplotlib's ``tab20b``.

    Raises:
      ValueError: If ``rich`` is not installed, ``shape`` is not 1-D or 2-D,
        ``sharding`` is empty, or ``sharding`` cannot be parsed.
    """
    if not RICH_ENABLED:
        raise ValueError("`visualize_sharding` requires `rich` to be installed.")

    if len(shape) > 2 or len(shape) < 1:
        raise ValueError(
            "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")

    num_rows, num_cols, slices = _parse_sharding_spec(sharding)

    console = rich.console.Console(width=max_width)
    use_color = use_color and console.color_system is not None
    if use_color and not color_map:
        try:
            import matplotlib as mpl  # pytype: disable=import-error
            color_map = mpl.colormaps["tab20b"]
        except ModuleNotFoundError:
            # No colour map available: fall back to the boxed, uncoloured table.
            use_color = False

    base_height = int(10 * scale)
    # Guard the division so a zero-length leading dimension does not raise.
    aspect_ratio = (shape[1] if len(shape) == 2 else 1) / max(shape[0], 1)
    base_width = int(base_height * aspect_ratio)
    # Terminal character cells are taller than wide; widen to compensate.
    height_to_width_ratio = 2.5

    # Every tile covers the same fraction of the tensor (1/num_cols wide and
    # 1/num_rows tall), so the cell size is computed once. Scaling by the
    # fraction, not by the shard count, keeps the table from growing as the
    # number of shards increases.
    cell_width = int(base_width * height_to_width_ratio / num_cols)
    cell_width = min(max(cell_width, min_width), max_width)
    cell_height = int(base_height / num_rows)

    # The sharding is only available as a string here, so the device kind
    # cannot be read from it; default to TPU.
    # TODO(@manfei): get device kind from commands for TPU/GPU/CPU
    device_kind = 'TPU'

    color_iter = make_color_iter(color_map, num_rows, num_cols)
    table = rich.table.Table(show_header=False, show_lines=not use_color,
                             padding=0,
                             highlight=not use_color, pad_edge=False,
                             box=rich.box.SQUARE if not use_color else None)
    for i in range(num_rows):
        row = []
        for j in range(num_cols):
            entry = f"{device_kind} {slices[i, j]}"
            # Split the free space around the label as evenly as possible; the
            # odd leftover unit goes to the right / bottom.
            left_padding, remainder = divmod(cell_width - len(entry) - 2, 2)
            right_padding = left_padding + remainder
            top_padding, remainder = divmod(cell_height - 2, 2)
            bottom_padding = top_padding + remainder
            if use_color:
                color = _canonicalize_color(next(color_iter)[:3])
                text_color = _get_text_color(color)
                # Without box lines the cell needs one extra unit per side.
                top_padding += 1
                bottom_padding += 1
                left_padding += 1
                right_padding += 1
            else:
                color = None
                text_color = None
            padding = (top_padding, right_padding, bottom_padding, left_padding)
            # Small cells can produce negative padding; clamp to zero.
            padding = tuple(max(x, 0) for x in padding)  # type: ignore
            row.append(
                rich.padding.Padding(
                    rich.align.Align(entry, "center", vertical="middle"),
                    padding,
                    style=rich.style.Style(bgcolor=color, color=text_color)))
        table.add_row(*row)
    console.print(table, end='\n\n')


def _parse_sharding_spec(
        sharding: str
) -> tuple[int, int, dict[tuple[int, int], Union[int, list[int]]]]:
    """Parse a sharding spec string into a grid of device assignments.

    Args:
      sharding: One of ``'{replicated}'``, ``'{devices=[R,C]d0,d1,...}'`` or
        ``'{devices=[R,C,K]d0,d1,... last_tile_dim_replicate}'``. A single
        tile dimension (``[R]`` or ``[R,K]`` with the replicate marker) is
        laid out as ``R`` rows and one column.

    Returns:
      ``(num_rows, num_cols, slices)`` where ``slices`` maps ``(row, col)`` to
      the device(s) holding that tile: a single device id for a fully tiled
      sharding, or a list of device ids when the tile is replicated.

    Raises:
      ValueError: If ``sharding`` is empty or is not in a recognised form.
    """
    if not sharding:
        raise ValueError("sharding has no value")

    if sharding == '{replicated}':
        # One tile covering the whole tensor, held by every device.
        device_ids = list(range(xr.global_runtime_device_count()))
        return 1, 1, {(0, 0): device_ids}

    replicate_marker = 'last_tile_dim_replicate'
    try:
        if not sharding.endswith('}'):
            raise ValueError(sharding)
        dims_start = sharding.index('[')
        dims_end = sharding.index(']', dims_start)
        # Split on commas instead of indexing single characters so that
        # multi-digit tile dimensions and device ids are read correctly.
        tile_dims = [
            int(dim) for dim in sharding[dims_start + 1:dims_end].split(',')
        ]
        device_text = sharding[dims_end + 1:-1].strip()
        partially_replicated = device_text.endswith(replicate_marker)
        if partially_replicated:
            device_text = device_text[:-len(replicate_marker)].strip()
            # The trailing tile dimension counts the replicas of each tile.
            replication = tile_dims.pop()
        else:
            replication = 1
        device_ids = [int(dev) for dev in device_text.split(',')]
    except ValueError as err:
        raise ValueError("sharding is not organized as expected") from err

    if len(tile_dims) == 1:
        num_rows, num_cols = tile_dims[0], 1
    elif len(tile_dims) == 2:
        num_rows, num_cols = tile_dims
    else:
        raise ValueError("sharding is not organized as expected")

    # Reject inconsistent specs here rather than failing later with a KeyError.
    if (min(num_rows, num_cols, replication) < 1 or
            num_rows * num_cols * replication != len(device_ids)):
        raise ValueError("sharding is not organized as expected")

    slices: dict[tuple[int, int], Union[int, list[int]]] = {}
    for cell in range(num_rows * num_cols):
        key = divmod(cell, num_cols)
        if partially_replicated:
            # Devices are listed tile by tile: each run of `replication`
            # consecutive ids belongs to one tile.
            first = cell * replication
            slices[key] = device_ids[first:first + replication]
        else:
            slices[key] = device_ids[cell]
    return num_rows, num_cols, slices
|
||
def visualize_tensor_sharding(ter, **kwargs):
    """Visualizes the sharding of the tensor ``ter``.

    Looks up the tensor's sharding spec string and forwards it, together with
    the tensor's shape, to ``visualize_sharding``.

    Args:
      ter: Tensor whose sharding should be drawn.
      **kwargs: Passed through unchanged to ``visualize_sharding``.

    Returns:
      Whatever ``visualize_sharding`` returns.
    """
    # Imported lazily, at call time, as in the original implementation.
    import torch_xla
    spec = torch_xla._XLAC._get_xla_sharding_spec(ter)
    return visualize_sharding(ter.shape, spec, **kwargs)