Skip to content

Commit

Permalink
#9336: Refactoring moreh layernorm (#9636)
Browse files Browse the repository at this point in the history
The current `moreh_layernorm` only accepts 4D tensors as input. To meet this restriction, the mean and rstd tensors are stored inefficiently, with only one valid data point per tile.
The current implementation also supports neither optional outputs nor the `fp32_dest_acc_en` feature.

Refactoring moreh_layernorm
- add `fp32_dest_acc_en` support
- support non-4D tensors
- support optional output tensors
  • Loading branch information
hschoi4448 authored Jul 5, 2024
1 parent 5b395d6 commit c52e153
Show file tree
Hide file tree
Showing 23 changed files with 2,036 additions and 1,290 deletions.

Large diffs are not rendered by default.

75 changes: 64 additions & 11 deletions tt_eager/tt_dnn/kernels/compute/moreh_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ ALWI void pack_tile_with_dt(uint32_t ifrom_dst, uint32_t icb)
pack_tile(ifrom_dst, icb);
}

ALWI void copy_tile_init_with_dt(uint32_t icb)
ALWI void copy_tile_init_with_dt(uint32_t icb, uint32_t transpose = 0)
{
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format_srca(icb);
#endif
copy_tile_to_dst_init_short(icb);
copy_tile_to_dst_init_short(icb, transpose);
}

ALWI void add_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
Expand All @@ -57,27 +57,86 @@ ALWI void add_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
add_tiles_init(icb0, icb1);
}

// Wrapper around add_bcast_rows_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void add_bcast_rows_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    add_bcast_rows_init_short(icb0, icb1);
}

// Wrapper around add_bcast_cols_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void add_bcast_cols_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    add_bcast_cols_init_short(icb0, icb1);
}

// Wrapper around add_bcast_scalar_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void add_bcast_scalar_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    add_bcast_scalar_init_short(icb0, icb1);
}

// Wrapper around sub_tiles_init(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void sub_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    sub_tiles_init(icb0, icb1);
}

// Wrapper around sub_bcast_cols_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void sub_bcast_cols_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    sub_bcast_cols_init_short(icb0, icb1);
}

// Wrapper around sub_tiles_bcast_scalar_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void sub_tiles_bcast_scalar_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    sub_tiles_bcast_scalar_init_short(icb0, icb1);
}

// Wrapper around mul_tiles_init(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void mul_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mul_tiles_init(icb0, icb1);
}

// Wrapper around mul_bcast_rows_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void mul_bcast_rows_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mul_bcast_rows_init_short(icb0, icb1);
}

// Wrapper around mul_bcast_cols_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void mul_bcast_cols_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mul_bcast_cols_init_short(icb0, icb1);
}

// Wrapper around mul_tiles_bcast_scalar_init_short(icb0, icb1).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; the init call itself is made in every build.
ALWI void mul_tiles_bcast_scalar_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mul_tiles_bcast_scalar_init_short(icb0, icb1);
}

// Wrapper around reduce_init_delta<at_start>(...).
// In builds where FP32_DEST_ACC_EN is defined, unpack_reconfig_data_format(icb0, icb1)
// is issued first; reduce_init_delta<at_start> is then called in every build with
// the template parameters `reduce_type` / `reduce_dim` followed by `ocb`, `icb0`, `icb1`.
//
// NOTE(review): the runtime arguments `reduce_op` and `dim` are accepted but never
// forwarded; the template parameters are passed in their place. Confirm this is
// intended before relying on the runtime values.
template <bool at_start, PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_delta_with_dt(
    PoolType reduce_op, ReduceDim dim, uint32_t ocb = 16, uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    reduce_init_delta<at_start>(reduce_type, reduce_dim, ocb, icb0, icb1);
}


