From 4989c849a624b8e2fccbe93759e32688f633e555 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Fri, 2 Feb 2024 22:31:08 +0000
Subject: [PATCH 1/4] wrap constant arg in XLAExportInterpreter

---
 test/stablehlo/test_xla_export_interpreter.py | 33 +++++++++++++++++++
 torch_xla/stablehlo.py                        | 16 +++++++++
 2 files changed, 49 insertions(+)
 create mode 100644 test/stablehlo/test_xla_export_interpreter.py

diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py
new file mode 100644
index 00000000000..88d8af410f2
--- /dev/null
+++ b/test/stablehlo/test_xla_export_interpreter.py
@@ -0,0 +1,33 @@
+import re
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.stablehlo import exported_program_to_stablehlo
+
+device = xm.xla_device()
+
+
+class XLAExportInterpreterTest(unittest.TestCase):
+
+  def test_constant_wrapping(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return 1.0 - x
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    ep = ep.run_decompositions()
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 09e21514b8f..861e587ae15 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -233,6 +233,22 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     new_kwargs = dict(kwargs)
     if 'device' in kwargs:
       new_kwargs['device'] = self._device
+    # Note: Use `_disable_current_modes` to always create constant tensor.
+    # Under `fake_tensor_mode` a fake tensor will be created. This is not a
+    # use case for XLAExportInterpreter right now, adding to be future-proof.
+    with torch.utils._python_dispatch._disable_current_modes():
+      # If the op spec expects a `Tensor` input, we wrap the python primitive
+      # type to a torch.tensor. The dtype for float respects
+      # `torch.get_default_dtype`. Without this wrapping, the python float
+      # will be wrapped before it enters dispatcher, and it doesn't respect
+      # the global default dtype.
+      args_and_specs = tuple(zip(args, target._schema.arguments))
+      args = tuple(
+          map(
+              lambda arg_spec: torch.tensor(arg_spec[0])
+              if isinstance(arg_spec[0], (float, int, bool)) and type(arg_spec[
+              1].type) == torch.TensorType else arg_spec[0],
+              args_and_specs))
     return super().call_function(target, args, new_kwargs)

   def run_node(self, n) -> Any:

From 1d53747404dc311262abd8bb9643e043e64f7f03 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Fri, 2 Feb 2024 23:09:37 +0000
Subject: [PATCH 2/4] do not wrap bool

---
 torch_xla/stablehlo.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 861e587ae15..68e48820d14 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -237,8 +237,8 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     # Under `fake_tensor_mode` a fake tensor will be created. This is not a
     # use case for XLAExportInterpreter right now, adding to be future-proof.
     with torch.utils._python_dispatch._disable_current_modes():
-      # If the op spec expects a `Tensor` input, we wrap the python primitive
-      # type to a torch.tensor. The dtype for float respects
+      # If the op spec expects a `Tensor` input, we wrap the python float/int
+      # to a torch.tensor. The dtype for float respects
       # `torch.get_default_dtype`. Without this wrapping, the python float
       # will be wrapped before it enters dispatcher, and it doesn't respect
       # the global default dtype.
@@ -246,7 +246,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
       args = tuple(
           map(
               lambda arg_spec: torch.tensor(arg_spec[0])
-              if isinstance(arg_spec[0], (float, int, bool)) and type(arg_spec[
+              if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
              1].type) == torch.TensorType else arg_spec[0],
              args_and_specs))
     return super().call_function(target, args, new_kwargs)
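
For reference, the scalar-wrapping rule that patches 1 and 2 introduce can be exercised outside the interpreter. The sketch below is illustrative only: `torch.ops.aten.sub.Tensor` is simply an op whose schema declares its second positional input as a `Tensor`, so a bare Python float paired with it gets promoted to a tensor whose dtype follows `torch.get_default_dtype()`:

import torch

# An op whose schema declares both positional inputs as `Tensor` (illustrative choice).
target = torch.ops.aten.sub.Tensor
args = (torch.rand(2, 3), 1.0)  # the Python float is the candidate for wrapping

# Same rule as the patch: wrap a Python float/int only when the matching
# schema argument expects a Tensor.
args_and_specs = tuple(zip(args, target._schema.arguments))
wrapped = tuple(
    torch.tensor(arg)
    if isinstance(arg, (float, int)) and type(spec.type) == torch.TensorType
    else arg
    for arg, spec in args_and_specs)

print(type(wrapped[1]))  # <class 'torch.Tensor'>
print(wrapped[1].dtype)  # follows torch.get_default_dtype(), e.g. torch.float32
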
From a2b8e8c24f46f1f775927254ad00c0738fdefa3b Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Mon, 5 Feb 2024 19:00:18 +0000
Subject: [PATCH 3/4] check '_schema' attr for callsite

---
 test/stablehlo/test_xla_export_interpreter.py | 14 ++++++++
 torch_xla/stablehlo.py                        | 35 ++++++++++---------
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py
index 88d8af410f2..85a4607d3a8 100644
--- a/test/stablehlo/test_xla_export_interpreter.py
+++ b/test/stablehlo/test_xla_export_interpreter.py
@@ -27,6 +27,20 @@ def forward(self, x):
         re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
         is not None)

+  def test_constant_wrapping_scalar_variant(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.rsub(x, 1.0)
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+

 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 68e48820d14..c28c7b239e6 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -233,23 +233,24 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     new_kwargs = dict(kwargs)
     if 'device' in kwargs:
       new_kwargs['device'] = self._device
-    # Note: Use `_disable_current_modes` to always create constant tensor.
-    # Under `fake_tensor_mode` a fake tensor will be created. This is not a
-    # use case for XLAExportInterpreter right now, adding to be future-proof.
-    with torch.utils._python_dispatch._disable_current_modes():
-      # If the op spec expects a `Tensor` input, we wrap the python float/int
-      # to a torch.tensor. The dtype for float respects
-      # `torch.get_default_dtype`. Without this wrapping, the python float
-      # will be wrapped before it enters dispatcher, and it doesn't respect
-      # the global default dtype.
-      args_and_specs = tuple(zip(args, target._schema.arguments))
-      args = tuple(
-          map(
-              lambda arg_spec: torch.tensor(arg_spec[0])
-              if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
-              1].type) == torch.TensorType else arg_spec[0],
-              args_and_specs))
-    return super().call_function(target, args, new_kwargs)
+    # If the op spec expects a `Tensor` input, we wrap the python float/int
+    # to a torch.tensor. The dtype for float respects
+    # `torch.get_default_dtype`. Without this wrapping, the python float
+    # will be wrapped before it enters dispatcher, and it doesn't respect
+    # the global default dtype.
+    if hasattr(target, '_schema'):
+      # Note: Use `_disable_current_modes` to always create constant tensor.
+      # Under `fake_tensor_mode` a fake tensor will be created. This is not a
+      # use case for XLAExportInterpreter right now, adding to be future-proof.
+      with torch.utils._python_dispatch._disable_current_modes():
+        args_and_specs = tuple(zip(args, target._schema.arguments))
+        args = tuple(
+            map(
+                lambda arg_spec: torch.tensor(arg_spec[0])
+                if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
+                1].type) == torch.TensorType else arg_spec[0],
+                args_and_specs))
+      return super().call_function(target, args, new_kwargs)

   def run_node(self, n) -> Any:
     if n.op == 'placeholder':
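
The `hasattr(target, '_schema')` guard added in patch 3 exists because not every `call_function` target in an exported FX graph is an aten `OpOverload` carrying a torchgen schema; plain Python callables such as `operator.getitem` also show up as targets. A quick check (the specific ops here are only examples):

import operator

import torch

# OpOverloads expose the schema the interpreter inspects for Tensor-typed arguments.
print(hasattr(torch.ops.aten.sub.Tensor, '_schema'))  # True

# Plain Python callables used as FX targets carry no such schema, so they must be skipped.
print(hasattr(operator.getitem, '_schema'))  # False
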
From d8723fac11bc9b79fe0fc09d14bfcbc87126f131 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Mon, 5 Feb 2024 19:38:54 +0000
Subject: [PATCH 4/4] fix typo

---
 torch_xla/stablehlo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index c28c7b239e6..a8ff5df7667 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -250,7 +250,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
                 if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
                 1].type) == torch.TensorType else arg_spec[0],
                 args_and_specs))
-      return super().call_function(target, args, new_kwargs)
+    return super().call_function(target, args, new_kwargs)

   def run_node(self, n) -> Any:
     if n.op == 'placeholder':
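
Assuming a torch_xla build that includes these four patches, the behavior the new tests assert can be reproduced directly. The snippet below mirrors `test_constant_wrapping`: the Python scalar in `1.0 - x` should lower to a `stablehlo.constant` matching the input's shape and dtype:

import re

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo


class M(torch.nn.Module):

  def forward(self, x):
    return 1.0 - x


ep = torch.export.export(M(), (torch.rand(2, 3),)).run_decompositions()
shlo_text = exported_program_to_stablehlo(ep).get_stablehlo_text()
# The wrapped 1.0 shows up as a constant of shape 2x3 and dtype f32.
assert re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)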