-
Notifications
You must be signed in to change notification settings - Fork 487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow scalar all gather #5797
Allow scalar all gather #5797
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I will merge it once CI passed. (If I forgot feel free to give me a ping)
Ping on this PR, the failing tests seem unrelated to this change. |
sorry my bad.. should of merge it earlier. Do you mind resolve the merge conflict and rebase? Do you need this to backport to 2.2 release? |
Resolved the conflict and no need to backport. |
sg, I will merge it once all ci passed. |
Thanks! |
This change allows all gather to work for scalars. This is supported by the normal pytorch distributed API, but if you attempt to all gather scalars under xla, a low level error is emitted as the xla gather axis, 0, is not less than the rank, also 0, of the inputs.
This is fixed by reshaping the scalar tensors to shape (1,) and then reshaping the sliced outputs back to scalars.
This was tested manually on Cloud TPU.