Commit

do not wrap bool
lsy323 committed Feb 2, 2024
1 parent 4989c84 · commit 1d53747
Showing 1 changed file with 3 additions and 3 deletions.
torch_xla/stablehlo.py: 6 changes (3 additions & 3 deletions)
@@ -237,16 +237,16 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
       # Under `fake_tensor_mode` a fake tensor will be created. This is not a
       # use case for XLAExportInterpreter right now, adding to be future-proof.
       with torch.utils._python_dispatch._disable_current_modes():
-        # If the op spec expects a `Tensor` input, we wrap the python primitive
-        # type to a torch.tensor. The dtype for float respects
+        # If the op spec expects a `Tensor` input, we wrap the python float/int
+        # to a torch.tensor. The dtype for float respects
         # `torch.get_default_dtype`. Without this wrapping, the python float
         # will be wrapped before it enters dispatcher, and it doesn't respect
         # the global default dtype.
         args_and_specs = tuple(zip(args, target._schema.arguments))
         args = tuple(
             map(
                 lambda arg_spec: torch.tensor(arg_spec[0])
-                if isinstance(arg_spec[0], (float, int, bool)) and type(arg_spec[
+                if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
                     1].type) == torch.TensorType else arg_spec[0],
                 args_and_specs))
     return super().call_function(target, args, new_kwargs)
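For context, a minimal self-contained sketch (an editor's illustration, not part of the commit) of the dtype behavior the diff comment describes. The wrap_scalar helper below is hypothetical, and the rationale for excluding bool is inferred from the commit title rather than stated in the diff.

import torch

# Wrapping a python float via torch.tensor respects the global default
# dtype, which is the behavior the diff comment relies on.
torch.set_default_dtype(torch.float64)
assert torch.tensor(0.5).dtype == torch.float64
torch.set_default_dtype(torch.float32)
assert torch.tensor(0.5).dtype == torch.float32

# Python ints and bools have no such ambiguity: they always wrap to
# fixed dtypes, regardless of the default dtype.
assert torch.tensor(3).dtype == torch.int64
assert torch.tensor(True).dtype == torch.bool

# A Python fact worth noting (not from the commit): bool is a subclass
# of int, so isinstance(True, (float, int)) is still True. Code that
# must leave bools unwrapped needs an explicit check, for example:
def wrap_scalar(x):
  # Hypothetical helper: wrap float/int scalars, pass bools through.
  if isinstance(x, (float, int)) and not isinstance(x, bool):
    return torch.tensor(x)
  return x

assert wrap_scalar(True) is True          # bool passes through unwrapped
assert torch.is_tensor(wrap_scalar(1.0))  # float gets wrapped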