class ArgFetcher {
private:
int arg_idx = 0;
Expand Down Expand Up @@ -489,17 +548,14 @@ ALWI void reduce_tile_to_cb(
tile_regs_acquire();
cb_wait_front(icb1, onetile);

#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
reduce_init_delta<false>(reduce_op, dim);
reduce_init_delta_with_dt<false>(reduce_op, dim, ocb, icb0, icb1);
for (uint32_t x = 0; x < size; ++x) {
cb_wait_front(icb0, x + 1); // must be a cumulative wait for correctness

constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
reduce_tile(icb0, icb1, x, bcast_scaler0, dst0);
}
reduce_revert_delta();
reduce_revert_delta(ocb);
tile_regs_commit();

if (pop0)
Expand Down Expand Up @@ -800,10 +856,7 @@ ALWI void reduce_tile_and_recip_tile_to_cb(
cb_wait_front(icb1, onetile);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
reduce_init_delta<false>(reduce_op, dim);
reduce_init_delta_with_dt<false>(reduce_op, dim, ocb, icb0, icb1);
for (uint32_t x = 0; x < size; ++x) {
cb_wait_front(icb0, x + 1); // must be a cumulative wait for correctness

Expand Down
42 changes: 42 additions & 0 deletions tt_eager/tt_dnn/op_library/moreh_helper_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "tt_eager/tt_dnn/op_library/work_split.hpp"
#include "tt_metal/detail/util.hpp"
#include "common/constants.hpp"

#include "third_party/magic_enum/magic_enum.hpp"

Expand Down Expand Up @@ -250,6 +251,47 @@ void check_tensor(
check_tensor(tensor.value(), op_name, tensor_name, data_types, layout, check_dtype, check_layout);
}

// Reports whether index `dim` is treated as one of the last two dimensions of a
// tensor with `rank` dimensions.
//   rank == 0       -> false
//   rank == 1 or 2  -> true for every `dim` (the index is not range-checked)
//   rank >= 3       -> true exactly when dim >= rank - 2
bool is_hw_dim(uint32_t dim, uint32_t rank) {
    switch (rank) {
        case 0: return false;
        case 1:
        case 2: return true;
        default: return dim >= rank - 2;
    }
}

// Multiplies together the extents of `shape` at indices [rank - dim, rank).
// An extent whose index satisfies is_hw_dim(index, rank) is first replaced by
// tt::div_up(extent, constants::TILE_WIDTH) before being folded into the product.
// Returns 1 when the index range is empty.
uint32_t compute_inner(Shape shape, uint32_t dim) {
    auto rank = shape.rank();
    uint32_t product = 1;

    uint32_t idx = rank - dim;
    while (idx < rank) {
        auto extent = shape[idx];
        if (is_hw_dim(idx, rank)) {
            extent = tt::div_up(extent, constants::TILE_WIDTH);
        }
        product *= extent;
        idx++;
    }

    return product;
}

// Multiplies together the extents of `shape` at indices [0, rank - dim).
// An extent whose index satisfies is_hw_dim(index, rank) is first replaced by
// tt::div_up(extent, constants::TILE_WIDTH) before being folded into the product.
// Returns 1 when the index range is empty.
uint32_t compute_outer(Shape shape, uint32_t dim) {
    auto rank = shape.rank();
    const auto limit = rank - dim;  // one past the last index that is included
    uint32_t product = 1;

    uint32_t idx = 0;
    while (idx < limit) {
        auto extent = shape[idx];
        if (is_hw_dim(idx, rank)) {
            extent = tt::div_up(extent, constants::TILE_WIDTH);
        }
        product *= extent;
        idx++;
    }

    return product;
}

} // namespace primary
} // namespace operations
} // namespace tt
6 changes: 6 additions & 0 deletions tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ auto create_override_addresses_callback(
}


// Returns true when `rank` is 1 or 2 (for any `dim`), or when `rank` >= 3 and
// `dim` >= rank - 2; returns false otherwise (including rank == 0).
bool is_hw_dim(uint32_t dim, uint32_t rank);

// Product of the extents of `shape` at indices [rank - dim, rank); extents at
// indices for which is_hw_dim() is true are divided (rounding up via tt::div_up)
// by constants::TILE_WIDTH first.
uint32_t compute_inner(Shape shape, uint32_t dim);

// Product of the extents of `shape` at indices [0, rank - dim); extents at
// indices for which is_hw_dim() is true are divided (rounding up via tt::div_up)
// by constants::TILE_WIDTH first.
uint32_t compute_outer(Shape shape, uint32_t dim);

} // namespace primary
} // namespace operations
} // namespace tt
Loading

0 comments on commit c52e153

Please sign in to comment.