Skip to content

Commit

Permalink
#8632: Use generate_reduce_scaler function
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjin-na committed Jun 1, 2024
1 parent 43b1132 commit 80965ba
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp"
#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0);
Expand All @@ -26,29 +27,7 @@ void kernel_main() {
#ifdef REDUCE_SCALER
constexpr uint32_t cb_id_in2 = 2;
constexpr uint32_t scaler = get_compile_time_arg_val(4);
cb_reserve_back(cb_id_in2, 1);
constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
uint32_t write_addr = get_write_ptr(cb_id_in2);
// Fill tile with zeros
for (uint32_t i = 0; i < num_zeros_reads; ++i) {
noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE);
write_addr += MEM_ZEROS_SIZE;
}
noc_async_read_barrier();
if constexpr (scaler != 0) {
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id_in2));
uint32_t idx = 0;
for (uint32_t k = 0; k < 4; ++k) {
uint32_t curr_idx = idx;
for (uint32_t j = 0; j < 8; ++j) {
ptr[curr_idx] = scaler;
curr_idx++;
}
idx += 128;
}
}
cb_push_back(cb_id_in2, 1);
generate_reduce_scaler(cb_id_in2, scaler);
#endif

constexpr uint32_t cb_id_mask_h = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp"
#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0);
Expand All @@ -13,29 +14,7 @@ void kernel_main() {
constexpr uint32_t scaler = get_compile_time_arg_val(1);

constexpr uint32_t cb_id_in2 = 2;
cb_reserve_back(cb_id_in2, 1);
constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
uint32_t write_addr = get_write_ptr(cb_id_in2);
// Fill tile with zeros
for (uint32_t i = 0; i < num_zeros_reads; ++i) {
noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE);
write_addr += MEM_ZEROS_SIZE;
}
noc_async_read_barrier();
if constexpr (scaler != 0) {
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id_in2));
uint32_t idx = 0;
for (uint32_t k = 0; k < 4; ++k) {
uint32_t curr_idx = idx;
for (uint32_t j = 0; j < 8; ++j) {
ptr[curr_idx] = scaler;
curr_idx++;
}
idx += 128;
}
}
cb_push_back(cb_id_in2, 1);
generate_reduce_scaler(cb_id_in2, scaler);

constexpr uint32_t cb_id_mask_w = 3;
#ifdef DO_MASK_W
Expand Down

0 comments on commit 80965ba

Please sign in to comment.