Bharane/unary composite op removal (#10817)
* #10778: Update Tril op with ttnn support

* #10778: Update triu op with ttnn support

* #10778: Update round op with ttnn support

* #10778: Update polygamma with ttnn support

* #10778: Update Argmin with ttnn support

* #10778: Update argmax op with ttnn support
bharane-ab authored Jul 31, 2024
1 parent 5feb7bd commit d8ba4eb
Showing 12 changed files with 35 additions and 400 deletions.
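
For orientation before the per-file hunks: this commit replaces tt_lib.tensor call sites with their ttnn equivalents (argmax/argmin similarly move to ttnn.argmax / ttnn.argmin; see the test and profiling hunks further down). Below is a minimal, hedged sketch of the new call forms, reconstructed only from the call sites updated in this diff; the helper name migrated_call_forms and the placeholder argument values are illustrative, not part of the change, and other overloads are not covered.

import ttnn

def migrated_call_forms(t, mem_config):
    # Each line mirrors one call pattern updated in this commit; the old
    # tt_lib.tensor form it replaces is noted in the trailing comment.
    tril_out = ttnn.tril(t, diagonal=0, memory_config=mem_config)    # was tt_lib.tensor.tril(t, 0, mem_config)
    triu_out = ttnn.triu(t, diagonal=0, memory_config=mem_config)    # was tt_lib.tensor.triu(t, 0, mem_config)
    poly_out = ttnn.polygamma(t, k=2, memory_config=mem_config)      # was tt_lib.tensor.polygamma(t, 2, output_mem_config=...)
    round_out = ttnn.round(t, decimals=0, memory_config=mem_config)  # was tt_lib.tensor.round(t, 0, output_mem_config=...)
    return tril_out, triu_out, poly_out, round_out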
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -75,6 +75,7 @@ Pointwise Unary

ttnn/abs
ttnn/acos
+ttnn/logical_not_
ttnn/acosh
ttnn/asin
ttnn/asinh
13 changes: 0 additions & 13 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -300,12 +300,8 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.logical_ori

-.. autofunction:: tt_lib.tensor.polygamma

.. autofunction:: tt_lib.tensor.frac

-.. autofunction:: tt_lib.tensor.round

.. autofunction:: tt_lib.tensor.floor_div

.. autofunction:: tt_lib.tensor.rfloor_div
@@ -352,10 +348,6 @@ Tensor creation operations

.. autofunction:: tt_lib.tensor.empty

-.. autofunction:: tt_lib.tensor.tril

-.. autofunction:: tt_lib.tensor.triu

Broadcast and Reduce
====================

@@ -537,11 +529,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.repeat

-.. autofunction:: tt_lib.tensor.argmax

-.. autofunction:: tt_lib.tensor.argmin


Loss Functions
==============

6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/logical_not_.rst
@@ -0,0 +1,6 @@
+.. _ttnn.logical_not_:
+
+ttnn.logical_not_
+###################
+
+.. autofunction:: ttnn.logical_not_
4 changes: 2 additions & 2 deletions models/experimental/mistral/tt/mistral_transformer.py
@@ -103,11 +103,11 @@ def forward(
)
diagonal = 0

-mask = tt_lib.tensor.tril(tensor, diagonal)
+mask = ttnn.tril(tensor, diagonal)
tensor.deallocate()
# make the mask banded to account for sliding window
diagonal = -self.args.sliding_window
-mask = tt_lib.tensor.triu(mask, diagonal)
+mask = ttnn.triu(mask, diagonal)
mask = ttnn.log(mask)
mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config, pad_value=-10000)

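
Taken together, the two updated call sites above build a banded (sliding-window) causal mask. A minimal standalone sketch of that pattern with the new ttnn calls, assuming a ones-filled tensor and a sliding_window value as in the surrounding model code:

mask = ttnn.tril(tensor, 0)                # keep the causal (lower-triangular) part
mask = ttnn.triu(mask, -sliding_window)    # drop positions older than the sliding window
mask = ttnn.log(mask)                      # ones -> 0.0, zeros -> -inf, giving an additive mask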
2 changes: 1 addition & 1 deletion models/experimental/nanogpt/tt/nanogpt_attention.py
@@ -48,7 +48,7 @@ def __init__(self, config, base_address, device, tt_cache_path, dtype):
self.n_head = self.config.n_head
self.n_embd = self.config.n_embd

-temp_bias = tt_lib.tensor.tril(tt_lib.tensor.ones([1, 1, self.block_size, self.block_size]))
+temp_bias = ttnn.tril(tt_lib.tensor.ones([1, 1, self.block_size, self.block_size]))
temp_bias = tt_to_torch_tensor(temp_bias)
self.register_buffer(
"bias",
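
The nanogpt change follows the same pattern: a lower-triangular attention bias is built on device and then converted back to torch for register_buffer. A brief sketch under the same assumptions as the surrounding file (tt_lib.tensor.ones, tt_to_torch_tensor, and block_size come from that file):

causal_bias = ttnn.tril(tt_lib.tensor.ones([1, 1, block_size, block_size]))  # ones on/below the diagonal
causal_bias = tt_to_torch_tensor(causal_bias)                                # torch tensor of shape [1, 1, block_size, block_size]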
@@ -31,7 +31,7 @@ def test_argmax(self, input_shapes, dim, all, device):
.to(tt_lib.tensor.Layout.TILE)
.to(device)
)
-tt_output_tensor_on_device = tt_lib.tensor.argmax(input_tensor, dim=dim, all=all)
+tt_output_tensor_on_device = ttnn.experimental.argmax(input_tensor, dim=dim, all=all)
tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
if all:
golden_tensor = torch.argmax(input_data)
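
The updated test routes the device computation through ttnn.experimental.argmax and checks it against a torch golden value. A compact sketch of that check, assuming input_data is the torch tensor the device tensor was built from (as in the surrounding test; the non-all branch is truncated above):

tt_result = ttnn.experimental.argmax(input_tensor, dim=dim, all=all)
golden = torch.argmax(input_data)  # full-tensor golden used when all=True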
8 changes: 4 additions & 4 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -660,7 +660,7 @@ def eltwise_logit(x, *args, eps, device, dtype, layout, input_mem_config, output
@setup_host_and_device
def eltwise_polygamma(x, *args, k, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-t1 = ttl.tensor.polygamma(t0, k, output_mem_config=output_mem_config)
+t1 = ttnn.polygamma(t0, k=k, memory_config=output_mem_config)

return tt2torch_tensor(t1)

@@ -790,7 +790,7 @@ def eltwise_round(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-t1 = ttl.tensor.round(t0, decimals, output_mem_config=output_mem_config)
+t1 = ttnn.round(t0, decimals=decimals, memory_config=output_mem_config)

return tt2torch_tensor(t1)

@@ -1170,15 +1170,15 @@ def eltwise_unary_lt(
def triu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
tx = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
diag = kwargs.get("diag", 0)
-t1 = ttl.tensor.triu(tx, diag, output_mem_config)
+t1 = ttnn.triu(tx, diagonal=diag, memory_config=output_mem_config)
return tt2torch_tensor(t1)


@setup_host_and_device
def tril(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
tx = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
diag = kwargs.get("diag", 0)
-t1 = ttl.tensor.tril(tx, diag, output_mem_config)
+t1 = ttnn.tril(tx, diagonal=diag, memory_config=output_mem_config)
return tt2torch_tensor(t1)


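
In the triu/tril sweep wrappers above, the diagonal offset arrives through kwargs and is passed as diagonal=. A short usage sketch of the underlying calls with explicit offsets, assuming ttnn.triu/ttnn.tril follow the usual torch/numpy diagonal convention (the offset values here are placeholders):

upper = ttnn.triu(tx, diagonal=1, memory_config=output_mem_config)   # zero the main diagonal and everything below it
lower = ttnn.tril(tx, diagonal=-1, memory_config=output_mem_config)  # zero the main diagonal and everything above it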
36 changes: 18 additions & 18 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -1516,43 +1516,43 @@ def pow_float(x):


def argmax_1(x):
-tt_lib.tensor.argmax(x, dim=-1)
+ttnn.argmax(x, dim=-1)


def argmax_2(x):
-tt_lib.tensor.argmax(x, dim=-2)
+ttnn.argmax(x, dim=-2)


def argmax_3(x):
-tt_lib.tensor.argmax(x, dim=-3)
+ttnn.argmax(x, dim=-3)


def argmax_4(x):
-tt_lib.tensor.argmax(x, dim=-4)
+ttnn.argmax(x, dim=-4)


def argmax_all(x):
-tt_lib.tensor.argmax(x, dim=-1, all=True)
+ttnn.argmax(x, dim=-1, all=True)


def argmin_1(x):
-tt_lib.tensor.argmin(x, dim=-1)
+ttnn.argmin(x, dim=-1)


def argmin_2(x):
-tt_lib.tensor.argmin(x, dim=-2)
+ttnn.argmin(x, dim=-2)


def argmin_3(x):
-tt_lib.tensor.argmin(x, dim=-3)
+ttnn.argmin(x, dim=-3)


def argmin_4(x):
-tt_lib.tensor.argmin(x, dim=-4)
+ttnn.argmin(x, dim=-4)


def argmin_all(x):
-tt_lib.tensor.argmin(x, dim=-1, all=True)
+ttnn.argmin(x, dim=-1, all=True)


def primary_moreh_softmax_0(x):
@@ -2264,42 +2264,42 @@ def clone(x):
},
{
"op": argmax_1,
"name": "tt_lib.tensor.argmax_dim_3",
"name": "ttnn.argmax_dim_3",
"num_repeats": 2,
},
{
"op": argmax_2,
"name": "tt_lib.tensor.argmax_dim_2",
"name": "ttnn.argmax_dim_2",
"num_repeats": 2,
},
{
"op": argmax_3,
"name": "tt_lib.tensor.argmax_dim_1",
"name": "ttnn.argmax_dim_1",
"num_repeats": 2,
},
{
"op": argmax_all,
"name": "tt_lib.tensor.argmax_all",
"name": "ttnn.argmax_all",
"num_repeats": 2,
},
{
"op": argmin_1,
"name": "tt_lib.tensor.argmin_dim_3",
"name": "ttnn.argmin_dim_3",
"num_repeats": 2,
},
{
"op": argmin_2,
"name": "tt_lib.tensor.argmin_dim_2",
"name": "ttnn.argmin_dim_2",
"num_repeats": 2,
},
{
"op": argmin_3,
"name": "tt_lib.tensor.argmin_dim_1",
"name": "ttnn.argmin_dim_1",
"num_repeats": 2,
},
{
"op": argmin_all,
"name": "tt_lib.tensor.argmin_all",
"name": "ttnn.argmin_all",
"num_repeats": 2,
},
{
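
For reference, the renamed profiling entries above exercise ttnn.argmax / ttnn.argmin over each negative dim index plus an all=True full-tensor reduction; a compact recap of those call forms, taken directly from the wrappers shown above:

ttnn.argmax(x, dim=-1)            # reduce over the last dimension
ttnn.argmax(x, dim=-4)            # reduce over the outermost of four dimensions
ttnn.argmax(x, dim=-1, all=True)  # reduce over the whole tensor
ttnn.argmin(x, dim=-2)
ttnn.argmin(x, dim=-1, all=True)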
(Diffs for the remaining changed files did not load and are not shown here.)
