diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py
index 1980e236f75..66148898e1d 100644
--- a/test/stablehlo/test_mark_pattern.py
+++ b/test/stablehlo/test_mark_pattern.py
@@ -1,4 +1,6 @@
+import os
 import sys
+import tempfile
 import unittest
 
 import torch
@@ -6,7 +8,9 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.experimental.xla_marker
 from torch.utils import _pytree as pytree
+from torch_xla import stablehlo
 from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
+from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
 
 
 class XlaMarkPatternTest(unittest.TestCase):
@@ -24,13 +28,20 @@ def run_func_get_stablehlo(self, f, input_args):
     stablehlo = xm.get_stablehlo(out)
     return stablehlo
 
+  def export_func(self, f, args, saved_model_path=None):
+    exported = torch.export.export(f, args)
+    stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
+    if saved_model_path is not None:
+      save_torch_module_as_tf_saved_model(f, args, saved_model_path)
+    return stablehlo_gm
+
   def test_basic(self):
 
     def f(x):
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
       x = x + 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, False)
       return x
 
     input_args = (torch.randn(5),)
@@ -49,29 +60,24 @@ def __init__(self):
       def forward(self, x, y):
         q, k, v = x.split(128, dim=-2)
         q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id="0", is_input=True)
+            q, "sdpa", pos=0, id=0, is_input=True)
         k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id="0", is_input=True)
+            k, "sdpa", pos=1, id=0, is_input=True)
         v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id="0", is_input=True)
+            v, "sdpa", pos=2, id=0, is_input=True)
         attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
         attn_out = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out,
-            "sdpa",
-            pos=0,
-            id="0",
-            is_input=False,
-            attr={"scale": 0.25})
+            attn_out, "sdpa", pos=0, id=0, is_input=False, attr={"scale": 0.25})
         q, k, v = y.split(128, dim=-2)
         q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id="1", is_input=True)
+            q, "sdpa", pos=0, id=1, is_input=True)
         k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id="1", is_input=True)
+            k, "sdpa", pos=1, id=1, is_input=True)
         v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id="1", is_input=True)
+            v, "sdpa", pos=2, id=1, is_input=True)
         attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
         attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
+            attn_out2, "sdpa", pos=0, id=1, is_input=False, attr={"scale": 2})
         return attn_out, attn_out2
 
     input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
@@ -113,14 +119,56 @@ def forward(self, x, y):
     self.assertTrue(
         '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
 
+  def test_uuid_ser_des(self):
+    import uuid
+    from torch_xla.experimental.xla_marker import _get_uuid_tensor_internal, decode_uuid_tensor
+    id = uuid.uuid4()
+    id_hex = id.hex
+    id_tensor = _get_uuid_tensor_internal(id)
+    decoded = decode_uuid_tensor(id_tensor)
+    self.assertEqual(decoded, id_hex)
+
+  def test_composite_builder_export_sdpa_pattern(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
+        self.b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
+
+      def forward(self, x, y):
+        q, k, v = x.split(128, dim=-2)
+        q, k, v = self.b.mark_inputs(q, k, v)
+        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
+        attn_out = self.b.mark_outputs(attn_out)
+
+        q, k, v = y.split(128, dim=-2)
+        q, k, v = self.b2.mark_inputs(q, k, v)
+        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
+        attn_out2 = self.b2.mark_outputs(attn_out2)
+        return attn_out, attn_out2
+
+    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
+    tmp_path = tempfile.mkdtemp()
+    stablehlo_gm = self.export_func(M(), input_args, tmp_path)
+    stablehlo = stablehlo_gm.get_stablehlo_text()
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
+    self.assertTrue(
+        '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
+        stablehlo)
+    self.assertTrue(
+        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
+    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
+
   def test_multiple_input(self):
 
     def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
       out = x + y
       out = out * x * y
-      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
+      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, 0, False)
       return out
 
     input_args = (torch.ones(5), torch.ones(5))
@@ -132,12 +180,12 @@ def f(x, y):
   def test_multiple_output(self):
 
     def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
       out1 = x + y
       out2 = x * y
-      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
-      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
+      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, 0, False)
+      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, 0, False)
       return out1, out2
 
     input_args = (torch.ones(5), torch.ones(5))
@@ -147,14 +195,13 @@ def f(x, y):
   def test_nested_pattern(self):
 
    def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
       x = x * 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
-                                                    False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, False)
       return x
 
     input_args = (torch.ones(5),)
@@ -164,14 +211,13 @@ def f(x):
   def test_tangent_output(self):
     # Special case of nested pattern, outputs don't have dependencies.
     def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
       x = x + 1
       y = x - 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
-                                                    False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, 0, False)
       return x, y
 
     input_args = (torch.ones(5),)
diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py
index 39e0442e2f4..c901f3684b1 100644
--- a/test/stablehlo/test_pt2e_qdq.py
+++ b/test/stablehlo/test_pt2e_qdq.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 import unittest
 from typing import Callable, Dict, List
 
@@ -107,9 +108,9 @@ def test_resnet18(self):
         stablehlo_txt.count("stablehlo.uniform_dequantize"),
         fx_node_cnt["dequantize"])
     # Save as tf.saved_model
-    save_path = '/tmp/tf_saved_model/tmp1'
-    save_torch_module_as_tf_saved_model(m, args, save_path)
-    self.assertTrue(os.path.exists(os.path.join(save_path, 'saved_model.pb')))
+    tmp_path = tempfile.mkdtemp()
+    save_torch_module_as_tf_saved_model(m, args, tmp_path)
+    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py
index b02a702f9ec..40330a5b063 100644
--- a/torch_xla/experimental/mark_pattern_utils.py
+++ b/torch_xla/experimental/mark_pattern_utils.py
@@ -1,8 +1,8 @@
-import uuid
 from typing import Dict, Union
 
 import torch
-from torch_xla.experimental import xla_marker
+import torch_xla.experimental.xla_marker
+from torch_xla.experimental.xla_marker import get_uuid_tensor
 
 
 class StableHLOCompositeBuilder:
@@ -21,7 +21,7 @@ def __init__(self, name: str, attr: Dict[str, Union[int, float, str]] = None):
 
     self.attr = attr
     self.name = name
-    self.id = uuid.uuid4().hex
+    self.id = get_uuid_tensor()
     self._inputs = []
     self._outputs = []
 
diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py
index f967fb9b787..ab600bb3e2f 100644
--- a/torch_xla/experimental/xla_marker.py
+++ b/torch_xla/experimental/xla_marker.py
@@ -1,7 +1,8 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, Union
+import uuid
 
 import torch
 import torch_xla
@@ -10,15 +11,44 @@
 xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")
 
 xla_pattern_marking_lib.define(
-    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
+    "mark_tensor(Tensor x, str name, int pos, int id, bool is_input, Any? attr=None) -> Tensor"
 )
+xla_pattern_marking_lib.define(
+    "mark_tensor.tensor(Tensor x, str name, int pos, Tensor id, bool is_input, Any? attr=None) -> Tensor"
+)
+
+
+def _get_uuid_tensor_internal(id: uuid.UUID):
+  int_arr = []
+  for i in range(4):
+    int_arr.append(int(id.int >> (128 - 32 * (i + 1)) & 0xFFFFFFFF))
+  # Need to use int64 here to avoid an overflow issue in torch.
+  return torch.tensor(int_arr, dtype=torch.int64)
+
+
+def get_uuid_tensor():
+  id = uuid.uuid4()
+  return _get_uuid_tensor_internal(id)
+
+
+def decode_uuid_tensor(x):
+  assert len(
+      x.shape
+  ) == 1, f"The uuid tensor is expected to be a 1D tensor. Got shape: {x.shape}."
+  assert x.numel(
+  ) == 4, f"The uuid tensor is expected to have 4 elements. Tensor has {x.numel()} elements."
+  uuid_int = 0
+  for i in range(4):
+    # The encoder stores the most significant 32-bit chunk first.
+    uuid_int = (uuid_int << 32) | x.cpu()[i].item()
+  return f"{uuid_int:032x}"  # Same format as uuid.UUID.hex: zero-padded, no '0x'.
+
 
 
 @dataclass
 class BoundaryMetadata:
   name: str  # Name of the Pattern.
   pos: int  # Arg/return position.
-  id: str  # Patten instance id.
+  id: Union[int, torch.Tensor]  # Pattern instance id.
   is_input: bool = True  # If the marked tensor is input/output.
   attr: dict = None  # Attribute of the pattern, expected to be attached to output.
 
@@ -27,8 +57,14 @@ class BoundaryMetadataSerializer(json.JSONEncoder):
 
   def default(self, obj):
     if dataclasses.is_dataclass(obj):
+      if isinstance(obj, BoundaryMetadata):
+        if isinstance(obj.id, torch.Tensor):
+          obj.id = decode_uuid_tensor(obj.id)
+        else:
+          obj.id = str(obj.id)
       return dataclasses.asdict(obj)
-    return super().default(obj)
+    else:
+      return super().default(obj)
 
 
 def _assert_valid_composite_attr(attr):
@@ -49,7 +85,7 @@
 def mark_tensor_xla(x: torch.Tensor,
                     name: str,
                     pos: int,
-                    id: str,
+                    id: int,
                     is_input: bool,
                     attr: Dict = None):
   """Attach pattern boundary metadata to a XLA Tensor.
@@ -58,7 +94,7 @@
     x: torch.Tensor (On XLA device) - the marked tensor.
     name: str - The name of the pattern, it will be the name of the stablehlo composite op.
    pos: int - Input/output position of the annotated tensor in the pattern.
-    id: str - Unique identifier of the pattern instance.
+    id: int - Unique identifier of the pattern instance.
     is_input: bool - If the annotated tensor is the input to the pattern.
     attr: dict - Attribute of the pattern, it will be passed down to the attribute
       field in the stablehlo composite.
@@ -73,7 +109,7 @@
 def mark_tensor(x: torch.Tensor,
                 name: str,
                 pos: int,
-                id: str,
+                id: int,
                 is_input: bool,
                 attr: Dict = None):
   # Do nothing for non-xla tensor.
@@ -84,7 +120,44 @@
 def mark_tensor_meta(x: torch.Tensor,
                      name: str,
                      pos: int,
-                     id: str,
+                     id: int,
                      is_input: bool,
                      attr: Dict = None):
   return torch.empty_like(x)
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "XLA")
+def mark_tensor_xla(x: torch.Tensor,
+                    name: str,
+                    pos: int,
+                    id: torch.Tensor,
+                    is_input: bool,
+                    attr: Dict = None):
+  """Variant of `mark_tensor` where `id` is a torch.Tensor from `get_uuid_tensor`.
+  """
+  _assert_valid_composite_attr(attr)
+  pattern_info = BoundaryMetadata(name, pos, id, is_input, attr)
+  return torch_xla._XLAC._xla_mark_tensor(
+      x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor.tensor",
+      "CompositeExplicitAutograd")
+def mark_tensor(x: torch.Tensor,
+                name: str,
+                pos: int,
+                id: torch.Tensor,
+                is_input: bool,
+                attr: Dict = None):
+  # Do nothing for non-xla tensor.
+  return x
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "Meta")
+def mark_tensor_meta(x: torch.Tensor,
+                     name: str,
+                     pos: int,
+                     id: torch.Tensor,
+                     is_input: bool,
+                     attr: Dict = None):
+  return torch.empty_like(x)
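For reviewers, a round-trip sketch of the uuid helpers added in `torch_xla/experimental/xla_marker.py`. This is a minimal sketch, assuming a torch_xla build with this patch applied (in particular the corrected MSB-first `decode_uuid_tensor` above); it is not part of the patch itself.

```python
import uuid

import torch
from torch_xla.experimental.xla_marker import (_get_uuid_tensor_internal,
                                               decode_uuid_tensor)

uid = uuid.uuid4()
t = _get_uuid_tensor_internal(uid)

# The 128-bit uuid is packed into four 32-bit chunks, most significant chunk
# first; int64 storage keeps the unsigned 32-bit values from overflowing.
assert t.dtype == torch.int64 and t.numel() == 4
assert t[0].item() == (uid.int >> 96) & 0xFFFFFFFF

# decode_uuid_tensor reassembles the chunks into the uuid.UUID.hex format.
assert decode_uuid_tensor(t) == uid.hex
```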
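The `BoundaryMetadataSerializer` change means a tensor id never reaches the boundary-metadata JSON that `_xla_mark_tensor` consumes; it is decoded back to a hex string first. A small sketch under the same assumption (patched torch_xla); the printed layout is illustrative, not captured from a real run:

```python
import json

from torch_xla.experimental.xla_marker import (BoundaryMetadata,
                                               BoundaryMetadataSerializer,
                                               get_uuid_tensor)

# The encoder replaces the tensor id with its decoded hex string before
# converting the dataclass to a dict.
meta = BoundaryMetadata(name="sdpa", pos=0, id=get_uuid_tensor(), is_input=True)
print(json.dumps(meta, cls=BoundaryMetadataSerializer))
# Illustrative output:
# {"name": "sdpa", "pos": 0, "id": "<32 hex chars>", "is_input": true, "attr": null}
```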
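End to end, the builder flow exercised by `test_composite_builder_export_sdpa_pattern` condenses to the sketch below. Assumptions: a torch_xla install with this patch, plus the TF dependencies for the saved-model step; the `SDPA` module name is hypothetical.

```python
import tempfile

import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model


class SDPA(torch.nn.Module):

  def __init__(self):
    super().__init__()
    # The builder now carries a uuid tensor id (see get_uuid_tensor above), so
    # repeated instances of the same pattern name stay distinguishable.
    self.b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})

  def forward(self, x):
    q, k, v = x.split(128, dim=-2)
    q, k, v = self.b.mark_inputs(q, k, v)
    out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
    return self.b.mark_outputs(out)


args = (torch.randn(32, 8, 384, 64),)
exported = torch.export.export(SDPA(), args)
shlo = stablehlo.exported_program_to_stablehlo(exported)
# The marked region lowers to a single stablehlo.composite call.
assert shlo.get_stablehlo_text().count("@stablehlo.composite") == 1
save_torch_module_as_tf_saved_model(SDPA(), args, tempfile.mkdtemp())
```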