diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp
index c8da0247d5c..bfa667c8898 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp
@@ -4,7 +4,6 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"
-
 ALWI bool need_to_do_mask_h(uint32_t w_idx, uint32_t origin_num_h_tiles, uint32_t origin_num_w_tiles) {
     return ((w_idx / origin_num_w_tiles) + 1) % origin_num_h_tiles == 0;
 }
@@ -69,7 +68,6 @@ void MAIN {
     constexpr uint32_t origin_Wt = (origin_W + TILE_WIDTH - 1) / TILE_WIDTH;

     for (uint32_t outer_idx = 0; outer_idx < num_rows_per_core; outer_idx++) {
-
         /*
          * Sum[x]
          * cb_xsum
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp
index f78d3af08cd..c49fc5e641d 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp
@@ -158,7 +158,7 @@ void MAIN {
                     tile_regs_release();
                 }
             }  // block_size loop
-        }      // num_inner loop
+        }  // num_inner loop

         // We don't pop cb_x until we compute xmm.
         /*
@@ -442,7 +442,7 @@ void MAIN {
                 cb_pop_front(cb_beta, block_size);
                 cb_push_back(cb_out, block_size);
             }  // if (beta_has_value)
-        }      // num_inner loop
+        }  // num_inner loop
         cb_pop_front(cb_recip_std, onetile);
         cb_pop_front(cb_xmm, num_inner);
     }  // num_rows_per_core loop
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/writer_moreh_layer_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/writer_moreh_layer_norm.cpp
index 720fced5a60..60f7d8cd004 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/writer_moreh_layer_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/writer_moreh_layer_norm.cpp
@@ -6,8 +6,17 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

 template <typename T>
-void write_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t num_inner, uint32_t normalized_dims, uint32_t outer_idx, uint32_t output_height, uint32_t output_width, uint32_t Ht, uint32_t Wt, T addrg)
-{
+void write_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t num_inner,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t output_height,
+    uint32_t output_width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
     constexpr uint32_t onetile = 1;

     const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -126,11 +135,31 @@ void kernel_main() {

     for (uint32_t outer_idx = 0; outer_idx < num_rows_per_core; outer_idx++) {
         if (mean_has_value) {
-            write_mean_rstd(cb_id_mean, tile_offset, num_inner, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, mean_addrg);
+            write_mean_rstd(
+                cb_id_mean,
+                tile_offset,
+                num_inner,
+                normalized_dims,
+                outer_idx,
+                mean_rstd_height,
+                mean_rstd_width,
+                Ht,
+                Wt,
+                mean_addrg);
         }

         if (rstd_has_value) {
-            write_mean_rstd(cb_id_rstd, tile_offset, num_inner, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, rstd_addrg);
+            write_mean_rstd(
+                cb_id_rstd,
+                tile_offset,
+                num_inner,
+                normalized_dims,
+                outer_idx,
+                mean_rstd_height,
+                mean_rstd_width,
+                Ht,
+                Wt,
+                rstd_addrg);
         }

         // output
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp
index 9d3dce058dd..35315cd2594 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp
@@ -4,8 +4,6 @@

 #include "moreh_layer_norm_device_operation.hpp"

-#include
-
 #include "tt_dnn/op_library/moreh_helper_functions.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/types.hpp"
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.hpp
index d42653beebf..9e1af4db658 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.hpp
@@ -2,9 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#include
-#include
-
 #include "ttnn/decorators.hpp"
 #include "ttnn/device_operation.hpp"
 #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp
index d02626a7607..d8268c7dbbb 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.cpp
@@ -11,23 +11,40 @@ std::vector<std::optional<Tensor>> MorehLayerNorm::invoke(
     const Tensor& input,
     const uint32_t normalized_dims,
     const float eps,
-    const std::optional<const Tensor> gamma,
-    const std::optional<const Tensor> beta,
-    const std::optional<const Tensor> output,
-    const std::optional<const Tensor> mean,
-    const std::optional<const Tensor> rstd,
+    const std::optional<const Tensor>& gamma,
+    const std::optional<const Tensor>& beta,
+    const std::optional<const Tensor>& output,
+    const std::optional<const Tensor>& mean,
+    const std::optional<const Tensor>& rstd,
     const std::optional<MemoryConfig>& memory_config,
     const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
     return ttnn::prim::moreh_layer_norm(
-        input,
-        normalized_dims,
-        eps,
-        gamma,
-        beta,
-        output,
-        mean,
-        rstd,
-        memory_config,
-        compute_kernel_config);
+        input, normalized_dims, eps, gamma, beta, output, mean, rstd, memory_config, compute_kernel_config);
+}
+
+std::vector<Tensor> MorehLayerNorm::create_async_output_tensors(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
+    const auto& input = input_tensors.at(0);
+    return {
+        Tensor(operation::get_workers_for_op_output({input})),
+        Tensor(operation::get_workers_for_op_output({input})),
+        Tensor(operation::get_workers_for_op_output({input}))};
+}
+
+std::vector<bool> MorehLayerNorm::create_async_return_flag(
+    const Tensor& input,
+    const uint32_t normalized_dims,
+    const float eps,
+    const std::optional<const Tensor>& gamma,
+    const std::optional<const Tensor>& beta,
+    const std::optional<const Tensor>& output,
+    const std::optional<const Tensor>& mean,
+    const std::optional<const Tensor>& rstd,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
+    const auto return_mean = mean.has_value();
+    const auto return_rstd = rstd.has_value();
+
+    return {true, return_mean, return_rstd};
 }
 }  // namespace ttnn::operations::moreh::moreh_layer_norm
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp
index 5d3f51aaaa5..238634eb638 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp
@@ -13,17 +13,34 @@ struct MorehLayerNorm {
         const Tensor& input,
         const uint32_t normalized_dims,
         const float eps,
-        const std::optional<const Tensor> gamma,
-        const std::optional<const Tensor> beta,
-        const std::optional<const Tensor> output,
-        const std::optional<const Tensor> mean,
-        const std::optional<const Tensor> rstd,
+        const std::optional<const Tensor>& gamma,
+        const std::optional<const Tensor>& beta,
+        const std::optional<const Tensor>& output,
+        const std::optional<const Tensor>& mean,
+        const std::optional<const Tensor>& rstd,
+        const std::optional<MemoryConfig>& memory_config,
+        const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
+
+    static std::vector<Tensor> create_async_output_tensors(
+        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);
+
+    // The parameters of this function must be identical to those of invoke.
+    static std::vector<bool> create_async_return_flag(
+        const Tensor& input,
+        const uint32_t normalized_dims,
+        const float eps,
+        const std::optional<const Tensor>& gamma,
+        const std::optional<const Tensor>& beta,
+        const std::optional<const Tensor>& output,
+        const std::optional<const Tensor>& mean,
+        const std::optional<const Tensor>& rstd,
         const std::optional<MemoryConfig>& memory_config,
         const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
 };
 }  // namespace ttnn::operations::moreh::moreh_layer_norm

 namespace ttnn {
-constexpr auto moreh_layer_norm =
-    ttnn::register_operation<"ttnn::moreh_layer_norm", ttnn::operations::moreh::moreh_layer_norm::MorehLayerNorm>();
+constexpr auto moreh_layer_norm = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::moreh_layer_norm",
+    ttnn::operations::moreh::moreh_layer_norm::MorehLayerNorm>();
 }
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_gamma_beta_grad_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_gamma_beta_grad_kernel.cpp
index 2153cc1a08a..d028f9c26a3 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_gamma_beta_grad_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_gamma_beta_grad_kernel.cpp
@@ -4,7 +4,6 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"
-
 namespace NAMESPACE {
 void MAIN {
     constexpr uint32_t num_cols_per_core = get_compile_time_arg_val(0);
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_gamma_beta_grad.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_gamma_beta_grad.cpp
index 19cff5b290a..90e43e1488f 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_gamma_beta_grad.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_gamma_beta_grad.cpp
@@ -4,9 +4,17 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"
-
 template <typename T>
-void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_dims, uint32_t outer_idx, uint32_t height, uint32_t width, uint32_t Ht, uint32_t Wt, T addrg) {
+void read_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t height,
+    uint32_t width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
     constexpr uint32_t onetile = 1;

     const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -45,12 +53,11 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         }

         // rotate data
-        for (uint32_t i = 0 ; i < 16; i++ ) {
+        for (uint32_t i = 0; i < 16; i++) {
             l1_ptr[i * FACE_WIDTH] = l1_ptr[i];
             l1_ptr[i * FACE_WIDTH + 256 * 2] = l1_ptr[i + 256];
         }
-    }
-    else {
+    } else {
         auto idx = tile_offset + outer_idx;

         auto w = idx % width;
@@ -67,9 +74,7 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         auto dst_noc_addr = get_noc_addr(noc_id, addrg);

         noc_async_read(
-            dst_noc_addr + tilized_idx * cb_dtype_bytes,
-            l1_write_addr + tilized_idx * cb_dtype_bytes,
-            cb_dtype_bytes);
+            dst_noc_addr + tilized_idx * cb_dtype_bytes, l1_write_addr + tilized_idx * cb_dtype_bytes, cb_dtype_bytes);
         noc_async_read_barrier();

         if (idx != 0) {
@@ -177,12 +182,30 @@ void kernel_main() {
                 uint32_t mean_rstd_tile_offset = tile_offset / num_inner;

                 // mean
-                read_mean_rstd(cb_id_mean, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, mean_addrg);
+                read_mean_rstd(
+                    cb_id_mean,
+                    mean_rstd_tile_offset,
+                    normalized_dims,
+                    outer_idx,
+                    mean_rstd_height,
+                    mean_rstd_width,
+                    mean_rstd_Ht,
+                    mean_rstd_Wt,
+                    mean_addrg);
                 // rstd
-                read_mean_rstd(cb_id_rstd, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, rstd_addrg);
+                read_mean_rstd(
+                    cb_id_rstd,
+                    mean_rstd_tile_offset,
+                    normalized_dims,
+                    outer_idx,
+                    mean_rstd_height,
+                    mean_rstd_width,
+                    mean_rstd_Ht,
+                    mean_rstd_Wt,
+                    rstd_addrg);
             }  // gamma_grad_has_value
         }  // num_rows_per_core loop
-    }      // num_inner loop
+    }  // num_inner loop
 }  // void kernel_main()
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_large.cpp
index c64b4e06160..d7da30cce52 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_large.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_large.cpp
@@ -4,9 +4,17 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"
-
 template <typename T>
-void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_dims, uint32_t outer_idx, uint32_t height, uint32_t width, uint32_t Ht, uint32_t Wt, T addrg) {
+void read_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t height,
+    uint32_t width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
     constexpr uint32_t onetile = 1;

     const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -45,12 +53,11 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         }

         // rotate data
-        for (uint32_t i = 0 ; i < 16; i++ ) {
+        for (uint32_t i = 0; i < 16; i++) {
             l1_ptr[i * FACE_WIDTH] = l1_ptr[i];
             l1_ptr[i * FACE_WIDTH + 256 * 2] = l1_ptr[i + 256];
         }
-    }
-    else {
+    } else {
         auto idx = tile_offset + outer_idx;

         auto w = idx % width;
@@ -67,9 +74,7 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         auto dst_noc_addr = get_noc_addr(noc_id, addrg);

         noc_async_read(
-            dst_noc_addr + tilized_idx * cb_dtype_bytes,
-            l1_write_addr + tilized_idx * cb_dtype_bytes,
-            cb_dtype_bytes);
+            dst_noc_addr + tilized_idx * cb_dtype_bytes, l1_write_addr + tilized_idx * cb_dtype_bytes, cb_dtype_bytes);
         noc_async_read_barrier();

         if (idx != 0) {
@@ -170,10 +175,28 @@ void kernel_main() {
         uint32_t mean_rstd_tile_offset = tile_offset / num_inner;

         // mean
-        read_mean_rstd(cb_id_mean, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, mean_addrg);
+        read_mean_rstd(
+            cb_id_mean,
+            mean_rstd_tile_offset,
+            normalized_dims,
+            outer_idx,
+            mean_rstd_height,
+            mean_rstd_width,
+            mean_rstd_Ht,
+            mean_rstd_Wt,
+            mean_addrg);
         // rstd
-        read_mean_rstd(cb_id_rstd, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, rstd_addrg);
+        read_mean_rstd(
+            cb_id_rstd,
+            mean_rstd_tile_offset,
+            normalized_dims,
+            outer_idx,
+            mean_rstd_height,
+            mean_rstd_width,
+            mean_rstd_Ht,
+            mean_rstd_Wt,
+            rstd_addrg);

         // For Sum[dy] and Sum[y * dy]
         for (uint32_t inner_idx = 0; inner_idx < num_inner; inner_idx++) {
@@ -199,7 +222,7 @@ void kernel_main() {
                 noc_async_read_barrier();
                 cb_push_back(cb_id_gamma, onetile);
             }  // gamma_has_value
-        }      // num_inner loop
+        }  // num_inner loop

         // For ((n * dy - Sum[dy]) - (y * Sum[y * dy])) * (rstd / n)
         for (uint32_t inner_idx = 0; inner_idx < num_inner; inner_idx++) {
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_small.cpp
index 8de8067b7db..f3c01bc575e 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_small.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/reader_moreh_layer_norm_backward_input_grad_small.cpp
@@ -5,7 +5,16 @@
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

 template <typename T>
-void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_dims, uint32_t outer_idx, uint32_t height, uint32_t width, uint32_t Ht, uint32_t Wt, T addrg) {
+void read_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t height,
+    uint32_t width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
     constexpr uint32_t onetile = 1;

     const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -44,12 +53,11 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         }

         // rotate data
-        for (uint32_t i = 0 ; i < 16; i++ ) {
+        for (uint32_t i = 0; i < 16; i++) {
             l1_ptr[i * FACE_WIDTH] = l1_ptr[i];
             l1_ptr[i * FACE_WIDTH + 256 * 2] = l1_ptr[i + 256];
         }
-    }
-    else {
+    } else {
         auto idx = tile_offset + outer_idx;

         auto w = idx % width;
@@ -66,9 +74,7 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
         auto dst_noc_addr = get_noc_addr(noc_id, addrg);

         noc_async_read(
-            dst_noc_addr + tilized_idx * cb_dtype_bytes,
-            l1_write_addr + tilized_idx * cb_dtype_bytes,
-            cb_dtype_bytes);
+            dst_noc_addr + tilized_idx * cb_dtype_bytes, l1_write_addr + tilized_idx * cb_dtype_bytes, cb_dtype_bytes);
         noc_async_read_barrier();

         if (idx != 0) {
@@ -168,10 +174,28 @@ void kernel_main() {
         uint32_t mean_rstd_tile_offset = tile_offset / num_inner;

         // mean
-        read_mean_rstd(cb_id_mean, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, mean_addrg);
+        read_mean_rstd(
+            cb_id_mean,
+            mean_rstd_tile_offset,
+            normalized_dims,
+            outer_idx,
+            mean_rstd_height,
+            mean_rstd_width,
+            mean_rstd_Ht,
+            mean_rstd_Wt,
+            mean_addrg);
         // rstd
-        read_mean_rstd(cb_id_rstd, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, rstd_addrg);
+        read_mean_rstd(
+            cb_id_rstd,
+            mean_rstd_tile_offset,
+            normalized_dims,
+            outer_idx,
+            mean_rstd_height,
+            mean_rstd_width,
+            mean_rstd_Ht,
+            mean_rstd_Wt,
+            rstd_addrg);

         // input (N, C, H, W)
         const uint32_t input_l1_write_ptr = get_write_ptr(cb_id_input);
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.cpp
index 1cc21f696dd..15ce6553233 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.cpp
@@ -4,8 +4,6 @@

 #include "moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp"

-#include
-
 #include "ttnn/tensor/tensor.hpp"

 namespace ttnn::operations::moreh::moreh_layer_norm_backward_gamma_beta_grad {
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp
index af870776184..3238da257d7 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp
@@ -2,9 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#include
-#include
-
 #include "ttnn/decorators.hpp"
 #include "ttnn/device_operation.hpp"
 #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
@@ -70,7 +67,7 @@ struct MorehLayerNormBackwardGammaBetaGradOperation {
         const std::optional<const Tensor> &gamma_grad,
         const std::optional<const Tensor> &beta_grad,
         const std::optional<MemoryConfig> &memory_config,
-        const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
+        const std::optional<DeviceComputeKernelConfig> &compute_kernel_config);
 };

 }  // namespace ttnn::operations::moreh::moreh_layer_norm_backward_gamma_beta_grad
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_device_operation.hpp
index 03c892023ee..300a0438fc0 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_device_operation.hpp
@@ -2,9 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#include
-#include
-
 #include "ttnn/decorators.hpp"
 #include "ttnn/device_operation.hpp"
 #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp
index e9f8d4b511e..0ccd10edee3 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp
@@ -2,8 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-#include
-
 #include "moreh_layer_norm_backward_input_grad_device_operation.hpp"
 #include "tt_metal/common/work_split.hpp"
 #include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp"
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp
index 6d744fc08e6..05deee20a09 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.cpp
@@ -4,8 +4,8 @@

 #include "moreh_layer_norm_backward.hpp"

-#include "ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp"
-#include "ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_device_operation.hpp"
+#include "device/moreh_layer_norm_backward_gamma_beta_grad_device_operation.hpp"
+#include "device/moreh_layer_norm_backward_input_grad_device_operation.hpp"

 namespace ttnn::operations::moreh::moreh_layer_norm_backward {
 std::vector<std::optional<Tensor>> moreh_layernorm_backward_gamma_beta_grad(
@@ -28,7 +28,15 @@ std::vector<std::optional<Tensor>> moreh_layernorm_backward_gamma_beta_grad(
     }

     const auto& ret = ttnn::prim::moreh_layer_norm_backward_gamma_beta_grad(
-        output_grad, input, mean, rstd, normalized_dims, gamma_grad, beta_grad, memory_config, compute_kernel_config_val);
+        output_grad,
+        input,
+        mean,
+        rstd,
+        normalized_dims,
+        gamma_grad,
+        beta_grad,
+        memory_config,
+        compute_kernel_config_val);

     return ret;
 }
@@ -49,15 +57,7 @@ std::vector<std::optional<Tensor>> MorehLayerNormBackward::invoke(

     if (input_grad.has_value()) {
         outputs.push_back(ttnn::prim::moreh_layer_norm_backward_input_grad(
-            output_grad,
-            input,
-            mean,
-            rstd,
-            normalized_dims,
-            input_grad,
-            gamma,
-            memory_config,
-            compute_kernel_config));
+            output_grad, input, mean, rstd, normalized_dims, input_grad, gamma, memory_config, compute_kernel_config));
     } else {
         outputs.push_back(std::nullopt);
     }
@@ -69,4 +69,31 @@ std::vector<std::optional<Tensor>> MorehLayerNormBackward::invoke(

     return outputs;
 }
+
+std::vector<Tensor> MorehLayerNormBackward::create_async_output_tensors(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
+    const auto& output_grad = input_tensors.at(0);
+    return {
+        Tensor(operation::get_workers_for_op_output({output_grad})),
+        Tensor(operation::get_workers_for_op_output({output_grad})),
+        Tensor(operation::get_workers_for_op_output({output_grad}))};
+}
+
+std::vector<bool> MorehLayerNormBackward::create_async_return_flag(
+    const Tensor& output_grad,
+    const Tensor& input,
+    const Tensor& mean,
+    const Tensor& rstd,
+    uint32_t normalized_dims,
+    const std::optional<const Tensor>& gamma,
+    const std::optional<const Tensor>& input_grad,
+    const std::optional<const Tensor>& gamma_grad,
+    const std::optional<const Tensor>& beta_grad,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
+    const auto return_input_grad = input_grad.has_value();
+    const auto return_gamma_grad = gamma_grad.has_value();
+    const auto return_beta_grad = beta_grad.has_value();
+    return {return_input_grad, return_gamma_grad, return_beta_grad};
+}
 }  // namespace ttnn::operations::moreh::moreh_layer_norm_backward
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp
index 74d03f685c0..75d93b64946 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward.hpp
@@ -21,11 +21,28 @@ struct MorehLayerNormBackward {
         const std::optional<const Tensor>& beta_grad,
         const std::optional<MemoryConfig>& memory_config,
         const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
+
+    static std::vector<Tensor> create_async_output_tensors(
+        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);
+
+    // The parameters of this function must be identical to those of invoke.
+    static std::vector<bool> create_async_return_flag(
+        const Tensor& output_grad,
+        const Tensor& input,
+        const Tensor& mean,
+        const Tensor& rstd,
+        uint32_t normalized_dims,
+        const std::optional<const Tensor>& gamma,
+        const std::optional<const Tensor>& input_grad,
+        const std::optional<const Tensor>& gamma_grad,
+        const std::optional<const Tensor>& beta_grad,
+        const std::optional<MemoryConfig>& memory_config,
+        const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
 };
 }  // namespace ttnn::operations::moreh::moreh_layer_norm_backward

 namespace ttnn {
-constexpr auto moreh_layer_norm_backward = ttnn::register_operation<
+constexpr auto moreh_layer_norm_backward = ttnn::register_operation_with_auto_launch_op<
     "ttnn::moreh_layer_norm_backward",
     ttnn::operations::moreh::moreh_layer_norm_backward::MorehLayerNormBackward>();
 }
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward_pybind.cpp
index fcccaf83f25..23c1aac83ea 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/moreh_layer_norm_backward_pybind.cpp
@@ -15,17 +15,17 @@ void bind_moreh_layer_norm_backward_operation(py::module& module) {
         ttnn::moreh_layer_norm_backward,
         "Moreh Layer Norm Backward Operation",
         ttnn::pybind_arguments_t{
-          py::arg("output_grad"),
-          py::arg("input"),
-          py::arg("mean"),
-          py::arg("rstd"),
-          py::arg("normalized_dims"),
-          py::kw_only(),
-          py::arg("gamma") = std::nullopt,
-          py::arg("input_grad") = std::nullopt,
-          py::arg("gamma_grad") = std::nullopt,
-          py::arg("beta_grad") = std::nullopt,
-          py::arg("memory_config") = std::nullopt,
-          py::arg("compute_kernel_config") = std::nullopt});
+            py::arg("output_grad"),
+            py::arg("input"),
+            py::arg("mean"),
+            py::arg("rstd"),
+            py::arg("normalized_dims"),
+            py::kw_only(),
+            py::arg("gamma") = std::nullopt,
+            py::arg("input_grad") = std::nullopt,
+            py::arg("gamma_grad") = std::nullopt,
+            py::arg("beta_grad") = std::nullopt,
+            py::arg("memory_config") = std::nullopt,
+            py::arg("compute_kernel_config") = std::nullopt});
 }
 }  // namespace ttnn::operations::moreh::moreh_layer_norm_backward