From 5dbbb284d57f8804a8dc0086974343dd1a7ca3f9 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Mon, 5 Feb 2024 16:13:46 -0800
Subject: [PATCH] Wrap constant arg in XLAExportInterpreter (#6460)

---
 test/stablehlo/test_xla_export_interpreter.py | 47 +++++++++++++++++++
 torch_xla/stablehlo.py                        | 17 +++++++
 2 files changed, 64 insertions(+)
 create mode 100644 test/stablehlo/test_xla_export_interpreter.py

diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py
new file mode 100644
index 00000000000..85a4607d3a8
--- /dev/null
+++ b/test/stablehlo/test_xla_export_interpreter.py
@@ -0,0 +1,47 @@
+import re
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.stablehlo import exported_program_to_stablehlo
+
+device = xm.xla_device()
+
+
+class XLAExportInterpreterTest(unittest.TestCase):
+
+  def test_constant_wrapping(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return 1.0 - x
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    ep = ep.run_decompositions()
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+
+  def test_constant_wrapping_scalar_variant(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.rsub(x, 1.0)
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 09e21514b8f..a8ff5df7667 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -233,6 +233,23 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     new_kwargs = dict(kwargs)
     if 'device' in kwargs:
       new_kwargs['device'] = self._device
+    # If the op spec expects a `Tensor` input, wrap the Python float/int
+    # argument into a torch.tensor. The dtype of a wrapped float follows
+    # `torch.get_default_dtype`. Without this, the Python float would only be
+    # wrapped once it reaches the dispatcher, which does not respect the
+    # global default dtype.
+    if hasattr(target, '_schema'):
+      # Note: use `_disable_current_modes` to always create a constant tensor.
+      # Under `fake_tensor_mode` a fake tensor would be created instead, which
+      # XLAExportInterpreter does not need today; this is future-proofing.
+      with torch.utils._python_dispatch._disable_current_modes():
+        args_and_specs = tuple(zip(args, target._schema.arguments))
+        args = tuple(
+            map(
+                lambda arg_spec: torch.tensor(arg_spec[0])
+                if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
+                    1].type) == torch.TensorType else arg_spec[0],
+                args_and_specs))
     return super().call_function(target, args, new_kwargs)
 
   def run_node(self, n) -> Any:
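
For context, the sketch below (not part of the patch) illustrates the wrapping rule that the `torch_xla/stablehlo.py` hunk applies inside `XLAExportInterpreter.call_function`: when an op's schema declares a `Tensor` parameter but the FX graph supplies a Python int/float, the scalar is materialized with `torch.tensor(...)` so that a float constant picks up `torch.get_default_dtype()`. The helper name `wrap_scalar_args` and the choice of `aten.sub.Tensor` in the usage example are illustrative assumptions, not part of the patch.

# Illustrative sketch only; `wrap_scalar_args` is a hypothetical helper that
# mirrors the lambda in XLAExportInterpreter.call_function.
import torch


def wrap_scalar_args(target, args):
  """Wrap Python int/float args as tensors when the op schema expects a Tensor.

  torch.tensor(1.0) honors torch.get_default_dtype(); leaving the raw Python
  float for the dispatcher to promote does not.
  """
  if not hasattr(target, '_schema'):
    return args
  wrapped = []
  for arg, spec in zip(args, target._schema.arguments):
    if isinstance(arg, (float, int)) and type(spec.type) == torch.TensorType:
      # The patch additionally runs this under
      # torch.utils._python_dispatch._disable_current_modes() so that a real
      # constant tensor is created even if a fake-tensor mode is active.
      wrapped.append(torch.tensor(arg))
    else:
      wrapped.append(arg)
  return tuple(wrapped)


if __name__ == '__main__':
  # Schema: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor,
  # so the Python float 1.0 passed for `self` gets wrapped into a tensor.
  op = torch.ops.aten.sub.Tensor
  x = torch.rand(2, 3)
  self_arg, other_arg = wrap_scalar_args(op, (1.0, x))
  print(self_arg.dtype)  # follows torch.get_default_dtype(), e.g. torch.float32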