From 5161b53e5e384d031ef0be77958ee1bbec5a67d0 Mon Sep 17 00:00:00 2001 From: Naif Tarafdar <135640067+tarafdarTT@users.noreply.github.com> Date: Sun, 11 Aug 2024 12:39:54 -0400 Subject: [PATCH] #9749: move repeat (#11215) --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 2 - .../demos/falcon7b_common/tt/falcon_model.py | 6 +- .../tests/unit_tests/test_attn_sdpa.py | 2 +- .../llama2_70b/tests/test_llama_attention.py | 4 +- .../llama2_70b/tests/test_llama_decoder.py | 4 +- .../tests/unit_tests/test_attn_sdpa.py | 2 +- .../llama2_70b/tt/llama_attention_galaxy.py | 4 +- .../llama2_70b/tt/llama_decoder_galaxy.py | 4 +- .../llama2_70b/tt/llama_model_optimized.py | 4 +- .../sweep_tests/tt_lib_ops.py | 6 +- .../unit_testing/misc/test_repeat.py | 2 +- tests/ttnn/profiling/ops_for_profiling.py | 4 +- tests/ttnn/profiling/reference.txt | 2 +- .../grayskull/test_repeat.py | 2 +- ttnn/CMakeLists.txt | 4 + .../tt_dnn/op_library/CMakeLists.txt | 2 - .../tt_dnn/op_library/repeat/repeat_op.cpp | 90 ------------------- .../csrc/tt_lib_bindings_tensor_dm_ops.cpp | 15 ---- ttnn/cpp/ttnn/operations/data_movement.hpp | 37 -------- .../data_movement/data_movement_pybind.hpp | 34 +------ .../reader_repeat_interleaved_start_id.cpp | 0 ...peat_stick_layout_interleaved_start_id.cpp | 0 .../data_movement/repeat/device/repeat_op.cpp | 53 +++++++++++ .../repeat/device}/repeat_op.hpp | 18 +--- .../repeat/device/repeat_program_factory.cpp} | 44 +++++---- .../repeat/device/repeat_program_factory.hpp | 13 +++ .../data_movement/repeat/repeat.cpp | 58 ++++++++++++ .../data_movement/repeat/repeat.hpp | 34 +++++++ .../data_movement/repeat/repeat_pybind.cpp | 74 +++++++++++++++ .../data_movement/repeat/repeat_pybind.hpp | 14 +++ .../repeat_interleave/repeat_interleave.hpp | 1 - .../ttnn/operations/eltwise/binary/binary.cpp | 2 +- .../binary_backward/binary_backward.hpp | 1 - .../eltwise/complex_binary/complex_binary.hpp | 1 - .../eltwise/complex_unary/complex_unary.hpp | 1 - .../complex_unary_backward.hpp | 1 - .../ternary_backward/ternary_backward.hpp | 1 - .../eltwise/unary_backward/unary_backward.hpp | 1 - 38 files changed, 298 insertions(+), 249 deletions(-) delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp delete mode 100644 ttnn/cpp/ttnn/operations/data_movement.hpp rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/repeat => operations/data_movement/repeat/device}/kernels/dataflow/reader_repeat_interleaved_start_id.cpp (100%) rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/repeat => operations/data_movement/repeat/device}/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp (100%) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/repeat => operations/data_movement/repeat/device}/repeat_op.hpp (50%) rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/repeat/multi_core/repeat_op_multi_core.cpp => operations/data_movement/repeat/device/repeat_program_factory.cpp} (80%) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.hpp diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst 
index c0141575436..ca9ba9aac8f 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -405,5 +405,3 @@ Other Operations .. autofunction:: tt_lib.tensor.mean_hw .. autofunction:: tt_lib.tensor.lamb_optimizer - -.. autofunction:: tt_lib.tensor.repeat diff --git a/models/demos/falcon7b_common/tt/falcon_model.py b/models/demos/falcon7b_common/tt/falcon_model.py index 23a2d62dcb0..9a6215b1c27 100644 --- a/models/demos/falcon7b_common/tt/falcon_model.py +++ b/models/demos/falcon7b_common/tt/falcon_model.py @@ -159,10 +159,10 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token ) # Repeat attn masks for all heads for i in range(self.num_devices): - tt_attention_mask[i] = ttnn.experimental.tensor.repeat( + tt_attention_mask[i] = ttnn.repeat( tt_attention_mask[i], - [1, self.config.num_attention_heads, 1, 1], - output_mem_config=self.model_config["ATTN_MASK_MEMCFG"], + ttnn.Shape([1, self.config.num_attention_heads, 1, 1]), + memory_config=self.model_config["ATTN_MASK_MEMCFG"], ) # Tilize attn masks for i in range(self.num_devices): diff --git a/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py b/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py index 208f6f448e9..0acd34ac57b 100644 --- a/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py +++ b/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py @@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask): attn = ttnn.multiply(attn, scale) ## Need to figure out how to broadcast in t dim - # attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big + # attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big # attn_mask = tt2torch_tensor(attn_mask) # attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1) # attn_mask = torch2tt_tensor(attn_mask, self.device) diff --git a/models/experimental/llama2_70b/tests/test_llama_attention.py b/models/experimental/llama2_70b/tests/test_llama_attention.py index 5ca80a91275..561a24f2024 100644 --- a/models/experimental/llama2_70b/tests/test_llama_attention.py +++ b/models/experimental/llama2_70b/tests/test_llama_attention.py @@ -217,8 +217,8 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos): attn_masks = ttnn.to_device(attn_masks, llama_attention_model.device_mesh) repeat_shape = (1, batch, 1, 1) - attn_masks = tt_lib.tensor.repeat( - attn_masks, repeat_shape, output_mem_config=llama_attention_model.model_config["DRAM_MEMCFG"] + attn_masks = ttnn.repeat( + attn_masks, ttnn.Shape(repeat_shape), memory_config=llama_attention_model.model_config["DRAM_MEMCFG"] ) return ( xs, diff --git a/models/experimental/llama2_70b/tests/test_llama_decoder.py b/models/experimental/llama2_70b/tests/test_llama_decoder.py index 4f428904fd9..a1e211e0b3b 100644 --- a/models/experimental/llama2_70b/tests/test_llama_decoder.py +++ b/models/experimental/llama2_70b/tests/test_llama_decoder.py @@ -219,8 +219,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos): attn_masks = ttnn.to_device(attn_masks, llama_decoder_model.device_mesh) repeat_shape = (1, batch, 1, 1) - attn_masks = tt_lib.tensor.repeat( - attn_masks, repeat_shape, output_mem_config=llama_decoder_model.model_config["DRAM_MEMCFG"] + attn_masks = ttnn.repeat( + attn_masks, ttnn.Shape(repeat_shape), 
memory_config=llama_decoder_model.model_config["DRAM_MEMCFG"] ) return ( xs, diff --git a/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py b/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py index 80093dbc4b8..a00d5e83f1d 100644 --- a/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py +++ b/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py @@ -68,7 +68,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask): attn = ttnn.multiply(attn, scale) ## Need to figure out how to broadcast in t dim - # attn_mask = tt_lib.tensor.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big + # attn_mask = ttnn.repeat(attn_mask, [1, attn.shape()[1], 1, 1]) # this causes memory error as the broadcast result is too big # attn_mask = tt2torch_tensor(attn_mask) # attn_mask = attn_mask.repeat(1, attn.shape()[1], 1, 1) # attn_mask = torch2tt_tensor(attn_mask, self.device) diff --git a/models/experimental/llama2_70b/tt/llama_attention_galaxy.py b/models/experimental/llama2_70b/tt/llama_attention_galaxy.py index d8a3d9dd905..b929bfa8ba4 100644 --- a/models/experimental/llama2_70b/tt/llama_attention_galaxy.py +++ b/models/experimental/llama2_70b/tt/llama_attention_galaxy.py @@ -138,8 +138,8 @@ def prepare_inputs(self, x, start_pos): repeat_shape = (attn_batch, 1, 1, 1) for i in range(self.num_devices): - attn_masks[i] = tt_lib.tensor.repeat( - attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"] + attn_masks[i] = ttnn.repeat( + attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"] ) # Put attn_mask on the device with the sharded config attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"] diff --git a/models/experimental/llama2_70b/tt/llama_decoder_galaxy.py b/models/experimental/llama2_70b/tt/llama_decoder_galaxy.py index 50fcfe974a1..32aafac7407 100644 --- a/models/experimental/llama2_70b/tt/llama_decoder_galaxy.py +++ b/models/experimental/llama2_70b/tt/llama_decoder_galaxy.py @@ -213,8 +213,8 @@ def prepare_inputs(self, x, start_pos): repeat_shape = (1, self.n_local_heads, 1, 1) for i in range(self.num_devices): - attn_masks[i] = tt_lib.tensor.repeat( - attn_masks[i], repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"] + attn_masks[i] = ttnn.repeat( + attn_masks[i], ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"] ) # Put attn_mask on the device with the sharded config attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"] diff --git a/models/experimental/llama2_70b/tt/llama_model_optimized.py b/models/experimental/llama2_70b/tt/llama_model_optimized.py index 36d9f4e0d22..a7396d89e41 100644 --- a/models/experimental/llama2_70b/tt/llama_model_optimized.py +++ b/models/experimental/llama2_70b/tt/llama_model_optimized.py @@ -278,8 +278,8 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None): attn_masks = ttnn.to_device(attn_masks, self.device_mesh) repeat_shape = (1, batch, 1, 1) - attn_masks = tt_lib.tensor.repeat( - attn_masks, repeat_shape, output_mem_config=self.model_config["DRAM_MEMCFG"] + attn_masks = ttnn.repeat( + attn_masks, ttnn.Shape(repeat_shape), memory_config=self.model_config["DRAM_MEMCFG"] ) # Put attn_mask on the device with the sharded config attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"] diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py 
index d0243950311..9de0ff7ccc3 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1115,9 +1115,9 @@ def repeat( **kwargs, ): t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = ttl.tensor.repeat(t0, repeat, output_mem_config=output_mem_config) - - return tt2torch_tensor(t1) + t1 = ttnn.repeat(t0, ttnn.Shape(repeat), memory_config=output_mem_config) + output_tensor = ttnn.to_torch(t1) + return output_tensor @setup_host_and_device diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py index cc40d6e39a9..c67c2d9c44a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py @@ -38,7 +38,7 @@ def run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, ou tt_cpu = input.repeat(torch.Size(repeats)) - tt = ttl.tensor.repeat(tt_input, ttl.tensor.Shape(repeats), output_mem_config) + tt = ttnn.repeat(tt_input, ttnn.Shape(repeats), memory_config=output_mem_config) tt_dev = tt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16) diff --git a/tests/ttnn/profiling/ops_for_profiling.py b/tests/ttnn/profiling/ops_for_profiling.py index 743bca84b48..efcdf4900b2 100644 --- a/tests/ttnn/profiling/ops_for_profiling.py +++ b/tests/ttnn/profiling/ops_for_profiling.py @@ -1520,7 +1520,7 @@ def swiglu_2(x): def repeat(x): - tt_lib.tensor.repeat(x, (1, 1, 1, 4)) + ttnn.repeat(x, ttnn.Shape((1, 1, 1, 4))) def repeat_interleave_0(x): @@ -2252,7 +2252,7 @@ def clone(x): }, { "op": repeat, - "name": "tt_lib.tensor.repeat", + "name": "ttnn.repeat", }, { "op": repeat_interleave_0, diff --git a/tests/ttnn/profiling/reference.txt b/tests/ttnn/profiling/reference.txt index 0bd16793f35..2327b4b72cf 100644 --- a/tests/ttnn/profiling/reference.txt +++ b/tests/ttnn/profiling/reference.txt @@ -96,7 +96,7 @@ tt_lib.tensor.real,200,0.027,0.029,0.06,0.012 ttnn.real_bw,200,0.821,0.827,0.847 ttnn.reglu_dim_2,200,0.102,0.107,0.245,0.045 ttnn.reglu_dim_3,200,0.105,0.111,0.244,0.045 -tt_lib.tensor.repeat,200,0.025,0.027,0.368,0.009 +ttnn.repeat,200,0.025,0.027,0.368,0.009 ttnn.repeat_interleave_dim_0,200,0.039,0.043,0.375,0.01 ttnn.repeat_interleave_dim_1,80,0.42,0.429,323.298,0.219 ttnn.repeat_interleave_dim_2,80,0.152,0.154,150.628,0.076 diff --git a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_repeat.py b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_repeat.py index 188696063cd..6fd5ce74d7b 100644 --- a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_repeat.py +++ b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_repeat.py @@ -32,7 +32,7 @@ def run_repeat_tests( x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0]) - tt_result = ttnn.repeat(x, ttnn.Shape(shape)) + tt_result = ttnn.repeat(x, shape) tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 9b91a30c141..89b3765865d 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -141,6 +141,10 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/split.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_op.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/split/device/split_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/non_zero_indices/device/non_zero_indices_op.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt index 31fcf0c6894..7d6b3fa58af 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt @@ -120,8 +120,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/repeat/multi_core/repeat_op_multi_core.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/repeat/repeat_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_tms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_falcon7b.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_create_qkv_heads_decode.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp deleted file mode 100644 index d7f04919923..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.cpp +++ /dev/null @@ -1,90 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp" - -#include "ttnn/tensor/tensor_utils.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/auto_format.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp" - -using namespace tt::constants; - -namespace tt { - -namespace tt_metal { - -RepeatOpParallelizationStrategy Repeat::get_parallelization_strategy(const std::vector &input_tensors) const { - return RepeatOpParallelizationStrategy::MULTI_CORE; -} - -void Repeat::validate(const std::vector &input_tensors) const { - const auto &input_tensor = input_tensors[0]; - tt::tt_metal::Shape input_shape = input_tensor.get_legacy_shape(); - TT_FATAL(this->repeat_dim < input_shape.rank(), "Repeat dim specified is larger than input tensor rank."); - if (input_tensor.get_layout() == Layout::ROW_MAJOR && this->repeat_dim == input_shape.rank() - 1) { - TT_FATAL( - (input_shape[this->repeat_dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == 0, - "Current repeat implementation requires aligned last dim when repeating on last dim"); - } - TT_FATAL(this->num_repeats > 0, "Number of repeats should be greater than 0"); - TT_FATAL(input_tensor.buffer(), "Operand to repeat needs to be allocated in a buffer on device."); - TT_FATAL(input_tensor.device(), "Operand to repeat needs to be on device."); - TT_FATAL( - input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, - "Input to repeat must be interleaved."); - TT_FATAL( - this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, - "Output of repeat must be interleaved."); -} - -std::vector Repeat::compute_output_shapes(const std::vector &input_tensors) const { - tt::tt_metal::Shape shape_out = input_tensors[0].get_legacy_shape(); - shape_out[this->repeat_dim] *= this->num_repeats; - return {shape_out}; -} - -std::vector Repeat::create_output_tensors(const std::vector &input_tensors) const { - const Tensor &ref_in_tensor = input_tensors[0]; - - return operation::generic_create_output_tensors( - *this, input_tensors, ref_in_tensor.get_dtype(), ref_in_tensor.get_layout(), this->output_mem_config); -} - -operation::ProgramWithCallbacks Repeat::create_program( - const std::vector &input_tensors, std::vector &output_tensors) const { - switch (this->get_parallelization_strategy(input_tensors)) { - case RepeatOpParallelizationStrategy::MULTI_CORE: - default: - return repeat_multi_core(input_tensors[0], this->repeat_dim, this->num_repeats, output_tensors[0]); - }; -} - -Tensor repeat(const Tensor &input_tensor, const Shape &shape, const MemoryConfig &output_mem_config) { - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; - operation::launch_op( - [shape, output_mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) -> std::vector { - auto& input_tensor = input_tensors.at(0); - uint32_t input_rank = input_tensor.get_legacy_shape().rank(); - TT_FATAL(shape.rank() == input_rank, "Number of repeat dims must be equal to number of tensor dims"); - Tensor output = input_tensor; - for (uint32_t dim = 0; dim < shape.rank(); ++dim) { - if (shape[dim] == 1) { - continue; - } - TT_FATAL(shape[dim] > 0, "Number of repetitions along a dim must be greater than 0"); - if (input_tensor.get_layout() == Layout::ROW_MAJOR && dim == input_rank - 1) { - TT_FATAL( - (input_tensor.get_legacy_shape()[dim] * input_tensor.element_size()) % 
input_tensor.buffer()->alignment() == 0, - "Current repeat implementation requires aligned last dim when repeating on last dim"); - } - output = operation::run_without_autoformat(Repeat{dim, shape[dim], output_mem_config}, {output}).at(0); - } - return {output}; - }, {input_tensor}, output_tensors); - return output_tensors.at(0); -} - -} // namespace tt_metal - -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index 53f03cd68e7..dd929622737 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -7,7 +7,6 @@ #include "ttnn/deprecated/tt_dnn/op_library/move/move_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/reshape/reshape_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/fold/fold_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/bcast/bcast_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp" @@ -63,20 +62,6 @@ namespace tt::tt_metal::detail{ )doc" ); - m_tensor.def("repeat", &tt::tt_metal::repeat, - py::arg("input"), py::arg("size"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( - Returns a new tensor filled with repetition of input ``input`` tensor according to number of times specified in ``size``. The rank of ``size`` should be less than or equal to the rank of tensor ``input_a``. - - Output tensor will have same data type as input. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "input", "Input tensor for which repetition is computed", "Tensor", "Tensor of any shape", "Yes" - "size", "The number of times to repeat this tensor along each dimension", "List[Int]", "Positive repetition values", "Yes" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); - m_tensor.def("assign", [](const Tensor& input_a, const Tensor& input_b, uint8_t queue_id){ return assign(queue_id, input_a, input_b); }, diff --git a/ttnn/cpp/ttnn/operations/data_movement.hpp b/ttnn/cpp/ttnn/operations/data_movement.hpp deleted file mode 100644 index aa2341ccf0a..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/tensor/types.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp" -#include "ttnn/operations/core/core.hpp" - -#include - -namespace ttnn { -namespace operations { -namespace data_movement { - - -struct Repeat { - static ttnn::Tensor operator()( - const ttnn::Tensor& input_tensor, - const ttnn::Shape& shape, - std::optional output_mem_config = std::nullopt) { - MemoryConfig mem_config = output_mem_config.value_or(input_tensor.memory_config()); - auto output_tensor = tt::tt_metal::repeat(input_tensor, shape.value, mem_config); - return output_tensor; - } -}; - - -} // namespace data_movement -} // namespace operations - -constexpr auto repeat = ttnn::register_operation_with_auto_launch_op<"ttnn::repeat", ttnn::operations::data_movement::Repeat>(); - -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp index b2f6574dc54..195b62e9eca 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp @@ -9,7 +9,6 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/data_movement.hpp" #include "ttnn/operations/data_movement/concat/concat_pybind.hpp" #include "ttnn/operations/data_movement/pad/pad_pybind.hpp" #include "ttnn/operations/data_movement/permute/permute_pybind.hpp" @@ -24,6 +23,7 @@ #include "ttnn/operations/data_movement/untilize_with_halo_v2/untilize_with_halo_v2_pybind.hpp" #include "ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.hpp" #include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp" +#include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp" namespace py = pybind11; @@ -32,43 +32,12 @@ namespace ttnn { namespace operations { namespace data_movement { -void bind_repeat(py::module& module) { - auto doc = R"doc( -repeat(input_tensor: ttnn.Tensor, shape : ttnn.Shape) -> ttnn.Tensor - -Returns a new tensor filled with repetition of input :attr:`input_tensor` according to number of times specified in :attr:`shape`. - -Args: - * :attr:`input_tensor`: the input_tensor to apply the repeate operation. - * :attr:`shape`: The number of repetitions for each element. 
- -Keyword Args: - * :attr:`memory_config`: the memory configuration to use for the operation - -Example: - - >>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device) - >>> print(tensor) - tensor([[1, 2], - [1, 2], - [3, 4], - [3, 4]]) - )doc"; - - ttnn::bind_registered_operation( - module, - ttnn::repeat, - doc, - ttnn::pybind_arguments_t{ - py::arg("input_tensor"), py::arg("shape"), py::kw_only(), py::arg("memory_config") = std::nullopt}); -} void py_module(py::module& module) { detail::bind_permute(module); detail::bind_concat(module); detail::bind_pad(module); detail::bind_slice(module); - bind_repeat(module); detail::bind_repeat_interleave(module); detail::bind_tilize(module); detail::bind_tilize_with_val_padding(module); @@ -80,6 +49,7 @@ void py_module(py::module& module) { detail::bind_untilize_with_halo_v2(module); bind_non_zero_indices(module); bind_fill_rm(module); + py_bind_repeat(module); } } // namespace data_movement diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_interleaved_start_id.cpp rename to ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp rename to ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp new file mode 100644 index 00000000000..ba3508c786b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "repeat_op.hpp"
+#include "repeat_program_factory.hpp"
+#include "ttnn/tensor/tensor_utils.hpp"
+
+using namespace tt::constants;
+
+
+namespace ttnn::operations::data_movement {
+
+
+void RepeatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
+    const auto &input_tensor = input_tensors[0];
+    tt::tt_metal::Shape input_shape = input_tensor.get_legacy_shape();
+    TT_FATAL(this->repeat_dim < input_shape.rank(), "Repeat dim specified is larger than input tensor rank.");
+    if (input_tensor.get_layout() == Layout::ROW_MAJOR && this->repeat_dim == input_shape.rank() - 1) {
+        TT_FATAL(
+            (input_shape[this->repeat_dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == 0,
+            "The last dim of tensor being repeated must be 32 byte aligned for DRAM Tensor and 16 byte aligned for L1 tensor");
+    }
+    TT_FATAL(this->num_repeats > 0, "Number of repeats should be greater than 0");
+    TT_FATAL(input_tensor.buffer(), "Operand to repeat needs to be allocated in a buffer on device.");
+    TT_FATAL(input_tensor.device(), "Operand to repeat needs to be on device.");
+    TT_FATAL(
+        input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+        "Input to repeat must be interleaved.");
+    TT_FATAL(
+        this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
+        "Output of repeat must be interleaved.");
+}
+
+std::vector<tt::tt_metal::Shape> RepeatDeviceOperation::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+    tt::tt_metal::Shape shape_out = input_tensors[0].get_legacy_shape();
+    shape_out[this->repeat_dim] *= this->num_repeats;
+    return {shape_out};
+}
+
+std::vector<Tensor> RepeatDeviceOperation::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
+    const Tensor &ref_in_tensor = input_tensors[0];
+
+    return operation::generic_create_output_tensors(
+        *this, input_tensors, ref_in_tensor.get_dtype(), ref_in_tensor.get_layout(), this->output_mem_config);
+}
+
+operation::ProgramWithCallbacks RepeatDeviceOperation::create_program(
+    const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const {
+    return detail::repeat_multi_core(input_tensors[0], this->repeat_dim, this->num_repeats, output_tensors[0]);
+}
+
+}  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp
similarity index 50%
rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp
rename to ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp
index e7c24eb8059..4597d85843c 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp
@@ -7,13 +7,11 @@
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/run_operation.hpp"
 
