diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py
index 88d8af410f2..85a4607d3a8 100644
--- a/test/stablehlo/test_xla_export_interpreter.py
+++ b/test/stablehlo/test_xla_export_interpreter.py
@@ -27,6 +27,20 @@ def forward(self, x):
         re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
         is not None)
 
+  def test_constant_wrapping_scalar_variant(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.rsub(x, 1.0)
+
+    ep = torch.export.export(M(), (torch.rand(2, 3),))
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
+        is not None)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 68e48820d14..c28c7b239e6 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -233,23 +233,24 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     new_kwargs = dict(kwargs)
     if 'device' in kwargs:
       new_kwargs['device'] = self._device
-    # Note: Use `_disable_current_modes` to alwasys create constant tensor.
-    # Under `fake_tensor_mode` a fake tensor will be created. This is not a
-    # use case for XLAExportInterpreter right now, adding to be future-proof.
-    with torch.utils._python_dispatch._disable_current_modes():
-      # If the op spec expects a `Tensor` input, we wrap the python float/int
-      # to a torch.tensor. The dtype for float respects
-      # `torch.get_default_dtype`. Without this wrapping, the python float
-      # will be wrapped before it enters dispatcher, and it doesn't respect
-      # the global default dtype.
-      args_and_specs = tuple(zip(args, target._schema.arguments))
-      args = tuple(
-          map(
-              lambda arg_spec: torch.tensor(arg_spec[0])
-              if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
-                  1].type) == torch.TensorType else arg_spec[0],
-              args_and_specs))
-    return super().call_function(target, args, new_kwargs)
+    # If the op spec expects a `Tensor` input, we wrap the python float/int
+    # into a torch.tensor. The dtype for float respects
+    # `torch.get_default_dtype`. Without this wrapping, the python float
+    # would be wrapped before it enters the dispatcher and would not respect
+    # the global default dtype.
+    if hasattr(target, '_schema'):
+      # Note: Use `_disable_current_modes` to always create a constant tensor.
+      # Under `fake_tensor_mode` a fake tensor would be created. This is not a
+      # use case for XLAExportInterpreter right now; adding to be future-proof.
+      with torch.utils._python_dispatch._disable_current_modes():
+        args_and_specs = tuple(zip(args, target._schema.arguments))
+        args = tuple(
+            map(
+                lambda arg_spec: torch.tensor(arg_spec[0])
+                if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
+                    1].type) == torch.TensorType else arg_spec[0],
+                args_and_specs))
+    return super().call_function(target, args, new_kwargs)
 
   def run_node(self, n) -> Any:
     if n.op == 'placeholder':
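
Note (illustrative sketch, not part of the patch): the new `hasattr(target, '_schema')` guard matters because `call_function` targets in an exported FX graph are not always OpOverloads; a target such as `operator.getitem` carries no `_schema`, and the unguarded code would raise AttributeError on it. Below is a minimal standalone sketch of the wrapping rule, using the hypothetical name `wrap_scalar_args` and assuming only that OpOverload targets expose `_schema.arguments` and that `torch.tensor` on a python float follows `torch.get_default_dtype()`.

import torch


def wrap_scalar_args(target, args):
  # Targets without a schema (e.g. operator.getitem) are returned unchanged
  # instead of raising AttributeError -- this is what the new guard protects.
  if not hasattr(target, '_schema'):
    return args
  specs = target._schema.arguments
  # Wrap python floats/ints only where the schema expects a Tensor, so the
  # resulting constant gets torch.get_default_dtype() rather than the dtype
  # the dispatcher would pick for an unwrapped python scalar.
  return tuple(
      torch.tensor(arg) if isinstance(arg, (float, int)) and
      type(spec.type) == torch.TensorType else arg
      for arg, spec in zip(args, specs))


# aten.sub.Tensor declares `other: Tensor`, so the python float is wrapped
# into a 0-d tensor carrying the global default dtype.
wrapped = wrap_scalar_args(torch.ops.aten.sub.Tensor, (torch.rand(2, 3), 1.0))
assert wrapped[1].dtype == torch.get_default_dtype()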