Refactor ExecuteReplicated to operate on sharded data directly #5737
Conversation
```diff
-  virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
-      const Computation& computation,
-      const std::vector<std::vector<DataPtr>>& arguments,
+  virtual std::vector<DataPtr> ExecuteReplicated(
```
Actually, I am very reluctant to change the return type here -- device x arguments (vector of vector) is more general, as the original ExecuteReplicated was meant to be implemented for non-SPMD as well. Also, I see that you are removing the input handler and the output handler... those abstractions were introduced to avoid any assumption about the computation client's return type, other than that it's a device Data. The actual sharding-related processing happens in the sharding util, only after realizing that it's a sharded data.

Any reason you need to refactor the SPMD-related code path under pjrt_computation_client and xla_sharding_util? I assume this is for the IFRT migration, and you want to skip having a separate IFRT client?
I don't think there's a reason to keep the "vector of vector" structure here. We may have supported some way of running replicated execution without SPMD before with XRT, but it doesn't really exist in PJRT or IFRT. You either execute on one device (PJRT's ExecuteSharded) or all local devices (IFRT's Execute). There was actually no usage of ExecuteReplicated at all within PT/XLA before PJRT and SPMD. IIRC we only kept this interface for this long to keep compatibility with XRT (even though we hadn't actually used that code path in XRT for years, if ever). We explicitly only call ExecuteReplicated when async->cached_computation->is_sharded.

It doesn't look like InputHandler would have handled unsharded data (what would GetDataShard do to unsharded data?), nor does it look like OutputHandler would have handled sharded data (what would WrapDataShards do to sharded data?). I think it makes much more sense to just construct the sharded data within the ComputationClient, applying the xla::OpShardings from the actual PjRtExecutable. It's a relatively minor refactor in PJRT, and it makes IFRT much easier to implement.
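To make that concrete, here is a minimal sketch of what "construct the sharded data within the ComputationClient" could look like. The types here are simplified stand-ins for the real pytorch/xla `PjRtData`, `PjRtShardedData`, and `xla::OpSharding`, not the actual declarations:

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for the real pytorch/xla types.
struct OpSharding {};                     // stands in for xla::OpSharding
struct PjRtData { std::string device; };  // wraps one device buffer
struct PjRtShardedData {                  // per-output shards + their sharding
  std::vector<std::shared_ptr<PjRtData>> shards;
  OpSharding sharding;
};

// PjRtLoadedExecutable::Execute returns results[device][output]. The client
// transposes that into one sharded Data per output and attaches the
// executable's output shardings, so callers never see per-device vectors.
std::vector<std::shared_ptr<PjRtShardedData>> WrapIntoShardedData(
    std::vector<std::vector<std::shared_ptr<PjRtData>>> results,
    const std::vector<OpSharding>& output_shardings) {
  const size_t num_devices = results.size();
  const size_t num_outputs = results.empty() ? 0 : results[0].size();
  std::vector<std::shared_ptr<PjRtShardedData>> outputs;
  outputs.reserve(num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    auto sharded = std::make_shared<PjRtShardedData>();
    sharded->sharding = output_shardings[i];
    for (size_t d = 0; d < num_devices; ++d) {
      sharded->shards.push_back(std::move(results[d][i]));
    }
    outputs.push_back(std::move(sharded));
  }
  return outputs;
}
```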
Yes, InputHandler and OutputHandler work with both unsharded and sharded data, and they're a no-op for unsharded data. We were able to switch around the data handling logic without touching the data representation, and vice versa. Having that abstraction layer was the point I was trying to make, not the functional perspective -- it works either way.

So for me, the question really comes down to whether we are creating a separate IFRT computation client and later making it the main/default client for all use cases, or whether you are trying to rewrite the PJRT computation client directly?
Also, ExecuteReplicated really calls Execute from the runtime, which takes nested buffers... https://source.corp.google.com/piper///depot/google3/third_party/tensorflow/compiler/xla/pjrt/pjrt_client.h;l=1373?q=Execute

Now for us, we have a sharded data (or IFRT array) that wraps a vector of buffers, so we can make it work either way -- but assuming that the computation data could just be a wrapper of the device buffer, the former isn't strange at all?
Having said that (those are my reasons for the hesitancy), if you feel that this would make the IFRT adoption easier, then let's make this refactor functional and land it.
The fundamental difference between PJRT and IFRT in this case is that it's trivial to "unshard" data in PJRT, since the actual runtime construct (PjRtBuffer) doesn't know anything about sharding. All sharding information is effectively in the framework (PjRtShardedData) and in the compiled PjRtExecutable.

In IFRT, the runtime data construct (ifrt::Array) carries its own sharding information. Disassembling it actually creates new ifrt::Arrays with new shardings, then we'd have to reassemble the array with the original sharding before passing it into the ifrt::Executable. The executable will then return a sharded array, which we'll again have to disassemble, then reassemble in OutputHandler with the sharding that matches what the executable returned. This creates a lot of space for errors to hide, particularly around the ownership of the underlying shards of data, which will be aliased by multiple ifrt::Arrays (compared to PJRT, where a single buffer of device data is always owned by one PjRtBuffer).
To avoid this, we either need to 1) implement the hack in my IFRT prototype PR where I skip InputHandler and OutputHandler when using IFRT, or 2) abstract more implementation details under ComputationClient, since the data representations need to be handled differently (see the interface sketch below).
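A rough sketch of option 2, matching the new signature in the diff above but using hypothetical pared-down nested types rather than the actual pytorch/xla declarations:

```cpp
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"

// Hypothetical, simplified version of the interface this PR moves toward:
// arguments and results are flat spans/vectors of (possibly sharded) DataPtr,
// and each backend handles its own sharded-data representation internally.
class ComputationClient {
 public:
  struct Data { virtual ~Data() = default; };
  using DataPtr = std::shared_ptr<Data>;
  struct Computation { virtual ~Computation() = default; };
  struct ExecuteReplicatedOptions {};

  virtual ~ComputationClient() = default;

  virtual std::vector<DataPtr> ExecuteReplicated(
      const Computation& computation, absl::Span<const DataPtr> arguments,
      absl::Span<const std::string> devices,
      const ExecuteReplicatedOptions& options) = 0;
};
```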
Since DataPtr already represents sharded data, it's actually very easy for us to approximate IFRT's behavior with PJRT. PjRtShardedData fills the same role as ifrt::Array: references to the local device data plus the sharding of that data. It would be much harder for us to approximate PJRT's behavior from IFRT.

This PR is definitely inspired by IFRT, but I also think this is the right way to evolve this abstraction post-XRT. It doesn't really make the runtime client any thicker, and overall it pushes functionality down the stack (i.e., out of PyTorch/XLA) now that we have more capable runtimes.
I hoped that the SPMD sharding logic/functions would be runtime agnostic, working with the PjRtShardedData & DataPtr abstraction. Question: does ifrt::Array track references or actual copies? If it's the former, like PjRtBuffer, then assembling and disassembling shouldn't be a problem? One thing for sure: we don't want to materialize that data to CPU at any time. The input handler & output handler in the SPMD util just reorganize references.

Now come to think of it, yeah -- adopting ifrt::Array is the right way, and it doesn't seem to be a simple drop-in replacement. Let's make it work. At the same time, I wonder if we could see about leaving InputHandler and OutputHandler intact so they can work with both PjRt and IfRt computation clients? But that sounds like they'd need to become dummy calls with most of the logic plumbed down to the runtime layer...

I am also going to expand InputHandler to take care of resharding the inputs after the auto-sharding pass. At a high level, InputHandler is supposed to prepare the inputs in PyTorch/XLA before the execution, and OutputHandler post-processes the outputs of the execution. And they were never used with xrt, only for pjrt & spmd.
EDIT:
"This addresses a couple of pain points I found when prototyping IFRT in #5677. ComputationClient better hides the details of how sharded data is handled, since it's not required to be broken up before passing into ExecuteReplicated."

+1 to hiding the details of how sharded data is handled. Currently that's done by the DataPtr abstraction. At the same time, the computation client is a thin layer: it wraps and provides APIs to the underlying runtime client. The input for SPMD needs to be sharded and wrapped in a DataPtr (again wrapping PjRtShardedData), and the entire array of device params is prepared for the execute API. I think we are looking at different levels of abstraction; for me the data abstraction is enough -- the sharding & sharded-data logic stays outside the execution API, in the PyTorch/XLA layer. Let's discuss more when you resume the work.
Each operation in IFRT that produces a new Array (copying, disassembling, assembling, resharding, etc.) allows you to copy, donate, or reuse the input. I don't believe reusing the input actually means the new buffer shares ownership. This is not relevant in PJRT, since it's functionally impossible (by design) to alias a PJRT buffer. One piece of device data corresponds to one PJRT buffer, and we share ownership with a shared_ptr to that buffer. That's no longer true in IFRT: particularly when we're resharding, we may end up with multiple arrays that correspond to the same device data, while only one array owns that data.

The IFRT client will also need to handle one more unique case: for single-controller, Arrays will actually represent both the addressable and non-addressable shards of a global tensor. The representations of sharded data are different enough that the ComputationClient implementations will have to treat the sharded data containers differently.
The current function of InputHandler (i.e., converting a vector of objects that each contain a vector of shards into a vector of per-device vectors of shards) should move down into PjRtComputationClient, since that is specific to the API of PjRtClient::Execute. Likewise, the inverse operation in OutputHandler belongs in PjRtComputationClient for the same reason. OutputHandler does one more thing: assign shardings to the outputs. (A sketch of the transpose follows.)
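As a self-contained sketch of that transpose, with a generic shard type standing in for the real DataPtr:

```cpp
#include <cstddef>
#include <vector>

// arguments[i] holds the per-device shards of argument i. The PJRT execute
// API instead wants device_arguments[d] = the d-th shard of every argument,
// in argument order, so we transpose [argument][device] -> [device][argument].
template <typename Shard>
std::vector<std::vector<Shard>> ToDeviceArguments(
    const std::vector<std::vector<Shard>>& arguments, size_t num_devices) {
  std::vector<std::vector<Shard>> device_arguments(num_devices);
  for (size_t d = 0; d < num_devices; ++d) {
    device_arguments[d].reserve(arguments.size());
    for (const auto& shards : arguments) {
      device_arguments[d].push_back(shards[d]);
    }
  }
  return device_arguments;
}
```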
Intuitively, I would say it makes the most sense to assign the global shapes and shardings to each sharded data based on the executable that produced them. I'm pretty sure the ShardingSpecs in OutputHandler originate from the compiled Executable's GetHloModules anyway, and PJRT added APIs for specifically this purpose since we started SPMD. But these calls are far too slow in practice.

When I went to add back the sharding assignment in OutputHandler, I noticed that we don't even actually use the shardings there! We assign all of the outputs to placeholders:
xla/torch_xla/csrc/xla_graph_executor.cpp, lines 749 to 756 in b20a082:

```cpp
{
  tsl::profiler::TraceMe activity("update_placeholder",
                                  tsl::profiler::TraceMeLevel::kInfo);
  for (size_t i = 0; i < results.size(); ++i) {
    XLA_CHECK(async->tensors_data[i] != nullptr);
    async->tensors_data[i]->Assign(*results[i]);
  }
}
```
xla/torch_xla/csrc/xla_graph_executor.cpp, lines 1022 to 1025 in b20a082:

```cpp
for (size_t i = 0; i < results.size(); ++i) {
  if (async->tensors_data[i] != nullptr) {
    async->tensors_data[i]->Assign(*results[i]);
  } else {
```
Assign completely ignores the sharding and shape on the right side:
xla/torch_xla/csrc/runtime/pjrt_computation_client.h, lines 192 to 198 in b20a082:

```cpp
void Assign(const torch::lazy::BackendData& data) override {
  const PjRtShardedData& pjrt_sharded_data =
      dynamic_cast<const PjRtShardedData&>(data);
  if (&pjrt_sharded_data != this) {
    shards = std::move(pjrt_sharded_data.shards);
  }
}
```
In commit 5c45ca6, I just gave all of the outputs a placeholder "manual" sharding, since the placeholders have the real sharding anyway.
edit: I might be wrong about the last part. Added a check and running CI again... edit edit: Yeah, we only use the shardings on the placeholders.
That makes sense, since we should extract the output shardings (propagated, non-propagated) after the compilation, and they are available when we prepare the placeholders. One thing to note (synced offline) is that the UNKNOWN sharding type carries a special meaning for the upcoming auto-sharding change. For this PR, we can just focus on avoiding any regressions.
As discussed offline, I gave the outputs here "real" shardings based on the compiled executable, since UNKNOWN will have a real semantic meaning soon. The extra call to GetOutputShardings works here since I only call it once per compilation, rather than once per execution.
Unblocking the PR and will review again when it's ready.
I'm going to leave this as a draft, because this PR introduces a substantial performance regression. For even a basic ResNet run on TPU v4, the throughput drops compared to a baseline built from master. My first thought is that I'm spawning many more threads in this PR to prepare the input. [Baseline and with-this-PR throughput numbers, latency breakdown, and full metrics were in collapsed details; elided.]

Digging in with the profiler, the additional cost is overwhelmingly coming from processing the output. [Profiler numbers elided.]

Here's the problem: [details elided]. The solution to restore performance is to build our own [elided]. The best solution here is to add lighter-weight implementations of [elided].
```cpp
    std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);
```

```cpp
xla::HloModuleConfig hlo_config(computation.program_shape());
```
This looks like we are pushing the execution-results post-processing logic from PyTorch/XLA (sharding util) down to the runtime client... which I am still not sure about.
None of the substantive logic is moving down. The input and output shardings are still entirely determined by the user's sharding annotations and the compiler. The output shardings here even come from the same place they were coming from before: the compiled executable.

xla/torch_xla/csrc/xla_sharding_util.cpp, lines 660 to 670 in b20a082:

```cpp
if (computation_proto.has_spmd_output_sharding()) {
  if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) {
    auto tuple_shardings =
        computation_proto.spmd_output_sharding().tuple_shardings();
    output_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(),
                                                    tuple_shardings.end());
  } else {
    output_shardings = std::vector<xla::OpSharding>{
        computation_proto.spmd_output_sharding()};
  }
}
```
This is just shifting the plumbing around so that ExecuteReplicated can consume and emit sharded data. The only reason I'm building an HloModule in this draft is that calling GetOutputShardings directly on the executable is just too slow in practice. This is otherwise just the XlaComputation that we extract from the executable in Compile.

It turns out that the shapes/shardings that we set here and in OutputHandler are not used, because we only use the shape/sharding on the placeholder anyway. See my other comment: #5737 (comment)
```cpp
    sharding_specs);
runtime::GetComputationClient()->ExecuteReplicated(
    *async->cached_computation->computation,
    UnwrapXlaData(async->parameters_data), devices,
```
Ok, so the reorganizing of the params data into device_arguments happens inside ExecuteReplicated, before the call to Execute.
Wow, this PR became a rabbit hole. Basically, since this PR spawns potentially many more threads than before (proportional to the number of tensors instead of the number of devices), some known issues around concurrency overhead become way more severe. I ended up making several changes here separate from the main goal of updating the ExecuteReplicated API. [List of changes elided.]

I finally have a good level of performance. [Metrics elided.] More importantly, llama2 7B inference with dynamo + SPMD is comparable or marginally faster with this PR compared to HEAD. [Benchmark numbers elided.] Still, consider the details of this PR a draft: it needs to be split up, and we should align on the API organization before a final review. Also, a test just failed, which means I probably broke something else. [Details of the llama experiment setup elided.]
Force-pushed from a5f5ef6 to f51ae47.
Finally getting back to this PR after merging supporting changes. This actually does seem to be making a material impact on performance on Llama2 7B. At current master, the benchmark hovers around [value elided]; with this PR it's consistently lower, for example [value elided].

The CI failures look similar to [elided].
LGTM, thanks Will! I'll let @yeounoh send the approval.
```cpp
auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(handle);
XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData.";
return pjrt_data;
```
Can we make this an `else if (PjRtData* pjrt_data = dynamic_cast<PjRtData*>(handle.get()))`, and add the XLA_CHECK as the fallthrough case? The error message threw me off for a sec.
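I.e., something along these lines -- a self-contained sketch with a plain `assert` standing in for `XLA_CHECK` and a hypothetical `ReplicateShardedData` stub for the sharded branch:

```cpp
#include <cassert>
#include <memory>

struct Data { virtual ~Data() = default; };
struct PjRtData : Data {};
struct PjRtShardedData : Data {};

// Hypothetical stand-in: the real code replicates the shards into one buffer.
std::shared_ptr<PjRtData> ReplicateShardedData(const PjRtShardedData&) {
  return std::make_shared<PjRtData>();
}

std::shared_ptr<PjRtData> AsPjRtData(const std::shared_ptr<Data>& handle) {
  if (auto* sharded = dynamic_cast<PjRtShardedData*>(handle.get())) {
    return ReplicateShardedData(*sharded);
  } else if (PjRtData* pjrt_data = dynamic_cast<PjRtData*>(handle.get())) {
    // Aliasing constructor: share ownership with handle, point at pjrt_data.
    return std::shared_ptr<PjRtData>(handle, pjrt_data);
  }
  // Fallthrough: neither concrete type matched, so the message is accurate.
  assert(false && "Data must be PjRtData or PjRtShardedData.");
  return nullptr;
}
```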
Good catch, this is obscure. Fixed.
```cpp
// Time in nanoseconds that it takes to prepare an argument. Used to tune
// number of threads spawned by ParallelFor. Measured on 2023/11/28.
static constexpr int64_t argument_handle_cost_ns = 10000;
```
How important is the accuracy of this value? I assume it depends on the host machine's specs, so it would differ across TPU generations and GPU VMs.
I'm just shooting for the right order of magnitude here -- I didn't see any noticeable difference between this and a 30000 ns estimate from a previous commit. Supposedly we could set this to the number of CPU cycles (which would be more consistent), but I don't know how to measure that accurately.

Since this is really obscure, and it doesn't seem to be that sensitive, I just put a nice round number in the correct range here rather than adding another configuration knob.
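For reference, a sketch of how a per-argument cost estimate feeds tsl's thread pool, using the `ParallelFor` API as I understand it (total count, cost-per-unit hint, and a range callback) -- the cost hint is what decides how finely the loop gets sharded across threads:

```cpp
#include <cstdint>

#include "tsl/platform/threadpool.h"

// Time in nanoseconds to prepare one argument; a rough, measured estimate.
static constexpr int64_t kArgumentHandleCostNs = 10000;

// ParallelFor splits [0, num_arguments) into ranges; the cost hint tells the
// pool roughly how expensive each iteration is, so it can pick a shard count
// that balances per-thread latency against parallelism.
void PrepareArguments(tsl::thread::ThreadPool* pool, int64_t num_arguments) {
  pool->ParallelFor(num_arguments, kArgumentHandleCostNs,
                    [](int64_t start, int64_t end) {
                      for (int64_t i = start; i < end; ++i) {
                        // ... prepare argument i (elided) ...
                      }
                    });
}
```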
```diff
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
-    const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
+    absl::Span<const ComputationClient::DataPtr> arguments,
     absl::Span<const std::string> devices,
```
I wonder if we still need `devices` as input here -- should the caller need to be aware of which devices the shards live on?
`devices` doesn't really make sense now that we dropped XRT, since PJRT only lets you execute on one device or all devices. I added a TODO to remove this in my IFRT prototype PR. In practice, the caller does seem to be setting the "correct" value here, though.
lgtm
```diff
@@ -231,10 +231,16 @@ class PjRtComputationClient : public ComputationClient {
       std::vector<std::string> devices,
       std::unique_ptr<xla::PjRtLoadedExecutable> executable)
       : Computation(std::move(computation), std::move(devices)),
-        executable(std::move(executable)) {}
+        executable(std::move(executable)) {
```
If it's already being done as part of data placeholder prep, then this is redundant. It looks like we have to extract the output shardings as part of the pjrt_computation initialization. I think with AOT compilation and auto-sharding, it makes sense to do it strictly around the compilation (we will have to do it anyway there) and avoid doing it as part of the execution call. Preferably, keep the pjrt_computation just a wrapper without much processing/overhead.
Yeah, this is redundant. I only added this here because I wanted to avoid returning UNKNOWN out of ExecuteReplicated, and GetOutputShardings is far too slow to call per execution. This does run once per compilation and is cached. PjRtComputation is still a wrapper; it's just caching a really slow call we know we'll need down the line. We were effectively doing the same thing by saving ProgramShape in Computation already. (A sketch of this caching is below.)
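In sketch form, assuming the `GetOutputShardings` accessor on `xla::PjRtLoadedExecutable` (which returns `std::optional<std::vector<xla::OpSharding>>`); this mirrors, but greatly simplifies, the real PjRtComputation:

```cpp
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "xla/pjrt/pjrt_client.h"

// Simplified sketch: cache the (slow) GetOutputShardings call once at
// construction, i.e. once per compilation, so per-step ExecuteReplicated
// calls can attach real output shardings for free.
class PjRtComputation {
 public:
  explicit PjRtComputation(std::unique_ptr<xla::PjRtLoadedExecutable> exe)
      : executable_(std::move(exe)),
        output_shardings_(executable_->GetOutputShardings()) {}

  const std::optional<std::vector<xla::OpSharding>>& output_shardings() const {
    return output_shardings_;
  }

 private:
  std::unique_ptr<xla::PjRtLoadedExecutable> executable_;
  std::optional<std::vector<xla::OpSharding>> output_shardings_;
};
```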
Depending on how auto-sharding works, and if we add the output shardings to the top-level Computation, the placeholders' shardings could easily instead use the compiled executable's output shardings from here, rather than digging into the HLO proto manually:
xla/torch_xla/csrc/xla_sharding_util.cpp, lines 618 to 633 in b9475d9:

```cpp
    const torch::lazy::BackendDevice& device) {
  const auto& computation_proto = computation->computation().proto();
  uint64_t num_outputs = output_shapes->size();
  std::vector<xla::OpSharding> output_shardings;
  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(num_outputs);
  if (computation_proto.has_spmd_output_sharding()) {
    if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) {
      auto tuple_shardings =
          computation_proto.spmd_output_sharding().tuple_shardings();
      output_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(),
                                                      tuple_shardings.end());
    } else {
      output_shardings = std::vector<xla::OpSharding>{
          computation_proto.spmd_output_sharding()};
    }
  }
```
I like the fact that it's making the runtime and the replicated execution more performant, but it's less modular than before. Meanwhile, I left some comments around output sharding extraction being part of the execution -- as a heads-up/FYI.

Overall, this is great -- thanks @will-cromar
Force-pushed from deed402 to 9bc594e.
Refactor ExecuteReplicated to operate on sharded data directly (pytorch#5737)

* Refactor ExecuteReplicated to operate on sharded data directly
* Remove old handlers
* formatting
* Improve naming and logging
* update docstring
* Remove obsolete unit tests
* improve comment
* Remove slow calls to get output shapes.
* fix implicit sharding
* remove declarations of input/output handlers
* formatting
* give everything a manual placeholder sharding
* see if CI passes
* formatting
* Shard parameter and output handling
* Use absl::BlockingCounter
* formatting
* fix merge
* Assign valid output shardings
* tune and document costs
* formatting
* implicitly replicate output to match outputhandler
* clarify ReplicateShardedData
* fix merge
* Refactor ExecuteReplicated to consume sharded data directly now that we don't have to keep consistency with XRT.
* Remove InputHandler and OutputHandler.
* Change ExecuteReplicated to absl::Span<const DataPtr> to match ExecuteComputation.
* ReplicateShardedData can pass const DataPtr& handle directly to ExecuteReplicated now.
* Make ReplicateShardedData private and return a PjRtData explicitly now.
* Parallelize ExecuteReplicated data handling over num_outputs instead of num_devices. Since each shard runs much more quickly now, use tsl::ThreadPool::ParallelFor to balance thread latency with parallelism.

This addresses a couple of pain points I found when prototyping IFRT in #5677. ComputationClient better hides the details of how sharded data is handled, since it's not required to be broken up before passing into ExecuteReplicated.

This PR appears to have a small positive impact on performance on llama2 7B on v4-8. [Before/after numbers elided.]