Add autocast support for XlaPatchedLinear #8421
Conversation
Need linter fix.
@jeffhataws @rpsilva-aws Thanks for the comments. I made the changes and ran YAPF formatting (I was using Ruff earlier), so the linter should be fixed. Also consolidated all tests into a single file.
Attaching the HLO for the XlaPatchedLinear test here.
For reference, the code is:
Also adding exact count assertions in a new commit below, per @rpsilva-aws's and @avizon-aws's suggestion.
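The snippet attached in the PR isn't reproduced here; purely as an illustration, a minimal sketch of such a test might look like the following. The `XLAPatchedLinear` import path, the autocast dtype, and the HLO substring checked are assumptions for this sketch, not the PR's actual test code.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast
# Assumed import path for the patched linear autograd.Function:
from torch_xla.distributed.fsdp.utils import XLAPatchedLinear


def test_patched_linear_autocast():
    device = xm.xla_device()
    x = torch.randn(8, 10, device=device)
    weight = torch.randn(20, 10, device=device, requires_grad=True)

    with autocast(device, dtype=torch.bfloat16):
        y = XLAPatchedLinear.apply(x, weight)

    # Under autocast, the einsum inside the patched forward should run in bf16.
    assert y.dtype == torch.bfloat16

    # Exact-count style assertion against the generated HLO text, per the
    # review suggestion; the substring checked here is hypothetical.
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([y])
    assert hlo.count("bf16") > 0
```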
Waiting for lint fixes and tests to pass before merging.
This PR adds autocast support for XlaPatchedLinear. This layer is currently ignored by autocast because of its custom forward and backward functions. Adding decorators, as described in https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops, brings it within the autocast context (when autocast is enabled). A test for this is also added.
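For context, the decorator pattern from that doc, applied to a custom `autograd.Function`, looks roughly like the sketch below. It assumes the `torch.amp.custom_fwd`/`custom_bwd` API (PyTorch >= 2.4); the `device_type` argument and the einsum equations follow the upstream XLAPatchedLinear shape conventions but are assumptions here, not a copy of the PR's diff.

```python
import torch
from torch.amp import custom_fwd, custom_bwd


class XLAPatchedLinear(torch.autograd.Function):
    # custom_fwd/custom_bwd make autocast aware of this Function's custom
    # forward/backward; the device_type value here is an assumption.
    @staticmethod
    @custom_fwd(device_type='xla')
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        # weight has shape (out_features, in_features) = (m, n)
        product = torch.einsum('...n,mn->...m', input, weight)
        return product if bias is None else product + bias

    @staticmethod
    @custom_bwd(device_type='xla')
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
        grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
        grad_bias = None
        if bias is not None:
            # Sum the output gradient over all leading (batch) dims.
            grad_bias = grad_output.sum(dim=tuple(range(grad_output.dim() - 1)))
        return grad_input, grad_weight, grad_bias
```

With these decorators in place, calls made inside an autocast region pick up the autocast state, and the inner einsum can run in the autocast dtype, building on the einsum autocast support from #8420.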
Related to PR #8420 (einsum autocast support); this is an extension of that work to XlaPatchedLinear, since it uses einsum internally.
Relevant issue: #8405