Refactor ExecuteReplicated to operate on sharded data directly #5737
Changes from all commits
7eaa91b
0cc6dcd
e9be48f
d3e3976
4dbac97
5bc1a23
be7b3da
5d48684
6e1e0d3
8b6fff3
64ef125
20be9ca
ac61413
806a7bd
787b568
109b5de
6e1c28a
464c6c3
11bf3ff
dba6d1c
9dcd75e
c2b8812
9bc594e
674f2c3
@@ -1,6 +1,7 @@
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"

#include <algorithm>
#include <future>
#include <unordered_set>
#include <vector>
@@ -386,13 +387,16 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
       std::move(status_or.value()));
 }

-ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
+std::shared_ptr<PjRtComputationClient::PjRtData>
+PjRtComputationClient::ReplicateShardedData(
     const ComputationClient::DataPtr& handle) {
-  if (PjRtShardedData* sharded_data =
-          dynamic_cast<PjRtShardedData*>(handle.get())) {
+  if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
+    return unsharded_data;
+  } else if (auto sharded_data =
+                 std::dynamic_pointer_cast<PjRtShardedData>(handle)) {
     XLA_COUNTER("ReplicateShardedData", 1);
-    TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
-               << ", shape=" << handle->shape() << ")";
+    TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle()
+               << ", shape=" << sharded_data->shape() << ")";
     if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
       // Data is replicated, return the first shard
       return sharded_data->shards[0];
@@ -428,35 +432,22 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
           std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
         computations = Compile(std::move(instances));

-    auto shards = sharded_data->shards;
-    XLA_CHECK_EQ(shards.size(), GetLocalDevices().size());
-    auto device_index = build_index_map(GetLocalDevices());
-
-    std::vector<std::vector<ComputationClient::DataPtr>> arguments_by_device(
-        GetLocalDevices().size(), std::vector<ComputationClient::DataPtr>(1));
-    for (auto shard : shards) {
-      std::vector<std::string> device_spec =
-          absl::StrSplit(shard->device(), ':');
-      XLA_CHECK_EQ(device_spec.size(), 2)
-          << "Invalid device specification: " << shard->device();
-      int device_i = device_index[std::stoi(device_spec[1])];
-      TF_VLOG(3) << shard->device() << " is mapped to local device index "
-                 << device_i;
-      arguments_by_device[device_i][0] = shard;
-    }
     torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
         execute_options;
     auto sharded_results =
-        ExecuteReplicated(*computations.front(), arguments_by_device,
+        ExecuteReplicated(*computations.front(), {sharded_data},
                           GetLocalDevices(), execute_options);
     XLA_CHECK(sharded_results.size() > 0)
         << "empty ExecuteReplicated results returned.";
-    XLA_CHECK(sharded_results[0].size() == 1)
+    XLA_CHECK(sharded_results.size() == 1)
         << "Wrong number of outputs, expected: 1, actual: "
-        << sharded_results[0].size();
-    return sharded_results[0][0];
+        << sharded_results.size();
+    return std::dynamic_pointer_cast<PjRtShardedData>(sharded_results[0])
+        ->shards[0];
   }
-  return handle;
+
+  XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got "
+              << handle->ToString();
 }

 std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
@@ -472,12 +463,12 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
   for (auto handle : handles) {
     // Use XLA replication to reassemble the sharded data. If input handle
     // is not sharded, then it is a no-op.
-    auto new_handle = ReplicateShardedData(handle);
-    const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(*new_handle);
+    std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
+    XLA_CHECK(pjrt_data);

     xla::Literal& literal =
-        literals.emplace_back(host_output_shape(pjrt_data.buffer.get()));
-    futures.push_back(pjrt_data.buffer->ToLiteral(&literal));
+        literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
+    futures.push_back(pjrt_data->buffer->ToLiteral(&literal));

     total_size += literal.size_bytes();
   }
@@ -642,10 +633,10 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }

-std::vector<std::vector<ComputationClient::DataPtr>>
+std::vector<ComputationClient::DataPtr>
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
-    const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
+    absl::Span<const ComputationClient::DataPtr> arguments,
     absl::Span<const std::string> devices,
Review comment: I wonder if we still need …
     const ExecuteReplicatedOptions& options) {
   // Shared ownership of the timed section ensures that it will only get logged
@@ -657,34 +648,41 @@ PjRtComputationClient::ExecuteReplicated(
       tsl::profiler::TraceMeLevel::kInfo);
   const PjRtComputation& pjrt_computation =
       dynamic_cast<const PjRtComputation&>(computation);
-  XLA_CHECK(devices.size() == arguments.size())
-      << "ExecuteReplicated over " << devices.size() << " devices, but "
-      << arguments.size() << " arguments devices.";
-  absl::BlockingCounter counter(devices.size());
-  std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());

+  std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(
+      devices.size(), std::vector<xla::PjRtBuffer*>(arguments.size()));
   {
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_argument_handle",
         tsl::profiler::TraceMeLevel::kInfo);
-    for (int32_t i = 0; i < devices.size(); ++i) {
-      auto buffer_converter = [&, i]() {
-        xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
-        XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
-
-        std::vector<xla::PjRtBuffer*> buffers;
-        for (auto& argument : arguments[i]) {
-          const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
-
-          XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
-              << pjrt_device->DebugString() << " vs "
-              << pjrt_data->buffer->device()->DebugString();
-          buffers.push_back(pjrt_data->buffer.get());
-        }
-        argument_handles[i] = std::move(buffers);
-        counter.DecrementCount();
-      };
-      thread::Schedule(std::move(buffer_converter));
-    }
+    absl::BlockingCounter counter(arguments.size());
+
+    // Time in nanoseconds that it takes to prepare an argument. Used to tune
+    // number of threads spawned by ParallelFor. Measured on 2023/11/28.
+    static constexpr int64_t argument_handle_cost_ns = 10000;
Review comment: How important is the accuracy of this value? I assume this value depends on the host machine's specs, so it would be different across TPU generations and GPU VMs.
Reply: I'm just shooting for the right order of magnitude here -- I didn't see any noticeable difference between this and a … Since this is really obscure, and it doesn't seem to be that sensitive, I just put a nice round number in the correct range here rather than adding another configuration knob.
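For context on the reply above: the per-unit cost passed to a tsl::thread::ThreadPool::ParallelFor only guides how the range is split across threads, which is why an order-of-magnitude estimate is enough. A minimal standalone sketch (hedged: it assumes the tsl/platform headers and uses a local pool and a dummy loop body as stand-ins; it is not code from this PR, which uses the client's own pool_ member):

#include <cstdint>
#include <vector>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

int main() {
  // Local pool for illustration; PjRtComputationClient keeps a long-lived pool member instead.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "example", /*num_threads=*/8);

  std::vector<int64_t> out(1024, 0);

  // Estimated cost per item in nanoseconds. A larger estimate splits the range
  // across more threads; a smaller one keeps it on fewer. Only the order of
  // magnitude matters for that scheduling decision, not the exact value.
  static constexpr int64_t cost_per_item_ns = 10000;

  pool.ParallelFor(out.size(), cost_per_item_ns,
                   [&](int64_t start, int64_t end) {
                     for (int64_t i = start; i < end; ++i) {
                       out[i] = i * 2;  // stand-in for preparing one argument
                     }
                   });
  return 0;
}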
+    pool_.ParallelFor(
+        arguments.size(), argument_handle_cost_ns,
+        [&](int64_t start, int64_t end) {
+          for (int32_t i = start; i < end; ++i) {
+            auto pjrt_data =
+                std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
+            XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
+                << "Expected one shard per device";
+
+            for (int32_t d = 0; d < devices.size(); d++) {
+              std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];
+
+              xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
+              XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
+              XLA_CHECK(pjrt_device->IsAddressable())
+                  << pjrt_device->DebugString();
+
+              argument_handles[d][i] = shard->buffer.get();
+            }
+            counter.DecrementCount();
+          };
+        });
     counter.Wait();
   }
@@ -726,38 +724,58 @@ PjRtComputationClient::ExecuteReplicated(
       }));
   }

-  std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
-  data_handles.reserve(results.size());
-  std::vector<size_t> dims(results.size());
+  size_t num_outputs = results[0].size();
+  std::vector<ComputationClient::DataPtr> data_handles(num_outputs);

   {
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_result_handle",
         tsl::profiler::TraceMeLevel::kInfo);
-    for (int32_t i = 0; i < results.size(); ++i) {
-      xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
-      XLA_CHECK(pjrt_device->IsAddressable())
-          << pjrt_device->DebugString() << " is not addressable.";
-
-      std::vector<ComputationClient::DataPtr> datas;
-      datas.reserve(results[i].size());
-      dims[i] = results[i].size();
-      for (int32_t j = 0; j < results[i].size(); ++j) {
-        std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
-        XLA_CHECK(pjrt_device == buffer->device())
-            << "Exepcted device: " << pjrt_device->DebugString()
-            << " vs. actual device: " << buffer->device()->DebugString();
-
-        std::shared_ptr<PjRtData> data =
-            std::make_shared<PjRtData>(devices[i], std::move(buffer));
-        datas.push_back(data);
-      }
-      data_handles.push_back(datas);
-    }
+
+    const xla::Shape& result_shape = computation.program_shape().result();
+    TF_VLOG(3) << "Processing output with shape " << result_shape.ToString();
+    const std::vector<xla::Shape>& output_shapes =
+        result_shape.IsTuple() ? result_shape.tuple_shapes()
+                               : std::vector<xla::Shape>({result_shape});
+    XLA_CHECK_EQ(output_shapes.size(), num_outputs);
+
+    const std::vector<xla::OpSharding>& output_shardings =
+        pjrt_computation.output_shardings_
+            ? *pjrt_computation.output_shardings_
+            :
+            // Without an explicit sharding annotation, the output is implicitly
+            // replicated
+            std::vector(num_outputs, xla::HloSharding::Replicate().ToProto());
+    XLA_CHECK_EQ(output_shardings.size(), num_outputs);
+
+    absl::BlockingCounter counter(num_outputs);
+
+    // Time in nanoseconds that it takes to process a result buffer.
+    // Measured on 2023/11/28.
+    static constexpr int64_t result_handle_cost_ns = 10000;
+    pool_.ParallelFor(
+        num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) {
+          for (int32_t i = start; i < end; ++i) {
+            std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
+            for (int32_t d = 0; d < devices.size(); d++) {
+              std::unique_ptr<xla::PjRtBuffer> buffer =
+                  std::move(results[d][i]);
+              shards[d] =
+                  std::make_shared<PjRtData>(devices[d], std::move(buffer));
+            }
+
+            data_handles[i] = std::make_shared<PjRtShardedData>(
+                spmd_device_str, output_shapes[i], std::move(shards),
+                output_shardings[i]);
+            TF_VLOG(5) << "Created sharded data with shape "
+                       << data_handles[i]->shape().ToString();
+            counter.DecrementCount();
+          }
+        });
+    counter.Wait();
   }

-  TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
-             << "with dimensions [" << absl::StrJoin(dims, ",") << "].";
+  TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
   return data_handles;
 }
Actually, I am very reluctant to change the return type here -- device x arguments (vector of vector) is more general, as the original ExecuteReplicated was meant to be implemented for non-SPMD as well. Also, I see that you are removing the input handler and the output handler... those abstractions were introduced to avoid any assumption about the computation client's return type, other than that it's a device Data. The actual sharding-related processing happens in the sharding util, only after realizing that it's sharded data.
Any reason you need to refactor the SPMD-related code path under pjrt_computation_client and xla_sharding_util? I assume this is for the IFRT migration, and you want to skip having a separate IFRT client?
I don't think there's a reason to keep the "vector of vector" structure here. We may have supported some way of running replicated execution without SPMD before with XRT, but it doesn't really exist in PJRT or IFRT. You either execute on one device (PJRT's ExecuteSharded) or on all local devices (IFRT's Execute). There was actually no usage of ExecuteReplicated at all within PT/XLA before PJRT and SPMD. IIRC we only kept this interface for this long to keep compatibility with XRT (even though we hadn't actually used that code path in XRT for years, if ever). We explicitly only call `ExecuteReplicated` when `async->cached_computation->is_sharded`.
It doesn't look like `InputHandler` would have handled unsharded data (what would `GetDataShard` do to unsharded data?), nor does it look like `OutputHandler` would have handled sharded data (what would `WrapDataShards` do to sharded data?). I think it makes much more sense to just construct the sharded data within the `ComputationClient`, applying the `xla::OpSharding`s from the actual `PjRtExecutable`. It's a relatively minor refactor in PJRT, and it makes IFRT much easier to implement.
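As a rough illustration of the reorganization this pushes into the client (a hedged sketch with invented stand-in types Buffer/Shard/ShardedData and an ArgumentHandles helper, not the real PjRtData/PjRtShardedData classes): each sharded argument carries one shard per local device, and the client transposes those into the per-device buffer lists that the underlying runtime execute call consumes, essentially the ParallelFor loop in the diff above.

#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct Buffer { std::string device; };              // stands in for xla::PjRtBuffer
struct Shard { std::shared_ptr<Buffer> buffer; };   // stands in for PjRtData
struct ShardedData { std::vector<Shard> shards; };  // stands in for PjRtShardedData

// Transpose N sharded arguments into, for each of D devices, the list of N
// raw buffers that device's executable instance will consume.
std::vector<std::vector<Buffer*>> ArgumentHandles(
    const std::vector<ShardedData>& arguments, size_t num_devices) {
  std::vector<std::vector<Buffer*>> handles(
      num_devices, std::vector<Buffer*>(arguments.size()));
  for (size_t i = 0; i < arguments.size(); ++i) {
    assert(arguments[i].shards.size() == num_devices);  // one shard per device
    for (size_t d = 0; d < num_devices; ++d) {
      handles[d][i] = arguments[i].shards[d].buffer.get();
    }
  }
  return handles;
}

int main() {
  ShardedData a{{{std::make_shared<Buffer>(Buffer{"TPU:0"})},
                 {std::make_shared<Buffer>(Buffer{"TPU:1"})}}};
  auto handles = ArgumentHandles({a}, /*num_devices=*/2);
  assert(handles.size() == 2 && handles[0].size() == 1);
  return 0;
}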
Yes, InputHandler and OutputHandler work with both unsharded and sharded data, and they are a no-op for unsharded data. We were able to switch around the data handling logic without touching the data representation, and vice versa. Having that abstraction layer was the point I was trying to make, not the functional perspective -- it works either way.
So for me, the question really comes down to whether we are creating a separate IFRT computation client and then later making it the main/default client for all use cases, or whether you are trying to rewrite the PJRT computation client directly?
Also, ExecuteReplicated really calls Execute from the runtime, which takes nested buffers... https://source.corp.google.com/piper///depot/google3/third_party/tensorflow/compiler/xla/pjrt/pjrt_client.h;l=1373?q=Execute
Now for us, we have a sharded data (or IFRT array) that wraps a vector of buffers, so we can make it work either way -- but, assuming that the computation data could just be a wrapper of the device buffer, the former is not strange at all?
Having said that (those are my reasons for the hesitancy), if you feel that this would make the IFRT adoption easier, then let's make this refactor functional and land it.
The fundamental difference between PJRT and IFRT in this case is that it's trivial to "unshard" data in PJRT, since the actual runtime construct (`PjRtBuffer`) doesn't know anything about sharding. All sharding information is effectively in the framework (`PjRtShardedData`) and in the compiled `PjRtExecutable`.
In IFRT, the runtime data construct (`ifrt::Array`) carries its own sharding information. Disassembling it actually creates new `ifrt::Array`s with new shardings, then we'd have to reassemble the array with the original sharding before passing it into the `ifrt::Executable`. The executable will then return a sharded array, which we'll again have to disassemble, then reassemble in `OutputHandler` with the sharding that matches what the executable returned. This creates a lot of space for errors to hide, particularly around the ownership of the underlying shards of data, which will be aliased by multiple `ifrt::Array`s (compared to PJRT, where a single buffer of device data is always owned by one `PjRtBuffer`).
To avoid this, we either need to 1) implement the hack in my IFRT prototype PR where I skip `InputHandler` and `OutputHandler` when using IFRT, or 2) abstract more implementation details under `ComputationClient`, since the data representations need to be handled differently.
Since `DataPtr` already represents sharded data, it's actually very easy for us to approximate IFRT's behavior with PJRT. `PjRtShardedData` fills the same role as `ifrt::Array`: references to the local device data plus the sharding of that data. It would be much harder for us to approximate PJRT's behavior from IFRT.
This PR is definitely inspired by IFRT, but I also think this is the right way to evolve this abstraction post-XRT. It doesn't really make the runtime client any thicker, and overall it pushes functionality down in the stack (i.e. out of PyTorch/XLA) now that we have more capable runtimes.
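A small, hedged sketch of the ownership model being described, using invented stand-in types rather than the real classes: each piece of device memory has exactly one owning buffer object, and the sharded wrapper only holds references to those owners plus the sharding, so picking a shard never creates a second owner. (The IFRT concern above is precisely that assembling/disassembling produces new ifrt::Array objects instead.)

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct DeviceBuffer { std::string device; };  // stands in for xla::PjRtBuffer

// Stands in for PjRtData: shares ownership of exactly one device buffer.
struct Shard { std::shared_ptr<DeviceBuffer> buffer; };

// Stands in for PjRtShardedData (the role ifrt::Array plays inside the runtime):
// references to the per-device shards plus the sharding of the global tensor.
struct ShardedData {
  std::vector<Shard> shards;
  std::string sharding;  // stands in for xla::OpSharding
};

int main() {
  auto b0 = std::make_shared<DeviceBuffer>(DeviceBuffer{"TPU:0"});
  auto b1 = std::make_shared<DeviceBuffer>(DeviceBuffer{"TPU:1"});
  ShardedData sharded{{{b0}, {b1}}, "REPLICATED"};

  // "Unsharding" replicated data is just picking one shard; no new owner of
  // the underlying device memory is created.
  Shard replica = sharded.shards[0];
  assert(replica.buffer.get() == b0.get());
  return 0;
}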
I hoped that the SPMD sharding logic/functions would be runtime agnostic, working with the `PjRtShardedData` & `DataPtr` abstraction. Question: does `ifrt::Array` track references or actual copies? If it's the former, like `PjRtBuffer`, then assemble/disassemble shouldn't be a problem? One thing for sure: we don't want to materialize that data to CPU at any time. The input handler & output handler in the SPMD util just reorganize references.
Now come to think of it, yeah -- adopting `ifrt::Array` is the right way, and it doesn't seem to be a simple drop-in replacement. Let's make it work, and at the same time I wonder if we could look at leaving `InputHandler` and `OutputHandler` intact so they can work with both PJRT and IFRT computation clients? But that sounds like they would need to become dummy calls with most of the logic plumbed down to the runtime layer... I am also going to expand the `InputHandler` to take care of resharding the inputs after the auto-sharding pass. At the high level, `InputHandler` is supposed to prepare the inputs in PyTorch/XLA before the execution, and `OutputHandler` post-processes the outputs of the execution. And they were never used with `xrt`, only for `pjrt` & `spmd`.
EDIT: +1 to hiding the details of how sharded data is handled. Currently that's done by the `DataPtr` abstraction. At the same time, the computation client is a thin layer; it wraps and provides APIs to the underlying runtime client. The input for SPMD needs to be sharded and wrapped in a `DataPtr` (again wrapping `PjRtShardedData`), and the entire array of device params is prepared for the execute API. I think we are looking at different levels of abstraction; for me the data abstraction is enough -- the sharding & sharded data logic stays outside the execution API, which stays in the PyTorch/XLA layer. Let's discuss more when you resume the work.
Each operation in IFRT that produces a new Array (copying, disassembling, assembling, resharding, etc.) allows you to copy, donate, or reuse the input. I don't believe reusing the input actually means the new buffer shares ownership. This is not relevant in PJRT, since it's functionally impossible (by design) to alias a PJRT buffer. One piece of device data corresponds to one PJRT buffer, and we share ownership with a shared_ptr to that buffer. That's no longer true in IFRT: particularly when we're resharding, we may end up with multiple arrays that correspond to the same device data, and one array owns that data.
The IFRT client will also need to handle one more unique case: for single-controller, Arrays will actually represent both the addressable and non-addressable shards of a global tensor. The representations of sharded data are different enough that the `ComputationClient` implementations will have to treat the sharded data containers differently.
The current function of `InputHandler` (i.e. converting a vector of objects that each contain a vector of shards into a vector of vectors of shards) should move down into `PjRtComputationClient`, since that is specific to the API of `PjRtClient::Execute`. Likewise, the inverse operation in `OutputHandler` belongs in `PjRtComputationClient` for the same reason. `OutputHandler` does one more thing: assign shardings to the outputs.
Intuitively, I would say it makes the most sense to assign the global shapes and shardings to each sharded data based on the executable that produced them. I'm pretty sure the `ShardingSpec`s in `OutputHandler` originate from the compiled `Executable`'s `GetHloModules` anyway, and PJRT added APIs for specifically this purpose since we started SPMD. But these calls are far too slow in practice.
When I went to add back the sharding assignment in `OutputHandler`, I noticed that we don't even actually use the shardings there! We assign all of the outputs to placeholders (xla/torch_xla/csrc/xla_graph_executor.cpp, lines 749 to 756 and lines 1022 to 1025 in b20a082), and `Assign` completely ignores the sharding and shape on the right side (xla/torch_xla/csrc/runtime/pjrt_computation_client.h, lines 192 to 198 in b20a082). In commit 5c45ca6, I just gave all of the outputs a placeholder "manual" sharding, since the placeholders have the real sharding anyway.
edit: I might be wrong about the last part. Added a check and running CI again... edit edit: Yeah, we only use the shardings on the placeholders.
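To make the "inverse operation" concrete, here is a hedged, simplified sketch of the output path that moves into the client (invented stand-in types and a GatherOutputs helper; the real code is the result-handling loop in the diff above): per-device, per-output result buffers are regrouped into one sharded value per output, each tagged with the sharding taken from the compiled executable (or a replicated default).

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Buffer { std::string device; };             // stands in for xla::PjRtBuffer
struct Shard { std::shared_ptr<Buffer> buffer; };  // stands in for PjRtData
struct ShardedResult {                             // stands in for PjRtShardedData
  std::vector<Shard> shards;
  std::string sharding;  // stands in for the xla::OpSharding from the executable
};

// results[d][o] is output o produced on device d; shardings[o] comes from the
// compiled executable (or defaults to replicated when it has no annotation).
std::vector<ShardedResult> GatherOutputs(
    std::vector<std::vector<std::unique_ptr<Buffer>>>& results,
    const std::vector<std::string>& shardings) {
  size_t num_devices = results.size();
  size_t num_outputs = results.empty() ? 0 : results[0].size();
  assert(shardings.size() == num_outputs);

  std::vector<ShardedResult> outputs(num_outputs);
  for (size_t o = 0; o < num_outputs; ++o) {
    outputs[o].sharding = shardings[o];
    outputs[o].shards.resize(num_devices);
    for (size_t d = 0; d < num_devices; ++d) {
      // Take ownership of the per-device buffer for this output.
      outputs[o].shards[d].buffer = std::move(results[d][o]);
    }
  }
  return outputs;
}

int main() {
  std::vector<std::vector<std::unique_ptr<Buffer>>> results(2);
  for (int d = 0; d < 2; ++d)
    results[d].push_back(std::make_unique<Buffer>(Buffer{"TPU:" + std::to_string(d)}));
  auto outputs = GatherOutputs(results, {"REPLICATED"});
  assert(outputs.size() == 1 && outputs[0].shards.size() == 2);
  return 0;
}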
That makes sense, since we should extract the output shardings (propagated and non-propagated) after the compilation, and they are available when we prepare the placeholders. One thing to note (synced offline) is that the UNKNOWN sharding type carries a special meaning for the upcoming auto-sharding change. For this PR, we can just focus on avoiding any regressions.
As discussed offline, I gave the outputs here "real" shardings based on the compiled executable, since `UNKNOWN` will have a real semantic meaning soon. The extra call to `GetOutputShardings` works here since I only call it once per compilation, rather than once per execution.
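A sketch of that compile-time caching pattern (hedged: FakeExecutable, Compile, and the stringly-typed shardings here are illustrative stand-ins; the real query is xla::PjRtExecutable::GetOutputShardings(), and output_shardings_ mirrors the optional member shown in the diff above):

#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a compiled executable; the real query, GetOutputShardings(),
// returns std::optional<std::vector<xla::OpSharding>>.
struct FakeExecutable {
  std::optional<std::vector<std::string>> GetOutputShardings() const {
    return std::vector<std::string>{"{devices=[4]0,1,2,3}"};
  }
};

// Stand-in for the Computation wrapper: shardings are read once at compile
// time and cached alongside the executable.
struct Computation {
  FakeExecutable executable;
  std::optional<std::vector<std::string>> output_shardings_;
};

Computation Compile(FakeExecutable exe) {
  Computation computation{std::move(exe), std::nullopt};
  computation.output_shardings_ = computation.executable.GetOutputShardings();
  return computation;
}

int main() {
  Computation computation = Compile(FakeExecutable{});
  // Each execution just reads the cached value (falling back to a replicated
  // default when the executable carried no annotation).
  for (int step = 0; step < 3; ++step) {
    size_t n = computation.output_shardings_ ? computation.output_shardings_->size() : 0;
    std::cout << "step " << step << ": " << n << " output shardings\n";
  }
  return 0;
}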