Skip to content

Commit

Permalink
Convert aten.alias to no-op
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrysky3 committed Sep 19, 2024
1 parent 970339e commit f5abb9a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/lowering/misc/test_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
import torch_ttnn
import pytest
import ttnn


class AliasModule(torch.nn.Module):
    """Minimal module whose forward pass is a single ``aten.alias`` op.

    Used to exercise the lowering pass that should erase the alias node.
    """

    def forward(self, x):
        # Invoke the aten op directly (rather than a tensor method) so the
        # traced graph contains exactly one aten.alias.default call_function.
        return torch.ops.aten.alias.default(x)


@pytest.mark.parametrize(
    "input_shape",
    [(4, 4)],
)
def test_alias(device, input_shape):
    """Compile AliasModule with the ttnn backend and verify that
    ``aten.alias`` is rewritten into a no-op (removed from the graph)
    while the numerical result is unchanged.
    """
    m = AliasModule()
    # NOTE(review): renamed from `input`, which shadowed the builtin.
    input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
    # Call the module (not .forward directly) so the standard nn.Module
    # __call__ path is used, matching how the compiled module is invoked.
    result_before = m(input_tensor)

    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m(input_tensor)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten: alias is a pure no-op, so no
    # call_function nodes should remain (only placeholder/output nodes).
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.op for node in nodes].count("call_function") == 0
    # Check inference result matches the eager execution.
    assert torch.allclose(result_before, result_after)
4 changes: 4 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def call_function(self, target, args, kwargs):
return self.call_function_prop_meta(ttnn.squeeze, args, kwargs)
return self.call_function_prop_meta(target, args, kwargs)

if target == torch.ops.aten.alias.default:
# alias is no-op
return args[0]

return self.call_function_prop_meta(target, args, kwargs)


Expand Down

0 comments on commit f5abb9a

Please sign in to comment.