Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ExecuteReplicated to operate on sharded data directly #5737

Merged
merged 24 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 0 additions & 94 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,100 +348,6 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
}
}

TEST_F(XLAShardingTest, InputHandler) {
if ((torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") ||
(torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() <
2)) {
GTEST_SKIP()
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
<< torch_xla::runtime::GetComputationClient()->GetLocalDevices().size()
<< " local devices detected).";
}

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices = {"TPU:0", "TPU:1"};
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr, std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape)};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

std::vector<torch_xla::runtime::ComputationClient::DataPtr> arguments =
UnwrapXlaData(tensors_data);
auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices);

auto arg0_dev0 = arguments_by_device[0][0];
auto arg0_dev1 = arguments_by_device[1][0];
EXPECT_TRUE(XlaDataValuesEqual(arg0_dev0, arg0_dev1, at::kFloat));

auto arg1_dev0 = arguments_by_device[0][1];
auto arg1_dev1 = arguments_by_device[1][1];
EXPECT_TRUE(XlaDataValuesEqual(arg1_dev0, arg1_dev1, at::kFloat));
}

TEST_F(XLAShardingTest, OutputHandler) {
if ((torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") ||
(torch_xla::runtime::GetComputationClient()->GetLocalDevices().size() <
2)) {
GTEST_SKIP()
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
<< torch_xla::runtime::GetComputationClient()->GetLocalDevices().size()
<< " local devices detected).";
}

std::vector<std::string> devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices();

// Prepare an input vecotr `outputs` with 2 arguments per device.
std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
outputs;
outputs.reserve(devices.size());
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
for (auto device : devices) {
outputs.push_back(
UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device})));
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
devices.size())
.ToProto(),
tensor_shape);
std::vector<XLATensor::ShardingSpecPtr> sharding_specs{sharding_spec,
sharding_spec};

// Shard a PjRtData into a PjRtShardedData.
std::vector<torch_xla::runtime::ComputationClient::DataPtr> sharded_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/true);
EXPECT_EQ(sharded_outputs.size(), 2);
auto shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_FALSE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));

// Wrap sharded data into a PjRtShardedData with `devices.size()` shards.
std::vector<torch_xla::runtime::ComputationClient::DataPtr> wrapped_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/false);
EXPECT_EQ(wrapped_outputs.size(), 2);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
wrapped_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_TRUE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));
}

TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
int64_t n_devices =
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ cc_library(
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@tsl//tsl/platform:env",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
Expand Down
21 changes: 7 additions & 14 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,20 +291,13 @@ class ComputationClient {
const ExecuteComputationOptions& options =
ExecuteComputationOptions{}) = 0;

// Executes the computation in replicated mode.
// The size of the arguments vector is the number of replicas to execute,
// and it must match the size of the computation.devices() as well as the
// devices passed as argument. The destination devices for each replicated
// computation come from the devices the Data objects are stored into, which
// must match the devices argument. Within arguments[i], every Data
// object must be coming from the same device. Returns a vector (of the same
// size of the arguments vector) with the results of the parallel execution.
// The result[i], a vector itself, will be the result of the computation fed
// with arguments[i]. If options.explode_tuple is true, the output tuples will
// be decomposed into their single elements.
virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation,
const std::vector<std::vector<DataPtr>>& arguments,
// Executes the computation on multiple local devices in parallel.
// Each argument to the executable is expected to be sharded in the same order
// as `devices`. If options.explode_tuple is true, the output tuples will be
// decomposed into their single elements. Returns a vector of outputs, each
// of which is sharded in the same order as `devices`.
virtual std::vector<DataPtr> ExecuteReplicated(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I am very reluctant to change the return type here -- device x arguments (vector of vector) is more general as the original ExecuteReplicated was meant to be implemented for non-SPMD as well. Also, I see that you are removing the input handler and the output handler... those abstractions were introduced to avoid any assumption about the computation client's return type, other than it's a device Data. The actual sharidng related processing happens in the sharding util, only after realizing that it's a sharded data.

Any reason you need to refactor the SPMD related code path, under pjrt_computation_client and xla_sharding_util? I assume this is for IFRT migration, and you want to skip having a separate ifrt client?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's a reason to keep the "vector of vector" structure here. We may have supported some way of running Replicated execution without SPMD before with XRT, but it doesn't really exist in PJRT or IFRT. You either execute on one device (PJRT's ExecuteSharded) or all local devices (IFRT's Execute). There was actually no usage of ExecuteReplicated at all within PT/XLA before PJRT and SPMD. IIRC we only kept this interface for this long to keep compatibility with XRT (even though we hadn't actually used that code path in XRT for years, if ever). We explicitly only call ExecuteReplicated when async->cached_computation->is_sharded.

It doesn't look like InputHandler would have handled unsharded data (what would GetDataShard do to unsharded data?), nor does it look like OutputHandler would have handled sharded data (what would WrapDataShards do to sharded data?). I think it makes much more sense to just construct the sharded data within the ComputationClient, applying the xla::OpShardings from the actual PjRtExecutable. It's a relatively minor refactor in PJRT, and it makes IFRT much easier to implement.

Copy link
Contributor

@yeounoh yeounoh Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, InputHandler and OuptutHandler works with both unsharded and sharded, and it's a no-op for unsharded data. We were able to switch around the data handling logic without touchting the data representation, and also the vice versa. Having that abstraction layer was the point I was trying to make, not the functional perspective ---- it works either way.

So for me, the question really comes down to if we are creating a separate IFRT computation client, and then later make it the main/default client for all use cases. Or, you are trying to rewrite the PJRT computation client directly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also ExecuteReplicate really calls Execute from the runtime, which has nested buffer... https://source.corp.google.com/piper///depot/google3/third_party/tensorflow/compiler/xla/pjrt/pjrt_client.h;l=1373?q=Execute

Now for us, we have a sharded data (or ifrt array) that wraps a vector of buffer, so we can make it either way -- but, assuming that the computation data could just be a wrapper of the device buffer -- the former is not strange at all?

Copy link
Contributor

@yeounoh yeounoh Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having said that (my reasonings for the hesitancy), if you feel that this would make the IFRT adoption easier, then let's make this refactor functional and land.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fundamental difference between PJRT and IFRT in this case is that it's trivial to "unshard" data in PJRT, since the actual runtime construct (PjRtBuffer) doesn't know anything about sharding. All sharding information is effectively in the framework (PjRtShardedData) and in the compiled PjRtExecutable.

In IFRT, the runtime data construct (ifrt::Array) carries its own sharding information. Disassembling it actually creates new ifrt::Arrays with new shardings, then we'd have to reassemble the array with the original sharding before passing it into the ifrt::Executable. The executable will then return a sharded array, which we'll again have to disassemble, then reassemble in OutputHandler with the sharding that matches what the executable returned. This creates a lot of space for errors to hide, particularly around the ownership of the underlying shards of data, which will be aliased by multiple ifrt::Arrays (compared to PJRT, where a single buffer of device data is always owned by one PjRtBuffer).

To avoid this, we either need to 1) implement the hack in my IFRT prototype PR where I skip InputHandler and OutputHandler when using IFRT or 2) abstract more implementation details under ComputationClient, since the data representations need to be handled differently.

Since DataPtr already represents sharded data, it's actually very easy for us to approximate IFRT's behavior with PJRT. PjRtShardedData fills the same role as ifrt::Array: references to the local device data plus the sharding of that data. It would be much harder for us to approximate PjRt's behavior from IFRT.

This PR is definitely inspired by IFRT, but I also think this is the right way to evolve this abstraction post-XRT. It doesn't really make the runtime client any thicker, and overall it pushes functionality down in the stack (ie out of PyTorch/XLA) now that we have more capable runtimes.

Copy link
Contributor

@yeounoh yeounoh Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hoped that the SPMD sharding logics/functions to be runtime agnostic, working with the PjRtShardedData & DataPtr abstraction. Question, ifrt::Array tracks references or the actual copies? If it's the former like the PjRtBuffer, then assemble, disassemble shouldn't be a problem? One thing for sure, we don't want to materialize those data to CPU in any time. The input handler & output handler in SPMD util just reorganizes references.

Now come to think of it, yea -- adopting ifrt::Array is the right way, and it doesn't seem to be a simple drop-in replacement. Let's make it work, and at the same time -- I wonder if we could see to make leave InputHandler and OutputHandler intact so it can work with both PjRt and IfRt computaiton clients? But that sounds like they need to be a dummy call with most of logics plumbed down to the runtime layer...

I am also going to expand the InputHandler to take care of resharding the inputs after the auto-sharding pass. At the high-level, InputHandler is supposed to prepare the inputs in PyTorch/XLA before the execution, OutputHandler post-processes the outputs of the execution. And they were never used with xrt, only for pjrt & spmd.

EDIT:

This addresses a couple of pain points I found when prototyping IFRT in #5677. ComputationClient better hides the details of how sharded data is handled, since it's not required to be broken up before passing into ExecuteReplicated.

+1 hide the details of how sharded data is handled. Currently that's done by the DataPtr abstraction. At the same time, the computation client is a thin layer, it wraps and provides APIs to the underlying runtime client. The input for SPMD needs to be sharded and wrapped in a DataPtr (again wraps PjRtShardedData), and the entire array of device params are prepared for the execute API. I think we are looking at different levels of abstraction, for me the data abstraction is enough -- and the sharding & sharded data logic stays outside the execution API, which stays in the PyTorch/XLA layer. Let's discuss more when you resume the work.

Copy link
Collaborator Author

@will-cromar will-cromar Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each operation in IFRT that produces a new Array (copying, disassembling, assembling, resharding, etc) allows you to copy, donate, or reuse the input. I don't believe reusing the input actually means the new buffer shares ownership. This is not relevant in PJRT, since it's functionally impossible (by design) to alias a PJRT buffer. One piece of device data is owned by corresponds to one PJRT buffer, and we share ownership with a shared_ptr to that buffer. That's no longer true in IFRT: particularly when we're resharding, we may end up with multiple arrays that correspond to the same device data, and one array owns that data.

The IFRT client will also need to handle one more unique case: for single-controller, Arrays will actually represent both the addressable and non-addressable shards of a global tensor. The representations of sharded data are different enough that the ComputationClient implementations will have to treat the sharded data containers differently.

The current function of InputHandler (ie converting a vector of objects that each contain a vector of shards, into a vector of vectors of shards) should move down into PjRtComputationClient, since that is specific to the API of PjRtClient::Execute. Likewise, the inverse operation in OutputHandler belongs in PjRtComputationClient for the same reason. OutputHandler does one more thing: assign shardings to the outputs.

Intuitively, I would say it makes the most sense to assign the global shapes and shardings to each sharded data based on the executable that produced them. I'm pretty sure the ShardingSpecs in OutputHandler originate from the compiled Executable's GetHloModules anyway, and PJRT added APIs for specifically this purpose since we started SPMD. But these calls are far too slow in practice.

When I went to add back the sharding assignment in OutputHandler, I noticed that we don't even actually use the shardings there! We assign all of the outputs to placeholders:

{
tsl::profiler::TraceMe activity("update_placeholder",
tsl::profiler::TraceMeLevel::kInfo);
for (size_t i = 0; i < results.size(); ++i) {
XLA_CHECK(async->tensors_data[i] != nullptr);
async->tensors_data[i]->Assign(*results[i]);
}
}

for (size_t i = 0; i < results.size(); ++i) {
if (async->tensors_data[i] != nullptr) {
async->tensors_data[i]->Assign(*results[i]);
} else {

Assign completely ignores the sharding and shape on the right side:

void Assign(const torch::lazy::BackendData& data) override {
const PjRtShardedData& pjrt_sharded_data =
dynamic_cast<const PjRtShardedData&>(data);
if (&pjrt_sharded_data != this) {
shards = std::move(pjrt_sharded_data.shards);
}
}

In commit 5c45ca6, I just gave all of the outputs a placeholder "manual" sharding, since the placeholders have the real sharding anyway.

edit: I might be wrong about the last part. Added a check and running CI again... edit edit: Yeah, we only use the shardings on the placeholders.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, since we should extract the output (propagated, non-propagated) after the compilation, and they are available when we prepare the placeholders. One thing to note (synced offline) is that the UNKNOWN sharidng type carry a special meaning for the upcoming auto-sharding change. For this PR, we can just focus on avoiding any regressions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, I gave the outputs here "real" shardings based on the compiled executable, since UNKNOWN will have a real semantic meaning soon. The extra call to GetOutputShardings works here since I only call it once per compilation, rather than once per execution.

const Computation& computation, absl::Span<const DataPtr> arguments,
absl::Span<const std::string> devices,
const ExecuteReplicatedOptions& options) = 0;

Expand Down
180 changes: 99 additions & 81 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"

#include <algorithm>
#include <future>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -386,13 +387,16 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
std::move(status_or.value()));
}

ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
std::shared_ptr<PjRtComputationClient::PjRtData>
PjRtComputationClient::ReplicateShardedData(
const ComputationClient::DataPtr& handle) {
if (PjRtShardedData* sharded_data =
dynamic_cast<PjRtShardedData*>(handle.get())) {
if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
return unsharded_data;
} else if (auto sharded_data =
std::dynamic_pointer_cast<PjRtShardedData>(handle)) {
XLA_COUNTER("ReplicateShardedData", 1);
TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
<< ", shape=" << handle->shape() << ")";
TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle()
<< ", shape=" << sharded_data->shape() << ")";
if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
// Data is replicated, return the first shard
return sharded_data->shards[0];
Expand Down Expand Up @@ -428,35 +432,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
computations = Compile(std::move(instances));

auto shards = sharded_data->shards;
XLA_CHECK_EQ(shards.size(), GetLocalDevices().size());
auto device_index = build_index_map(GetLocalDevices());

std::vector<std::vector<ComputationClient::DataPtr>> arguments_by_device(
GetLocalDevices().size(), std::vector<ComputationClient::DataPtr>(1));
for (auto shard : shards) {
std::vector<std::string> device_spec =
absl::StrSplit(shard->device(), ':');
XLA_CHECK_EQ(device_spec.size(), 2)
<< "Invalid device specification: " << shard->device();
int device_i = device_index[std::stoi(device_spec[1])];
TF_VLOG(3) << shard->device() << " is mapped to local device index "
<< device_i;
arguments_by_device[device_i][0] = shard;
}
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;
auto sharded_results =
ExecuteReplicated(*computations.front(), arguments_by_device,
ExecuteReplicated(*computations.front(), {sharded_data},
GetLocalDevices(), execute_options);
XLA_CHECK(sharded_results.size() > 0)
<< "empty ExecuteReplicated results returned.";
XLA_CHECK(sharded_results[0].size() == 1)
XLA_CHECK(sharded_results.size() == 1)
<< "Wrong number of outputs, expected: 1, actual: "
<< sharded_results[0].size();
return sharded_results[0][0];
<< sharded_results.size();
return std::dynamic_pointer_cast<PjRtShardedData>(sharded_results[0])
->shards[0];
}
return handle;

XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got "
<< handle->ToString();
}

std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
Expand All @@ -472,12 +463,12 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
for (auto handle : handles) {
// Use XLA replication to reassemble the sharded data. If input handle
// is not sharded, then it is a no-op.
auto new_handle = ReplicateShardedData(handle);
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);
std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
XLA_CHECK(pjrt_data);

xla::Literal& literal =
literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
futures.push_back(pjrt_data.buffer->ToLiteral(&literal));
literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
futures.push_back(pjrt_data->buffer->ToLiteral(&literal));

total_size += literal.size_bytes();
}
Expand Down Expand Up @@ -642,10 +633,10 @@ PjRtComputationClient::ExecuteComputation(
return datas;
}

