-
Notifications
You must be signed in to change notification settings - Fork 989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add allgather check for xpu #2199
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall I think this is a good idea/design. I think we can refactor this out into something like the following:
def assert_tensor_devices(a:tensor,b:tensor, operation:str):
if a.device.type != b.device.type:
raise RuntimeError()
And we can check for the device types/etc with tracking the operation name via the string, similar to verify_operation
. (Tbh we may even just be able to modify verify_operation?)
Makes sense, we can directly modify verify_operation by adding additional common check for any device ? |
Yep. Just I'd put it outside the |
@muellerzr could you please re-trigger CI and review? As of yet I placed it inside verify_operation as the wrapper is already called in gather method. Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking much better to me! Nice job
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One note for our failures here
Related to #2180 . There can be an alternate design as (cc @muellerzr Let me know your thoughts.) :