
Commit

Added decorator description to docstring
aws-nm9 committed Dec 9, 2024
1 parent 4ed2466 commit c93a683
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -735,6 +735,14 @@ class XLAPatchedLinear(torch.autograd.Function):
dimensions. The torch.matmul default behavior makes it very hard for the XLA compiler
to propagate the sharding annotation.
The autocast decorators @custom_fwd and @custom_bwd are used, as per the autocast docs [1], to bring this class/layer within
the autocast context when autocast is enabled.
torch.get_autocast_dtype() fetches the datatype for ops run in autocast [2] for the specified device (here, 'xla').
References:
[1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
[2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
TODO (alanwaketan): Let's patch it on the dispatcher level.
"""

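For context, the sketch below shows how @custom_fwd/@custom_bwd from torch.amp can wrap a custom torch.autograd.Function so its forward and backward run under autocast for the 'xla' device, and how torch.get_autocast_dtype('xla') reports the dtype autocast would use. This is a minimal illustration under stated assumptions, not the actual XLAPatchedLinear implementation; the class name PatchedLinearSketch and the einsum formulation are assumptions for the example.

import torch
from torch.amp import custom_fwd, custom_bwd


class PatchedLinearSketch(torch.autograd.Function):
    # Hypothetical, simplified stand-in for XLAPatchedLinear (illustration only).

    @staticmethod
    @custom_fwd(device_type='xla')  # forward participates in the 'xla' autocast context
    def forward(ctx, input, weight, bias=None):
        # input: (..., in_features), weight: (out_features, in_features)
        ctx.save_for_backward(input, weight, bias)
        product = torch.einsum('...n,mn->...m', input, weight)
        if bias is None:
            return product
        return product + bias

    @staticmethod
    @custom_bwd(device_type='xla')  # backward runs in the same autocast state as forward
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
        grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
        grad_bias = None
        if bias is not None:
            # Sum grad_output over all leading (batch) dimensions to get the bias grad.
            grad_bias = grad_output.sum(dim=tuple(range(grad_output.dim() - 1)))
        return grad_input, grad_weight, grad_bias


# Usage sketch: inside `with torch.autocast('xla', dtype=torch.bfloat16):`,
# autocast is enabled for the 'xla' device, the decorated forward/backward run
# within that context, and torch.get_autocast_dtype('xla') returns torch.bfloat16.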
