#9749: move repeat (#11215)
ntarafdar authored Aug 11, 2024
1 parent fcf94ac commit 5161b53
Showing 38 changed files with 298 additions and 249 deletions.
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -405,5 +405,3 @@ Other Operations
.. autofunction:: tt_lib.tensor.mean_hw

.. autofunction:: tt_lib.tensor.lamb_optimizer

.. autofunction:: tt_lib.tensor.repeat
6 changes: 3 additions & 3 deletions models/demos/falcon7b_common/tt/falcon_model.py
@@ -159,10 +159,10 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
)
# Repeat attn masks for all heads
for i in range(self.num_devices):
tt_attention_mask[i] = ttnn.experimental.tensor.repeat(
tt_attention_mask[i] = ttnn.repeat(
tt_attention_mask[i],
[1, self.config.num_attention_heads, 1, 1],
output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
ttnn.Shape([1, self.config.num_attention_heads, 1, 1]),
memory_config=self.model_config["ATTN_MASK_MEMCFG"],
)
# Tilize attn masks
for i in range(self.num_devices):
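The hunk above is representative of the call-site change made throughout this commit: tt_lib.tensor.repeat (here reached via ttnn.experimental.tensor.repeat) becomes ttnn.repeat, the repeat counts are wrapped in ttnn.Shape, and the output_mem_config keyword becomes memory_config. A minimal standalone sketch of the new form follows; the device setup, head count, and DRAM memory config are illustrative stand-ins, not taken from falcon_model.py.

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # A [1, 1, seq, seq] attention mask, kept row-major so it can be repeated and then
    # tilized, mirroring the order of operations in model_preprocessing above.
    mask = ttnn.from_torch(torch.zeros(1, 1, 32, 32, dtype=torch.bfloat16), device=device)

    # Old (removed): ttnn.experimental.tensor.repeat(mask, [1, 71, 1, 1], output_mem_config=...)
    mask = ttnn.repeat(
        mask,
        ttnn.Shape([1, 71, 1, 1]),              # one copy of the mask per attention head
        memory_config=ttnn.DRAM_MEMORY_CONFIG,  # stand-in for model_config["ATTN_MASK_MEMCFG"]
    )

    ttnn.close_device(device)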
@@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask):
attn = ttnn.multiply(attn, scale)

## Need to figure out how to broadcast in t dim
# attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = tt2torch_tensor(attn_mask)
# attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1)
# attn_mask = torch2tt_tensor(attn_mask, self.device)
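The commented-out lines above describe a host-side workaround for broadcasting the mask in the t dimension: pull the mask back to torch, repeat it there, and push it back to the device. A hedged sketch of that round trip, written against the ttnn conversion helpers rather than the tt2torch_tensor/torch2tt_tensor utilities referenced in the comment; the function name and the availability of .device() on the input tensor are assumptions.

    import ttnn

    def repeat_mask_on_host(attn_mask, num_heads):
        """Host-side fallback: device -> torch -> repeat across heads -> device.
        Assumes attn_mask is a [1, 1, t, t] ttnn.Tensor already on a device."""
        mask_torch = ttnn.to_torch(attn_mask)               # device -> host
        mask_torch = mask_torch.repeat(1, num_heads, 1, 1)  # broadcast over the head dim
        # .device() assumed to return the device the mask currently lives on
        return ttnn.from_torch(mask_torch, layout=ttnn.TILE_LAYOUT, device=attn_mask.device())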
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tests/test_llama_attention.py
@@ -217,8 +217,8 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
attn_masks = ttnn.to_device(attn_masks, llama_attention_model.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=llama_attention_model.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=llama_attention_model.model_config["DRAM_MEMCFG"]
)
return (
xs,
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tests/test_llama_decoder.py
@@ -219,8 +219,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos):
attn_masks = ttnn.to_device(attn_masks, llama_decoder_model.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=llama_decoder_model.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=llama_decoder_model.model_config["DRAM_MEMCFG"]
)
return (
xs,
@@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask):
attn = ttnn.multiply(attn, scale)

## Need to figure out how to broadcast in t dim
# attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big
# attn_mask = tt2torch_tensor(attn_mask)
# attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1)
# attn_mask = torch2tt_tensor(attn_mask, self.device)
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_attention_galaxy.py
@@ -138,8 +138,8 @@ def prepare_inputs(self, x, start_pos):
repeat_shape = (attn_batch, 1, 1, 1)

for i in range(self.num_devices):
attn_masks[i] = tt_lib.tensor.repeat(
attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks[i] = ttnn.repeat(
attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_decoder_galaxy.py
@@ -213,8 +213,8 @@ def prepare_inputs(self, x, start_pos):
repeat_shape = (1, self.n_local_heads, 1, 1)

for i in range(self.num_devices):
attn_masks[i] = tt_lib.tensor.repeat(
attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks[i] = ttnn.repeat(
attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_model_optimized.py
@@ -278,8 +278,8 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
attn_masks = ttnn.to_device(attn_masks, self.device_mesh)

repeat_shape = (1, batch, 1, 1)
attn_masks = tt_lib.tensor.repeat(
attn_masks, repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"]
attn_masks = ttnn.repeat(
attn_masks, ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"]
)
# Put attn_mask on the device with the sharded config
attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"]
6 changes: 3 additions & 3 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1115,9 +1115,9 @@ def repeat(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.repeat(t0, repeat, output_mem_config=output_mem_config)

return tt2torch_tensor(t1)
t1 = ttnn.repeat(t0, ttnn.Shape(repeat), memory_config=output_mem_config)
output_tensor = ttnn.to_torch(t1)
return output_tensor


@setup_host_and_device
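For reference, the round trip the updated sweep wrapper performs can be sketched on its own; the shapes, dtype, and comparison below are illustrative and not part of the test suite.

    import torch
    import ttnn

    def repeat_roundtrip(device, x, repeats):
        """Run ttnn.repeat on device and compare against torch.Tensor.repeat."""
        t0 = ttnn.from_torch(x, layout=ttnn.TILE_LAYOUT, device=device)
        t1 = ttnn.repeat(t0, ttnn.Shape(repeats), memory_config=ttnn.DRAM_MEMORY_CONFIG)
        actual = ttnn.to_torch(t1)
        expected = x.repeat(torch.Size(repeats))  # torch reference, as in run_repeat below
        return torch.allclose(actual, expected.to(actual.dtype))

    # Illustrative usage:
    # device = ttnn.open_device(device_id=0)
    # ok = repeat_roundtrip(device, torch.randn(1, 1, 32, 32, dtype=torch.bfloat16), (1, 2, 1, 1))
    # ttnn.close_device(device)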
@@ -38,7 +38,7 @@ def run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, ou

tt_cpu = input.repeat(torch.Size(repeats))

tt = ttl.tensor.repeat(tt_input, ttl.tensor.Shape(repeats), output_mem_config)
tt = ttnn.repeat(tt_input, ttnn.Shape(repeats), memory_config=output_mem_config)

tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

4 changes: 2 additions & 2 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -1520,7 +1520,7 @@ def swiglu_2(x):


def repeat(x):
tt_lib.tensor.repeat(x, (1, 1, 1, 4))
ttnn.repeat(x, ttnn.Shape((1, 1, 1, 4)))


def repeat_interleave_0(x):
@@ -2252,7 +2252,7 @@ def clone(x):
},
{
"op": repeat,
"name": "tt_lib.tensor.repeat",
"name": "ttnn.repeat",
},
{
"op": repeat_interleave_0,
2 changes: 1 addition & 1 deletion tests/ttnn/profiling/reference.txt
@@ -96,7 +96,7 @@ tt_lib.tensor.real,200,0.027,0.029,0.06,0.012
ttnn.real_bw,200,0.821,0.827,0.847
ttnn.reglu_dim_2,200,0.102,0.107,0.245,0.045
ttnn.reglu_dim_3,200,0.105,0.111,0.244,0.045
tt_lib.tensor.repeat,200,0.025,0.027,0.368,0.009
ttnn.repeat,200,0.025,0.027,0.368,0.009
ttnn.repeat_interleave_dim_0,200,0.039,0.043,0.375,0.01
ttnn.repeat_interleave_dim_1,80,0.42,0.429,323.298,0.219
ttnn.repeat_interleave_dim_2,80,0.152,0.154,150.628,0.076
@@ -32,7 +32,7 @@ def run_repeat_tests(

x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])

tt_result = ttnn.repeat(x, ttnn.Shape(shape))
tt_result = ttnn.repeat(x, shape)

tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)

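Note that this sweep test now passes the repeat shape to ttnn.repeat directly, while the model call sites changed elsewhere in this commit wrap it in ttnn.Shape. Presumably the Python binding accepts both spellings, but that is an inference from the diff rather than something stated here; a sketch, with the tensor setup as an illustrative assumption:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    x = ttnn.from_torch(torch.ones(1, 1, 32, 32, dtype=torch.bfloat16), device=device)

    out_wrapped = ttnn.repeat(x, ttnn.Shape((1, 2, 1, 1)))  # as in the model call sites
    out_plain = ttnn.repeat(x, (1, 2, 1, 1))                 # as in this sweep test

    ttnn.close_device(device)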
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
@@ -141,6 +141,10 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/split.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/device/non_zero_indices_op.cpp
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt
@@ -120,8 +120,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/multi_core/repeat_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/repeat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_tms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_falcon7b.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_decode.cpp
90 changes: 0 additions & 90 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp

This file was deleted.

@@ -7,7 +7,6 @@
#include "ttnn/deprecated/tt_dnn/op_library/move/move_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reshape/reshape_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/fold/fold_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp"
@@ -63,20 +62,6 @@ namespace tt::tt_metal::detail{
)doc"
);

m_tensor.def("repeat", &tt::tt_metal::repeat,
py::arg("input"), py::arg("size"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns a new tensor filled with repetition of input ``input`` tensor according to number of times specified in ``size``. The rank of ``size`` should be less than or equal to the rank of tensor ``input_a``.
Output tensor will have same data type as input.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Input tensor for which repetition is computed", "Tensor", "Tensor of any shape", "Yes"
"size", "The number of times to repeat this tensor along each dimension", "List[Int]", "Positive repetition values", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("assign",
[](const Tensor& input_a, const Tensor& input_b, uint8_t queue_id){
return assign(queue_id, input_a, input_b); },
37 changes: 0 additions & 37 deletions ttnn/cpp/ttnn/operations/data_movement.hpp

This file was deleted.

34 changes: 2 additions & 32 deletions ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp
@@ -9,7 +9,6 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/operations/data_movement/concat/concat_pybind.hpp"
#include "ttnn/operations/data_movement/pad/pad_pybind.hpp"
#include "ttnn/operations/data_movement/permute/permute_pybind.hpp"
@@ -24,6 +23,7 @@
#include "ttnn/operations/data_movement/untilize_with_halo_v2/untilize_with_halo_v2_pybind.hpp"
#include "ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.hpp"
#include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp"
#include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp"


namespace py = pybind11;
@@ -32,43 +32,12 @@ namespace ttnn {
namespace operations {
namespace data_movement {

void bind_repeat(py::module& module) {
auto doc = R"doc(
repeat(input_tensor: ttnn.Tensor, shape : ttnn.Shape) -> ttnn.Tensor
Returns a new tensor filled with repetition of input :attr:`input_tensor` according to number of times specified in :attr:`shape`.
Args:
* :attr:`input_tensor`: the input_tensor to apply the repeate operation.
* :attr:`shape`: The number of repetitions for each element.
Keyword Args:
* :attr:`memory_config`: the memory configuration to use for the operation
Example:
>>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device)
>>> print(tensor)
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4]])
)doc";

ttnn::bind_registered_operation(
module,
ttnn::repeat,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg("shape"), py::kw_only(), py::arg("memory_config") = std::nullopt});
}

void py_module(py::module& module) {
detail::bind_permute(module);
detail::bind_concat(module);
detail::bind_pad(module);
detail::bind_slice(module);
bind_repeat(module);
detail::bind_repeat_interleave(module);
detail::bind_tilize(module);
detail::bind_tilize_with_val_padding(module);
@@ -80,6 +49,7 @@ void py_module(py::module& module) {
detail::bind_untilize_with_halo_v2(module);
bind_non_zero_indices(module);
bind_fill_rm(module);
py_bind_repeat(module);
}

} // namespace data_movement
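The repeat docstring removed from this file (and, per the commit title, moved into repeat_pybind.cpp) includes a usage example whose parentheses are unbalanced as rendered above. A hedged, corrected sketch of the same idea; whether a rank-2 input and a rank-2 repeat shape are accepted by the binding is an assumption here, not something the diff confirms.

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    t = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), device=device)
    out = ttnn.repeat(t, ttnn.Shape([2, 1]))  # repeat twice along dim 0, once along dim 1
    print(ttnn.to_torch(out))                 # expected values: [[1, 2], [1, 2], [3, 4], [3, 4]]

    ttnn.close_device(device)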
