Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan committed Feb 2, 2024
1 parent cef020f commit 5e7ce0f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
50 changes: 25 additions & 25 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_basic(self):

def f(x):
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
x = x + 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False)
return x

input_args = (torch.randn(5),)
Expand All @@ -59,29 +59,29 @@ def __init__(self):

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q = torch.ops.xla.mark_tensor(
q, "sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k = torch.ops.xla.mark_tensor(
k, "sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v = torch.ops.xla.mark_tensor(
v, "sdpa", pos=2, id="0", is_input=True)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = torch.ops.xla_pattern_marking.mark_tensor(
attn_out = torch.ops.xla.mark_tensor(
attn_out,
"sdpa",
pos=0,
id="0",
is_input=False,
attr={"scale": 0.25})
q, k, v = y.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q = torch.ops.xla.mark_tensor(
q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k = torch.ops.xla.mark_tensor(
k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v = torch.ops.xla.mark_tensor(
v, "sdpa", pos=2, id="1", is_input=True)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
attn_out2 = torch.ops.xla.mark_tensor(
attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
return attn_out, attn_out2

Expand Down Expand Up @@ -193,11 +193,11 @@ def forward(self, x, y):
def test_multiple_input(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
out = x + y
out = out * x * y
out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False)
return out

input_args = (torch.ones(5), torch.ones(5))
Expand All @@ -209,12 +209,12 @@ def f(x, y):
def test_multiple_output(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
out1 = x + y
out2 = x * y
out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False)
return out1, out2

input_args = (torch.ones(5), torch.ones(5))
Expand All @@ -224,13 +224,13 @@ def f(x, y):
def test_nested_pattern(self):

def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
x = x * 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0",
False)

input_args = (torch.ones(5),)
Expand All @@ -240,13 +240,13 @@ def f(x):
def test_tangent_output(self):
# Special case of nested pattern, outputs don't have dependencies.
def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
y = x - 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0",
False)

input_args = (torch.ones(5),)
Expand Down
12 changes: 6 additions & 6 deletions torch_xla/experimental/mark_pattern_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class StableHLOCompositeBuilder:
"""
Helper for building a StableHLO Composite by marking input and output tensors. It
should be used with the StableHLO converters from `torch_xla.stablehlo`.
Args:
name (str):
The name of the built StableHLO Composite op.
Expand All @@ -37,7 +37,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
if not isinstance(tensor, torch.Tensor):
raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
marked_tensors.append(
torch.ops.xla_pattern_marking.mark_tensor(
torch.ops.xla.mark_tensor(
tensor,
name=self.name,
pos=pos,
Expand All @@ -52,9 +52,9 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):

def mark_inputs(self, *tensors: torch.Tensor):
"""
Mark the input tensors of the StableHLO Composite. This method must only be
Mark the input tensors of the StableHLO Composite. This method must only be
called once per builder.
Args:
*tensors (torch.Tensor):
Torch tensors to mark.
Expand All @@ -68,9 +68,9 @@ def mark_inputs(self, *tensors: torch.Tensor):

def mark_outputs(self, *tensors: torch.Tensor):
"""
Mark the output tensors of the StableHLO Composite. This method must only be
Mark the output tensors of the StableHLO Composite. This method must only be
called once per builder.
Args:
*tensors (torch.Tensor):
Torch tensors to mark.
Expand Down

0 comments on commit 5e7ce0f

Please sign in to comment.