-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#5725: Adding bilinear support in upsample
- Loading branch information
1 parent
eedf9af
commit 1673fd9
Showing
14 changed files
with
639 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
|
||
#include "compute_kernel_api/tilize.h" | ||
#include "compute_kernel_api/reduce.h" | ||
#include "compute_kernel_api/pack_untilize.h" | ||
|
||
template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t unpA_face_r_dim> | ||
inline void reduce_h_fused( | ||
const uint32_t in_cb_id, | ||
const uint32_t in_scalar_cb_id, | ||
const uint32_t in_ntiles_hwc, | ||
const uint32_t in_stick_index, | ||
const uint32_t out_cb_id) { | ||
|
||
cb_reserve_back(out_cb_id, 1); | ||
tile_regs_acquire(); | ||
cb_wait_front(in_cb_id, 4); | ||
unpack_tilizeA_B_block(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, 2 /* unpack 1 or 2 faces ) */, unpA_face_r_dim); | ||
for (uint32_t c_i = 0; c_i < in_ntiles_c; ++c_i) { | ||
reduce_tile_math(c_i, 2 /* reduce 1 or 2 faces */); | ||
} | ||
cb_pop_front(in_cb_id, 4); | ||
|
||
tile_regs_wait(); | ||
tile_regs_commit(); | ||
pack_untilize_dst<out_ntiles_c>(out_cb_id, 1, 0, 1, 2); /* pack 1 row (1x16 or 1x32) */ | ||
tile_regs_release(); | ||
|
||
cb_push_back(out_cb_id, 1); | ||
} | ||
|
||
namespace NAMESPACE{ | ||
void MAIN{ | ||
constexpr uint32_t out_cb_id = tt::CB::c_out0; | ||
constexpr uint32_t in1_cb_id = tt::CB::c_in1; | ||
constexpr uint32_t bias_cb_id = tt::CB::c_in2; | ||
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; | ||
constexpr uint32_t in2_cb_id = tt::CB::c_intermed0; | ||
|
||
constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0); | ||
constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1); | ||
constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2); | ||
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3); | ||
constexpr uint32_t out_h = get_compile_time_arg_val(4); | ||
constexpr uint32_t out_w = get_compile_time_arg_val(5); | ||
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7); | ||
|
||
constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(8); | ||
constexpr uint32_t num_output_tiles = out_ntiles_c; //* nblocks; | ||
|
||
tilizeA_B_reduce_init<false, true>(in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, 2, 4); | ||
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, 1, 2); /* pack 1 row (1x16 or 1x32) */ | ||
for(uint32_t i = 0; i < nsticks_per_core_by_nblocks; i++){ | ||
cb_wait_front(in_scalar_cb_id, 1); | ||
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, window_size_hw>(in1_cb_id, | ||
in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id); | ||
cb_pop_front(in_scalar_cb_id, 1); | ||
} | ||
} // MAIN | ||
} //NAMESPACE |
108 changes: 108 additions & 0 deletions
108
...n/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "dataflow_api.h" | ||
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" | ||
|
||
#define ALWI inline __attribute__((always_inline)) | ||
|
||
// Fill given four values into the memory starting at the given address. | ||
// WARNING: Use with caution as there's no memory protection. Make sure size is within limits | ||
ALWI bool fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) { | ||
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(begin_addr); | ||
|
||
ptr[0] = (val | (val1 << 16)); | ||
ptr[1] = (val2 | (val3 << 16)); | ||
return true; | ||
} | ||
|
||
|
||
void kernel_main() { | ||
|
||
uint32_t stick_nbytes = get_arg_val<uint32_t>(0); | ||
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1); | ||
uint32_t scale_h = get_arg_val<uint32_t>(2); | ||
uint32_t scale_w = get_arg_val<uint32_t>(3); | ||
uint32_t in_w = get_arg_val<uint32_t>(4); | ||
uint32_t out_w = get_arg_val<uint32_t>(5); | ||
uint32_t src1_addr = get_arg_val<uint32_t>(6); | ||
uint32_t read_offset = get_arg_val<uint32_t>(8); | ||
uint32_t is_last_row = get_arg_val<uint32_t>(9); | ||
uint32_t in_h = 1; | ||
constexpr bool src1_is_dram = false; | ||
|
||
constexpr uint32_t in_cb_id = get_compile_time_arg_val(0); | ||
constexpr uint32_t out_cb_id = tt::CB::c_in1; | ||
constexpr uint32_t is_reader = get_compile_time_arg_val(2); | ||
|
||
uint32_t in_image_row_nbytes = in_w * stick_nbytes; | ||
uint32_t out_image_row_nbytes = out_w * stick_nbytes; | ||
uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2; | ||
uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2; | ||
uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core; | ||
uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core; | ||
uint32_t l1_read_addr = get_read_ptr(in_cb_id); //+ image_row_begin * in_image_row_nbytes; | ||
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; | ||
|
||
// assuming shard begins with a new row. TODO: generalize? | ||
float scale_h_inv = 1.0f / scale_h; | ||
float scale_w_inv = 1.0f / scale_w; | ||
float x, y, x_index, y_index, dx, dy; | ||
y_index = (float)(0.5f) * (float)scale_h_inv + 0.5f; | ||
for (uint32_t image_row = 0 ; image_row < in_image_rows_per_core * scale_h; ++image_row){ | ||
x_index = (float)(0.5f) * (float)scale_w_inv -0.5f; | ||
for(uint32_t j=0; j < in_w * scale_w; j++){ | ||
cb_reserve_back(out_cb_id, 4); | ||
cb_reserve_back(in_scalar_cb_id, 1); | ||
|
||
x = x_index < 0 ? 0 : x_index; | ||
y = y_index < read_offset ? read_offset : y_index; | ||
dx = x - int(x); | ||
dy = y - int(y); | ||
|
||
uint32_t x1 = int(x); | ||
uint32_t y1 = int(y); | ||
uint32_t x2 = min(x1 + 1, in_w-1); | ||
uint32_t y2 = y1 + 1; //, in_image_rows_per_core - 1); | ||
if(is_last_row){ | ||
y2 = min(y2, in_image_rows_per_core); //if last row, y2 should be in_image_rows_per_core | ||
} | ||
|
||
fill_four_val(get_write_ptr(in_scalar_cb_id), float_to_bfloat16((1-dx) * (1-dy)), | ||
float_to_bfloat16(dx * (1 - dy)), float_to_bfloat16((1 - dx) * dy), float_to_bfloat16(dx * dy)); | ||
|
||
uint32_t l1_write_addr = get_write_ptr(out_cb_id); | ||
uint32_t l1_read_addr_temp = l1_read_addr + x1 * stick_nbytes + y1 * in_w * stick_nbytes; | ||
//1st tile | ||
uint64_t src_noc_addr = get_noc_addr(l1_read_addr_temp); | ||
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); | ||
l1_write_addr += stick_nbytes; | ||
|
||
//2nd tile | ||
l1_read_addr_temp = l1_read_addr + y1 * in_w * stick_nbytes + x2 * stick_nbytes; | ||
src_noc_addr = get_noc_addr(l1_read_addr_temp); | ||
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); | ||
l1_write_addr += stick_nbytes; | ||
|
||
//3rd tile | ||
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x1 * stick_nbytes; | ||
src_noc_addr = get_noc_addr(l1_read_addr_temp); | ||
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); | ||
l1_write_addr += stick_nbytes; | ||
|
||
//4th tile | ||
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x2 * stick_nbytes; | ||
src_noc_addr = get_noc_addr(l1_read_addr_temp); | ||
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); | ||
l1_write_addr += stick_nbytes; | ||
|
||
//push scaler and data into cb. | ||
noc_async_read_barrier(); | ||
cb_push_back(out_cb_id, 4); | ||
cb_push_back(in_scalar_cb_id, 1); | ||
x_index += scale_w_inv; | ||
} | ||
y_index += scale_h_inv; | ||
} | ||
} |
Oops, something went wrong.