diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py
new file mode 100644
index 00000000000..88d8af410f2
--- /dev/null
+++ b/test/stablehlo/test_xla_export_interpreter.py
@@ -0,0 +1,33 @@
+import re
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.stablehlo import exported_program_to_stablehlo
+
+device = xm.xla_device()
+
+
+class XLAExportInterpreterTest(unittest.TestCase):
+
+  def test_constant_wrapping(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return 1.0 - x
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    ep = ep.run_decompositions()
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 09e21514b8f..861e587ae15 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -233,6 +233,22 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     new_kwargs = dict(kwargs)
     if 'device' in kwargs:
       new_kwargs['device'] = self._device
+    # Note: use `_disable_current_modes` to always create a constant tensor.
+    # Under `fake_tensor_mode` a fake tensor would be created instead. This is
+    # not a use case for XLAExportInterpreter now; added to be future-proof.
+    with torch.utils._python_dispatch._disable_current_modes():
+      # If the op spec expects a `Tensor` input, wrap the Python primitive
+      # (float, int, bool) in a torch.tensor. The dtype of a wrapped float
+      # respects `torch.get_default_dtype`. Without this wrapping, the Python
+      # float would be wrapped before it enters the dispatcher and would not
+      # respect the global default dtype.
+      args_and_specs = tuple(zip(args, target._schema.arguments))
+      args = tuple(
+          map(
+              lambda arg_spec: torch.tensor(arg_spec[0])
+              if isinstance(arg_spec[0], (float, int, bool)) and
+              type(arg_spec[1].type) == torch.TensorType else arg_spec[0],
+              args_and_specs))
     return super().call_function(target, args, new_kwargs)

   def run_node(self, n) -> Any:
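
For context, the wrapping added above relies on the fact that explicitly constructing a tensor from a Python scalar follows `torch.get_default_dtype()`; per the comment in the diff, the implicit scalar wrapping that happens before dispatch does not. A minimal sketch (not part of the patch) illustrating the default-dtype behavior:

import torch

# torch.tensor() on a Python float follows the global default dtype.
torch.set_default_dtype(torch.float64)
print(torch.tensor(1.0).dtype)  # torch.float64

torch.set_default_dtype(torch.float32)
print(torch.tensor(1.0).dtype)  # torch.float32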