Gradient bucketing using a pre-defined bucket size cap #6417

Closed
amithrm wants to merge 10 commits from the bucket_allreduce branch

Conversation

@amithrm (Collaborator) commented Jan 30, 2024

No description provided.

@JackCaoG requested a review from alanwaketan on January 30, 2024 18:25
@alanwaketan (Collaborator):

Do you mind adding a test case?

@amithrm (Collaborator, Author) commented Mar 4, 2024

Added the test case and rebased @JackCaoG @alanwaketan

grad_bytes = grad.numel() * grad.element_size()

# Gradient is larger than bucket_cap, don't bucketize
if grad_bytes > bucket_cap:
Collaborator:

Curious why you want to specialize this case?

Collaborator (Author):

If grad_bytes (the bytes already in the tensor) is larger than the bucket cap, we send it straight away as a single tensor instead of bucketing it.

Collaborator:

Right, I understood the logic. But why? Does combining it with the bucket introduce some problems?

@jeffhataws (Collaborator) commented Mar 16, 2024:

Yeah, it looks like you can get rid of this if statement (up to the continue); the "if total > bucket_cap" check should take care of this condition when the bucket is empty.

Collaborator (Author):

The issue with combining this with the rest is that the "buffer" allocated in the underlying runtime may not have enough space to fit this large tensor. The idea is to have a large buffer that can fit all the tensors in a bucket. It can happen that total_bytes is just below the maximum allowed, and adding this tensor to the bucket would spill over that maximum. Hence it should go "alone" without bucketing.

Collaborator (Author):

I see your concerns now! Fixed the code flow.
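
For readers following along, the fixed flow discussed in this thread (flush the bucket before it would spill, and let an oversized gradient form a bucket of its own) could look roughly like the sketch below. This is an illustrative sketch, not the PR's actual diff: the helper name `_allreduce_in_buckets` and its parameters are assumptions, while `xm.all_reduce` and `xm.REDUCE_SUM` are existing torch_xla helpers.

```python
import torch_xla.core.xla_model as xm


def _allreduce_in_buckets(gradients, bucket_cap, scale=1.0, groups=None,
                          pin_layout=True):
  """Hypothetical sketch: all-reduce `gradients` in buckets of at most
  `bucket_cap` bytes; an oversized gradient ends up in a bucket of its own."""
  total = 0
  tensor_bucket = []
  for grad in gradients:
    grad_bytes = grad.numel() * grad.element_size()
    # Flush the tensors accumulated so far if adding this gradient would
    # spill over the cap (this also covers the case where the previous
    # gradient was itself larger than the cap).
    if tensor_bucket and total + grad_bytes > bucket_cap:
      xm.all_reduce(
          xm.REDUCE_SUM,
          tensor_bucket,
          scale=scale,
          groups=groups,
          pin_layout=pin_layout)
      total = 0
      tensor_bucket = []
    total += grad_bytes
    tensor_bucket.append(grad)
  # Flush whatever is left at the end.
  if tensor_bucket:
    xm.all_reduce(
        xm.REDUCE_SUM,
        tensor_bucket,
        scale=scale,
        groups=groups,
        pin_layout=pin_layout)
```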

@@ -990,14 +1042,13 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
"""
count = xrt_world_size()
if count > 1:
gradients = _fetch_gradients(optimizer)
Collaborator:

Can we keep the original behavior? And maybe use a flag to turn this feature on?

Collaborator (Author):

OK, let me work on that.

Collaborator:

Maybe we should introduce an argument "bucket_cap_mb" that turns this on, instead of an environment variable? bucket_cap_mb=0 would turn bucketing off and be the default?

Collaborator (Author):

Done.
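
A possible shape for that interface, sketched as if it lived next to reduce_gradients in xla_model.py and reusing the `_allreduce_in_buckets` sketch from the earlier thread; the exact signature and defaults in the PR may differ:

```python
def reduce_gradients(optimizer, groups=None, pin_layout=True, bucket_cap_mb=0):
  """Hypothetical sketch: bucket_cap_mb=0 (the default) keeps the original
  per-tensor all_reduce behavior; a positive value enables bucketing."""
  count = xrt_world_size()
  if count > 1:
    gradients = _fetch_gradients(optimizer)
    if bucket_cap_mb > 0:
      _allreduce_in_buckets(
          gradients,
          bucket_cap=bucket_cap_mb * 1024 * 1024,
          scale=1.0 / count,
          groups=groups,
          pin_layout=pin_layout)
    else:
      all_reduce(
          REDUCE_SUM,
          gradients,
          scale=1.0 / count,
          groups=groups,
          pin_layout=pin_layout)
```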

# Bucketize till the total spills over
total += grad_bytes
if total > bucket_cap:
  all_reduce(
Collaborator:

Need to check "if len(tensor_bucket):" because tensor_bucket can be empty at the start, when grad_bytes > bucket_cap.
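
Concretely, the guard being suggested could look like the fragment below, placed inside the bucketing loop; variable names follow the diff excerpt above, and scale, groups, and pin_layout are assumed to come from the enclosing function.

```python
total += grad_bytes
if total > bucket_cap and len(tensor_bucket):
  # Only flush when something has been accumulated; this avoids calling
  # all_reduce on an empty list when the very first gradient already
  # exceeds bucket_cap.
  all_reduce(
      REDUCE_SUM,
      tensor_bucket,
      scale=scale,
      groups=groups,
      pin_layout=pin_layout)
  total = grad_bytes
  tensor_bucket = []
tensor_bucket.append(grad)
```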

@@ -974,6 +976,56 @@ def wait_device_ops(devices=[]):
  torch_xla._XLAC._xla_wait_device_ops(devices=devices)


def bucketed_allreduce(gradients):
Collaborator:

Maybe name it similarly to the original function all_reduce? How about all_reduce_bucketized?

Also, do you need to pass "groups" and "pin_layout" as well?
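
The suggested rename and signature might look like the sketch below, with groups and pin_layout forwarded so the bucketed path matches all_reduce; the defaults and the reuse of the `_allreduce_in_buckets` sketch from above are assumptions, not the PR's final code.

```python
def all_reduce_bucketized(gradients, bucket_cap_mb, scale=1.0, groups=None,
                          pin_layout=True):
  """Hypothetical sketch of a bucketed counterpart to all_reduce."""
  # Forward groups/pin_layout so bucketed and non-bucketed paths behave alike.
  _allreduce_in_buckets(
      gradients,
      bucket_cap=bucket_cap_mb * 1024 * 1024,
      scale=scale,
      groups=groups,
      pin_layout=pin_layout)
```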

@alanwaketan (Collaborator) left a review comment:

LGTM. Please address other comments as well.

@amithrm force-pushed the bucket_allreduce branch from 777b97f to 31dd451 on May 28, 2024 20:20
@jeffhataws (Collaborator):

@JackCaoG do you know why the build failed with "ERROR: Error initializing RemoteModule"?

@JackCaoG (Collaborator) commented May 28, 2024:

The PR is on a fork, so it can't use the remote cache, but there was a bug where it still tried to query the credentials. I think we fixed that error today; it should start building without the cache. If you rebase, the CI should start running.

@amithrm force-pushed the bucket_allreduce branch from 31dd451 to 05e2367 on May 29, 2024 02:56
@jeffhataws (Collaborator):

@JackCaoG it looks like the build is still failing for some reason after rebasing. Maybe another rebase is needed?

@JackCaoG (Collaborator):

The error still seems to be related to the fork. Let me grant both of you write access; then you can open the PR directly.

@JackCaoG (Collaborator):

OK, I gave @amithrm write access.

@jeffhataws (Collaborator):

Replaced by #7216 to avoid the build issues in CI testing.

@jeffhataws closed this on Jun 7, 2024.