Commit
Refactor ExecuteReplicated to operate on sharded data directly (pytorch#5737)

* Refactor ExecuteReplicated to operate on sharded data directly

* Remove old handlers

* formatting

* Improve naming and logging

* update docstring

* Remove obsolete unit tests

* improve comment

* Remove slow calls to get output shapes.

* fix implicit sharding

* remove declarations of input/output handlers

* formatting

* give everything a manual placeholder sharding

* see if CI passes

* formatting

* Shard parameter and output handling

* Use absl::BlockingCounter

* formatting

* fix merge

* Assign valid output shardings

* tune and document costs

* formatting

* implicitly replicate output to match outputhandler

* clarify ReplicateShardedData

* fix merge
will-cromar authored and ManfeiBai committed Dec 1, 2023
1 parent cd87fc8 commit 075cead
Showing 8 changed files with 128 additions and 318 deletions.
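As a quick orientation before the diffs, here is a minimal caller-side sketch of the interface change; it is illustrative only, and the `client`, `computation`, `per_replica_args`, `sharded_input`, `devices`, and `options` names are assumptions rather than code from the repository.

// Before this commit: arguments and results were grouped per replica.
std::vector<std::vector<ComputationClient::DataPtr>> old_results =
    client->ExecuteReplicated(*computation, per_replica_args, devices, options);
ComputationClient::DataPtr out0_dev0 = old_results[/*device=*/0][/*output=*/0];

// After this commit: each argument is a single sharded DataPtr holding one
// shard per device, and each output comes back as a single sharded DataPtr.
std::vector<ComputationClient::DataPtr> new_results =
    client->ExecuteReplicated(*computation, /*arguments=*/{sharded_input},
                              devices, options);
ComputationClient::DataPtr sharded_out0 = new_results[/*output=*/0];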
94 changes: 0 additions & 94 deletions test/cpp/test_xla_sharding.cpp
@@ -348,100 +348,6 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
}
}

TEST_F(XLAShardingTest, InputHandler) {
if ((torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") ||
(torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() <
2)) {
GTEST_SKIP()
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
<< torch_xla::runtime::GetComputationClient()->GetLocalDevices().size()
<< " local devices detected).";
}

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices = {"TPU:0", "TPU:1"};
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr, std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape)};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

std::vector<torch_xla::runtime::ComputationClient::DataPtr> arguments =
UnwrapXlaData(tensors_data);
auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices);

auto arg0_dev0 = arguments_by_device[0][0];
auto arg0_dev1 = arguments_by_device[1][0];
EXPECT_TRUE(XlaDataValuesEqual(arg0_dev0, arg0_dev1, at::kFloat));

auto arg1_dev0 = arguments_by_device[0][1];
auto arg1_dev1 = arguments_by_device[1][1];
EXPECT_TRUE(XlaDataValuesEqual(arg1_dev0, arg1_dev1, at::kFloat));
}

TEST_F(XLAShardingTest, OutputHandler) {
if ((torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") ||
(torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() <
2)) {
GTEST_SKIP()
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
<< torch_xla::runtime::GetComputationClient()->GetLocalDevices().size()
<< " local devices detected).";
}

std::vector<std::string> devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices();

// Prepare an input vector `outputs` with 2 arguments per device.
std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
outputs;
outputs.reserve(devices.size());
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
for (auto device : devices) {
outputs.push_back(
UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device})));
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
devices.size())
.ToProto(),
tensor_shape);
std::vector<XLATensor::ShardingSpecPtr> sharding_specs{sharding_spec,
sharding_spec};

// Shard a PjRtData into a PjRtShardedData.
std::vector<torch_xla::runtime::ComputationClient::DataPtr> sharded_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/true);
EXPECT_EQ(sharded_outputs.size(), 2);
auto shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_FALSE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));

// Wrap sharded data into a PjRtShardedData with `devices.size()` shards.
std::vector<torch_xla::runtime::ComputationClient::DataPtr> wrapped_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/false);
EXPECT_EQ(wrapped_outputs.size(), 2);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
wrapped_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_TRUE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));
}

TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
int64_t n_devices =
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -101,6 +101,7 @@ cc_library(
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@tsl//tsl/platform:env",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
21 changes: 7 additions & 14 deletions torch_xla/csrc/runtime/computation_client.h
@@ -291,20 +291,13 @@ class ComputationClient {
const ExecuteComputationOptions& options =
ExecuteComputationOptions{}) = 0;

// Executes the computation in replicated mode.
// The size of the arguments vector is the number of replicas to execute,
// and it must match the size of the computation.devices() as well as the
// devices passed as argument. The destination devices for each replicated
// computation come from the devices the Data objects are stored into, which
// must match the devices argument. Within arguments[i], every Data
// object must be coming from the same device. Returns a vector (of the same
// size of the arguments vector) with the results of the parallel execution.
// The result[i], a vector itself, will be the result of the computation fed
// with arguments[i]. If options.explode_tuple is true, the output tuples will
// be decomposed into their single elements.
virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation,
const std::vector<std::vector<DataPtr>>& arguments,
// Executes the computation on multiple local devices in parallel.
// Each argument to the executable is expected to be sharded in the same order
// as `devices`. If options.explode_tuple is true, the output tuples will be
// decomposed into their single elements. Returns a vector of outputs, each
// of which is sharded in the same order as `devices`.
virtual std::vector<DataPtr> ExecuteReplicated(
const Computation& computation, absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) = 0;

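To make the new contract concrete, a short sketch of the expected data layout; the index notation is illustrative, and `shards` refers to the `PjRtShardedData` member used in the implementation below.

// arguments[i]            : sharded handle for parameter i
// arguments[i]->shards[d] : shard of parameter i resident on devices[d]
// results[j]              : sharded handle for output j
// results[j]->shards[d]   : shard of output j resident on devices[d]
//
// With two devices and two parameters, a caller now passes two sharded
// DataPtrs instead of a 2x2 vector-of-vectors of per-device handles.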
180 changes: 99 additions & 81 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -1,6 +1,7 @@
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"

#include <algorithm>
#include <future>
#include <unordered_set>
#include <vector>

@@ -386,13 +387,16 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
std::move(status_or.value()));
}

ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
std::shared_ptr<PjRtComputationClient::PjRtData>
PjRtComputationClient::ReplicateShardedData(
const ComputationClient::DataPtr& handle) {
if (PjRtShardedData* sharded_data =
dynamic_cast<PjRtShardedData*>(handle.get())) {
if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
return unsharded_data;
} else if (auto sharded_data =
std::dynamic_pointer_cast<PjRtShardedData>(handle)) {
XLA_COUNTER("ReplicateShardedData", 1);
TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
<< ", shape=" << handle->shape() << ")";
TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle()
<< ", shape=" << sharded_data->shape() << ")";
if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
// Data is replicated, return the first shard
return sharded_data->shards[0];
@@ -428,35 +432,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
computations = Compile(std::move(instances));

auto shards = sharded_data->shards;
XLA_CHECK_EQ(shards.size(), GetLocalDevices().size());
auto device_index = build_index_map(GetLocalDevices());

std::vector<std::vector<ComputationClient::DataPtr>> arguments_by_device(
GetLocalDevices().size(), std::vector<ComputationClient::DataPtr>(1));
for (auto shard : shards) {
std::vector<std::string> device_spec =
absl::StrSplit(shard->device(), ':');
XLA_CHECK_EQ(device_spec.size(), 2)
<< "Invalid device specification: " << shard->device();
int device_i = device_index[std::stoi(device_spec[1])];
TF_VLOG(3) << shard->device() << " is mapped to local device index "
<< device_i;
arguments_by_device[device_i][0] = shard;
}
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;
auto sharded_results =
ExecuteReplicated(*computations.front(), arguments_by_device,
ExecuteReplicated(*computations.front(), {sharded_data},
GetLocalDevices(), execute_options);
XLA_CHECK(sharded_results.size() > 0)
<< "empty ExecuteReplicated results returned.";
XLA_CHECK(sharded_results[0].size() == 1)
XLA_CHECK(sharded_results.size() == 1)
<< "Wrong number of outputs, expected: 1, actual: "
<< sharded_results[0].size();
return sharded_results[0][0];
<< sharded_results.size();
return std::dynamic_pointer_cast<PjRtShardedData>(sharded_results[0])
->shards[0];
}
return handle;

XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got "
<< handle->ToString();
}

std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
@@ -472,12 +463,12 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
for (auto handle : handles) {
// Use XLA replication to reassemble the sharded data. If input handle
// is not sharded, then it is a no-op.
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);
std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
XLA_CHECK(pjrt_data);

xla::Literal& literal =
literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
futures.push_back(pjrt_data.buffer->ToLiteral(&literal));
literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
futures.push_back(pjrt_data->buffer->ToLiteral(&literal));

total_size += literal.size_bytes();
}
@@ -642,10 +633,10 @@ PjRtComputationClient::ExecuteComputation(
return datas;
}

std::vector<std::vector<ComputationClient::DataPtr>>
std::vector<ComputationClient::DataPtr>
PjRtComputationClient::ExecuteReplicated(
const ComputationClient::Computation& computation,
const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
absl::Span<const ComputationClient::DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) {
// Shared ownership of the timed section ensures that it will only get logged
Expand All @@ -657,34 +648,41 @@ PjRtComputationClient::ExecuteReplicated(
tsl::profiler::TraceMeLevel::kInfo);
const PjRtComputation& pjrt_computation =
dynamic_cast<const PjRtComputation&>(computation);
XLA_CHECK(devices.size() == arguments.size())
<< "ExecuteReplicated over " << devices.size() << " devices, but "
<< arguments.size() << " arguments devices.";
absl::BlockingCounter counter(devices.size());
std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());

std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(
devices.size(), std::vector<xla::PjRtBuffer*>(arguments.size()));
{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_argument_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < devices.size(); ++i) {
auto buffer_converter = [&, i]() {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

std::vector<xla::PjRtBuffer*> buffers;
for (auto& argument : arguments[i]) {
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
<< pjrt_device->DebugString() << " vs "
<< pjrt_data->buffer->device()->DebugString();
buffers.push_back(pjrt_data->buffer.get());
}
argument_handles[i] = std::move(buffers);
counter.DecrementCount();
};
thread::Schedule(std::move(buffer_converter));
}

absl::BlockingCounter counter(arguments.size());

// Time in nanoseconds that it takes to prepare an argument. Used to tune
// number of threads spawned by ParallelFor. Measured on 2023/11/28.
static constexpr int64_t argument_handle_cost_ns = 10000;
pool_.ParallelFor(
arguments.size(), argument_handle_cost_ns,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
auto pjrt_data =
std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
<< "Expected one shard per device";

for (int32_t d = 0; d < devices.size(); d++) {
std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];

xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString();

argument_handles[d][i] = shard->buffer.get();
}
counter.DecrementCount();
};
});
counter.Wait();
}
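Both the argument handling above and the result handling below fan work out across a thread pool and join with absl::BlockingCounter, with ParallelFor tuned by a measured per-item cost. Here is a self-contained sketch of just the fan-out/fan-in pattern; the `PrepareArgumentsForDevice` helper is hypothetical and stands in for the real per-device buffer work.

#include <string>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

// Hypothetical stand-in for the per-device work (collecting xla::PjRtBuffer*).
void PrepareArgumentsForDevice(const std::string& device) {}

void PrepareAllArguments(const std::vector<std::string>& devices) {
  absl::BlockingCounter counter(devices.size());
  std::vector<std::thread> workers;
  workers.reserve(devices.size());
  for (const std::string& device : devices) {
    // Capture the device name by value so the worker does not depend on the
    // loop variable's lifetime.
    workers.emplace_back([&counter, device]() {
      PrepareArgumentsForDevice(device);
      counter.DecrementCount();
    });
  }
  // Block until every device's arguments are ready, mirroring the Wait()
  // before the replicated execution.
  counter.Wait();
  for (std::thread& worker : workers) worker.join();
}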

Expand Down Expand Up @@ -726,38 +724,58 @@ PjRtComputationClient::ExecuteReplicated(
}));
}

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
data_handles.reserve(results.size());
std::vector<size_t> dims(results.size());
size_t num_outputs = results[0].size();
std::vector<ComputationClient::DataPtr> data_handles(num_outputs);

{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_result_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(results[i].size());
dims[i] = results[i].size();
for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
}

const xla::Shape& result_shape = computation.program_shape().result();
TF_VLOG(3) << "Processing output with shape " << result_shape.ToString();
const std::vector<xla::Shape>& output_shapes =
result_shape.IsTuple() ? result_shape.tuple_shapes()
: std::vector<xla::Shape>({result_shape});
XLA_CHECK_EQ(output_shapes.size(), num_outputs);

const std::vector<xla::OpSharding>& output_shardings =
pjrt_computation.output_shardings_
? *pjrt_computation.output_shardings_
:
// Without an explicit sharding annotation, the output is implicitly
// replicated
std::vector(num_outputs, xla::HloSharding::Replicate().ToProto());
XLA_CHECK_EQ(output_shardings.size(), num_outputs);

absl::BlockingCounter counter(num_outputs);

// Time in nanoseconds that it takes to process a result buffer.
// Measured on 2023/11/28.
static constexpr int64_t result_handle_cost_ns = 10000;
pool_.ParallelFor(
num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
for (int32_t d = 0; d < devices.size(); d++) {
std::unique_ptr<xla::PjRtBuffer> buffer =
std::move(results[d][i]);
shards[d] =
std::make_shared<PjRtData>(devices[d], std::move(buffer));
}

data_handles[i] = std::make_shared<PjRtShardedData>(
spmd_device_str, output_shapes[i], std::move(shards),
output_shardings[i]);
TF_VLOG(5) << "Created sharded data with shape "
<< data_handles[i]->shape().ToString();
counter.DecrementCount();
}
});
counter.Wait();
}

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
return data_handles;
}
