Gradient bucketing using a pre-defined bucket size cap #6417
Conversation
Do you mind adding a test case?
Force-pushed from 11f466d to fdb0f9e.
Added the test case and rebased @JackCaoG @alanwaketan
torch_xla/core/xla_model.py (Outdated)
grad_bytes = grad.numel() * grad.element_size()

# Gradient is larger than bucket_cap, don't bucketize
if grad_bytes > bucket_cap:
Curious why you want to specialize this case?
If grad_bytes (the bytes already held by the tensor) is larger than the bucket cap, we send it straight away as a single tensor instead of bucketing it.
Right, I understood the logic. But why? Does combining it with the bucket introduce some problems?
Yeah, looks like you can get rid of this if statement (up to the continue), and the "if total > bucket_cap" check should take care of this condition when the bucket is empty.
The issue with combining this with the rest is that the "buffer" allocated in the underlying runtime may not have enough space to fit this large tensor. The idea is to have a large buffer that can fit all the tensors in a bucket. It can happen that total_bytes is just below the maximum allowed, and adding this tensor to the bucket would spill over that maximum. Hence it should go "alone" without bucketizing.
I see your concerns now!! Fixed the code flow.
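For reference, a minimal sketch of what the revised flow could look like (illustration only, not the exact patch; the helper name _allreduce_in_buckets and the use of REDUCE_SUM, count, groups, and pin_layout from the surrounding module are assumptions): the bucket collected so far is flushed before a gradient that would push it past the cap, so the runtime buffer never has to hold more than bucket_cap bytes, and an oversized gradient simply goes out in a bucket of its own.

# Sketch only (hypothetical helper, not the actual patch).
def _allreduce_in_buckets(gradients, bucket_cap, count, groups, pin_layout):
  total = 0
  tensor_bucket = []
  for grad in gradients:
    grad_bytes = grad.numel() * grad.element_size()
    # Flush before this gradient would push the bucket past the cap.
    if tensor_bucket and total + grad_bytes > bucket_cap:
      all_reduce(
          REDUCE_SUM,
          tensor_bucket,
          scale=1.0 / count,
          groups=groups,
          pin_layout=pin_layout)
      total = 0
      tensor_bucket = []
    total += grad_bytes
    tensor_bucket.append(grad)
  # Flush whatever is left over.
  if tensor_bucket:
    all_reduce(
        REDUCE_SUM,
        tensor_bucket,
        scale=1.0 / count,
        groups=groups,
        pin_layout=pin_layout)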
@@ -990,14 +1042,13 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
  """
  count = xrt_world_size()
  if count > 1:
    gradients = _fetch_gradients(optimizer)
Can we keep the original behavior? And maybe use a flag to turn this feature on?
OK, let me work on that.
Maybe we should introduce an argument "bucket_cap_mb" that turns this on, instead of an environment variable? bucket_cap_mb=0 turns off bucketing and is the default?
done
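One way that argument could be plumbed through (a sketch of the suggestion only; the helper _allreduce_in_buckets is the hypothetical name used in the sketch above, not the final API):

def reduce_gradients(optimizer, groups=None, pin_layout=True, bucket_cap_mb=0):
  count = xrt_world_size()
  if count > 1:
    gradients = _fetch_gradients(optimizer)
    if bucket_cap_mb > 0:
      # Opt-in bucketed path (helper name is hypothetical).
      _allreduce_in_buckets(
          gradients,
          bucket_cap_mb * 1024 * 1024,
          count,
          groups=groups,
          pin_layout=pin_layout)
    else:
      # Default: the original, unbucketed all_reduce over all gradients.
      all_reduce(
          REDUCE_SUM,
          gradients,
          scale=1.0 / count,
          groups=groups,
          pin_layout=pin_layout)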
torch_xla/core/xla_model.py (Outdated)
# Bucketize till the total spills over
total += grad_bytes
if total > bucket_cap:
  all_reduce(
Need to check "if len(tensor_bucket):" because tensor_bucket can be empty at the start, when grad_bytes > bucket_cap.
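A sketch of how the loop body could incorporate that guard (same local names as the diff; illustration only, assuming the flush uses the same all_reduce arguments as the unbucketed path):

grad_bytes = grad.numel() * grad.element_size()
total += grad_bytes
# Flush the gradients gathered so far before this one spills the cap, but
# only if the bucket is non-empty: a single gradient larger than bucket_cap
# would otherwise trigger an all_reduce on an empty list.
if total > bucket_cap and len(tensor_bucket):
  all_reduce(
      REDUCE_SUM,
      tensor_bucket,
      scale=1.0 / count,
      groups=groups,
      pin_layout=pin_layout)
  total = grad_bytes
  tensor_bucket = []
tensor_bucket.append(grad)
# (Any leftover tensor_bucket still needs a final all_reduce after the loop.)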
@@ -974,6 +976,56 @@ def wait_device_ops(devices=[]):
  torch_xla._XLAC._xla_wait_device_ops(devices=devices)


def bucketed_allreduce(gradients):
Maybe name it similarly to the original function all_reduce? How about all_reduce_bucketized?
Also, do you need to pass "groups" and "pin_layout" as well?
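A possible shape for that helper (a sketch only; the parameter set simply mirrors all_reduce and is an assumption here, not the merged signature):

def all_reduce_bucketized(gradients,
                          bucket_cap_mb=0,
                          groups=None,
                          pin_layout=True):
  """All-reduce `gradients` in buckets of at most bucket_cap_mb MB.

  groups and pin_layout are forwarded to every underlying all_reduce call
  so the bucketed path behaves like the plain one.
  """
  ...  # bucketing loop as sketched above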
LGTM. Please address other comments as well.
@JackCaoG do you know why the build failed with "ERROR: Error initializing RemoteModule"?
It is on a fork, hence it can't use the remote cache, but there was a bug where it still tried to query the credentials. I think we fixed this error today; it should start building without the cache. If you rebase, the CI should start running.
Summary: This pull request tries to unify all TORCH_LIBRARY definitions across torch_xla into one xla library. Test Plan: CI
@JackCaoG looks like the build is still failing for some reason after rebasing. Maybe another rebase is needed?
The error still seems to be related to the fork. Let me grant both of you write access, then you can open the PR directly.
OK I gave @amithrm write access
Replaced by #7216 to avoid the build issues in CI testing.