std::vector<std::vector<ComputationClient::DataPtr>>
std::vector<ComputationClient::DataPtr>
PjRtComputationClient::ExecuteReplicated(
const ComputationClient::Computation& computation,
const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
absl::Span<const ComputationClient::DataPtr> arguments,
absl::Span<const std::string> devices,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we still need devices as input here - should the caller need to be aware of which devices the shards live on?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

devices doesn't really make sense since we dropped XRT, since PJRT only lets you execute on one device or all devices. I added a TODO to remove this in my IFRT prototype PR. In practice, the caller does seem to be setting the "correct" value here, though.

const ExecuteReplicatedOptions& options) {
// Shared ownership of the timed section ensures that it will only get logged
Expand All @@ -657,34 +648,41 @@ PjRtComputationClient::ExecuteReplicated(
tsl::profiler::TraceMeLevel::kInfo);
const PjRtComputation& pjrt_computation =
dynamic_cast<const PjRtComputation&>(computation);
XLA_CHECK(devices.size() == arguments.size())
<< "ExecuteReplicated over " << devices.size() << " devices, but "
<< arguments.size() << " arguments devices.";
absl::BlockingCounter counter(devices.size());
std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());

std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(
devices.size(), std::vector<xla::PjRtBuffer*>(arguments.size()));
{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_argument_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < devices.size(); ++i) {
auto buffer_converter = [&, i]() {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

std::vector<xla::PjRtBuffer*> buffers;
for (auto& argument : arguments[i]) {
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
<< pjrt_device->DebugString() << " vs "
<< pjrt_data->buffer->device()->DebugString();
buffers.push_back(pjrt_data->buffer.get());
}
argument_handles[i] = std::move(buffers);
counter.DecrementCount();
};
thread::Schedule(std::move(buffer_converter));
}

