
Add autocast support for XlaPatchedLinear #8421

Merged
merged 9 commits into from
Dec 10, 2024
59 changes: 59 additions & 0 deletions test/test_autocast_xla.py
@@ -0,0 +1,59 @@
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

import torch_xla.distributed.spmd.xla_sharding as xs

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

  def test_cross_entropy_loss(self):
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(16, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      loss = torch.nn.CrossEntropyLoss()(data, target)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
      self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
      self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
      self.assertRegex(hlo, r".*log.*f32.*log.*f32")

  def test_patchedlinear_autocast(self):
    hidden_size = 10
    intermediate_size = 15
    input_tensor = torch.randn(
        5, hidden_size, requires_grad=True)  # batch_size=5 as an example
    linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True)

    with torch.autocast("xla", enabled=True):
      linear = linear.to(device)
      input_tensor = input_tensor.to(device)
      output = xs.xla_patched_nn_linear_forward(linear, input_tensor)
      result = (output**output).mean()
      # A simple sum would suffice, but the gradient of a sum is constant, so the
      # HLO for linear.weight.grad would contain only the backward pass. Using
      # (output**output).mean() keeps the whole execution graph in the HLO.

    result.backward()
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad])
    # Verify that the matrix multiplication is performed in bfloat16 precision, not f32.
    # XLAPatchedLinear uses einsum instead of matmul, so this checks that the einsum
    # op in this layer is being autocast.
    self.assertRegex(hlo, r".*dot.*bf16")
    self.assertNotRegex(hlo, r".*dot.*f32")

    bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo))
    f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo))

    # Verify that precision conversions are happening.
    self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0)
    # Expect more bf16->f32 than f32->bf16 conversions, since the backward pass
    # upcasts to f32 for gradient computation.
    self.assertTrue(bf16_to_f32 == 11)
    self.assertTrue(f32_to_bf16 == 8)
    # Redundant given the two exact counts above, but this inequality is the
    # property we actually want to verify.
    self.assertTrue(bf16_to_f32 > f32_to_bf16)


if __name__ == "__main__":
  unittest.main()
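
For context, a minimal usage sketch of what the new test exercises. This is not part of the diff; it assumes this PR is applied, an XLA device is available, and bfloat16 is the XLA autocast dtype (the tensor shapes and names are illustrative only):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd.xla_sharding as xs

device = xm.xla_device()
linear = torch.nn.Linear(10, 15).to(device)
x = torch.randn(5, 10).to(device)

with torch.autocast("xla"):
  # With the custom_fwd(cast_inputs=...) decorator added by this PR, the einsum
  # inside XLAPatchedLinear is expected to run in bf16 under autocast.
  out = xs.xla_patched_nn_linear_forward(linear, x)

print(out.dtype)  # expected: torch.bfloat16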
29 changes: 0 additions & 29 deletions test/test_bf16_autocast.py

This file was deleted.

3 changes: 3 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -14,6 +14,7 @@
from typing import Tuple, Union, List, Sequence, Any, Optional, Set
from enum import IntEnum

from torch.amp import custom_fwd, custom_bwd

class Mesh:
"""Describe the logical XLA device topology mesh and the underlying resources.
@@ -738,6 +739,7 @@ class XLAPatchedLinear(torch.autograd.Function):
"""

  @staticmethod
  @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
  def forward(ctx, input, weight, bias=None):
    # bias is an optional argument
    ctx.save_for_backward(input, weight, bias)
@@ -748,6 +750,7 @@ def forward(ctx, input, weight, bias=None):
    return product + bias

  @staticmethod
  @custom_bwd(device_type='xla')
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
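
For readers unfamiliar with the decorators used above, here is a hedged, standalone sketch of the general torch.amp custom_fwd/custom_bwd pattern this PR applies. The ScaleByTwo class is illustrative only and not part of this change; the PR itself uses torch.get_autocast_dtype('xla') rather than a hard-coded dtype:

import torch
from torch.amp import custom_fwd, custom_bwd


class ScaleByTwo(torch.autograd.Function):
  """Illustrative example, not part of this PR."""

  @staticmethod
  @custom_fwd(device_type='xla', cast_inputs=torch.bfloat16)
  def forward(ctx, x):
    # Inside a torch.autocast("xla") region, custom_fwd casts x to bf16 before
    # this body runs; outside autocast, x is passed through unchanged.
    ctx.save_for_backward(x)
    return x * 2

  @staticmethod
  @custom_bwd(device_type='xla')
  def backward(ctx, grad_output):
    # custom_bwd re-enters the autocast state that was active during forward,
    # so the backward ops see consistent casting behavior.
    (x,) = ctx.saved_tensors
    return grad_output * 2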