#12632: Refactor moreh_layer_norm
ngohoang34 committed Oct 2, 2024
1 parent 59d8749 commit fbe2897
Showing 18 changed files with 262 additions and 103 deletions.
----------------------------------------
@@ -4,7 +4,6 @@

#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"


ALWI bool need_to_do_mask_h(uint32_t w_idx, uint32_t origin_num_h_tiles, uint32_t origin_num_w_tiles) {
return ((w_idx / origin_num_w_tiles) + 1) % origin_num_h_tiles == 0;
}
@@ -69,7 +68,6 @@ void MAIN {
constexpr uint32_t origin_Wt = (origin_W + TILE_WIDTH - 1) / TILE_WIDTH;

for (uint32_t outer_idx = 0; outer_idx < num_rows_per_core; outer_idx++) {
-
/*
* Sum[x]
* cb_xsum
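As a quick reference for the helper above: with the inner tiles walked row-major over the origin tile grid, `need_to_do_mask_h` fires exactly on the last row of tiles, where a partially filled tile can spill past `origin_H` and needs its extra rows masked. A standalone sketch; the 2x3 grid and the driver loop are made up for illustration:

#include <cstdint>
#include <cstdio>

// Same predicate as the kernel helper: true on the last tile-row of the grid,
// where the tile's bottom rows may lie outside the origin height.
bool need_to_do_mask_h(uint32_t w_idx, uint32_t origin_num_h_tiles, uint32_t origin_num_w_tiles) {
    return ((w_idx / origin_num_w_tiles) + 1) % origin_num_h_tiles == 0;
}

int main() {
    // Hypothetical 2x3 tile grid (Ht = 2, Wt = 3); w_idx walks it row-major, 0..5.
    for (uint32_t w_idx = 0; w_idx < 6; w_idx++) {
        std::printf("w_idx=%u mask_h=%d\n", w_idx, (int)need_to_do_mask_h(w_idx, 2, 3));
    }
    return 0;  // prints mask_h=1 only for w_idx 3, 4, 5 (the bottom tile row)
}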
----------------------------------------
@@ -158,7 +158,7 @@ void MAIN {
tile_regs_release();
}
} // block_size loop
-} // num_inner loop
+}  // num_inner loop
// We don't pop cb_x until we compute xmm.

/*
@@ -442,7 +442,7 @@ void MAIN {
cb_pop_front(cb_beta, block_size);
cb_push_back(cb_out, block_size);
} // if (beta_has_value)
-} // num_inner loop
+}  // num_inner loop
cb_pop_front(cb_recip_std, onetile);
cb_pop_front(cb_xmm, num_inner);
} // num_rows_per_core loop
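For orientation, the running quantities this compute kernel accumulates (Sum[x], then the mean, the reciprocal standard deviation in cb_recip_std, and the gamma/beta affine step guarded by `beta_has_value`) follow the textbook layer-norm definition. A scalar host-side sketch of the same math, with tiling and circular buffers abstracted away; the names here are illustrative, not kernel APIs:

#include <cmath>
#include <cstddef>
#include <vector>

// One normalized row of length n: y = (x - mean) * rstd * gamma + beta,
// where rstd = 1 / sqrt(Var[x] + eps). gamma/beta are assumed to have length n.
void layer_norm_row(const std::vector<float>& x, float eps,
                    const std::vector<float>& gamma, const std::vector<float>& beta,
                    std::vector<float>& y, float& mean, float& rstd) {
    const float n = static_cast<float>(x.size());
    float sum = 0.0f;
    for (float v : x) sum += v;             // Sum[x]
    mean = sum / n;                         // E[x]
    float var = 0.0f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= n;                               // biased variance, as layer norm uses
    rstd = 1.0f / std::sqrt(var + eps);     // reciprocal standard deviation
    y.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
    }
}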
----------------------------------------
@@ -6,8 +6,17 @@
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

template <typename T>
-void write_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t num_inner, uint32_t normalized_dims, uint32_t outer_idx, uint32_t output_height, uint32_t output_width, uint32_t Ht, uint32_t Wt, T addrg)
-{
+void write_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t num_inner,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t output_height,
+    uint32_t output_width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
constexpr uint32_t onetile = 1;

const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -126,11 +135,31 @@ void kernel_main() {

for (uint32_t outer_idx = 0; outer_idx < num_rows_per_core; outer_idx++) {
if (mean_has_value) {
-write_mean_rstd(cb_id_mean, tile_offset, num_inner, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, mean_addrg);
+write_mean_rstd(
+    cb_id_mean,
+    tile_offset,
+    num_inner,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    Ht,
+    Wt,
+    mean_addrg);
}

if (rstd_has_value) {
-write_mean_rstd(cb_id_rstd, tile_offset, num_inner, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, Ht, Wt, rstd_addrg);
+write_mean_rstd(
+    cb_id_rstd,
+    tile_offset,
+    num_inner,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    Ht,
+    Wt,
+    rstd_addrg);
}

// output
----------------------------------------
@@ -4,8 +4,6 @@

#include "moreh_layer_norm_device_operation.hpp"

-#include <cstdint>
-
#include "tt_dnn/op_library/moreh_helper_functions.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
----------------------------------------
@@ -2,9 +2,6 @@
//
// SPDX-License-Identifier: Apache-2.0

-#include <variant>
-#include <vector>
-
#include "ttnn/decorators.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
----------------------------------------
@@ -11,23 +11,40 @@ std::vector<std::optional<Tensor>> MorehLayerNorm::invoke(
const Tensor& input,
const uint32_t normalized_dims,
const float eps,
-const std::optional<const Tensor> gamma,
-const std::optional<const Tensor> beta,
-const std::optional<const Tensor> output,
-const std::optional<const Tensor> mean,
-const std::optional<const Tensor> rstd,
+const std::optional<const Tensor>& gamma,
+const std::optional<const Tensor>& beta,
+const std::optional<const Tensor>& output,
+const std::optional<const Tensor>& mean,
+const std::optional<const Tensor>& rstd,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
return ttnn::prim::moreh_layer_norm(
-input,
-normalized_dims,
-eps,
-gamma,
-beta,
-output,
-mean,
-rstd,
-memory_config,
-compute_kernel_config);
+input, normalized_dims, eps, gamma, beta, output, mean, rstd, memory_config, compute_kernel_config);
}

+std::vector<Tensor> MorehLayerNorm::create_async_output_tensors(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
+    const auto& input = input_tensors.at(0);
+    return {
+        Tensor(operation::get_workers_for_op_output({input})),
+        Tensor(operation::get_workers_for_op_output({input})),
+        Tensor(operation::get_workers_for_op_output({input}))};
+}
+
+std::vector<bool> MorehLayerNorm::create_async_return_flag(
+    const Tensor& input,
+    const uint32_t normalized_dims,
+    const float eps,
+    const std::optional<const Tensor>& gamma,
+    const std::optional<const Tensor>& beta,
+    const std::optional<const Tensor>& output,
+    const std::optional<const Tensor>& mean,
+    const std::optional<const Tensor>& rstd,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
+    const auto return_mean = mean.has_value();
+    const auto return_rstd = rstd.has_value();
+
+    return {true, return_mean, return_rstd};
+}
} // namespace ttnn::operations::moreh::moreh_layer_norm
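The flag logic above is the whole contract: the op always returns its primary output, while mean and rstd come back only when the caller supplied tensors for them. A minimal sketch of how a launch wrapper might combine the three async outputs with these flags; this wrapper is illustrative, not the actual ttnn launcher:

#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Pair each async output slot with its return flag; unrequested slots
// (e.g. mean/rstd when the caller passed std::nullopt) become empty optionals.
template <typename TensorT>
std::vector<std::optional<TensorT>> apply_return_flags(
    std::vector<TensorT>&& outputs, const std::vector<bool>& flags) {
    std::vector<std::optional<TensorT>> result;
    for (std::size_t i = 0; i < outputs.size(); ++i) {
        if (flags[i]) {
            result.emplace_back(std::move(outputs[i]));
        } else {
            result.emplace_back(std::nullopt);
        }
    }
    return result;
}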
----------------------------------------
@@ -13,17 +13,34 @@ struct MorehLayerNorm {
const Tensor& input,
const uint32_t normalized_dims,
const float eps,
-const std::optional<const Tensor> gamma,
-const std::optional<const Tensor> beta,
-const std::optional<const Tensor> output,
-const std::optional<const Tensor> mean,
-const std::optional<const Tensor> rstd,
+const std::optional<const Tensor>& gamma,
+const std::optional<const Tensor>& beta,
+const std::optional<const Tensor>& output,
+const std::optional<const Tensor>& mean,
+const std::optional<const Tensor>& rstd,
const std::optional<MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);

+static std::vector<Tensor> create_async_output_tensors(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);
+
+// The parameters of this function must be identical to those of invoke.
+static std::vector<bool> create_async_return_flag(
+    const Tensor& input,
+    const uint32_t normalized_dims,
+    const float eps,
+    const std::optional<const Tensor>& gamma,
+    const std::optional<const Tensor>& beta,
+    const std::optional<const Tensor>& output,
+    const std::optional<const Tensor>& mean,
+    const std::optional<const Tensor>& rstd,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
};
} // namespace ttnn::operations::moreh::moreh_layer_norm

namespace ttnn {
-constexpr auto moreh_layer_norm =
-ttnn::register_operation<"ttnn::moreh_layer_norm", ttnn::operations::moreh::moreh_layer_norm::MorehLayerNorm>();
+constexpr auto moreh_layer_norm = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::moreh_layer_norm",
+    ttnn::operations::moreh::moreh_layer_norm::MorehLayerNorm>();
}
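Switching to `register_operation_with_auto_launch_op` moves the async bookkeeping (the three worker output tensors plus the return flags) into the launch machinery, so call sites stay unchanged. A hedged sketch of such a call: the include path, the `ttnn::Tensor` spelling, and the argument values are assumptions for illustration; only the parameter order is taken from `invoke` above.

// #include "ttnn/operations/moreh/moreh_layer_norm/moreh_layer_norm.hpp"  // assumed path
std::vector<std::optional<ttnn::Tensor>> run_layer_norm(
    const ttnn::Tensor& input, const ttnn::Tensor& gamma, const ttnn::Tensor& beta,
    const ttnn::Tensor& mean_out, const ttnn::Tensor& rstd_out) {
    // Per create_async_return_flag, mean/rstd come back engaged only because
    // preallocated tensors are passed for them here.
    return ttnn::moreh_layer_norm(
        input,
        /*normalized_dims=*/1,
        /*eps=*/1e-5f,
        gamma,
        beta,
        /*output=*/std::nullopt,
        mean_out,
        rstd_out,
        /*memory_config=*/std::nullopt,
        /*compute_kernel_config=*/std::nullopt);
}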
----------------------------------------
@@ -4,7 +4,6 @@

#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"


namespace NAMESPACE {
void MAIN {
constexpr uint32_t num_cols_per_core = get_compile_time_arg_val(0);
----------------------------------------
@@ -4,9 +4,17 @@

#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"


template <typename T>
-void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_dims, uint32_t outer_idx, uint32_t height, uint32_t width, uint32_t Ht, uint32_t Wt, T addrg) {
+void read_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t height,
+    uint32_t width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
constexpr uint32_t onetile = 1;

const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -45,12 +53,11 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
}

// rotate data
-for (uint32_t i = 0 ; i < 16; i++ ) {
+for (uint32_t i = 0; i < 16; i++) {
l1_ptr[i * FACE_WIDTH] = l1_ptr[i];
l1_ptr[i * FACE_WIDTH + 256 * 2] = l1_ptr[i + 256];
}
-}
-else {
+} else {
auto idx = tile_offset + outer_idx;

auto w = idx % width;
@@ -67,9 +74,7 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di

auto dst_noc_addr = get_noc_addr(noc_id, addrg);
noc_async_read(
-dst_noc_addr + tilized_idx * cb_dtype_bytes,
-l1_write_addr + tilized_idx * cb_dtype_bytes,
-cb_dtype_bytes);
+dst_noc_addr + tilized_idx * cb_dtype_bytes, l1_write_addr + tilized_idx * cb_dtype_bytes, cb_dtype_bytes);

noc_async_read_barrier();
if (idx != 0) {
@@ -177,12 +182,30 @@ void kernel_main() {
uint32_t mean_rstd_tile_offset = tile_offset / num_inner;

// mean
-read_mean_rstd(cb_id_mean, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, mean_addrg);
+read_mean_rstd(
+    cb_id_mean,
+    mean_rstd_tile_offset,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    mean_rstd_Ht,
+    mean_rstd_Wt,
+    mean_addrg);

// rstd
-read_mean_rstd(cb_id_rstd, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, rstd_addrg);
+read_mean_rstd(
+    cb_id_rstd,
+    mean_rstd_tile_offset,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    mean_rstd_Ht,
+    mean_rstd_Wt,
+    rstd_addrg);
} // gamma_grad_has_value

} // num_rows_per_core loop
-} // num_inner loop
+}  // num_inner loop
} // void kernel_main()
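The `tilized_idx` offsets and the `256 * 2` constant in the rotate loop above both come from the standard 32x32 tile layout: four 16x16 faces of 256 elements each, ordered top-left, top-right, bottom-left, bottom-right. A standalone sketch of that index math, assuming this layout; the helper name is made up, as the kernel computes the offset inline:

#include <cstdint>

constexpr uint32_t FACE_HEIGHT = 16;
constexpr uint32_t FACE_WIDTH = 16;

// Map an (h, w) element inside a 32x32 tile to its offset in face-major order.
uint32_t tilized_idx(uint32_t h, uint32_t w) {
    const uint32_t face = (h / FACE_HEIGHT) * 2 + (w / FACE_WIDTH);   // 2x2 faces per tile
    return face * FACE_HEIGHT * FACE_WIDTH                            // skip whole faces
         + (h % FACE_HEIGHT) * FACE_WIDTH + (w % FACE_WIDTH);         // offset within face
}
// e.g. tilized_idx(0, 17) == 257: second face, row 0, column 1.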
----------------------------------------
@@ -4,9 +4,17 @@

#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"


template <typename T>
-void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_dims, uint32_t outer_idx, uint32_t height, uint32_t width, uint32_t Ht, uint32_t Wt, T addrg) {
+void read_mean_rstd(
+    uint32_t cb_id,
+    uint32_t tile_offset,
+    uint32_t normalized_dims,
+    uint32_t outer_idx,
+    uint32_t height,
+    uint32_t width,
+    uint32_t Ht,
+    uint32_t Wt,
+    T addrg) {
constexpr uint32_t onetile = 1;

const uint32_t cb_tile_bytes = get_tile_size(cb_id);
@@ -45,12 +53,11 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di
}

// rotate data
-for (uint32_t i = 0 ; i < 16; i++ ) {
+for (uint32_t i = 0; i < 16; i++) {
l1_ptr[i * FACE_WIDTH] = l1_ptr[i];
l1_ptr[i * FACE_WIDTH + 256 * 2] = l1_ptr[i + 256];
}
-}
-else {
+} else {
auto idx = tile_offset + outer_idx;

auto w = idx % width;
@@ -67,9 +74,7 @@ void read_mean_rstd(uint32_t cb_id, uint32_t tile_offset, uint32_t normalized_di

auto dst_noc_addr = get_noc_addr(noc_id, addrg);
noc_async_read(
-dst_noc_addr + tilized_idx * cb_dtype_bytes,
-l1_write_addr + tilized_idx * cb_dtype_bytes,
-cb_dtype_bytes);
+dst_noc_addr + tilized_idx * cb_dtype_bytes, l1_write_addr + tilized_idx * cb_dtype_bytes, cb_dtype_bytes);

noc_async_read_barrier();
if (idx != 0) {
@@ -170,10 +175,28 @@ void kernel_main() {
uint32_t mean_rstd_tile_offset = tile_offset / num_inner;

// mean
-read_mean_rstd(cb_id_mean, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, mean_addrg);
+read_mean_rstd(
+    cb_id_mean,
+    mean_rstd_tile_offset,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    mean_rstd_Ht,
+    mean_rstd_Wt,
+    mean_addrg);

// rstd
-read_mean_rstd(cb_id_rstd, mean_rstd_tile_offset, normalized_dims, outer_idx, mean_rstd_height, mean_rstd_width, mean_rstd_Ht, mean_rstd_Wt, rstd_addrg);
+read_mean_rstd(
+    cb_id_rstd,
+    mean_rstd_tile_offset,
+    normalized_dims,
+    outer_idx,
+    mean_rstd_height,
+    mean_rstd_width,
+    mean_rstd_Ht,
+    mean_rstd_Wt,
+    rstd_addrg);

// For Sum[dy] and Sum[y * dy]
for (uint32_t inner_idx = 0; inner_idx < num_inner; inner_idx++) {
@@ -199,7 +222,7 @@ void kernel_main() {
noc_async_read_barrier();
cb_push_back(cb_id_gamma, onetile);
} // gamma_has_value
-} // num_inner loop
+}  // num_inner loop

// For ((n * dy - Sum[dy]) - (y * Sum[y * dy])) * (rstd / n)
for (uint32_t inner_idx = 0; inner_idx < num_inner; inner_idx++) {
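The formula in the comment above, ((n * dy - Sum[dy]) - (y * Sum[y * dy])) * (rstd / n), is the standard layer-norm input gradient, with y the normalized activation and n the number of normalized elements. A scalar reference sketch of the same computation, assuming any gamma factor has already been folded into dy; the kernel's tiling and circular buffers are abstracted away:

#include <cstddef>
#include <vector>

// dx[i] = ((n * dy[i] - Sum[dy]) - y[i] * Sum[y * dy]) * (rstd / n)
std::vector<float> layer_norm_input_grad(
    const std::vector<float>& dy, const std::vector<float>& y, float rstd) {
    const float n = static_cast<float>(dy.size());
    float sum_dy = 0.0f, sum_y_dy = 0.0f;
    for (std::size_t i = 0; i < dy.size(); ++i) {
        sum_dy += dy[i];           // Sum[dy]
        sum_y_dy += y[i] * dy[i];  // Sum[y * dy]
    }
    std::vector<float> dx(dy.size());
    for (std::size_t i = 0; i < dy.size(); ++i) {
        dx[i] = ((n * dy[i] - sum_dy) - y[i] * sum_y_dy) * (rstd / n);
    }
    return dx;
}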
----------------------------------------
(remaining changed files not loaded)