absl::BlockingCounter counter(arguments.size());

// Time in nanoseconds that it takes to prepare an argument. Used to tune
// number of threads spawned by ParallelFor. Measured on 2023/11/28.
static constexpr int64_t argument_handle_cost_ns = 10000;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How important is the accuracy of this value? I assume this value depends on the host machine's specs, so it would be different across TPU generations and GPU VMs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just shooting for the right order of magnitude here -- I didn't see any noticeable difference between this and a 30000 ns estimate from a previous commit. Supposedly we can set this to the number of CPU cycles (which would be more consistent), but I don't know how to measure that accurately.

Since this is really obscure, and it doesn't seem to be that sensitive, I just put a nice round number in the correct range here rather than adding another configuration knob.

pool_.ParallelFor(
arguments.size(), argument_handle_cost_ns,
[&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
auto pjrt_data =
std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
<< "Expected one shard per device";

for (int32_t d = 0; d < devices.size(); d++) {
std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];

xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString();

argument_handles[d][i] = shard->buffer.get();
}
counter.DecrementCount();
};
});
counter.Wait();
}

Expand Down Expand Up @@ -726,38 +724,58 @@ PjRtComputationClient::ExecuteReplicated(
}));
}

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
data_handles.reserve(results.size());
std::vector<size_t> dims(results.size());
size_t num_outputs = results[0].size();
std::vector<ComputationClient::DataPtr> data_handles(num_outputs);

