From 5c50b953f413e13a51e849014d0376f116aadfd5 Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Wed, 28 Feb 2024 20:47:11 -0800 Subject: [PATCH] Support building StableHLO composite with different attr value types (#6615) --- test/stablehlo/test_mark_pattern.py | 35 ++++++++++++++++++-- torch_xla/experimental/mark_pattern_utils.py | 7 ++-- torch_xla/experimental/xla_marker.py | 34 ++++++++++++++----- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index 41bf3fe76c8f..6306151e0029 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -9,6 +9,7 @@ import torch_xla.experimental.xla_marker from torch.utils import _pytree as pytree from torch_xla import stablehlo +from torch_xla.experimental import xla_marker from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder from utils import has_tf_package @@ -108,14 +109,19 @@ def forward(self, x, y): pos=0, id="0", is_input=False, - attr={"scale": 0.25}) + attr=xla_marker.serialize_composite_attr({"scale": 0.25})) q, k, v = y.split(128, dim=-2) q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True) k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True) v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True) attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4) attn_out2 = torch.ops.xla.mark_tensor( - attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2}) + attn_out2, + "sdpa", + pos=0, + id="1", + is_input=False, + attr=xla_marker.serialize_composite_attr({"scale": 2})) return attn_out, attn_out2 input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64))) @@ -236,6 +242,31 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) + def test_composite_builder_mix_attr_value_types(self): + + class M(torch.nn.Module): + + def forward(self, x, y): + builder = StableHLOCompositeBuilder( + "sample_composite", { + "int_attr": 1, + "float_attr": 2.3, + "bool_attr": True, + "str_attr": "helloworld", + }) + x, y = builder.mark_inputs(x, y) + z = x + y + z = builder.mark_outputs(z) + return z + + input_args = (torch.randn((5, 5)), torch.randn((5, 5))) + stablehlo = self.run_func_get_stablehlo(M(), input_args) + self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) + self.assertEqual(stablehlo.count('int_attr = 1 : i64'), 1) + self.assertEqual(stablehlo.count('float_attr = 2.300000e+00 : f32'), 1) + self.assertEqual(stablehlo.count('bool_attr = true'), 1) + self.assertEqual(stablehlo.count('str_attr = "helloworld"'), 1) + def test_multiple_inputs(self): def f(x, y): diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py index d67d2c6996b0..8e5dfeca9762 100644 --- a/torch_xla/experimental/mark_pattern_utils.py +++ b/torch_xla/experimental/mark_pattern_utils.py @@ -3,7 +3,7 @@ import torch import torch._dynamo as torchdynamo -import torch_xla.experimental.xla_marker +from torch_xla.experimental import xla_marker @torchdynamo.assume_constant_result @@ -33,6 +33,9 @@ def __init__(self, name: str, attr: Dict[str, Union[int, float, str]] = None): def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): marked_tensors = [] + serialized_attr = xla_marker.serialize_composite_attr( + self.attr) if not is_input else None + for pos, tensor in enumerate(tensors): if not isinstance(tensor, torch.Tensor): raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.") @@ -43,7 +46,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool): pos=pos, id=self.id, is_input=is_input, - attr=self.attr if not is_input else None, + attr=serialized_attr, )) if len(marked_tensors) == 1: diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py index 647b3a30927d..e15552f9ed63 100644 --- a/torch_xla/experimental/xla_marker.py +++ b/torch_xla/experimental/xla_marker.py @@ -4,6 +4,7 @@ from typing import Dict import torch +import torch._dynamo as torchdynamo from torch.library import impl import torch_xla from torch_xla.core.xla_model import XLA_LIB @@ -39,9 +40,25 @@ def _assert_valid_composite_attr(attr): for k, v in attr.items(): if not isinstance(k, str): raise ValueError("Composite attr name must be a Python str.") - if type(k) not in [str, float, int]: + if type(k) not in (str, float, int, bool): raise ValueError( - "Composite attr value must be either Python str, float, or int.") + "Composite attr value must be either Python str, float, int, or bool." + ) + + +@torchdynamo.assume_constant_result +def serialize_composite_attr(attr: Dict): + if attr is None: + return None + _assert_valid_composite_attr(attr) + return tuple(attr.items()) + + +@torchdynamo.assume_constant_result +def deserialize_composite_attr(attr) -> Dict: + if attr is None: + return None + return dict(attr) @impl(XLA_LIB, "mark_tensor", "XLA") @@ -50,7 +67,7 @@ def mark_tensor_xla(x: torch.Tensor, pos: int, id: str, is_input: bool, - attr: Dict = None): + attr=None): """Attach pattern boundary metadata to a XLA Tensor. Args: @@ -59,9 +76,10 @@ def mark_tensor_xla(x: torch.Tensor, pos: int - Input/output Position of the annotated tensor in the pattern. id: str - Unique identifier of the pattern instance. is_input: bool - If the annotated tensor is the input to the pattern. - attr: dict - Attribute of the pattern, it will be passed down to the attribute field - in the stablehlo composite. + attr - Attribute of the pattern. It must be a value generated by serialize_composite_attr + and will be passed down to the attribute field in the stablehlo composite. """ + attr = deserialize_composite_attr(attr) _assert_valid_composite_attr(attr) pattern_info = BoundaryMetadata(name, pos, id, is_input, attr) return torch_xla._XLAC._xla_mark_tensor( @@ -74,7 +92,7 @@ def mark_tensor(x: torch.Tensor, pos: int, id: str, is_input: bool, - attr: Dict = None): + attr=None): # Do nothing for non-xla tensor. return x @@ -85,5 +103,5 @@ def mark_tensor_meta(x: torch.Tensor, pos: int, id: str, is_input: bool, - attr: Dict = None): - return torch.empty_like(x) \ No newline at end of file + attr=None): + return torch.empty_like(x)