Transfer data directly to the device (#5772)
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
will-cromar authored and ManfeiBai committed Nov 29, 2023
1 parent 62138b1 commit e753e61
Showing 14 changed files with 391 additions and 268 deletions.
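
Taken together, the commit bullets above describe replacing the old callback-based `TensorSource` struct with an abstract interface: instead of a `populate_fn` that writes into a runtime-allocated buffer, each source exposes a raw host pointer, a shape, and byte strides derived from the source `at::Tensor`'s own layout. The following is a minimal sketch of that design, with hypothetical names and details rather than the exact contents of the commit's `tensor_source.h`:

#include <ATen/Tensor.h>
#include <cstdint>
#include <utility>
#include <vector>
#include "xla/shape.h"

// Sketch: an interface the runtime reads from directly, instead of a
// callback that populates a runtime-owned buffer.
class TensorSource {
 public:
  virtual ~TensorSource() = default;
  virtual const void* data() const = 0;         // host buffer to copy from
  virtual const xla::Shape& shape() const = 0;  // target device shape
  virtual std::vector<int64_t> byte_strides() const = 0;
};

// Sketch: an at::Tensor-backed source whose byte strides follow the
// tensor's own layout, so the runtime can consume the data as-is.
class AtenSource : public TensorSource {
 public:
  AtenSource(at::Tensor tensor, xla::Shape shape)
      : tensor_(std::move(tensor)), shape_(std::move(shape)) {}
  const void* data() const override { return tensor_.data_ptr(); }
  const xla::Shape& shape() const override { return shape_; }
  std::vector<int64_t> byte_strides() const override {
    std::vector<int64_t> strides;
    const int64_t elem_size = tensor_.element_size();
    for (const int64_t stride : tensor_.strides()) {
      strides.push_back(stride * elem_size);  // element strides -> byte strides
    }
    return strides;
  }

 private:
  at::Tensor tensor_;
  xla::Shape shape_;
};

The "Downcast at::Tensor if required" bullet suggests that a constructor like this would also convert dtypes the device does not support (for example, 64-bit floats on TPU) before the transfer.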
168 changes: 143 additions & 25 deletions test/spmd/test_spmd_debugging.py
@@ -5,30 +5,35 @@
import math
import numpy as np
import os
import io
import rich

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding


import test_xla_sharding_base


class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()# os.environ["XLA_USE_SPMD"] = "1"
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")
@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
@@ -38,36 +43,127 @@ def test_debugging_spmd_single_host_tiled(self):
partition_spec = (0, 1)
xs.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is:")
print(sharding)
print("then print:")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 0', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 1', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 2', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 3', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 4', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 5', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 6', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 7', "center", vertical="middle"),
(2,1,2,1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

partition_spec = (0, None)
t = torch.randn(8, 32, device=device)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, (0, None))
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is: ")
print(sharding)
print("then print: ")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
(2,0,2,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
(2,0,2,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
console.print(fask_table)
fake_console = rich.console.Console(file=io.StringIO(), width=120)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output

@unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(manfei): enable it.")

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
@@ -78,10 +174,32 @@ def test_single_host_replicated(self):
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, partition_spec_replicated)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
print("sharding is: ")
print(sharding)
print("then print: ")
visualize_tensor_sharding(t)
generatedtable = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(generatedtable)
output = console.file.getvalue()

color = None
text_color = None
fask_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
(0,0,1,0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
fake_console = rich.console.Console(file=io.StringIO(), width=120)
fake_console.print(fask_table)
fake_output = fake_console.file.getvalue()
assert output == fake_output


if __name__ == '__main__':
test = unittest.main()
64 changes: 42 additions & 22 deletions torch_xla/csrc/runtime/BUILD
@@ -1,3 +1,8 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_test",
)

load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
@@ -46,6 +51,7 @@ cc_library(
":metrics_reader",
":metrics",
":sys_util",
":tensor_source",
":types",
":util",
":xla_coordinator",
@@ -78,6 +84,7 @@ cc_library(
":env_vars",
":multi_wait",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
@@ -264,6 +271,17 @@ cc_library(
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
]
)

cc_library(
name = "types",
hdrs = ["types.h"],
@@ -339,25 +357,27 @@ ptxla_cc_test(
],
)

# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream.
# xla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@org_tensorflow//tensorflow/core/platform:logging",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
30 changes: 8 additions & 22 deletions torch_xla/csrc/runtime/computation_client.h
@@ -1,6 +1,7 @@
#ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_
#define XLA_CLIENT_COMPUTATION_CLIENT_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/hash.h>
@@ -20,6 +21,7 @@
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/runtime/util.h"
#include "xla/client/xla_computation.h"
@@ -192,25 +194,6 @@ class ComputationClient {

using ComputationPtr = std::shared_ptr<Computation>;

// The TensorSource provides a way for a client to populate a buffer allocated
// by the core computation client code.
struct TensorSource {
// The PopulateFn accepts a dense buffer in standard array layout
// (dim0-major) and deposits the source tensor data directly over the
// provided buffer.
using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;

TensorSource() = default;
TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn)
: shape(std::move(shape)),
device(std::move(device)),
populate_fn(std::move(populate_fn)) {}

xla::Shape shape;
std::string device;
PopulateFn populate_fn;
};

// TODO(wcromar): Should CompileInstance still exist? Should it be a subclass
// of torch::lazy::Computation?
struct CompileInstance {
@@ -275,19 +258,22 @@

// Transfers local tensor values to the TPU devices and fetches the handles.
virtual std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) = 0;

// Transfers local sharded tensor values to the TPU devices and returns a
// `PjRtShardedData`.
virtual DataPtr TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;

// Copies `data->buffer` to `dst` device buffer.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;

// Reads the tensor literal values stored at TPU server sites, behind the
// supplied handles.
// Note: a `TransferFromServer` call will block until the `DataPtrs` are
// ready if they were created by `TransferToServer` or `Execute*`. Calling
// this from Python while holding the GIL can cause deadlocks!
virtual std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) = 0;

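
Because `TransferFromServer` can now block until its inputs are ready, the note above matters for Python callers, and it is what the "fix gil deadlock" bullet in the commit message addresses. A hypothetical pybind11-style wrapper (illustrative names and paths, not the commit's actual binding code) shows the pattern: release the GIL before the blocking call.

#include <vector>
#include <pybind11/pybind11.h>
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "xla/literal.h"

namespace py = pybind11;

// Sketch: the transfer blocks until the DataPtrs are ready, and that
// readiness may depend on work that itself needs the GIL, so drop the
// GIL for the duration of the call to avoid a deadlock.
std::vector<xla::Literal> TransferFromServerSafe(
    torch_xla::runtime::ComputationClient* client,
    absl::Span<const torch_xla::runtime::ComputationClient::DataPtr> handles) {
  py::gil_scoped_release release;
  return client->TransferFromServer(handles);
}

For the transfer-in direction, callers of the reworked API would now pass shared pointers, along the lines of `client->TransferToServer({std::make_shared<AtenSource>(...)})` against the sketch shown earlier.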
(Diffs for the remaining 11 of the 14 changed files are not shown.)