{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_result_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(results[i].size());
dims[i] = results[i].size();
for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data =
std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
}

const xla::Shape& result_shape = computation.program_shape().result();
TF_VLOG(3) << "Processing output with shape " << result_shape.ToString();
const std::vector<xla::Shape>& output_shapes =
result_shape.IsTuple() ? result_shape.tuple_shapes()
: std::vector<xla::Shape>({result_shape});
XLA_CHECK_EQ(output_shapes.size(), num_outputs);

const std::vector<xla::OpSharding>& output_shardings =
pjrt_computation.output_shardings_
? *pjrt_computation.output_shardings_
:
// Without an explicit sharding annotation, the output is implicitly
// replicated
std::vector(num_outputs, xla::HloSharding::Replicate().ToProto());
XLA_CHECK_EQ(output_shardings.size(), num_outputs);

absl::BlockingCounter counter(num_outputs);

// Time in nanoseconds that it takes to process a result buffer.
// Measured on 2023/11/28.
static constexpr int64_t result_handle_cost_ns = 10000;
pool_.ParallelFor(
num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) {
for (int32_t i = start; i < end; ++i) {
std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
for (int32_t d = 0; d < devices.size(); d++) {
std::unique_ptr<xla::PjRtBuffer> buffer =
std::move(results[d][i]);
shards[d] =
std::make_shared<PjRtData>(devices[d], std::move(buffer));
}

data_handles[i] = std::make_shared<PjRtShardedData>(
spmd_device_str, output_shapes[i], std::move(shards),
output_shardings[i]);
TF_VLOG(5) << "Created sharded data with shape "
<< data_handles[i]->shape().ToString();
counter.DecrementCount();
}
});
counter.Wait();
}

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
return data_handles;
}

Expand Down
Loading