check '_schema' attr for callsite
lsy323 committed Feb 5, 2024
1 parent 1d53747 commit a2b8e8c
Showing 2 changed files with 32 additions and 17 deletions.
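
What changed: the scalar-wrapping logic in XLAExportInterpreter.call_function is now guarded with hasattr(target, '_schema'), since not every FX call target is an aten OpOverload that carries a schema. A minimal sketch of the distinction (illustrative only; operator.getitem is just one example of a schema-less call target):

import operator
import torch

# An aten OpOverload exposes a `_schema` describing its arguments.
print(hasattr(torch.ops.aten.rsub.Scalar, '_schema'))  # True
# Other callables that can appear as FX call targets do not, so an
# unguarded `target._schema` access would raise AttributeError.
print(hasattr(operator.getitem, '_schema'))  # False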
14 changes: 14 additions & 0 deletions test/stablehlo/test_xla_export_interpreter.py
@@ -27,6 +27,20 @@ def forward(self, x):
re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
is not None)

def test_constant_wrapping_scalar_variant(self):

class M(torch.nn.Module):

def forward(self, x):
return torch.ops.aten.rsub(x, 1.0)

ep = torch.export.export(M(), (torch.rand(2, 3),))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
is not None)


if __name__ == '__main__':
test = unittest.main()
35 changes: 18 additions & 17 deletions torch_xla/stablehlo.py
@@ -233,23 +233,24 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
new_kwargs = dict(kwargs)
if 'device' in kwargs:
new_kwargs['device'] = self._device
# Note: Use `_disable_current_modes` to always create constant tensor.
# Under `fake_tensor_mode` a fake tensor will be created. This is not a
# use case for XLAExportInterpreter right now, adding to be future-proof.
with torch.utils._python_dispatch._disable_current_modes():
# If the op spec expects a `Tensor` input, we wrap the python float/int
# to a torch.tensor. The dtype for float respects
# `torch.get_default_dtype`. Without this wrapping, the python float
# will be wrapped before it enters dispatcher, and it doesn't respect
# the global default dtype.
args_and_specs = tuple(zip(args, target._schema.arguments))
args = tuple(
map(
lambda arg_spec: torch.tensor(arg_spec[0])
if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
1].type) == torch.TensorType else arg_spec[0],
args_and_specs))
return super().call_function(target, args, new_kwargs)
# If the op spec expects a `Tensor` input, we wrap the python float/int
# to a torch.tensor. The dtype for float respects
# `torch.get_default_dtype`. Without this wrapping, the python float
# will be wrapped before it enters dispatcher, and it doesn't respect
# the global default dtype.
if hasattr(target, '_schema'):
# Note: Use `_disable_current_modes` to always create constant tensor.
# Under `fake_tensor_mode` a fake tensor will be created. This is not a
# use case for XLAExportInterpreter right now, adding to be future-proof.
with torch.utils._python_dispatch._disable_current_modes():
args_and_specs = tuple(zip(args, target._schema.arguments))
args = tuple(
map(
lambda arg_spec: torch.tensor(arg_spec[0])
if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
1].type) == torch.TensorType else arg_spec[0],
args_and_specs))
return super().call_function(target, args, new_kwargs)

def run_node(self, n) -> Any:
if n.op == 'placeholder':

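As a standalone illustration of the wrapping rule in the hunk above (a simplified sketch outside the interpreter; wrap_scalar_args is a hypothetical helper, not part of torch_xla): python floats/ints bound to Tensor-typed schema arguments become constant tensors, so they pick up torch.get_default_dtype(), while targets without a _schema are passed through untouched.

import torch

def wrap_scalar_args(target, args):
  # Hypothetical helper mirroring the interpreter logic: skip targets
  # that carry no aten schema (the new `hasattr` guard).
  if not hasattr(target, '_schema'):
    return args
  wrapped = []
  for arg, spec in zip(args, target._schema.arguments):
    # Wrap python scalars only where the schema expects a Tensor, so the
    # resulting constant follows torch.get_default_dtype() rather than the
    # dtype the dispatcher would pick for a bare python scalar.
    if isinstance(arg, (float, int)) and type(spec.type) == torch.TensorType:
      wrapped.append(torch.tensor(arg))
    else:
      wrapped.append(arg)
  return tuple(wrapped)

# `rsub.Tensor` expects a Tensor for `other`, so the python float is wrapped.
args = wrap_scalar_args(torch.ops.aten.rsub.Tensor, (torch.rand(2, 3), 1.0))
print(args[1].dtype)  # follows torch.get_default_dtype(), e.g. torch.float32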