-namespace tt {
-namespace tt_metal {
+namespace ttnn::operations::data_movement {
 
-enum class RepeatOpParallelizationStrategy { MULTI_CORE };
 
-struct Repeat {
+struct RepeatDeviceOperation {
     const uint32_t repeat_dim;
     const uint32_t num_repeats;
     const MemoryConfig output_mem_config;
@@ -23,20 +21,8 @@ struct Repeat {
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
-    RepeatOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;
 };
 
-operation::ProgramWithCallbacks repeat_multi_core(
-    const Tensor &input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor &output);
-operation::ProgramWithCallbacks repeat_single_core(
-    const Tensor &input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor &output);
-
-Tensor repeat(
-    const Tensor &input_tensor,
-    const Shape &shape,
-    const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-}  // namespace tt_metal
-}  // namespace tt
+}  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/multi_core/repeat_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp
similarity index 80%
rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/multi_core/repeat_op_multi_core.cpp
rename to ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp
index a35b2a0626a..1a88755388d 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/multi_core/repeat_op_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp
@@ -2,24 +2,22 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
 #include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
 
 using namespace tt::constants;
 
-namespace tt {
-namespace tt_metal {
+namespace ttnn::operations::data_movement::detail {
 
 operation::ProgramWithCallbacks repeat_multi_core(
     const Tensor &input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor &output) {
-    tt_metal::Program program = tt_metal::CreateProgram();
+    tt::tt_metal::Program program = tt::tt_metal::CreateProgram();
 
-    tt_metal::Device *device = output.device();
+    tt::tt_metal::Device *device = output.device();
 
-    const tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
+    const tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
 
     const bool rm_layout = output.get_layout() == Layout::ROW_MAJOR;
 
@@ -32,7 +30,7 @@ operation::ProgramWithCallbacks repeat_multi_core(
         single_page_size = align(output.element_size() * output.get_legacy_shape()[-1], output.buffer()->alignment());
     } else {
         num_output_pages = output.volume() / TILE_HW;
-        single_page_size = tt_metal::detail::TileSize(cb_data_format);
+        single_page_size = tt::tt_metal::detail::TileSize(cb_data_format);
     }
 
     auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
@@ -41,21 +39,21 @@ operation::ProgramWithCallbacks repeat_multi_core(
     auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
         split_work_to_cores(compute_with_storage_grid_size, num_output_pages, rm_orientation);
 
-    tt_metal::Buffer *dst_buffer = output.buffer();
+    tt::tt_metal::Buffer *dst_buffer = output.buffer();
     TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");
 
     uint32_t src0_cb_index = 0;
     uint32_t num_input_pages = 2;
-    tt_metal::CircularBufferConfig cb_src0_config =
-        tt_metal::CircularBufferConfig(num_input_pages * single_page_size, {{src0_cb_index, cb_data_format}})
+    tt::tt_metal::CircularBufferConfig cb_src0_config =
+        tt::tt_metal::CircularBufferConfig(num_input_pages * single_page_size, {{src0_cb_index, cb_data_format}})
             .set_page_size(src0_cb_index, single_page_size);
-    auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
+    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
 
     uint32_t num_dims = output.get_legacy_shape().rank();
 
     auto input_buffer = input_tensor.buffer();
     uint32_t src_addr = input_buffer->address();
-    uint32_t src_is_dram = input_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
+    uint32_t src_is_dram = input_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
     uint32_t src_page_size = input_buffer->page_size();
 
     uint32_t num_pages_per_block;
@@ -105,7 +103,7 @@ operation::ProgramWithCallbacks repeat_multi_core(
 
     // Reader compile-time args
     // Data is 32 byte aligned
-    bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
+    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
     std::vector<uint32_t> reader_compile_time_args = {// interleaved accessor args
         (std::uint32_t)src0_cb_index,
         (std::uint32_t)src_is_dram,
@@ -121,20 +119,20 @@ operation::ProgramWithCallbacks repeat_multi_core(
         (std::uint32_t)src0_cb_index,
         (std::uint32_t)dst_is_dram};
 
-    tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel(
+    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
         program,
         rm_layout
-            ? "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp"
-            : "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/repeat/kernels/dataflow/reader_repeat_interleaved_start_id.cpp",
+            ? "ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp"
+            : "ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp",
         all_cores,
-        tt_metal::ReaderDataMovementConfig(reader_compile_time_args, repeat_defines));
+        tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, repeat_defines));
 
-    tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel(
+    tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel(
         program,
         rm_layout
             ? "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp"
            : "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
         all_cores,
-        tt_metal::WriterDataMovementConfig(writer_compile_time_args));
+        tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args));
 
     const auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, rm_orientation);
     uint32_t g1_num_cores = core_group_1.num_cores();
@@ -172,9 +170,9 @@ operation::ProgramWithCallbacks repeat_multi_core(
         } else {
             writer_kernel_args = {dst_buffer->address(), num_pages_per_core, num_pages_written};
         }
-        tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args);
+        tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args);
 
-        tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args);
+        tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args);
 
         num_pages_written += num_pages_per_core;
     }
@@ -202,6 +200,4 @@ operation::ProgramWithCallbacks repeat_multi_core(
     return {std::move(program), override_runtime_args_callback};
 }
 
-}  // namespace tt_metal
-
-}  // namespace tt
+}  // namespace ttnn::operations::data_movement::detail
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp
new file mode 100644
index 00000000000..a625cf05a9d
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp
@@ -0,0 +1,13 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+#pragma once
+
+#include "tt_metal/host_api.hpp"
+
+namespace ttnn::operations::data_movement::detail {
+
+operation::ProgramWithCallbacks repeat_multi_core(
+    const Tensor &input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor &output);
+
+}  // namespace ttnn::operations::data_movement::detail
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp
new file mode 100644
index 00000000000..bf04f1f5656
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp
@@ -0,0 +1,58 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+
+#include "ttnn/common/constants.hpp"
+#include "ttnn/run_operation.hpp"
+#include "ttnn/decorators.hpp"
+#include "ttnn/operations/data_movement/repeat/repeat.hpp"
+#include "device/repeat_op.hpp"
+
+namespace ttnn::operations::data_movement {
+
+
+ttnn::Tensor RepeatOperation::operator()(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor,
+    const Shape& repeat_dims,
+    const std::optional<MemoryConfig>& memory_config_arg) {
+
+    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
+    operation::launch_op(
+        [repeat_dims, memory_config_arg](
+            const std::vector<Tensor>& input_tensors,
+            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+            const std::vector<std::optional<Tensor>>& optional_output_tensors) -> std::vector<Tensor> {
+            auto& input_tensor = input_tensors.at(0);
+            auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());
+            uint32_t input_rank = input_tensor.get_legacy_shape().rank();
+            TT_FATAL(repeat_dims.rank() == input_rank, "Number of repeat dims must be equal to number of tensor dims");
+            Tensor output = input_tensor;
+            for (uint32_t dim = 0; dim < repeat_dims.size(); ++dim) {
+                if (repeat_dims[dim] == 1) {
+                    continue;
+                }
+                TT_FATAL(repeat_dims[dim] > 0, "Number of repetitions along a dim must be greater than 0");
+                if (input_tensor.get_layout() == Layout::ROW_MAJOR && dim == input_rank - 1) {
+                    TT_FATAL(
+                        (input_tensor.get_legacy_shape()[dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == 0,
+                        "Current repeat implementation requires aligned last dim when repeating on last dim");
+                }
+                output = operation::run_without_autoformat(RepeatDeviceOperation{dim, repeat_dims[dim], memory_config}, {output}).at(0);
+            }
+            return {output};
+        },
+        {input_tensor},
+        output_tensors);
+    return output_tensors.at(0);
+}
+
+ttnn::Tensor RepeatOperation::operator()(
+    const ttnn::Tensor& input_tensor,
+    const Shape& repeat_dims,
+    const std::optional<MemoryConfig>& memory_config) {
+    return operator()(DefaultQueueId, input_tensor, repeat_dims, memory_config);
+}
+
+ttnn::Tensor RepeatOperation::operator()(const ttnn::Tensor& input_tensor, const Shape& repeat_dims) {
+    return operator()(DefaultQueueId, input_tensor, repeat_dims, std::nullopt);
+}
+
+}  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp
new file mode 100644
index 00000000000..53b5bea9237
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp
@@ -0,0 +1,34 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/run_operation.hpp"
+#include "ttnn/decorators.hpp"
+
+
+namespace ttnn {
+namespace operations::data_movement {
+
+struct RepeatOperation {
+    static ttnn::Tensor operator()(
+        uint8_t queue_id,
+        const ttnn::Tensor& input_tensor,
+        const Shape& repeat_dims,
+        const std::optional<MemoryConfig>& memory_config_arg);
+
+    static ttnn::Tensor operator()(
+        const ttnn::Tensor& input_tensor,
+        const Shape& repeat_dims,
+        const std::optional<MemoryConfig>& memory_config);
+
+    static ttnn::Tensor operator()(const ttnn::Tensor& input_tensor, const Shape& repeat_dims);
+};
+
+
+}  // namespace operations::data_movement
+
+constexpr auto repeat = ttnn::register_operation_with_auto_launch_op<"ttnn::repeat", ttnn::operations::data_movement::RepeatOperation>();
+
+}  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp
new file mode 100644
index 00000000000..de73bd9c03c
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp
@@ -0,0 +1,74 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "ttnn/cpp/pybind11/decorators.hpp"
+
+#include "repeat.hpp"
+
+namespace ttnn::operations::data_movement {
+namespace py = pybind11;
+
+namespace detail {
+    template <typename data_movement_operation_t>
+    void bind_repeat(py::module& module, const data_movement_operation_t& operation, const char* doc) {
+        ttnn::bind_registered_operation(
+            module,
+            operation,
+            doc,
+            ttnn::pybind_overload_t{
+                [] (const data_movement_operation_t& self,
+                    const ttnn::Tensor& input_tensor,
+                    const Shape& repeat_dims,
+                    const std::optional<MemoryConfig>& memory_config,
+                    uint8_t queue_id) {
+                        return self(queue_id, input_tensor, repeat_dims, memory_config);
+                },
+                py::arg("input_tensor"),
+                py::arg("repeat_dims"),
+                py::kw_only(),
+                py::arg("memory_config") = std::nullopt,
+                py::arg("queue_id") = 0,
+            }
+        );
+    }
+
+}  // namespace detail
+
+
+void py_bind_repeat(py::module& module) {
+    auto doc = R"doc(
+    repeat(input_tensor: ttnn.Tensor, repeat_dims: ttnn.Shape) -> ttnn.Tensor
+
+    Returns a new tensor filled with repetitions of the input :attr:`input_tensor` according to the number of times specified in :attr:`repeat_dims`.
+
+    Args:
+        * :attr:`input_tensor`: the input tensor to apply the repeat operation to.
+        * :attr:`repeat_dims`: the number of repetitions along each dimension; its rank must equal the rank of :attr:`input_tensor`.
+
+    Keyword Args:
+        * :attr:`memory_config`: the memory configuration to use for the operation
+
+    Example:
+
+        >>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), device=device), ttnn.Shape((2, 1)))
+        >>> print(tensor)
+        tensor([[1, 2],
+                [3, 4],
+                [1, 2],
+                [3, 4]])
+    )doc";
+
+    detail::bind_repeat(
+        module,
+        ttnn::repeat,
+        doc
+    );
+}
+
+
+}  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.hpp
new file mode 100644
index 00000000000..0a26f83ab6d
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.hpp
@@ -0,0 +1,14 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+
+#include "pybind11/pybind_fwd.hpp"
+
+namespace ttnn::operations::data_movement {
+
+void py_bind_repeat(pybind11::module& module);
+
+}  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.hpp
index c07ad7f55f6..a0279485667 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.hpp
@@ -6,7 +6,6 @@
 
 #include "ttnn/tensor/types.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/concat/concat.hpp"
