IFRT prototype #5677
Conversation
At this point, I am able to run our SPMD ResNet50 example successfully, but it's extremely slow. Anywhere I resharded an array, I chose "copy" semantics for safety. I'll have to take another pass and carefully think through the ownership of the underlying data, since the excessive copies are likely contributing to the poor performance. Sharded execution does not appear to be implemented at all within the IFRT/PJRT wrapper, so I marked that method unimplemented for now. Likewise, dynamic shape is currently unsupported. Until we have feature parity, I will keep IFRT as a separate client.
Coming back to this PR (finally) after merging supporting changes in separate PRs. Performance is significantly better after rebasing -- it only lags PJRT by ~10% on ResNet50 now, compared to 80% in my first draft. There's still room for optimization, particularly around reducing the number of copies used when transforming IFRT arrays. IFRT is still highly experimental in this state. Known outstanding issues other than performance:

I'll clean up this PR and send it for review as an optional/experimental setting.
Performance on Llama 7B is not bad! It's somewhere between PJRT's current performance and PJRT's performance before I started working on optimizations this month:
If you don't intend to merge this for the 2.2 release, I will hold off on the review until the branch cut.

Merging after the cut sounds good to me. This won't be useful in the 2.2 release.

I will take a look today.
// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
For these utils, can we share between PJRT and IFRT? Or are they actually different?
This was intentional. I wanted to minimize changes to common code, and ideally we would be able to remove the PJRT computation client at some point. Would you prefer I try to factor out all common functionality in this PR?
}
IfrtComputationClient::IfrtComputationClient() {
  std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
Let's do a grep of pjrt in this file and replace them with ifrt. Though I am curious: did you intend to query the EnvPjrtDevice here?
Yeah, see #5677 (comment)
    const std::vector<DataPtr>& shards, std::string device, xla::Shape shape,
    xla::OpSharding sharding) {
  // TODO: implement CreateDataPlaceholder for sharded data
  if (shards.size() == 0) {
Is there a legit use case of shards.size() == 0?
This is how sharded data placeholders get created right now (xla/torch_xla/csrc/xla_sharding_util.cpp, lines 590 to 593 in a2f80e4):

auto sharded_data_placeholder =
    runtime::GetComputationClient()->WrapDataShards(
        {}, GetVirtualDevice().toString(), sharding_specs[i]->shape,
        sharding_specs[i]->sharding);
I'd rather update CreateDataPlaceholder to take a sharding and make it more explicit.
This reverts commit 7d52f67.
This PR adds an alternative implementation of ComputationClient backed by IFRT, which currently just wraps PJRT. It is enabled with XLA_USE_IFRT=1.

PJRT client initialization moves into initialize_pjrt.cc/h, since IFRT wraps the same PjRtClient.

Changes relative to pjrt_computation_client:
- spmd_device_str is now const.
- Ownership of the PjRtClient changed to a unique_ptr. Only SE:TPU required us to use shared_ptr, which is now removed.

There are still some opportunities to refactor common functionality up to ComputationClient, but I'm trying to minimize changes to the high-level API for this PR.

IFRT is still highly experimental. Use at your own risk. See my comments below for caveats and limitations.
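Opting in would look something like the following; the environment flag comes from this PR, while the training script name is hypothetical:

```shell
# Sketch: enable the experimental IFRT-backed client via the flag
# introduced in this PR.
export XLA_USE_IFRT=1
if [ "${XLA_USE_IFRT}" = "1" ]; then
  echo "IFRT client enabled"
fi
# python train_resnet50_spmd.py  # hypothetical training script
```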