
Refactor ExecuteReplicated to operate on sharded data directly #5737

Merged
merged 24 commits into master on Nov 30, 2023

Conversation

will-cromar
Collaborator

@will-cromar will-cromar commented Oct 25, 2023

  • Refactor ExecuteReplicated to consume sharded data directly now that we don't have to keep consistency with XRT.
    • Remove InputHandler and OutputHandler.
  • Change input type of ExecuteReplicated to absl::Span<const DataPtr> to match ExecuteComputation
    • Useful side-effect: ReplicateShardedData can pass const DataPtr& handle directly to ExecuteReplicated now
  • Make ReplicateShardedData private and have it return a PjRtData explicitly.
  • Parallelize ExecuteReplicated data handling over num_outputs instead of num_devices. Since each per-shard task is now much smaller, use tsl::thread::ThreadPool::ParallelFor to balance thread-dispatch latency against parallelism (see the sketch below).
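A minimal sketch of this ParallelFor pattern (the pool size, cost constant, include paths, and the CollectOutput helper are illustrative assumptions, not the PR's exact code):

#include <thread>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

void CollectOutput(int64_t i);  // hypothetical per-output wrapper

void HandleOutputs(int64_t num_outputs) {
  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "execute_replicated",
                                      std::thread::hardware_concurrency());
  // Rough cost (in ns) of handling one output; ParallelFor uses it to decide
  // how many outputs to batch into each thread dispatch.
  constexpr int64_t kCostPerOutputNs = 10000;
  pool.ParallelFor(num_outputs, kCostPerOutputNs,
                   [&](int64_t start, int64_t end) {
                     for (int64_t i = start; i < end; ++i) {
                       CollectOutput(i);
                     }
                   });
}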

This addresses a couple of pain points I found when prototyping IFRT in #5677. ComputationClient better hides the details of how sharded data is handled, since it's not required to be broken up before passing into ExecuteReplicated.

This PR appears to have a small positive impact on performance on llama2 7B on v4-8. Before:

Totally decoded 1007 tokens in 7.06039 seconds

After:

Totally decoded 1007 tokens in 7.01075 seconds

virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
    const Computation& computation,
    const std::vector<std::vector<DataPtr>>& arguments,
virtual std::vector<DataPtr> ExecuteReplicated(
Contributor

Actually, I am very reluctant to change the return type here -- device x arguments (vector of vector) is more general, since the original ExecuteReplicated was meant to be implemented for non-SPMD as well. Also, I see that you are removing the input handler and the output handler... those abstractions were introduced to avoid any assumption about the computation client's return type, other than that it's a device Data. The actual sharding-related processing happens in the sharding util, only after realizing that it's a sharded data.

Any reason you need to refactor the SPMD related code path, under pjrt_computation_client and xla_sharding_util? I assume this is for IFRT migration, and you want to skip having a separate ifrt client?

Collaborator Author

I don't think there's a reason to keep the "vector of vector" structure here. We may have supported some way of running Replicated execution without SPMD before with XRT, but it doesn't really exist in PJRT or IFRT. You either execute on one device (PJRT's ExecuteSharded) or all local devices (IFRT's Execute). There was actually no usage of ExecuteReplicated at all within PT/XLA before PJRT and SPMD. IIRC we only kept this interface for this long to keep compatibility with XRT (even though we hadn't actually used that code path in XRT for years, if ever). We explicitly only call ExecuteReplicated when async->cached_computation->is_sharded.

It doesn't look like InputHandler would have handled unsharded data (what would GetDataShard do to unsharded data?), nor does it look like OutputHandler would have handled sharded data (what would WrapDataShards do to sharded data?). I think it makes much more sense to just construct the sharded data within the ComputationClient, applying the xla::OpShardings from the actual PjRtExecutable. It's a relatively minor refactor in PJRT, and it makes IFRT much easier to implement.

Contributor

@yeounoh yeounoh Oct 26, 2023

Yes, InputHandler and OutputHandler work with both unsharded and sharded data; they're a no-op for unsharded data. We were able to switch around the data handling logic without touching the data representation, and vice versa. Having that abstraction layer was the point I was trying to make, not the functional perspective -- it works either way.

So for me, the question really comes down to whether we are creating a separate IFRT computation client and later making it the main/default client for all use cases, or whether you are trying to rewrite the PJRT computation client directly.

Contributor

Also, ExecuteReplicated really calls Execute from the runtime, which takes nested buffers... https://source.corp.google.com/piper///depot/google3/third_party/tensorflow/compiler/xla/pjrt/pjrt_client.h;l=1373?q=Execute

Now for us, we have a sharded data (or IFRT array) that wraps a vector of buffers, so we could make it work either way -- but, assuming the computation data could just be a wrapper around a device buffer, the former isn't strange at all?

Contributor

@yeounoh yeounoh Oct 26, 2023

Having said that (my reasonings for the hesitancy), if you feel that this would make the IFRT adoption easier, then let's make this refactor functional and land.

Collaborator Author

The fundamental difference between PJRT and IFRT in this case is that it's trivial to "unshard" data in PJRT, since the actual runtime construct (PjRtBuffer) doesn't know anything about sharding. All sharding information is effectively in the framework (PjRtShardedData) and in the compiled PjRtExecutable.

In IFRT, the runtime data construct (ifrt::Array) carries its own sharding information. Disassembling it actually creates new ifrt::Arrays with new shardings, then we'd have to reassemble the array with the original sharding before passing it into the ifrt::Executable. The executable will then return a sharded array, which we'll again have to disassemble, then reassemble in OutputHandler with the sharding that matches what the executable returned. This creates a lot of space for errors to hide, particularly around the ownership of the underlying shards of data, which will be aliased by multiple ifrt::Arrays (compared to PJRT, where a single buffer of device data is always owned by one PjRtBuffer).

To avoid this, we either need to 1) implement the hack in my IFRT prototype PR where I skip InputHandler and OutputHandler when using IFRT or 2) abstract more implementation details under ComputationClient, since the data representations need to be handled differently.

Since DataPtr already represents sharded data, it's actually very easy for us to approximate IFRT's behavior with PJRT. PjRtShardedData fills the same role as ifrt::Array: references to the local device data plus the sharding of that data. It would be much harder for us to approximate PjRt's behavior from IFRT.
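To make that parallel concrete, the sharded-data container on the PJRT side conceptually looks like the following (an illustrative shape only; the real PjRtShardedData in pjrt_computation_client.h carries additional members):

// One handle per global tensor: per-device shards plus how they map to the
// global shape, mirroring what ifrt::Array tracks internally.
struct PjRtShardedData {
  std::vector<std::shared_ptr<PjRtData>> shards;  // one shard per local device
  xla::OpSharding sharding;                       // shard-to-global-tensor mapping
};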

This PR is definitely inspired by IFRT, but I also think this is the right way to evolve this abstraction post-XRT. It doesn't really make the runtime client any thicker, and overall it pushes functionality down in the stack (ie out of PyTorch/XLA) now that we have more capable runtimes.

Contributor

@yeounoh yeounoh Nov 1, 2023

I hoped that the SPMD sharding logic/functions would be runtime agnostic, working with the PjRtShardedData & DataPtr abstraction. Question: does ifrt::Array track references or actual copies? If it's the former, like PjRtBuffer, then assembling and disassembling shouldn't be a problem? One thing for sure: we don't want to materialize that data to CPU at any point. The input handler & output handler in the SPMD util just reorganize references.

Now, come to think of it, yes -- adopting ifrt::Array is the right way, and it doesn't seem to be a simple drop-in replacement. Let's make it work. At the same time, I wonder if we could leave InputHandler and OutputHandler intact so they can work with both PjRt and IfRt computation clients? But that sounds like they would need to become dummy calls, with most of the logic plumbed down to the runtime layer...

I am also going to expand InputHandler to take care of resharding the inputs after the auto-sharding pass. At a high level, InputHandler is supposed to prepare the inputs in PyTorch/XLA before the execution, and OutputHandler post-processes the outputs of the execution. They were never used with XRT, only for PJRT & SPMD.

EDIT:

This addresses a couple of pain points I found when prototyping IFRT in #5677. ComputationClient better hides the details of how sharded data is handled, since it's not required to be broken up before passing into ExecuteReplicated.

+1 on hiding the details of how sharded data is handled. Currently that's done by the DataPtr abstraction. At the same time, the computation client is a thin layer: it wraps and provides APIs to the underlying runtime client. The input for SPMD needs to be sharded and wrapped in a DataPtr (which in turn wraps PjRtShardedData), and the entire array of device params is prepared for the execute API. I think we are looking at different levels of abstraction; for me the data abstraction is enough -- the sharding & sharded-data logic stays outside the execution API, in the PyTorch/XLA layer. Let's discuss more when you resume the work.

Collaborator Author

@will-cromar will-cromar Nov 1, 2023

Each operation in IFRT that produces a new Array (copying, disassembling, assembling, resharding, etc.) allows you to copy, donate, or reuse the input. I don't believe reusing the input actually means the new buffer shares ownership. This is not relevant in PJRT, since it's functionally impossible (by design) to alias a PJRT buffer: one piece of device data corresponds to one PJRT buffer, and we share ownership with a shared_ptr to that buffer. That's no longer true in IFRT: particularly when we're resharding, we may end up with multiple arrays that correspond to the same device data, and one array owns that data.

The IFRT client will also need to handle one more unique case: for single-controller, Arrays will actually represent both the addressable and non-addressable shards of a global tensor. The representations of sharded data are different enough that the ComputationClient implementations will have to treat the sharded data containers differently.

The current function of InputHandler (ie converting a vector of objects that each contain a vector of shards, into a vector of vectors of shards) should move down into PjRtComputationClient, since that is specific to the API of PjRtClient::Execute. Likewise, the inverse operation in OutputHandler belongs in PjRtComputationClient for the same reason. OutputHandler does one more thing: assign shardings to the outputs.
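For reference, the regrouping being pushed down is roughly the following transpose (a sketch under the assumption that every argument is a PjRtShardedData whose shards line up with the local devices; GatherShards is a hypothetical name, not the PR's function):

std::vector<std::vector<xla::PjRtBuffer*>> GatherShards(
    absl::Span<const ComputationClient::DataPtr> arguments,
    absl::Span<const std::string> devices) {
  // One argument list per device, in the layout PjRtClient::Execute expects.
  std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());
  for (const auto& arg : arguments) {
    auto* sharded_data = dynamic_cast<PjRtShardedData*>(arg.get());
    for (size_t d = 0; d < devices.size(); ++d) {
      argument_handles[d].push_back(sharded_data->shards[d]->buffer.get());
    }
  }
  return argument_handles;
}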

Intuitively, I would say it makes the most sense to assign the global shapes and shardings to each sharded data based on the executable that produced them. I'm pretty sure the ShardingSpecs in OutputHandler originate from the compiled Executable's GetHloModules anyway, and PJRT added APIs for specifically this purpose since we started SPMD. But these calls are far too slow in practice.

When I went to add back the sharding assignment in OutputHandler, I noticed that we don't even actually use the shardings there! We assign all of the outputs to placeholders:

{
  tsl::profiler::TraceMe activity("update_placeholder",
                                  tsl::profiler::TraceMeLevel::kInfo);
  for (size_t i = 0; i < results.size(); ++i) {
    XLA_CHECK(async->tensors_data[i] != nullptr);
    async->tensors_data[i]->Assign(*results[i]);
  }
}

for (size_t i = 0; i < results.size(); ++i) {
  if (async->tensors_data[i] != nullptr) {
    async->tensors_data[i]->Assign(*results[i]);
  } else {
Assign completely ignores the sharding and shape on the right side:

void Assign(const torch::lazy::BackendData& data) override {
  const PjRtShardedData& pjrt_sharded_data =
      dynamic_cast<const PjRtShardedData&>(data);
  if (&pjrt_sharded_data != this) {
    shards = std::move(pjrt_sharded_data.shards);
  }
}

In commit 5c45ca6, I just gave all of the outputs a placeholder "manual" sharding, since the placeholders have the real sharding anyway.

edit: I might be wrong about the last part. Added a check and running CI again... edit edit: Yeah, we only use the shardings on the placeholders.

Contributor

That makes sense, since we should extract the output shardings (propagated, non-propagated) after the compilation, and they are available when we prepare the placeholders. One thing to note (synced offline) is that the UNKNOWN sharding type carries a special meaning for the upcoming auto-sharding change. For this PR, we can just focus on avoiding any regressions.

Collaborator Author

As discussed offline, I gave the outputs here "real" shardings based on the compiled executable, since UNKNOWN will have a real semantic meaning soon. The extra call to GetOutputShardings works here since I only call it once per compilation, rather than once per execution.
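For illustration, the caching idea looks roughly like this (member names are assumptions based on the constructor diff later in this thread; GetOutputShardings on the executable is the slow call being amortized):

PjRtComputation(xla::XlaComputation computation, std::vector<std::string> devices,
                std::unique_ptr<xla::PjRtLoadedExecutable> executable)
    : Computation(std::move(computation), std::move(devices)),
      executable(std::move(executable)) {
  // Called once per compilation and cached, instead of once per execution.
  output_shardings_ = this->executable->GetOutputShardings();
}

std::optional<std::vector<xla::OpSharding>> output_shardings_;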

@yeounoh yeounoh dismissed their stale review October 26, 2023 22:28

Unblocking the PR and will review again when it's ready.

@will-cromar
Collaborator Author

will-cromar commented Oct 27, 2023

I'm going to leave this as a draft, because this PR introduces a substantial performance regression. For even a basic ResNet run on TPU v4, the throughput drops from 7300 ex/sec to 2500 ex/sec. The IFRT prototype is similarly bad in terms of performance, potentially because it implements a version of this change.

My first thought is that I'm spawning many more threads in this PR to prepare the input PjRtBuffer*s (proportional to the number of arguments to an executable, compared to the number of devices), and the latency of starting a thread could add up. I would then expect ExecuteReplicatedTime to shoot up, which is.... not what is happening

Baseline built from master, after one epoch of ResNet50 + fake data on TPU v4:

Metric: ExecuteReplicatedTime
  TotalSamples: 586
  Accumulator: 15m54s068ms128.743us
  ValueRate: 03s474ms899.642us / second
  Rate: 2.2769 / second
  Percentiles: 1%=483ms438.931us; 5%=498ms772.820us; 10%=677ms741.942us; 20%=01s080ms232.120us; 50%=02s619ms632.615us; 80%=02s957ms133.071us; 90%=02s094ms352.027us; 95%=02s155ms739.991us; 99%=02s320ms891.565us

With this PR:

Metric: ExecuteReplicatedTime
  TotalSamples: 586
  Accumulator: 08m34s017ms471.972us
  ValueRate: 777ms521.088us / second
  Rate: 1.00226 / second
  Percentiles: 1%=691ms238.872us; 5%=707ms665.990us; 10%=717ms541.769us; 20%=724ms875.148us; 50%=753ms426.035us; 80%=778ms197.963us; 90%=911ms861.770us; 95%=980ms888.872us; 99%=01s145ms639.555us

So the total latency from ExecuteReplicated alone dropped by ~half, but the overall run is abysmally slow. I don't see anything sticking out in the metrics reports (attached), so I may need to dig in with the profiler.

edit: the DeviceLockWait skyrockets from 23s to 15m

Full metrics:

pjrt_metrics.txt
refactored_pjrt_metrics.txt

@will-cromar
Collaborator Author

will-cromar commented Oct 30, 2023

Digging in with the profiler, the additional cost is overwhelmingly coming from processing the output (ExecuteReplicated_result_handle). In particular, getting the output shapes, element types, and shardings seems to be very slow. These three calls take 400 to 500 ms. That's longer than the entire synchronous portion of ExecuteReplicated used to be, which usually took less than 10 ms in my profiles.

@will-cromar
Collaborator Author

will-cromar commented Oct 31, 2023

Here's the problem:

  • The underlying implementation of GetHloModules is very expensive, on the order of 100s of ms for our relatively small ResNet50 benchmark.
  • The default implementations of GetOutputElementTypes, GetOutputDimensions, and GetOutputSharding all call GetHloModules() or GetOutputShapes, which calls GetHloModules. The wrapped TPU client under the C API does not override these default implementations.
  • Direct calls to GetHloModules are even worse, since we have an expensive call, and then it's expensive to move the result across the PJRT C API. This can take up to 1 second in my profiles.

The solution to restore performance is to build our own xla::HloModule (since we have a copy of the xla::XlaComputation) and extract the output shapes and shardings from that. Throughput increases back to ~7330 ex/sec, and there's still room for improvement in the synchronous part of ExecuteReplicated. The total DeviceLockWait is still around 4 minutes, but the overall performance is similar with the commit I am about to push.

The best solution here is to add lighter-weight implementations of GetOutputDimensions etc in the underlying runtime, since these were designed to be cheaper to call across the PJRT C API.
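For context, the local-HloModule workaround described above could look like this (a sketch, not the PR's exact code; `computation` is assumed to be the xla::XlaComputation saved at compile time):

xla::HloModuleConfig config(computation.GetProgramShape().value());
std::unique_ptr<xla::HloModule> module =
    xla::HloModule::CreateFromProto(computation.proto(), config).value();

// Output shapes come from the module's result shape (a tuple when there are
// multiple outputs)...
const xla::Shape& result_shape = module->result_shape();

// ...and output shardings from the same proto field the sharding util already
// reads (see the spmd_output_sharding snippet quoted later in this thread).
std::vector<xla::OpSharding> output_shardings;
const xla::HloModuleProto& proto = computation.proto();
if (proto.has_spmd_output_sharding()) {
  const auto& tuple_shardings = proto.spmd_output_sharding().tuple_shardings();
  if (!tuple_shardings.empty()) {
    output_shardings.assign(tuple_shardings.begin(), tuple_shardings.end());
  } else {
    output_shardings = {proto.spmd_output_sharding()};
  }
}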

std::make_shared<PjRtData>(devices[i], std::move(buffer));
datas.push_back(data);

xla::HloModuleConfig hlo_config(computation.program_shape());
Contributor

This looks like we are pushing the execution-results post-processing logic from PyTorch/XLA (the sharding util) down to the runtime client... which I am still not sure about.

Collaborator Author

None of the substantive logic is moving down. The input and output shardings are still entirely determined by the user's sharding annotations and the compiler. The output shardings here even come from the same place as before -- the compiled executable:

if (computation_proto.has_spmd_output_sharding()) {
  if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) {
    auto tuple_shardings =
        computation_proto.spmd_output_sharding().tuple_shardings();
    output_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(),
                                                    tuple_shardings.end());
  } else {
    output_shardings = std::vector<xla::OpSharding>{
        computation_proto.spmd_output_sharding()};
  }
}

This is just shifting the plumbing around so that ExecuteReplicated can consume and emit sharded data. The only reason I'm building an HloModule in this draft is because calling GetOutputShardings directly on the executable is just too slow in practice. This is otherwise just the XlaComputation that we extract from the executable in Compile.

It turns out that the shapes/shardings that we set here and in OutputHandler are not used, because we only use the shape/sharding on the placeholder anyway. See my other comment: #5737 (comment)

        sharding_specs);
runtime::GetComputationClient()->ExecuteReplicated(
    *async->cached_computation->computation,
    UnwrapXlaData(async->parameters_data), devices,
Contributor

OK, so the reorganization of the params data into per-device arguments now happens inside ExecuteReplicated, right before calling Execute.

@will-cromar
Collaborator Author

Wow, this PR became a rabbit hole.

Basically, since this PR spawns potentially many more threads than before (proportional to the number of tensors instead of the number of devices), some known issues around concurrency overhead become much more severe. I ended up making several changes here that are separate from the main goal of updating the ExecuteReplicated API, which I will split into new PRs. The main performance updates in the past few commits:

  • Use tsl::thread::ThreadPool: TSL provides a more efficient threadpool implementation that re-uses threads to reduce overhead. As a nice side-benefit, all of the ExecuteReplicated calls are dispatched to the same thread, making profiles more readable. It also provides a mechanism for batching low-latency operations (like argument and result plumbing) into the same dispatch.
  • Use absl::BlockingCounter: it's a known issue that we end up waiting a long time on MultiWait::Done calls because incrementing the counter requires locking a shared mutex. Finished tasks are forced to queue up waiting for the lock, preventing new tasks from being started by the ThreadPool. absl::BlockingCounter implements the same idea but uses an atomic counter for lock-less decrements, which significantly reduces lock contention (see the sketch after this list).
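A minimal sketch of the ThreadPool + BlockingCounter pattern described in this list (the task body and include paths are illustrative, not the PR's exact code):

#include <thread>

#include "absl/synchronization/blocking_counter.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

void RunShardTasks(int num_tasks) {
  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "shard_tasks",
                                      std::thread::hardware_concurrency());
  // Finished tasks decrement an atomic counter instead of contending on the
  // shared mutex behind MultiWait::Done.
  absl::BlockingCounter counter(num_tasks);
  for (int i = 0; i < num_tasks; ++i) {
    pool.Schedule([&counter, i]() {
      (void)i;  // placeholder for the real per-shard argument/result plumbing
      counter.DecrementCount();
    });
  }
  counter.Wait();  // Block until every scheduled task has finished.
}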

I finally have a good level of performance. The DeviceLockWait (ie the amount of time waiting for the synchronous portion of ExecuteReplicated) drops to 16s from 23s on the ResNet50 benchmark I used above.

More importantly, llama2 7B inference with dynamo + SPMD is comparable or marginally faster with this PR compared to HEAD: Totally decoded 2047 tokens in 14.31419 seconds vs Totally decoded 2047 tokens in 14.46967 seconds. I'll include experiment details at the bottom. The benefits may be larger for models with more parameters, and/or with better tuning of the cost estimates used with ThreadPool::ParallelFor.

I still consider the details of this PR a draft. It needs to be split up, and we should align on the API organization before a final review. Also, a test just failed, which means I probably broke something else.

For my llama experiment, I used the optimize_spmd_sharding branch of our fork on v4-8 with the following command: XLA_USE_SPMD=1 PJRT_DEVICE=TPU python3 llama/example_text_completion.py llama/7B/ spiece.model --max_seq_len 2048 --max_gen_len 10000 --max_batch_size 1 --mp False --dynamo True --spmd True

@will-cromar will-cromar force-pushed the wcromar/refactor-execute-replicated branch from a5f5ef6 to f51ae47 Compare November 27, 2023 18:53
@will-cromar
Collaborator Author

Finally getting back to this PR after merging supporting changes.

This actually does seem to make a material impact on performance for Llama2 7B. At current master, the benchmark hovers around 7.06 seconds for 1000 tokens, for example:

Totally decoded 1007 tokens in 7.06039 seconds

It's consistently lower now, for example

Totally decoded 1007 tokens in 7.02751 seconds
...
Totally decoded 1007 tokens in 7.00845 seconds

@will-cromar
Collaborator Author

The CI failures look similar to master currently. Marking this PR ready for review.

@will-cromar will-cromar marked this pull request as ready for review November 29, 2023 18:46
@will-cromar will-cromar requested a review from yeounoh November 29, 2023 18:46
Collaborator

@jonb377 jonb377 left a comment

LGTM, thanks Will! I'll let @yeounoh send the approval.

Comment on lines 445 to 447
auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(handle);
XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData.";
return pjrt_data;
Collaborator

Can we make this an else if (PjRtData* pjrt_data = dynamic_cast<PjRtData*>(handle.get())), and add the XLA_CHECK as the fallthrough case? The error message threw me off for a sec

Collaborator Author

Good catch, this is obscure. Fixed.
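The resulting dispatch looks roughly like this (a sketch of the suggested control flow, not the exact committed code):

if (PjRtShardedData* sharded_data =
        dynamic_cast<PjRtShardedData*>(handle.get())) {
  // ... replicate the shards into a single PjRtData ...
} else if (PjRtData* pjrt_data = dynamic_cast<PjRtData*>(handle.get())) {
  // Already unsharded; pass it through unchanged.
  return std::dynamic_pointer_cast<PjRtData>(handle);
} else {
  XLA_CHECK(false) << "Data must be PjRtData or PjRtShardedData.";
}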


// Time in nanoseconds that it takes to prepare an argument. Used to tune
// number of threads spawned by ParallelFor. Measured on 2023/11/28.
static constexpr int64_t argument_handle_cost_ns = 10000;
Collaborator

How important is the accuracy of this value? I assume this value depends on the host machine's specs, so it would be different across TPU generations and GPU VMs

Collaborator Author

I'm just shooting for the right order of magnitude here -- I didn't see any noticeable difference between this and a 30000 ns estimate from a previous commit. Supposedly we can set this to the number of CPU cycles (which would be more consistent), but I don't know how to measure that accurately.

Since this is really obscure, and it doesn't seem to be that sensitive, I just put a nice round number in the correct range here rather than adding another configuration knob.

PjRtComputationClient::ExecuteReplicated(
    const ComputationClient::Computation& computation,
    const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
    absl::Span<const ComputationClient::DataPtr> arguments,
    absl::Span<const std::string> devices,
Collaborator

I wonder if we still need devices as input here - should the caller need to be aware of which devices the shards live on?

Collaborator Author

devices doesn't really make sense now that we've dropped XRT, since PJRT only lets you execute on one device or on all devices. I added a TODO to remove this in my IFRT prototype PR. In practice, the caller does seem to be setting the "correct" value here, though.

Collaborator

@JackCaoG JackCaoG left a comment

lgtm

@@ -231,10 +231,16 @@ class PjRtComputationClient : public ComputationClient {
        std::vector<std::string> devices,
        std::unique_ptr<xla::PjRtLoadedExecutable> executable)
        : Computation(std::move(computation), std::move(devices)),
          executable(std::move(executable)) {}
          executable(std::move(executable)) {
Contributor

If it's already being done as part of data placeholder prep, then this is redundant. It looks like we have to extract the output shardings as part of the pjrt_computation initialization. I think with AOT compilation and auto-sharding, it makes sense to do it strictly around the compilation (we will have to do it there anyway) and avoid doing it as part of the execution call. Preferably, keep the pjrt_computation a thin wrapper without much processing/overhead.

Collaborator Author

Yeah, this is redundant. I only added this here because I wanted to avoid returning UNKNOWN out of ExecuteReplicated, and GetOutputSharding is far too slow to call per execution. This does run once per compilation and is cached. PjRtComputation is still a wrapper; it's just caching a really slow call we know we'll need down the line. We were effectively doing the same thing by saving ProgramShape in Computation already.

Depending on how auto-sharding works, and if we add the output shardings to the top-level Computation, the placeholders' shardings could easily use the compiled executable's output shardings from here instead of digging into the HLO proto manually:

    const torch::lazy::BackendDevice& device) {
  const auto& computation_proto = computation->computation().proto();
  uint64_t num_outputs = output_shapes->size();
  std::vector<xla::OpSharding> output_shardings;
  std::vector<XLATensor::ShardingSpecPtr> sharding_specs(num_outputs);
  if (computation_proto.has_spmd_output_sharding()) {
    if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) {
      auto tuple_shardings =
          computation_proto.spmd_output_sharding().tuple_shardings();
      output_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(),
                                                      tuple_shardings.end());
    } else {
      output_shardings = std::vector<xla::OpSharding>{
          computation_proto.spmd_output_sharding()};
    }
  }

Contributor

@yeounoh yeounoh left a comment

I like the fact that this makes the runtime and the replicated execution more performant, but it's less modular than before. Meanwhile, I left some comments about the output sharding extraction being part of the execution -- as a heads-up / FYI.

Overall, this is great -- thanks @will-cromar

@will-cromar will-cromar force-pushed the wcromar/refactor-execute-replicated branch from deed402 to 9bc594e Compare November 30, 2023 17:49
@will-cromar will-cromar merged commit 7b92a94 into master Nov 30, 2023
18 of 19 checks passed
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
…ch#5737)

* Refactor ExecuteReplicated to operate on sharded data directly

* Remove old handlers

* formatting

* Improve naming and logging

* update docstring

* Remove obsolete unit tests

* improve comment

* Remove slow calls to get output shapes.

* fix implicit sharding

* remove declarations of input/output handlers

* formatting

* give everything a manual placeholder sharding

* see if CI passes

* formatting

* Shard parameter and output handling

* Use absl::BlockingCounter

* formatting

* fix merge

* Assign valid output shardings

* tune and document costs

* formatting

* implicitly replicate output to match outputhandler

* clarify ReplicateShardedData

* fix merge
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
…ch#5737)
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024