Skip to content

Commit

Permalink
Wrap constant arg in XLAExportInterpreter (#6460)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsy323 authored and bhavya01 committed Apr 22, 2024
1 parent df67eaf commit b987f24
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
47 changes: 47 additions & 0 deletions test/stablehlo/test_xla_export_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import re
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import exported_program_to_stablehlo

device = xm.xla_device()


class XLAExportInterpreterTest(unittest.TestCase):

def test_constant_wrapping(self):

class M(torch.nn.Module):

def forward(self, x):
return 1.0 - x

ep = torch.export.export(M(), (torch.rand(2, 3),))
ep = ep.run_decompositions()
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
is not None)

def test_constant_wrapping_scalar_variant(self):

class M(torch.nn.Module):

def forward(self, x):
return torch.ops.aten.rsub(x, 1.0)

ep = torch.export.export(M(), (torch.rand(2, 3),))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'stablehlo.constant.* : tensor<2x3xf32>', shlo_text)
is not None)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
17 changes: 17 additions & 0 deletions torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,23 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
new_kwargs = dict(kwargs)
if 'device' in kwargs:
new_kwargs['device'] = self._device
# If the op spec expects a `Tensor` input, we wrap the python float/int
# to a torch.tensor. The dtype for float respects
# `torch.get_default_dtype`. Without this wrapping, the python float
# will be wrapped before it enters dispatcher, and it doesn't respect
# the global default dtype.
if hasattr(target, '_schema'):
# Note: Use `_disable_current_modes` to alwasys create constant tensor.
# Under `fake_tensor_mode` a fake tensor will be created. This is not a
# use case for XLAExportInterpreter right now, adding to be future-proof.
with torch.utils._python_dispatch._disable_current_modes():
args_and_specs = tuple(zip(args, target._schema.arguments))
args = tuple(
map(
lambda arg_spec: torch.tensor(arg_spec[0])
if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
1].type) == torch.TensorType else arg_spec[0],
args_and_specs))
return super().call_function(target, args, new_kwargs)

def run_node(self, n) -> Any:
Expand Down

0 comments on commit b987f24

Please sign in to comment.