Skip to content

Commit

Permalink
Transfer data directly to the device (#5772)
Browse files Browse the repository at this point in the history
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
  • Loading branch information
will-cromar authored and ManfeiBai committed Nov 29, 2023
1 parent 5e9edcc commit 96ef6d7
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 119 deletions.
168 changes: 143 additions & 25 deletions test/spmd/test_spmd_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,35 @@
import math
import numpy as np
import os
import io
import rich

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding


import test_xla_sharding_base


class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()# os.environ["XLA_USE_SPMD"] = "1"
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")
@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
Expand All @@ -38,36 +43,127 @@ def test_debugging_spmd_single_host_tiled(self):
partition_spec = (0, 1)
xs.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is:")
print(sharding)
print("then print:")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 0', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 1', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 2', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 3', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 4', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 5', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 6', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 7', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

partition_spec = (0, None)
t = torch.randn(8, 32, device=device)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, (0, None))
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is: ")
print(sharding)
print("then print: ")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
(2,0,2,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
(2,0,2,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
console.print(fask_table)
fake_console = rich.console.Console(file=io.StringIO(), width=120)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
Expand All @@ -78,10 +174,32 @@ def test_single_host_replicated(self):
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, partition_spec_replicated)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is: ")
print(sharding)
print("then print: ")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
(0,0,1,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console = rich.console.Console(file=io.StringIO(), width=120)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output


if __name__ == '__main__':
test = unittest.main()
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,17 @@ cc_library(
]
)

# Header-only target: it publishes tensor_source.h via `hdrs` and lists no
# `srcs`. Dependencies are the sibling :debug_macros target, XLA's `literal`
# and `shape_util` targets, and the PyTorch headers.
cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
]
)

cc_library(
name = "types",
hdrs = ["types.h"],
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ class AtenSource : public TensorSource {
at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type());
if (target_torch_type != tensor.type().scalarType()) {
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
tensor_ = std::move(tensor.to(target_torch_type).contiguous());
} else {
tensor_ = std::move(tensor.contiguous());
}
tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false,
/*copy=*/true, at::MemoryFormat::Contiguous));
}

const void* data() const override { return tensor_.const_data_ptr(); }
Expand Down
19 changes: 15 additions & 4 deletions torch_xla/distributed/spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,21 @@
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward)
from .api import xla_distribute_tensor, xla_distribute_module
# from .debugging import visualize_tensor_sharding

__all__ = [
"XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
"ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
"wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
"xla_patched_nn_linear_forward"
"XLAShard",
"XLAShardedTensor",
"Mesh",
"HybridMesh",
"ShardingType",
"ShardingSpec",
"XLAPatchedLinear",
"mark_sharding",
"clear_sharding",
"wrap_if_sharded",
"xla_distribute_tensor",
"xla_distribute_module",
"xla_patched_nn_linear_forward",
"visualize_tensor_sharding",
]
32 changes: 11 additions & 21 deletions torch_xla/distributed/spmd/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

# pytype: disable=import-error
try:
import rich
import rich.align
Expand All @@ -27,7 +25,7 @@
RICH_ENABLED = False

# Sharding visualization
sharding_callbacks = weakref.WeakValueDictionary() # type: ignore
sharding_callbacks = weakref.WeakValueDictionary()
_INSPECT_SHARDING_CALL_NAME = "InspectSharding"


Expand Down Expand Up @@ -81,7 +79,6 @@ def visualize_sharding(shape: torch.Size,
raise ValueError(
"`visualize_sharding` only works for shapes with 1 and 2 dimensions.")

# sharding[sharding.index(']')+1:-1]# sharding.devices_indices_map(tuple(shape))
slices: dict[tuple[int, ...], set[int]] = {}
heights: dict[tuple[int, ...], Optional[float]] = {}
widths: dict[tuple[int, ...], float] = {}
Expand All @@ -102,32 +99,25 @@ def visualize_sharding(shape: torch.Size,
# `device_indices_map`: [0, 1, 2, 3]
# `sharding_spac`: [2, 2]
sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1]
print('sharding_spac: ', sharding_spac)
if len(sharding) >= 25 and sharding[-24:-1] == 'last_tile_dim_replicate':
device_list = list(sharding[sharding.index(']') + 1:-24])
print("device_list")
print(device_list)
device_indices_map = [int(i) for i in device_list[:-1] if i != ',']
heights = int(sharding_spac[1])
widths = int(sharding_spac[3])
last_dim_depth = int(sharding_spac[5])
devices_len = len(device_indices_map)
len_after_dim_down = devices_len // last_dim_depth
for i in range(len_after_dim_down):
slices.setdefault((i // widths, i % widths),
device_indices_map[i:i + last_dim_depth])
slices.setdefault(
(i // widths, i % widths),
device_indices_map[i*last_dim_depth:(i + 1)*last_dim_depth])
elif sharding[-1] == "}":
# e.g. '{devices=[2,2]0,1,2,3}' -- the closing ']' is at index 13
device_list = list(sharding[sharding.index(']') + 1:-1])
# print('device_list: ', device_list)
device_indices_map = [int(i) for i in device_list if i != ',']
# print('device_indices_map: ', device_indices_map)
heights = int(sharding_spac[1])
# print('heights: ', heights)
widths = int(sharding_spac[3])
# print('widths: ', widths)
devices_len = len(device_indices_map)
# print('devices_len: ', devices_len)
for i in range(devices_len):
slices.setdefault((i // widths, i % widths), device_indices_map[i])
else:
Expand All @@ -137,21 +127,20 @@ def visualize_sharding(shape: torch.Size,

num_rows = heights
num_cols = widths
print('slices', slices)

console = rich.console.Console(width=max_width)
use_color = use_color and console.color_system is not None
if use_color and not color_map:
try:
import matplotlib as mpl # pytype: disable=import-error
import matplotlib as mpl
color_map = mpl.colormaps["tab20b"]
except ModuleNotFoundError:
use_color = False

base_height = int(10 * scale)
base_height = int(3 * scale)
aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
base_width = int(base_height * aspect_ratio)
height_to_width_ratio = 2.5
height_to_width_ratio = 1.5

# e.g. '{devices=[2,2]0,1,2,3}' -- the closing ']' is at index 13
# e.g. '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' -- the closing ']' is at index 15
Expand All @@ -174,9 +163,7 @@ def visualize_sharding(shape: torch.Size,
for i in range(num_rows):
col = []
for j in range(num_cols):
entry = f"{device_kind} " + str(
slices[i,
j]) # "entry"# .join([str(s) for s in sorted(slices[i, j])])
entry = f"{device_kind} " + str(slices[i, j])
width, maybe_height = widths, heights # widths[i, j], heights[i, j]
width = int(width * base_width * height_to_width_ratio)
if maybe_height is None:
Expand All @@ -188,6 +175,7 @@ def visualize_sharding(shape: torch.Size,
right_padding = left_padding + remainder
top_padding, remainder = divmod(height - 2, 2)
bottom_padding = top_padding + remainder

if use_color:
color = _canonicalize_color(next(color_iter)[:3])
text_color = _get_text_color(color)
Expand All @@ -198,8 +186,10 @@ def visualize_sharding(shape: torch.Size,
else:
color = None
text_color = None

padding = (top_padding, right_padding, bottom_padding, left_padding)
padding = tuple(max(x, 0) for x in padding) # type: ignore

col.append(
rich.padding.Padding(
rich.align.Align(entry, "center", vertical="middle"),
Expand Down
Loading

0 comments on commit 96ef6d7

Please sign in to comment.