diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 861e587ae15..68e48820d14 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -237,8 +237,8 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     # Under `fake_tensor_mode` a fake tensor will be created. This is not a
     # use case for XLAExportInterpreter right now, adding to be future-proof.
     with torch.utils._python_dispatch._disable_current_modes():
-      # If the op spec expects a `Tensor` input, we wrap the python primitive
-      # type to a torch.tensor. The dtype for float respects
+      # If the op spec expects a `Tensor` input, we wrap the python float/int
+      # to a torch.tensor. The dtype for float respects
       # `torch.get_default_dtype`. Without this wrapping, the python float
       # will be wrapped before it enters dispatcher, and it doesn't respect
       # the global default dtype.
@@ -246,7 +246,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
       args = tuple(
           map(
               lambda arg_spec: torch.tensor(arg_spec[0])
-              if isinstance(arg_spec[0], (float, int, bool)) and type(arg_spec[
+              if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
                   1].type) == torch.TensorType else arg_spec[0],
               args_and_specss))
     return super().call_function(target, args, new_kwargs)