Replace all TT Lib Permute Uses with TTNN and remove old bindings (#10312)

#9746: remove all ttl permute bindings and uses and replace with ttnn
ntarafdar authored Jul 20, 2024
1 parent f52ec95 commit 484c441
Showing 25 changed files with 126 additions and 235 deletions.
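
For reference, the call-site migration in this commit is mechanical: tt_lib.tensor.permute(...) becomes ttnn.permute(...), and the output_mem_config keyword becomes memory_config. A minimal round-trip sketch of the new call path is below; the device id, tensor shape, and dtype are illustrative assumptions, not taken from this commit.

import torch
import ttnn

# Open a device (device id 0 is an assumption for this sketch).
device = ttnn.open_device(device_id=0)
try:
    x = torch.randn(1, 128, 2, 64)
    # ttnn.permute operates on a device tensor; TILE layout matches the call sites in this diff.
    t = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    # Old: tt_lib.tensor.permute(t, (0, 2, 1, 3), output_mem_config=cfg)
    # New: ttnn.permute; memory_config is optional and replaces output_mem_config.
    t = ttnn.permute(t, (0, 2, 1, 3))
    # Bring the result back to host, as the updated sweep-test wrappers below do.
    out = ttnn.to_torch(ttnn.to_layout(ttnn.from_device(t), ttnn.ROW_MAJOR_LAYOUT))
    print(out.shape)  # expected: torch.Size([1, 2, 128, 64])
finally:
    ttnn.close_device(device)
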
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -378,8 +378,6 @@ but in general retaining the data.

.. autofunction:: tt_lib.tensor.transpose

-.. autofunction:: tt_lib.tensor.permute
-
.. autofunction:: tt_lib.tensor.untilize

.. autofunction:: tt_lib.tensor.untilize_with_unpadding
2 changes: 1 addition & 1 deletion models/experimental/bert_tiny/tt/bert_self_attention.py
@@ -70,7 +70,7 @@ def make_attetntion_heads(x: ttnn.Tensor):
)
# Permute expects input to be in TILE layout
# Input shape: [1, 128, 2, 64]
-transposed = tt_lib.tensor.permute(reshape_unt, [0, 2, 1, 3])
+transposed = ttnn.permute(reshape_unt, [0, 2, 1, 3])

transposed = tt_to_torch_tensor(transposed)
transposed = ttnn.from_torch(transposed, dtype=ttnn.bfloat16)
2 changes: 1 addition & 1 deletion models/experimental/bloom/tt/bloom_merge_heads.py
@@ -36,7 +36,7 @@ def tt_merge_heads(
)

# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
-tt_permuted = tt_lib.tensor.permute(tt_reshaped, (0, 2, 1, 3))
+tt_permuted = ttnn.permute(tt_reshaped, (0, 2, 1, 3))

# reshape - fallback
reshaped_2 = fallback_ops.reshape(
4 changes: 2 additions & 2 deletions models/experimental/deit/tt/deit_self_attention.py
@@ -45,7 +45,7 @@ def transpose_for_scores(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
self.attention_head_size,
]
x = fallback_ops.reshape(x, *new_x_shape)
-x = tt_lib.tensor.permute(x, (0, 2, 1, 3))
+x = ttnn.permute(x, (0, 2, 1, 3))
return x

def forward(
@@ -81,7 +81,7 @@ def forward(
attention_probs = attention_probs * head_mask

context_layer = ttnn.matmul(attention_probs, value_layer)
-context_layer = tt_lib.tensor.permute(context_layer, (0, 2, 1, 3))
+context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
new_context_layer_shape = (1,) + tuple(context_layer.get_legacy_shape())[:-2] + (self.all_head_size,)
context_layer = fallback_ops.reshape(context_layer, *new_context_layer_shape)

2 changes: 1 addition & 1 deletion models/experimental/hrnet/tt/hrnet_model.py
@@ -415,7 +415,7 @@ def forward(self, x: tt_lib.tensor.Tensor):
y = module(y)

y = self.avg_pool2d(y)
-y = tt_lib.tensor.permute(y, (0, 3, 2, 1))
+y = ttnn.permute(y, (0, 3, 2, 1))
y = self.classifier(y)

return y
16 changes: 8 additions & 8 deletions models/experimental/mistral/mistral_helper_funcs.py
@@ -114,15 +114,15 @@ def get_freqs_cis(freqs_cis: torch.Tensor, query_shape, key_shape, device=None,
BCMUL = tt_lib.tensor.BcastOpMath.MUL

t_one_xq = tt_lib.tensor.ones(query_shape, output_mem_config=mem_config)
-t_one_xq = tt_lib.tensor.permute(t_one_xq, (3, 1, 2, 0), output_mem_config=mem_config)
+t_one_xq = ttnn.permute(t_one_xq, (3, 1, 2, 0), memory_config=mem_config)

-freqs_real = tt_lib.tensor.permute(freqs_cis.real, (3, 1, 2, 0), output_mem_config=mem_config)
-freqs_imag = tt_lib.tensor.permute(freqs_cis.imag, (3, 1, 2, 0), output_mem_config=mem_config)
+freqs_real = ttnn.permute(freqs_cis.real, (3, 1, 2, 0), memory_config=mem_config)
+freqs_imag = ttnn.permute(freqs_cis.imag, (3, 1, 2, 0), memory_config=mem_config)

bcast_freq_re_xq = tt_lib.tensor.bcast(t_one_xq, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xq = tt_lib.tensor.bcast(t_one_xq, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
-bcast_freq_re_xq = tt_lib.tensor.permute(bcast_freq_re_xq, (3, 1, 2, 0), output_mem_config=mem_config)
-bcast_freq_im_xq = tt_lib.tensor.permute(bcast_freq_im_xq, (3, 1, 2, 0), output_mem_config=mem_config)
+bcast_freq_re_xq = ttnn.permute(bcast_freq_re_xq, (3, 1, 2, 0), memory_config=mem_config)
+bcast_freq_im_xq = ttnn.permute(bcast_freq_im_xq, (3, 1, 2, 0), memory_config=mem_config)
t_one_xq.deallocate()

bcast_freq_xq = tt_lib.tensor.complex_tensor(bcast_freq_re_xq, bcast_freq_im_xq)
@@ -131,12 +131,12 @@ def get_freqs_cis(freqs_cis: torch.Tensor, query_shape, key_shape, device=None,
bcast_freq_im_xq.deallocate()

t_one_xk = tt_lib.tensor.ones(key_shape, output_mem_config=mem_config)
-t_one_xk = tt_lib.tensor.permute(t_one_xk, (3, 1, 2, 0), output_mem_config=mem_config)
+t_one_xk = ttnn.permute(t_one_xk, (3, 1, 2, 0), memory_config=mem_config)

bcast_freq_re_xk = tt_lib.tensor.bcast(t_one_xk, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xk = tt_lib.tensor.bcast(t_one_xk, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
-bcast_freq_re_xk = tt_lib.tensor.permute(bcast_freq_re_xk, (3, 1, 2, 0), output_mem_config=mem_config)
-bcast_freq_im_xk = tt_lib.tensor.permute(bcast_freq_im_xk, (3, 1, 2, 0), output_mem_config=mem_config)
+bcast_freq_re_xk = ttnn.permute(bcast_freq_re_xk, (3, 1, 2, 0), memory_config=mem_config)
+bcast_freq_im_xk = ttnn.permute(bcast_freq_im_xk, (3, 1, 2, 0), memory_config=mem_config)

bcast_freq_xk = tt_lib.tensor.complex_tensor(bcast_freq_re_xk, bcast_freq_im_xk)

8 changes: 4 additions & 4 deletions models/experimental/mistral/tt/mistral_attention.py
@@ -189,7 +189,7 @@ def forward(
xq.deallocate()

key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
-key = tt_lib.tensor.permute(key, [0, 2, 3, 1])
+key = ttnn.permute(key, [0, 2, 3, 1])
key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)

value = format_tensor(value, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
@@ -203,8 +203,8 @@ def forward(
scores = ttnn.multiply(scores, self.scale, memory_config=self.args.out_mem_config)

if mask is not None:
-mask = tt_lib.tensor.permute(mask, [2, 3, 0, 1])
-scores = tt_lib.tensor.permute(scores, [2, 3, 0, 1])
+mask = ttnn.permute(mask, [2, 3, 0, 1])
+scores = ttnn.permute(scores, [2, 3, 0, 1])

scores = tt_lib.tensor.bcast(
scores,
@@ -213,7 +213,7 @@ def forward(
tt_lib.tensor.BcastOpDim.HW,
output_mem_config=self.output_mem_config,
)
-scores = tt_lib.tensor.permute(scores, [2, 3, 0, 1])
+scores = ttnn.permute(scores, [2, 3, 0, 1])
desired_output_shape = [bsz, 32, seqlen, seqlen]
desired_output_shape[-1] = value.get_legacy_shape()[-1]

6 changes: 3 additions & 3 deletions models/experimental/nanogpt/tt/nanogpt_model.py
@@ -79,13 +79,13 @@ def forward(self, idx: torch.Tensor) -> tt_lib.tensor.Tensor:
tt_tok_emb = torch_to_tt_tensor_rm(tok_emb, self.device)
tt_pos_emb = torch_to_tt_tensor_rm(pos_emb, self.device)

-tt_tok_emb = tt_lib.tensor.permute(tt_tok_emb, (0, 2, 1, 3))
-tt_pos_emb = tt_lib.tensor.permute(tt_pos_emb, (0, 2, 1, 3))
+tt_tok_emb = ttnn.permute(tt_tok_emb, (0, 2, 1, 3))
+tt_pos_emb = ttnn.permute(tt_pos_emb, (0, 2, 1, 3))

tt_x = tt_lib.tensor.bcast(tt_tok_emb, tt_pos_emb, tt_lib.tensor.BcastOpMath.ADD, tt_lib.tensor.BcastOpDim.H)
tt_tok_emb.deallocate()
tt_pos_emb.deallocate()
-tt_x = tt_lib.tensor.permute(tt_x, (0, 2, 1, 3))
+tt_x = ttnn.permute(tt_x, (0, 2, 1, 3))

for block in self.h:
tt_x = block.forward(tt_x)
4 changes: 2 additions & 2 deletions models/experimental/roberta/tt/roberta_self_attention.py
@@ -84,7 +84,7 @@ def transpose_for_scores(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
self.attention_head_size,
]
x = fallback_ops.reshape(x, *new_x_shape)
-x = tt_lib.tensor.permute(x, (0, 2, 1, 3))
+x = ttnn.permute(x, (0, 2, 1, 3))
return x

def linear(self, x, weight, bias):
@@ -223,7 +223,7 @@ def forward(
attention_probs = ttnn.mul(attention_probs, head_mask, memory_config=self.mem_config)

context_layer = ttnn.matmul(attention_probs, value_layer, memory_config=self.mem_config)
-context_layer = tt_lib.tensor.permute(context_layer, (0, 2, 1, 3))
+context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))

# TODO left here. Finish porting and re-test everything. See other TODO s
# context_layer = context_layer.permute(0, 2, 1, 3). contiguous() TODO: CHECK contiguous
4 changes: 2 additions & 2 deletions models/experimental/stable_diffusion/tt/cross_attention.py
@@ -157,15 +157,15 @@ def batch_to_head_dim(self, tensor: ttl.tensor.Tensor) -> ttl.tensor.Tensor:
head_size = self.heads
_, batch_size, seq_len, dim = tensor.get_legacy_shape()
tensor = fallback_ops.reshape(tensor, batch_size // head_size, head_size, seq_len, dim)
-tensor = ttl.tensor.permute(tensor, (0, 2, 1, 3))
+tensor = ttnn.permute(tensor, (0, 2, 1, 3))
tensor = fallback_ops.reshape(tensor, 1, batch_size // head_size, seq_len, dim * head_size)
return tensor

def head_to_batch_dim(self, tensor: ttl.tensor.Tensor) -> ttl.tensor.Tensor:
head_size = self.heads
_, batch_size, seq_len, dim = tensor.get_legacy_shape()
tensor = fallback_ops.reshape(tensor, batch_size, seq_len, head_size, dim // head_size)
-tensor = ttl.tensor.permute(tensor, (0, 2, 1, 3))
+tensor = ttnn.permute(tensor, (0, 2, 1, 3))
tensor = fallback_ops.reshape(tensor, 1, batch_size * head_size, seq_len, dim // head_size)
return tensor

8 changes: 4 additions & 4 deletions models/experimental/stable_diffusion/tt/transformer_2d.py
@@ -483,11 +483,11 @@ def forward(

inner_dim = hidden_states.get_legacy_shape()[1]

-hidden_states = ttl.tensor.permute(hidden_states, (0, 2, 3, 1))
+hidden_states = ttnn.permute(hidden_states, (0, 2, 3, 1))
hidden_states = fallback_ops.reshape(hidden_states, 1, batch, height * width, inner_dim)
else:
inner_dim = hidden_states.get_legacy_shape()[1]
-hidden_states = ttl.tensor.permute(hidden_states, (0, 2, 3, 1))
+hidden_states = ttnn.permute(hidden_states, (0, 2, 3, 1))
hidden_states = fallback_ops.reshape(hidden_states, 1, batch, height * width, inner_dim)

hidden_states = self.proj_in(hidden_states)
@@ -506,13 +506,13 @@ def forward(
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = fallback_ops.reshape(hidden_states, batch, height, width, inner_dim)
-hidden_states = ttl.tensor.permute(hidden_states, (0, 3, 1, 2))
+hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2))

hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = fallback_ops.reshape(hidden_states, batch, height, width, inner_dim)
-hidden_states = ttl.tensor.permute(hidden_states, (0, 3, 1, 2))
+hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2))

output = ttnn.add(
hidden_states,
7 changes: 4 additions & 3 deletions models/experimental/swin/tt/swin_encoder.py
@@ -14,6 +14,7 @@
from models.experimental.swin.tt.swin_stage import TtSwinStage
from models.experimental.swin.tt.swin_patch_merging import TtSwinPatchMerging

+import ttnn
import tt_lib
from tt_lib.fallback_ops import fallback_ops

@@ -82,7 +83,7 @@ def forward(
_, batch_size, _, hidden_size = hidden_states.get_legacy_shape()

reshaped_hidden_state = fallback_ops.reshape(hidden_states, batch_size, *input_dimensions, hidden_size)
-reshaped_hidden_state = tt_lib.tensor.permute(reshaped_hidden_state, (0, 3, 1, 2))
+reshaped_hidden_state = ttnn.permute(reshaped_hidden_state, (0, 3, 1, 2))
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)

@@ -134,7 +135,7 @@ def custom_forward(*inputs):
*(output_dimensions[0], output_dimensions[1]),
hidden_size,
)
-reshaped_hidden_state = tt_lib.tensor.permute(reshaped_hidden_state, (0, 3, 1, 2))
+reshaped_hidden_state = ttnn.permute(reshaped_hidden_state, (0, 3, 1, 2))
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
@@ -143,7 +144,7 @@ def custom_forward(*inputs):
reshaped_hidden_state = fallback_ops.reshape(
reshaped_hidden_state, batch_size, *input_dimensions, hidden_size
)
-reshaped_hidden_state = tt_lib.tensor.permute(reshaped_hidden_state, (0, 3, 1, 2))
+reshaped_hidden_state = ttnn.permute(reshaped_hidden_state, (0, 3, 1, 2))
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)

8 changes: 4 additions & 4 deletions models/experimental/swin/tt/swin_self_attention.py
@@ -81,7 +81,7 @@ def transpose_for_scores(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
self.attention_head_size,
]
x = fallback_ops.reshape(x, *new_x_shape)
-x = tt_lib.tensor.permute(x, (0, 2, 1, 3))
+x = ttnn.permute(x, (0, 2, 1, 3))
return x

def forward(
@@ -120,15 +120,15 @@ def forward(
self.window_size[0] * self.window_size[1],
1,
)
-attention_scores = tt_lib.tensor.permute(attention_scores, (1, 2, 3, 0))
+attention_scores = ttnn.permute(attention_scores, (1, 2, 3, 0))
attention_scores = tt_lib.tensor.bcast(
attention_scores,
relative_position_bias,
tt_lib.tensor.BcastOpMath.ADD,
tt_lib.tensor.BcastOpDim.W,
)

-attention_scores = tt_lib.tensor.permute(attention_scores, (3, 0, 1, 2))
+attention_scores = ttnn.permute(attention_scores, (3, 0, 1, 2))

attention_scores = tt_to_torch_tensor(attention_scores)

@@ -160,7 +160,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = ttnn.matmul(attention_probs, value_layer)
-context_layer = tt_lib.tensor.permute(context_layer, (0, 2, 1, 3))
+context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))

new_context_layer_shape = tuple(context_layer.get_legacy_shape())[:-2] + (self.all_head_size,)
context_layer = fallback_ops.reshape(
4 changes: 2 additions & 2 deletions models/experimental/vit/tt/modeling_vit.py
@@ -124,7 +124,7 @@ def transpose_for_scores(self, x: tt_tensor) -> tt_tensor:
self.attention_head_size,
)
x = tt_lib.fallback_ops.reshape(x, *new_x_shape)
-return tt_lib.tensor.permute(x, (0, 2, 1, 3))
+return ttnn.permute(x, (0, 2, 1, 3))

def forward(
self,
@@ -159,7 +159,7 @@ def forward(

context_layer = ttnn.matmul(attention_probs, value_layer)

-context_layer = tt_lib.tensor.permute(context_layer, (0, 2, 1, 3))
+context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
new_context_layer_shape = (1,) + tuple(context_layer.get_legacy_shape())[:-2] + (self.all_head_size,)
context_layer = fallback_ops.reshape(context_layer, *new_context_layer_shape)

23 changes: 15 additions & 8 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -9,6 +9,7 @@
from models.helper_funcs import Linear as tt_Linear
from models.utility_functions import torch2tt_tensor, tt2torch_tensor, ttl_complex_2_torch_complex
from models.demos.metal_BERT_large_11.tt import custom_matmuls
+from tests.ttnn.utils_for_testing import assert_with_pcc


def setup_tt_tensor(x, device, layout, input_mem_config, dtype):
@@ -1366,8 +1367,10 @@ def repeat_interleave(
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.repeat_interleave(t0, repeat, dim, output_mem_config=output_mem_config)
-
-return tt2torch_tensor(t1)
+output_tensor = ttnn.from_device(t1)
+output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+output_tensor = ttnn.to_torch(output_tensor)
+return output_tensor


@setup_host_and_device
@@ -1639,11 +1642,13 @@ def prod(
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.prod(t0, all_dimensions, dim, output_mem_config=output_mem_config)
-output = tt2torch_tensor(t1)
+output_tensor = ttnn.from_device(t1)
+output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+output_tensor = ttnn.to_torch(output_tensor)
if all_dimensions:
-return output[:1, :1, :1, :1]
+return output_tensor[:1, :1, :1, :1]
else:
-return output
+return output_tensor


@setup_host_and_device
@@ -2194,9 +2199,11 @@ def permute(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-t1 = ttl.tensor.permute(t0, permute_dims, output_mem_config=output_mem_config)
-
-return tt2torch_tensor(t1)
+t1 = ttnn.permute(t0, permute_dims, memory_config=output_mem_config)
+output_tensor = ttnn.from_device(t1)
+output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+output_tensor = ttnn.to_torch(output_tensor)
+return output_tensor


@setup_host_and_device
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -46,6 +46,7 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding/device/embeddings_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/permute/permute_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp
@@ -135,7 +135,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reshape/reshape_op.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/permute/permute_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/composite/composite_ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backward/backward_ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/optimizer/optimizer_ops.cpp