
IFRT prototype #5677

Merged (33 commits) on Dec 15, 2023

Conversation

@will-cromar (Collaborator) commented Oct 5, 2023

  • Implement new ComputationClient with IFRT, which currently just wraps PJRT.
    • Enable IFRT wrapper with XLA_USE_IFRT=1
  • Move PJRT client initialization into initialize_pjrt.cc/h since IFRT wraps the same PjRtClient.
  • Minor fixes to existing pjrt_computation_client:
    • Make spmd_device_str const
    • Change the owned PjRtClient to a unique_ptr. Only SE:TPU required us to use shared_ptr, which is now removed.

There are still some opportunities to refactor common functionality up to ComputationClient, but I'm trying to minimize changes to the high-level API for this PR.

IFRT is still highly experimental. Use at your own risk. See my comments below for caveats and limitations.

@will-cromar (Collaborator, Author)

At this point, I am able to run our SPMD ResNet50 example successfully, but it's extremely slow. In every case where I resharded an array, I chose "copy" semantics for safety. I'll have to take another pass and carefully think through the ownership of the underlying data, since the excessive copies are likely contributing to the poor performance.

Sharded execution does not appear to be implemented at all within the IFRT/PJRT wrapper, so I marked that method unimplemented for now. Likewise, xla::ifrt::HloSharding (a wrapper over xla::HloSharding) is not implemented in the wrapper, so we'll have to carry the xla::OpSharding in our own object for now.

Dynamic shape is currently unsupported by xla::ifrt::Shape.

Until we have feature parity, I will keep IFRT as a separate ComputationClient implementation so we don't break PJRT.

@will-cromar force-pushed the wcromar/ifrt-prototype branch from f58cba4 to 40ab61d on November 28, 2023 19:25
@will-cromar changed the base branch from master to wcromar/refactor-execute-replicated on November 28, 2023 19:26
@will-cromar force-pushed the wcromar/ifrt-prototype branch from bbde87f to f411a37 on November 29, 2023 22:03
@will-cromar force-pushed the wcromar/refactor-execute-replicated branch from deed402 to 9bc594e on November 30, 2023 17:49
@will-cromar (Collaborator, Author) commented Dec 1, 2023

Coming back to this PR (finally) after merging supporting changes in separate PRs. Performance is significantly better after rebasing -- it only lags PJRT by ~10% on ResNet50 now, compared to 80% in my first draft. There's still room for optimization, particularly around reducing the number of copies used when transforming IFRT arrays.

IFRT is still highly experimental in this state. Known outstanding issues other than performance:

  • The IFRT/PJRT wrapper now natively supports xla::HloSharding, but I'm still getting errors at the current XLA pin. I'd like to try again after another pin update since this support was only added in October.
  • Sharded execution (aka multiprocess) is completely unsupported in IFRT.
  • Dynamic shapes can't be represented by xla::ifrt::Shape.

I'll clean up this PR and send it for review as an optional/experimental setting.

@will-cromar force-pushed the wcromar/ifrt-prototype branch from e4856a3 to 1855086 on December 1, 2023 00:03
@will-cromar changed the base branch from wcromar/refactor-execute-replicated to master on December 1, 2023 00:03
@will-cromar changed the title from [WIP] IFRT prototype to IFRT prototype on Dec 1, 2023
@will-cromar (Collaborator, Author)

Performance on LLama 7B is not bad! It's somewhere between PJRT now and PJRT before I started working on some optimizations this month:

Totally decoded 1007 tokens in 7.03615 seconds

@will-cromar marked this pull request as ready for review on December 1, 2023 19:00
@JackCaoG (Collaborator) commented Dec 4, 2023

If you don't intend to merge this for the 2.2 release, I will hold off on the review until the branch cut.

@will-cromar (Collaborator, Author)

Merging after the cut sounds good to me. This won't be useful in the 2.2 release.

@JackCaoG (Collaborator)

I will take a look today


// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
Collaborator:
For these utils, can we share the code between pjrt and ifrt? Or are they actually different?

@will-cromar (Collaborator, Author) Dec 13, 2023:
This was intentional. I wanted to minimize changes to common code, and ideally we would be able to remove the PJRT computation client at some point. Would you prefer I try to factor out all common functionality in this PR?

}

IfrtComputationClient::IfrtComputationClient() {
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
Collaborator:
Let's do a grep for pjrt in this file and replace those references with ifrt. Though I am curious: did you intend to query kEnvPjRtDevice here?

Collaborator Author:
Yeah, see #5677 (comment)

const std::vector<DataPtr>& shards, std::string device, xla::Shape shape,
xla::OpSharding sharding) {
// TODO: implement CreateDataPlaceholder for sharded data
if (shards.size() == 0) {
Collaborator:
Is there a legit use case for shards.size() == 0?

Collaborator Author:
This is how sharded data placeholders get created right now:

auto sharded_data_placeholder =
    runtime::GetComputationClient()->WrapDataShards(
        {}, GetVirtualDevice().toString(), sharding_specs[i]->shape,
        sharding_specs[i]->sharding);

I'd rather update CreateDataPlaceholder to take a sharding and make it more explicit

@will-cromar force-pushed the wcromar/ifrt-prototype branch from 3cec447 to 9ebc616 on December 14, 2023 22:51
@will-cromar merged commit 85b3cdc into master on Dec 15, 2023
20 checks passed
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024