Commit
Add mark_tensor variant taking torch.Tensor id (#6132)
Co-authored-by: Siyuan Liu <[email protected]>
2 people authored and bhavya01 committed Apr 22, 2024
1 parent 9760761 commit 0a2725d
Showing 4 changed files with 166 additions and 46 deletions.
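In brief: `mark_tensor` previously took a string `id`; this commit switches the base overload to an integer `id` and adds a `mark_tensor.tensor` overload whose `id` is a tensor produced by `get_uuid_tensor()`. A minimal usage sketch, assuming a torch_xla build that contains this commit and that the overload is resolved by the type of the `id` argument:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_marker  # registers the xla_pattern_marking ops
from torch_xla.experimental.xla_marker import get_uuid_tensor

uid = get_uuid_tensor()  # 1-D int64 tensor encoding a uuid4 as four 32-bit chunks
x = torch.randn(5).to(xm.xla_device())
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, uid, True)   # mark pattern input
x = x + 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, uid, False)  # mark pattern output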
110 changes: 78 additions & 32 deletions test/stablehlo/test_mark_pattern.py
@@ -1,12 +1,16 @@
import os
import sys
import tempfile
import unittest

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_marker
from torch.utils import _pytree as pytree
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model


class XlaMarkPatternTest(unittest.TestCase):
@@ -24,13 +28,20 @@ def run_func_get_stablehlo(self, f, input_args):
    stablehlo = xm.get_stablehlo(out)
    return stablehlo

+  def export_func(self, f, args, saved_model_path=None):
+    exported = torch.export.export(f, args)
+    stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
+    if saved_model_path is not None:
+      save_torch_module_as_tf_saved_model(f, args, saved_model_path)
+    return stablehlo_gm

  def test_basic(self):

    def f(x):
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
      x = x + 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, False)
      return x

    input_args = (torch.randn(5),)
@@ -49,29 +60,24 @@ def __init__(self):
      def forward(self, x, y):
        q, k, v = x.split(128, dim=-2)
        q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id="0", is_input=True)
+            q, "sdpa", pos=0, id=0, is_input=True)
        k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id="0", is_input=True)
+            k, "sdpa", pos=1, id=0, is_input=True)
        v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id="0", is_input=True)
+            v, "sdpa", pos=2, id=0, is_input=True)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out,
-            "sdpa",
-            pos=0,
-            id="0",
-            is_input=False,
-            attr={"scale": 0.25})
+            attn_out, "sdpa", pos=0, id=0, is_input=False, attr={"scale": 0.25})
        q, k, v = y.split(128, dim=-2)
        q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id="1", is_input=True)
+            q, "sdpa", pos=0, id=1, is_input=True)
        k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id="1", is_input=True)
+            k, "sdpa", pos=1, id=1, is_input=True)
        v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id="1", is_input=True)
+            v, "sdpa", pos=2, id=1, is_input=True)
        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
        attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
+            attn_out2, "sdpa", pos=0, id=1, is_input=False, attr={"scale": 2})
        return attn_out, attn_out2

    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
@@ -113,14 +119,56 @@ def forward(self, x, y):
    self.assertTrue(
        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)

+  def test_uuid_ser_des(self):
+    import uuid
+    from torch_xla.experimental.xla_marker import _get_uuid_tensor_internal, decode_uuid_tensor
+    id = uuid.uuid4()
+    id_hex = id.hex
+    id_tensor = _get_uuid_tensor_internal(id)
+    decoded = decode_uuid_tensor(id_tensor)
+    self.assertEqual(decoded, id_hex)

+  def test_composite_builder_export_sdpa_pattern(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
+        self.b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
+
+      def forward(self, x, y):
+        q, k, v = x.split(128, dim=-2)
+        q, k, v = self.b.mark_inputs(q, k, v)
+        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
+        attn_out = self.b.mark_outputs(attn_out)
+
+        q, k, v = y.split(128, dim=-2)
+        q, k, v = self.b2.mark_inputs(q, k, v)
+        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
+        attn_out2 = self.b2.mark_outputs(attn_out2)
+        return attn_out, attn_out2
+
+    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
+    tmp_path = tempfile.mkdtemp()
+    stablehlo_gm = self.export_func(M(), input_args, tmp_path)
+    stablehlo = stablehlo_gm.get_stablehlo_text()
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
+    self.assertTrue(
+        '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
+        stablehlo)
+    self.assertTrue(
+        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
+    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

  def test_multiple_input(self):

    def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
      out = x + y
      out = out * x * y
-      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
+      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, 0, False)
      return out

    input_args = (torch.ones(5), torch.ones(5))
@@ -132,12 +180,12 @@ def f(x, y):
  def test_multiple_output(self):

    def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
      out1 = x + y
      out2 = x * y
-      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
-      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
+      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, 0, False)
+      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, 0, False)
      return out1, out2

    input_args = (torch.ones(5), torch.ones(5))
@@ -147,14 +195,13 @@ def f(x, y):
  def test_nested_pattern(self):

    def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
      x = x * 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
-                                                    False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, False)
      return x

    input_args = (torch.ones(5),)
@@ -164,14 +211,13 @@ def f(x):
  def test_tangent_output(self):
    # Special case of a nested pattern where the outputs have no dependencies.
    def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
      x = x + 1
      y = x - 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
-                                                    False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, 0, False)
      return x, y

    input_args = (torch.ones(5),)
7 changes: 4 additions & 3 deletions test/stablehlo/test_pt2e_qdq.py
@@ -1,4 +1,5 @@
import os
+import tempfile
import unittest
from typing import Callable, Dict, List

@@ -107,9 +108,9 @@ def test_resnet18(self):
        stablehlo_txt.count("stablehlo.uniform_dequantize"),
        fx_node_cnt["dequantize"])
    # Save as tf.saved_model
-    save_path = '/tmp/tf_saved_model/tmp1'
-    save_torch_module_as_tf_saved_model(m, args, save_path)
-    self.assertTrue(os.path.exists(os.path.join(save_path, 'saved_model.pb')))
+    tmp_path = tempfile.mkdtemp()
+    save_torch_module_as_tf_saved_model(m, args, tmp_path)
+    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))


if __name__ == '__main__':
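For reference, the saved-model flow this test exercises, as a standalone sketch. The module here is a stand-in, not from the commit, and the TensorFlow integration dependencies are assumed to be installed:

import os
import tempfile

import torch
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

m = torch.nn.Linear(4, 2)  # hypothetical module for illustration
args = (torch.randn(1, 4),)
tmp_path = tempfile.mkdtemp()  # fresh directory instead of a hardcoded /tmp path
save_torch_module_as_tf_saved_model(m, args, tmp_path)
assert os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))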
6 changes: 3 additions & 3 deletions torch_xla/experimental/mark_pattern_utils.py
@@ -1,8 +1,8 @@
-import uuid
from typing import Dict, Union

import torch
-from torch_xla.experimental import xla_marker
+import torch_xla.experimental.xla_marker
+from torch_xla.experimental.xla_marker import get_uuid_tensor


class StableHLOCompositeBuilder:
@@ -21,7 +21,7 @@ def __init__(self, name: str, attr: Dict[str, Union[int, float, str]] = None):

    self.attr = attr
    self.name = name
-    self.id = uuid.uuid4().hex
+    self.id = get_uuid_tensor()
    self._inputs = []
    self._outputs = []

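After this change the builder stamps its markers with a uuid tensor instead of a hex string; the builder API itself is unchanged. A short sketch, where the module and composite name are illustrative and `mark_inputs`/`mark_outputs` are assumed to return a single tensor when given one:

import torch
import torch.nn.functional as F
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class GeluBlock(torch.nn.Module):  # hypothetical module for illustration

  def __init__(self):
    super().__init__()
    # self.b.id is now a tensor from get_uuid_tensor(), not uuid.uuid4().hex.
    self.b = StableHLOCompositeBuilder("my.gelu", {"approximate": "tanh"})

  def forward(self, x):
    x = self.b.mark_inputs(x)
    x = F.gelu(x, approximate="tanh")
    return self.b.mark_outputs(x)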
89 changes: 81 additions & 8 deletions torch_xla/experimental/xla_marker.py
@@ -1,7 +1,8 @@
import dataclasses
import json
from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, Union
+import uuid

import torch
import torch_xla
@@ -10,15 +11,44 @@
xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")

xla_pattern_marking_lib.define(
-    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
+    "mark_tensor(Tensor x, str name, int pos, int id, bool is_input, Any? attr=None) -> Tensor"
)

+xla_pattern_marking_lib.define(
+    "mark_tensor.tensor(Tensor x, str name, int pos, Tensor id, bool is_input, Any? attr=None) -> Tensor"
+)


+def _get_uuid_tensor_internal(id: uuid.UUID):
+  int_arr = []
+  for i in range(4):
+    int_arr.append(int(id.int >> (128 - 32 * (i + 1)) & 0xFFFFFFFF))
+  # Need to use int64 here to avoid an overflow issue in torch.
+  return torch.tensor(int_arr, dtype=torch.int64)
+
+
+def get_uuid_tensor():
+  id = uuid.uuid4()
+  return _get_uuid_tensor_internal(id)
+
+
+def decode_uuid_tensor(x):
+  assert len(
+      x.shape
+  ) == 1, f"The uuid tensor is expected to be a 1D tensor. Got shape: {x.shape}."
+  assert x.numel(
+  ) == 4, f"The uuid tensor is expected to have 4 elements. Tensor has {x.numel()} elements."
+  # Reassemble the chunks most-significant first, mirroring the big-endian
+  # packing in _get_uuid_tensor_internal, and format as 32 hex digits to
+  # match uuid.UUID.hex.
+  uuid_int = 0
+  for i in range(4):
+    uuid_int |= int(x.cpu()[i]) << (32 * (3 - i))
+  return f"{uuid_int:032x}"


@dataclass
class BoundaryMetadata:
  name: str  # Name of the pattern.
  pos: int  # Arg/return position.
-  id: str  # Pattern instance id.
+  id: Union[int, torch.Tensor]  # Pattern instance id.
  is_input: bool = True  # Whether the marked tensor is an input (True) or an output (False).
  attr: dict = None  # Attributes of the pattern, expected to be attached to the output.

@@ -27,8 +57,14 @@ class BoundaryMetadataSerializer(json.JSONEncoder):

  def default(self, obj):
    if dataclasses.is_dataclass(obj):
+      if isinstance(obj, BoundaryMetadata):
+        if isinstance(obj.id, torch.Tensor):
+          obj.id = decode_uuid_tensor(obj.id)
+        else:
+          obj.id = str(obj.id)
      return dataclasses.asdict(obj)
-    return super().default(obj)
+    else:
+      return super().default(obj)


def _assert_valid_composite_attr(attr):
@@ -49,7 +85,7 @@ def _assert_valid_composite_attr(attr):
def mark_tensor_xla(x: torch.Tensor,
                    name: str,
                    pos: int,
-                    id: str,
+                    id: int,
                    is_input: bool,
                    attr: Dict = None):
"""Attach pattern boundary metadata to a XLA Tensor.
@@ -58,7 +94,7 @@ def mark_tensor_xla(x: torch.Tensor,
    x: torch.Tensor (on XLA device) - the marked tensor.
    name: str - The name of the pattern; it will become the name of the stablehlo composite op.
    pos: int - Input/output position of the annotated tensor in the pattern.
-    id: str - Unique identifier of the pattern instance.
+    id: int - Unique identifier of the pattern instance.
    is_input: bool - Whether the annotated tensor is an input to the pattern.
    attr: dict - Attributes of the pattern; these will be passed down to the attribute field
        in the stablehlo composite.
@@ -73,7 +109,7 @@ def mark_tensor(x: torch.Tensor,
def mark_tensor(x: torch.Tensor,
                name: str,
                pos: int,
-                id: str,
+                id: int,
                is_input: bool,
                attr: Dict = None):
  # Do nothing for a non-XLA tensor.
@@ -84,7 +120,44 @@ def mark_tensor(x: torch.Tensor,
def mark_tensor_meta(x: torch.Tensor,
                     name: str,
                     pos: int,
-                     id: str,
+                     id: int,
                     is_input: bool,
                     attr: Dict = None):
  return torch.empty_like(x)


+@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "XLA")
+def mark_tensor_xla(x: torch.Tensor,
+                    name: str,
+                    pos: int,
+                    id: torch.Tensor,
+                    is_input: bool,
+                    attr: Dict = None):
+  """Variant of mark_tensor where `id` is a torch.Tensor generated by `get_uuid_tensor`."""
+  _assert_valid_composite_attr(attr)
+  pattern_info = BoundaryMetadata(name, pos, id, is_input, attr)
+  return torch_xla._XLAC._xla_mark_tensor(
+      x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor.tensor",
+      "CompositeExplicitAutograd")
+def mark_tensor(x: torch.Tensor,
+                name: str,
+                pos: int,
+                id: torch.Tensor,
+                is_input: bool,
+                attr: Dict = None):
+  # Do nothing for a non-XLA tensor.
+  return x
+
+
+@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "Meta")
+def mark_tensor_meta(x: torch.Tensor,
+                     name: str,
+                     pos: int,
+                     id: torch.Tensor,
+                     is_input: bool,
+                     attr: Dict = None):
+  return torch.empty_like(x)
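For reference, a worked round trip of the uuid packing above, plus how a tensor id serializes. This is a sketch against the functions in this diff; the example uuid is arbitrary:

import json
import uuid

from torch_xla.experimental.xla_marker import (BoundaryMetadata,
                                               BoundaryMetadataSerializer,
                                               _get_uuid_tensor_internal,
                                               decode_uuid_tensor)

u = uuid.UUID('00112233-4455-6677-8899-aabbccddeeff')
t = _get_uuid_tensor_internal(u)
# Big-endian packing: t holds [0x00112233, 0x44556677, 0x8899aabb, 0xccddeeff].
assert decode_uuid_tensor(t) == u.hex  # '00112233445566778899aabbccddeeff'

# During JSON encoding, BoundaryMetadataSerializer decodes a tensor id back
# to its hex string before emitting the dataclass fields.
meta = BoundaryMetadata('sdpa', 0, t, True, {'scale': 0.25})
print(json.dumps(meta, cls=BoundaryMetadataSerializer))
# {"name": "sdpa", "pos": 0, "id": "00112233445566778899aabbccddeeff", ...}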
