From 7884949b2d6fe6a8206e44c775e045997e9fe085 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Wed, 27 Nov 2024 02:56:15 +0000 Subject: [PATCH 1/8] Add autocast support for XlaPatchedLinear --- test/test_patchedlinear_autocast.py | 39 ++++++++++++++++++++++ torch_xla/distributed/spmd/xla_sharding.py | 3 ++ 2 files changed, 42 insertions(+) create mode 100644 test/test_patchedlinear_autocast.py diff --git a/test/test_patchedlinear_autocast.py b/test/test_patchedlinear_autocast.py new file mode 100644 index 00000000000..91635852647 --- /dev/null +++ b/test/test_patchedlinear_autocast.py @@ -0,0 +1,39 @@ +import os +import re +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import unittest + +import torch_xla.distributed.spmd.xla_sharding as xs + +device = xm.xla_device() + +class TestAutocastXla(unittest.TestCase): + def test_patchedlinear_autocast(self): + hidden_size = 10 + intermediate_size = 15 + input_tensor = torch.randn(5, hidden_size, requires_grad=True) # batch_size=5 as example + linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True) + + with torch.autocast("xla", enabled=True): + linear = linear.to(device) + input_tensor = input_tensor.to(device) + output = xs.xla_patched_nn_linear_forward(linear, input_tensor) + sum = (output**output).mean() + + sum.backward() + hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) + + self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None) + + self.assertTrue(re.search(r".*dot.*f32", hlo) is None) + + bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) + f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) + self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) + self.assertTrue(bf16_to_f32 > f32_to_bf16) # cast for grad in backward, so bf16_to_f32 should be higher + + +if __name__ == "__main__": + unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 7085325513e..1decbcf8402 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -14,6 +14,7 @@ from typing import Tuple, Union, List, Sequence, Any, Optional, Set from enum import IntEnum +from torch.amp import custom_fwd, custom_bwd class Mesh: """Describe the logical XLA device topology mesh and the underlying resources. 
@@ -738,6 +739,7 @@ class XLAPatchedLinear(torch.autograd.Function): """ @staticmethod + @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla')) def forward(ctx, input, weight, bias=None): # bias is an optional argument ctx.save_for_backward(input, weight, bias) @@ -748,6 +750,7 @@ def forward(ctx, input, weight, bias=None): return product + bias @staticmethod + @custom_bwd(device_type='xla') def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors grad_input = grad_weight = grad_bias = None From 59c0b8f6b5715e2392a71ce41e5669015a3586f3 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Wed, 27 Nov 2024 19:11:08 +0000 Subject: [PATCH 2/8] Added and addressed comments --- test/test_patchedlinear_autocast.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_patchedlinear_autocast.py b/test/test_patchedlinear_autocast.py index 91635852647..f9988280ee6 100644 --- a/test/test_patchedlinear_autocast.py +++ b/test/test_patchedlinear_autocast.py @@ -1,4 +1,3 @@ -import os import re import torch import torch_xla @@ -9,7 +8,7 @@ device = xm.xla_device() -class TestAutocastXla(unittest.TestCase): +class TestPatchedLinearAutocastXla(unittest.TestCase): def test_patchedlinear_autocast(self): hidden_size = 10 intermediate_size = 15 @@ -20,19 +19,23 @@ def test_patchedlinear_autocast(self): linear = linear.to(device) input_tensor = input_tensor.to(device) output = xs.xla_patched_nn_linear_forward(linear, input_tensor) - sum = (output**output).mean() + result = (output**output).mean() + # a simple sum would suffice, but then hlo for linear.weight.grad would be only the backward pass + # since grad of sum is constant. Hence this is done, to get whole execution graph in HLO. - sum.backward() + result.backward() hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) - + # Verify that matrix multiplication is performed in bfloat16 precision and not f32. + # XLAPatchedLinear uses einsum instead of matmul, hence this is basically checking if einsum op in this layer is being autocast. 
self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None) - self.assertTrue(re.search(r".*dot.*f32", hlo) is None) bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) + # Verify that precision conversions are happening self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) - self.assertTrue(bf16_to_f32 > f32_to_bf16) # cast for grad in backward, so bf16_to_f32 should be higher + # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation + self.assertTrue(bf16_to_f32 > f32_to_bf16) if __name__ == "__main__": From be1f1aff631183ca7406ef6cfb54fbea779c16e2 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Fri, 6 Dec 2024 06:15:25 +0000 Subject: [PATCH 3/8] YAPF formatting, consolidated tests in single file, addressed comments --- test/test_autocast_xla.py | 68 ++++++++++++++++++++++++++++++++++++++ test/test_bf16_autocast.py | 29 ---------------- 2 files changed, 68 insertions(+), 29 deletions(-) create mode 100644 test/test_autocast_xla.py delete mode 100644 test/test_bf16_autocast.py diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py new file mode 100644 index 00000000000..81ee59e9e5d --- /dev/null +++ b/test/test_autocast_xla.py @@ -0,0 +1,68 @@ +import re +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import unittest + +import torch_xla.distributed.spmd.xla_sharding as xs + +device = xm.xla_device() + + +class TestAutocastXla(unittest.TestCase): + + def test_cross_entropy_loss(self): + data = torch.randn(16, 10).to(torch.bfloat16).to(device) + target = torch.randn(16, 10).to(torch.bfloat16).to(device) + + with torch.autocast("xla"): + loss = torch.nn.CrossEntropyLoss()(data, target) + hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) + self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16") + self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32") + self.assertRegex(hlo, r".*log.*f32.*log.*f32") + + def test_einsum(self): + data = torch.randn(16, 10).to(torch.bfloat16).to(device) + target = torch.randn(5, 10).to(torch.bfloat16).to(device) + + with torch.autocast("xla"): + product = torch.einsum("...n,mn->...m", data, target) + # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO + hlo = torch_xla._XLAC._get_xla_tensors_hlo([product]) + # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast + self.assertRegex(hlo, r".*dot.*bf16") + self.assertNotRegex(hlo, r".*dot.*f32") + + def test_patchedlinear_autocast(self): + hidden_size = 10 + intermediate_size = 15 + input_tensor = torch.randn( + 5, hidden_size, requires_grad=True) # batch_size=5 as example + linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True) + + with torch.autocast("xla", enabled=True): + linear = linear.to(device) + input_tensor = input_tensor.to(device) + output = xs.xla_patched_nn_linear_forward(linear, input_tensor) + result = (output**output).mean() + # a simple sum would suffice, but then hlo for linear.weight.grad would be only the backward pass + # since grad of sum is constant. Hence this is done, to get whole execution graph in HLO. + + result.backward() + hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) + # Verify that matrix multiplication is performed in bfloat16 precision and not f32. 
+ # XLAPatchedLinear uses einsum instead of matmul, hence this is checking if einsum op in this layer is being autocast. + self.assertRegex(hlo, r".*dot.*bf16") + self.assertNotRegex(hlo, r".*dot.*f32") + + bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) + f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) + # Verify that precision conversions are happening + self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) + # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation + self.assertTrue(bf16_to_f32 > f32_to_bf16) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_bf16_autocast.py b/test/test_bf16_autocast.py deleted file mode 100644 index d5facd802cd..00000000000 --- a/test/test_bf16_autocast.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import re -import torch -import torch_xla -import torch_xla.core.xla_model as xm -import unittest - -device = xm.xla_device() - - -class TestAutocastXla(unittest.TestCase): - - def test_cross_entropy_loss(self): - data = torch.randn(16, 10).to(torch.bfloat16).to(device) - target = torch.randn(16, 10).to(torch.bfloat16).to(device) - with torch.autocast("xla"): - loss = torch.nn.CrossEntropyLoss()(data, target) - hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) - self.assertTrue( - re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None) - - self.assertTrue( - re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None) - - self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None) - - -if __name__ == "__main__": - unittest.main() From 4ed246671a802110d0cf21ccb9b8bdc912a811d0 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Fri, 6 Dec 2024 08:02:09 +0000 Subject: [PATCH 4/8] Changed assertions, re-lint with different spaces --- test/test_autocast_xla.py | 95 +++++++++++++---------------- test/test_patchedlinear_autocast.py | 42 ------------- 2 files changed, 43 insertions(+), 94 deletions(-) delete mode 100644 test/test_patchedlinear_autocast.py diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 81ee59e9e5d..4de48f89bf6 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -11,58 +11,49 @@ class TestAutocastXla(unittest.TestCase): - def test_cross_entropy_loss(self): - data = torch.randn(16, 10).to(torch.bfloat16).to(device) - target = torch.randn(16, 10).to(torch.bfloat16).to(device) - - with torch.autocast("xla"): - loss = torch.nn.CrossEntropyLoss()(data, target) - hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) - self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16") - self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32") - self.assertRegex(hlo, r".*log.*f32.*log.*f32") - - def test_einsum(self): - data = torch.randn(16, 10).to(torch.bfloat16).to(device) - target = torch.randn(5, 10).to(torch.bfloat16).to(device) - - with torch.autocast("xla"): - product = torch.einsum("...n,mn->...m", data, target) - # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO - hlo = torch_xla._XLAC._get_xla_tensors_hlo([product]) - # Verify that dot op has bf16 output and not f32, i.e. 
the computation is performed in bfloat16 precision by autocast - self.assertRegex(hlo, r".*dot.*bf16") - self.assertNotRegex(hlo, r".*dot.*f32") - - def test_patchedlinear_autocast(self): - hidden_size = 10 - intermediate_size = 15 - input_tensor = torch.randn( - 5, hidden_size, requires_grad=True) # batch_size=5 as example - linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True) - - with torch.autocast("xla", enabled=True): - linear = linear.to(device) - input_tensor = input_tensor.to(device) - output = xs.xla_patched_nn_linear_forward(linear, input_tensor) - result = (output**output).mean() - # a simple sum would suffice, but then hlo for linear.weight.grad would be only the backward pass - # since grad of sum is constant. Hence this is done, to get whole execution graph in HLO. - - result.backward() - hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) - # Verify that matrix multiplication is performed in bfloat16 precision and not f32. - # XLAPatchedLinear uses einsum instead of matmul, hence this is checking if einsum op in this layer is being autocast. - self.assertRegex(hlo, r".*dot.*bf16") - self.assertNotRegex(hlo, r".*dot.*f32") - - bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) - f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) - # Verify that precision conversions are happening - self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) - # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation - self.assertTrue(bf16_to_f32 > f32_to_bf16) + def test_cross_entropy_loss(self): + data = torch.randn(16, 10).to(torch.bfloat16).to(device) + target = torch.randn(16, 10).to(torch.bfloat16).to(device) + + with torch.autocast("xla"): + loss = torch.nn.CrossEntropyLoss()(data, target) + hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) + self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16") + self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32") + self.assertRegex(hlo, r".*log.*f32.*log.*f32") + + def test_patchedlinear_autocast(self): + hidden_size = 10 + intermediate_size = 15 + input_tensor = torch.randn( + 5, hidden_size, requires_grad=True) # batch_size=5 as example + linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True) + + with torch.autocast("xla", enabled=True): + linear = linear.to(device) + input_tensor = input_tensor.to(device) + output = xs.xla_patched_nn_linear_forward(linear, input_tensor) + result = (output**output).mean() + # a simple sum would suffice, but then hlo for linear.weight.grad would be only the backward pass + # since grad of sum is constant. Hence this is done, to get whole execution graph in HLO. + + result.backward() + hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) + # Verify that matrix multiplication is performed in bfloat16 precision and not f32. + # XLAPatchedLinear uses einsum instead of matmul, hence this is checking if einsum op in this layer is being autocast. 
+ self.assertRegex(hlo, r".*dot.*bf16") + self.assertNotRegex(hlo, r".*dot.*f32") + + bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) + f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) + + # Verify that precision conversions are happening + self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) + # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation + self.assertTrue(bf16_to_f32 == 11) + self.assertTrue(f32_to_bf16 == 8) + self.assertTrue(bf16_to_f32 > f32_to_bf16) #redundant given the above two, but this is what we actually want to verify if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_patchedlinear_autocast.py b/test/test_patchedlinear_autocast.py deleted file mode 100644 index f9988280ee6..00000000000 --- a/test/test_patchedlinear_autocast.py +++ /dev/null @@ -1,42 +0,0 @@ -import re -import torch -import torch_xla -import torch_xla.core.xla_model as xm -import unittest - -import torch_xla.distributed.spmd.xla_sharding as xs - -device = xm.xla_device() - -class TestPatchedLinearAutocastXla(unittest.TestCase): - def test_patchedlinear_autocast(self): - hidden_size = 10 - intermediate_size = 15 - input_tensor = torch.randn(5, hidden_size, requires_grad=True) # batch_size=5 as example - linear = torch.nn.Linear(hidden_size, intermediate_size, bias=True) - - with torch.autocast("xla", enabled=True): - linear = linear.to(device) - input_tensor = input_tensor.to(device) - output = xs.xla_patched_nn_linear_forward(linear, input_tensor) - result = (output**output).mean() - # a simple sum would suffice, but then hlo for linear.weight.grad would be only the backward pass - # since grad of sum is constant. Hence this is done, to get whole execution graph in HLO. - - result.backward() - hlo = torch_xla._XLAC._get_xla_tensors_hlo([linear.weight.grad]) - # Verify that matrix multiplication is performed in bfloat16 precision and not f32. - # XLAPatchedLinear uses einsum instead of matmul, hence this is basically checking if einsum op in this layer is being autocast. - self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None) - self.assertTrue(re.search(r".*dot.*f32", hlo) is None) - - bf16_to_f32 = len(re.findall(r".*convert.*f32.*convert.*bf16", hlo)) - f32_to_bf16 = len(re.findall(r".*convert.*bf16.*convert.*f32", hlo)) - # Verify that precision conversions are happening - self.assertTrue(bf16_to_f32 > 0 and f32_to_bf16 > 0) - # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation - self.assertTrue(bf16_to_f32 > f32_to_bf16) - - -if __name__ == "__main__": - unittest.main() From c93a6832b3b0daa524139595b16ce79305057ba0 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Mon, 9 Dec 2024 05:26:57 +0000 Subject: [PATCH 5/8] Added decorator description to docstring --- torch_xla/distributed/spmd/xla_sharding.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 1decbcf8402..0d20c3027b1 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -735,6 +735,14 @@ class XLAPatchedLinear(torch.autograd.Function): dimensions. The torch.matmul default behavior makes it very hard for XLA compiler to propagate the sharding annotation. 
+ Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within + autocast context, when autocast is enabled. + torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla'). + + References: + [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops + [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500 + TODO (alanwaketan): Let's patch it on the dispatcher level. """ From 39ec866e5515e65c4d19148d02d29fcf5b09d6f1 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Mon, 9 Dec 2024 17:21:00 +0000 Subject: [PATCH 6/8] Linting fixes --- test/test_autocast_xla.py | 5 +++-- torch_xla/distributed/spmd/xla_sharding.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 4de48f89bf6..352cc7bed1e 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -10,7 +10,6 @@ class TestAutocastXla(unittest.TestCase): - def test_cross_entropy_loss(self): data = torch.randn(16, 10).to(torch.bfloat16).to(device) target = torch.randn(16, 10).to(torch.bfloat16).to(device) @@ -52,7 +51,9 @@ def test_patchedlinear_autocast(self): # Verify more bf16->f32 conversions than f32->bf16, since this is expected during backward pass for grad computation self.assertTrue(bf16_to_f32 == 11) self.assertTrue(f32_to_bf16 == 8) - self.assertTrue(bf16_to_f32 > f32_to_bf16) #redundant given the above two, but this is what we actually want to verify + self.assertTrue( + bf16_to_f32 > f32_to_bf16 + ) #redundant given the above two, but this is what we actually want to verify if __name__ == "__main__": diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 0d20c3027b1..5ea1343bfcd 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -16,6 +16,7 @@ from torch.amp import custom_fwd, custom_bwd + class Mesh: """Describe the logical XLA device topology mesh and the underlying resources. From e17bc2ce80d6e6c7bc3826d1d8067967a7e1d747 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Mon, 9 Dec 2024 18:06:02 +0000 Subject: [PATCH 7/8] Lint fix - blank line after classname --- test/test_autocast_xla.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 352cc7bed1e..657603b9df7 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -10,6 +10,7 @@ class TestAutocastXla(unittest.TestCase): + def test_cross_entropy_loss(self): data = torch.randn(16, 10).to(torch.bfloat16).to(device) target = torch.randn(16, 10).to(torch.bfloat16).to(device) From 992a38a405e6216b2828eb1bf5099176439945d3 Mon Sep 17 00:00:00 2001 From: Nachiket Mokashi Date: Mon, 9 Dec 2024 18:45:33 +0000 Subject: [PATCH 8/8] Linting fix - blank space --- test/test_autocast_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 657603b9df7..962fa52a59c 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -10,7 +10,7 @@ class TestAutocastXla(unittest.TestCase): - + def test_cross_entropy_loss(self): data = torch.randn(16, 10).to(torch.bfloat16).to(device) target = torch.randn(16, 10).to(torch.bfloat16).to(device)
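
For reference, below is a minimal, self-contained sketch of the custom_fwd/custom_bwd pattern that patches 1/8 and 5/8 apply to XLAPatchedLinear. EinsumLinear is a hypothetical, bias-free stand-in written only for illustration (it is not part of this PR); it assumes torch_xla is installed and an XLA device is available, and it reuses the exact decorator arguments from the patch. The gradient einsum equations are simplified 2-D versions, not the PR's code.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch.amp import custom_fwd, custom_bwd


class EinsumLinear(torch.autograd.Function):
  """Bias-free stand-in for XLAPatchedLinear using the same autocast decorators."""

  @staticmethod
  @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
  def forward(ctx, input, weight):
    # cast_inputs has already converted input/weight to the XLA autocast dtype
    # (bf16 by default), so this einsum lowers to a bf16 dot op in the HLO.
    ctx.save_for_backward(input, weight)
    return torch.einsum('bn,mn->bm', input, weight)

  @staticmethod
  @custom_bwd(device_type='xla')
  def backward(ctx, grad_output):
    # Runs under the same autocast state as the forward pass; the saved
    # tensors are already bf16, so the backward dots are bf16 as well.
    input, weight = ctx.saved_tensors
    grad_input = torch.einsum('bm,mn->bn', grad_output, weight)
    grad_weight = torch.einsum('bm,bn->mn', grad_output, input)
    return grad_input, grad_weight


if __name__ == '__main__':
  device = xm.xla_device()
  x = torch.randn(5, 10, requires_grad=True, device=device)
  w = torch.randn(15, 10, requires_grad=True, device=device)
  with torch.autocast('xla'):
    out = EinsumLinear.apply(x, w)  # forward computed in bf16 under autocast
  out.sum().backward()
  # The gradient HLO should show bf16 dot ops plus f32/bf16 convert ops,
  # which is the same signature the new test asserts on.
  print(torch_xla._XLAC._get_xla_tensors_hlo([w.grad]))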