Add allgather check for xpu #2199

Merged 5 commits on Dec 5, 2023

Contributor

@abhilash1910 abhilash1910 commented Nov 29, 2023

Related to #2180. An alternate design could be the following (cc @muellerzr, let me know your thoughts):

if state.device.type != tensor.device.type:
    if state.device.type == "cuda":
        <print error>
    elif ......<logic for other devices>:
        <print error>

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Collaborator

@muellerzr muellerzr left a comment


Overall I think this is a good idea/design. I think we can refactor this out into something like the following:

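import torch

def assert_tensor_devices(a: torch.Tensor, b: torch.Tensor, operation: str):
    # Fail fast when the two tensors live on different device types.
    if a.device.type != b.device.type:
        raise RuntimeError(f"{operation}: device mismatch ({a.device.type} vs {b.device.type})")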

We can then check the device types, etc., while tracking the operation name via the string, similar to verify_operation. (Tbh, we may even be able to just modify verify_operation?)

@abhilash1910
Contributor Author

Makes sense. We can directly modify verify_operation by adding an additional common check for any device?

@muellerzr
Collaborator

Yep. I'd just put it outside the ACCELERATE_DEBUG_MODE check. This is something we can run quickly all the time: since it's only a device check, it doesn't need a communication call.
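A minimal sketch of that placement, for illustration only. It is not the code merged in this PR. `Device` and `Tensor` are stand-ins so the example runs without PyTorch, and `verify_operation_sketch`, `debug_mode_enabled`, and the `debug_check` callback are hypothetical names. How the real library parses ACCELERATE_DEBUG_MODE is also an assumption here.

```python
import os
from dataclasses import dataclass
from functools import wraps


@dataclass
class Device:
    """Stand-in for torch.device; only the `type` attribute matters here."""
    type: str


@dataclass
class Tensor:
    """Stand-in for torch.Tensor, so the sketch runs without PyTorch."""
    device: Device


def debug_mode_enabled() -> bool:
    # Assumption: the flag is read from the environment as a truthy string.
    value = os.environ.get("ACCELERATE_DEBUG_MODE", "")
    return value.lower() in {"1", "true", "yes"}


def verify_operation_sketch(state_device: Device, debug_check=None):
    """Hypothetical decorator factory: device check always, debug check on demand."""
    def decorator(function):
        @wraps(function)
        def wrapper(tensor, *args, **kwargs):
            operation = f"{function.__module__}.{function.__name__}"
            # Cheap, local check: no communication call, so it always runs.
            if state_device.type != tensor.device.type:
                raise RuntimeError(
                    f"Cannot run `{operation}`: input is on `{tensor.device.type}` "
                    f"but the process device is `{state_device.type}`."
                )
            # Costlier cross-process checks stay behind the debug flag.
            if debug_check is not None and debug_mode_enabled():
                debug_check(tensor, operation)
            return function(tensor, *args, **kwargs)
        return wrapper
    return decorator


# Records which operations ran the (hypothetical) debug-only check.
debug_calls = []


@verify_operation_sketch(Device("xpu"), debug_check=lambda t, op: debug_calls.append(op))
def gather(tensor):
    # Placeholder body; a real gather would communicate across processes.
    return tensor
```

With this layout, a tensor on the wrong device type raises immediately whether or not debug mode is set, while the callback passed as `debug_check` only fires when the environment variable is enabled.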

@abhilash1910
Contributor Author

@muellerzr could you please re-trigger CI and review? For now I have placed it inside verify_operation, as the wrapper is already called in the gather method. Thanks

Collaborator

@muellerzr muellerzr left a comment


This is looking much better to me! Nice job

Collaborator

@muellerzr muellerzr left a comment


One note for our failures here

src/accelerate/utils/operations.py (outdated, resolved)
@muellerzr muellerzr merged commit 47e6c36 into huggingface:main Dec 5, 2023
23 checks passed

3 participants