-#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
 #include "ttnn/operations/core/core.hpp"
 #include "ttnn/operations/data_movement/permute/permute.hpp"
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
index 3e45d16d6fe..b212102c27b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -6,7 +6,7 @@
 #include "binary.hpp"
 #include "device/binary_device_operation.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
+#include "ttnn/operations/data_movement/repeat/repeat.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
 
 namespace ttnn::operations::binary {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
index 203608675dc..c4043dbcb9a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp
@@ -8,7 +8,6 @@
 #include "ttnn/common/constants.hpp"
 #include "device/binary_backward_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 #include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp"
 #include "ttnn/operations/eltwise/complex/complex.hpp"
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary.hpp
index fc47106d5bf..bda47fe8a9e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary.hpp
@@ -7,7 +7,6 @@
 
 #include "device/complex_binary_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 
 namespace ttnn {
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp
index 52150a11216..79a1cd9338c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp
@@ -7,7 +7,6 @@
 
 #include "device/complex_unary_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 
 namespace ttnn {
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp
index 9aecbae57ca..d7a8bde2016 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp
@@ -7,7 +7,6 @@
 
 #include "device/complex_unary_backward_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 
 namespace ttnn {
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.hpp
index b4dc379666f..89fcaff3413 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.hpp
@@ -7,7 +7,6 @@
 
 #include "device/ternary_backward_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 
 namespace ttnn {
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 460d048bbca..d27953ef20a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -7,7 +7,6 @@
 
 #include "device/unary_backward_op.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/operations/data_movement.hpp"
 #include "ttnn/operations/eltwise/complex/complex.hpp"
 
 namespace ttnn {