
Add autocast support for XlaPatchedLinear #8421

Merged · 9 commits · Dec 10, 2024
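A minimal usage sketch of the behavior this PR enables, assuming a torch_xla build with SPMD sharding available; the shapes and variable names below are illustrative only, not taken from the PR:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd.xla_sharding as xs

device = xm.xla_device()
linear = torch.nn.Linear(10, 15).to(device)
x = torch.randn(5, 10, device=device)

# With this change, the patched linear forward participates in XLA autocast:
# the underlying matmul is lowered in bf16 while the parameters stay in f32.
with torch.autocast("xla", enabled=True):
  y = xs.xla_patched_nn_linear_forward(linear, x)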
39 changes: 39 additions & 0 deletions test/test_patchedlinear_autocast.py
@@ -0,0 +1,39 @@
import os
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

import torch_xla.distributed.spmd.xla_sharding as xs

device = xm.xla_device()

class TestAutocastXla(unittest.TestCase):

  def test_patchedlinear_autocast(self):
    hidden_size = 10
    intermediate_size = 15
    input_tensor = torch.randn(
        5, hidden_size, requires_grad=True)  # batch_size=5 as example
    linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True)

    with torch.autocast("xla", enabled=True):
      linear = linear.to(device)
      input_tensor = input_tensor.to(device)
      output = xs.xla_patched_nn_linear_forward(linear, input_tensor)
      # Use a non-trivial reduction so the backward pass contains real ops.
      result = (output**output).mean()

    result.backward()
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad])

    # The matmuls should be emitted in bf16 under autocast.
    self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None)

self.assertTrue(re.search(r".*dot.*f32", hlo) is None)

    # Count dtype conversion ops in each direction.
    bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo))
    f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo))
    self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0)
    # Gradients are cast back to f32 in the backward pass, so bf16->f32
    # conversions should outnumber f32->bf16 conversions.
    self.assertTrue(bf16_to_f32 > f32_to_bf16)


if __name__ == "__main__":
  unittest.main()
3 changes: 3 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -14,6 +14,7 @@
from typing import Tuple, Union, List, Sequence, Any, Optional, Set
from enum import IntEnum

from torch.amp import custom_fwd, custom_bwd

class Mesh:
"""Describe the logical XLA device topology mesh and the underlying resources.
@@ -738,6 +739,7 @@ class XLAPatchedLinear(torch.autograd.Function):
"""

  @staticmethod
  @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
  def forward(ctx, input, weight, bias=None):
    # bias is an optional argument
    ctx.save_for_backward(input, weight, bias)
@@ -748,6 +750,7 @@ def forward(ctx, input, weight, bias=None):
    return product + bias

  @staticmethod
  @custom_bwd(device_type='xla')
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
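For context, a short self-contained sketch (not taken from this PR) of how the custom_fwd/custom_bwd decorators behave: when autocast is enabled, custom_fwd(cast_inputs=...) casts floating-point tensor arguments to the given dtype and runs forward with autocast disabled, and custom_bwd re-applies the forward's autocast state in backward. The CUDA device type and fp16 dtype here are placeholders for illustration; the PR applies the same pattern with device_type='xla' and the XLA autocast dtype (bf16).

import torch
from torch.amp import custom_fwd, custom_bwd


class ScaledMatmul(torch.autograd.Function):
  """Illustrative only; mirrors the decorator pattern used by XLAPatchedLinear."""

  @staticmethod
  @custom_fwd(device_type='cuda', cast_inputs=torch.float16)
  def forward(ctx, a, b):
    # Under autocast, a and b arrive here already cast to float16.
    ctx.save_for_backward(a, b)
    return a @ b

  @staticmethod
  @custom_bwd(device_type='cuda')
  def backward(ctx, grad_output):
    # Runs under the same autocast state that was active during forward.
    a, b = ctx.saved_tensors
    return grad_output @ b.t(), a.t() @ grad_output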