Skip to content

Commit

Permalink
Relax Max Pool Requirement For C To Be Power Of 2 (#15022)
Browse files Browse the repository at this point in the history
  • Loading branch information
wransom-TT authored Nov 16, 2024
1 parent d1454aa commit 8dee2c2
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 39 deletions.
5 changes: 5 additions & 0 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ def run_max_pool(
[1, 512, 14, 14],
# wide yolo kernel
[1, 512, 10, 10],
[1, 96, 112, 112],
[1, 192, 132, 20],
)
),
)
Expand Down Expand Up @@ -335,6 +337,7 @@ def test_run_max_pool(
[8, 4096, 10, 16],
# wide yolo kernel
[1, 32768, 10, 10],
[1, 6144, 6, 6],
)
),
)
Expand Down Expand Up @@ -429,6 +432,8 @@ def test_run_max_pool_width_shard(
[16, 16, 528, 80],
# wide yolo kernel
[1, 4096, 10, 10],
[1, 768, 56, 56],
[1, 1280, 8, 6],
)
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,24 @@ void kernel_main() {

const int32_t pad_w = get_compile_time_arg_val(3);

// channel size in bytes, multiple of 32
// channel size in bytes
const uint32_t in_nbytes_c = get_compile_time_arg_val(4);

// input tensor height / width / channels
const int32_t in_w = get_compile_time_arg_val(6);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);
const int32_t in_w = get_compile_time_arg_val(5);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(6);

const uint32_t in_c = get_compile_time_arg_val(8);
const uint32_t in_c = get_compile_time_arg_val(7);

const uint32_t split_reader = get_compile_time_arg_val(10);
const uint32_t reader_id = get_compile_time_arg_val(11);
const uint32_t split_reader = get_compile_time_arg_val(9);
const uint32_t reader_id = get_compile_time_arg_val(10);

// compile time args
// value of 1 in bf16 in a uin32_t
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
constexpr uint32_t in_cb_sz = get_compile_time_arg_val(14);
constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(15);
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12);
constexpr uint32_t in_cb_sz = get_compile_time_arg_val(13);
constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14);

constexpr uint32_t TILE_SIZE = 32 * 32;
constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,22 @@ void kernel_main() {

const int32_t pad_w = get_compile_time_arg_val(3);

// channel size in bytes, multiple of 32
// channel size in bytes
const uint32_t in_nbytes_c = get_compile_time_arg_val(4);
const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5);

// input tensor height / width / channels
const int32_t in_w = get_compile_time_arg_val(6);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);
const int32_t in_w = get_compile_time_arg_val(5);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(6);

const uint32_t in_c = get_compile_time_arg_val(8);
const uint32_t in_c = get_compile_time_arg_val(7);

const uint32_t split_reader = get_compile_time_arg_val(10);
const uint32_t reader_id = get_compile_time_arg_val(11);
const uint32_t split_reader = get_compile_time_arg_val(9);
const uint32_t reader_id = get_compile_time_arg_val(10);

// value of 1 in bf16 in a uin32_t
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11);

constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12);

constexpr uint32_t TILE_WIDTH = 32;

Expand Down Expand Up @@ -99,7 +98,7 @@ void kernel_main() {
uint32_t h_multiples = 0;
for (uint32_t h = 0; h < window_h; ++ h, h_multiples += in_w_padded) {
uint32_t stick_offset = top_left_local_index + h_multiples;
uint32_t read_offset = in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2);
uint32_t read_offset = in_l1_read_base_addr + (stick_offset * in_nbytes_c);
noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, in_nbytes_c * window_w);
out_l1_write_addr += in_nbytes_c * window_w;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,20 @@ void kernel_main() {

// channel size in bytes, multiple of 32
const uint32_t in_nbytes_c = get_compile_time_arg_val(4);
const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5);

// input tensor height / width / channels
const int32_t in_w = get_compile_time_arg_val(6);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);
const int32_t in_w = get_compile_time_arg_val(5);
const uint32_t in_cb_nsticks = get_compile_time_arg_val(6);

const uint32_t in_c = get_compile_time_arg_val(8);
const uint32_t in_c = get_compile_time_arg_val(7);

const uint32_t split_reader = get_compile_time_arg_val(10);
const uint32_t reader_id = get_compile_time_arg_val(11);
const uint32_t split_reader = get_compile_time_arg_val(9);
const uint32_t reader_id = get_compile_time_arg_val(10);

// value of 1 in bf16 in a uin32_t
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(11);

constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(12);

// static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ void validate_maxpool(const Tensor& input, const sliding_window::SlidingWindowCo
TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Only BFLOAT16 supported for now");
TT_FATAL(input.get_layout() == Layout::ROW_MAJOR, "Only ROW_MAJOR supported for now");

// NOTE: This is not a hard requirement. If need to support non-power-of-2, simply change the address generator in reader to generic one.
uint32_t in_nbytes_c = (input.get_legacy_shape()[3]) * (input.get_dtype() == DataType::BFLOAT16 ? 2 : 1);
bool is_pow2 = (in_nbytes_c & (in_nbytes_c - 1)) == 0;
TT_FATAL(is_pow2, "Row size (nchannels * bytes = {}) should be power of 2 ({}).", in_nbytes_c, is_pow2);

TT_FATAL(input.memory_config().is_sharded(), "Input needs to be sharded");
TT_FATAL(out_mem_config.is_sharded(), "Output memory config needs to be sharded");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_

uint32_t in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes; // row of input (channels)
uint32_t out_nbytes_c = output_shape[3] / num_shards_c * out_nbytes; // row of output (channels)
TT_ASSERT((in_nbytes_c & (in_nbytes_c - 1)) == 0, "in_nbytes_c should be power of 2"); // in_nbytes_c is power of 2
TT_ASSERT(
(out_nbytes_c & (out_nbytes_c - 1)) == 0, "out_nbytes_c should be power of 2"); // out_nbytes_c is power of 2

tt::DataFormat indices_df = tt::DataFormat::RawUInt16; // datatype_to_dataformat_converter(reader_indices.get_dtype());
uint32_t indices_nbytes = datum_size(indices_df);
Expand Down Expand Up @@ -282,14 +279,12 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
*/
float one = 1.;
uint32_t bf16_one_u32 = *reinterpret_cast<uint32_t*>(&one);
uint32_t in_nbytes_c_log2 = (uint32_t)std::log2((float)in_nbytes_c);
std::vector<uint32_t> reader0_ct_args = {
out_nhw_per_core,
kernel_size_h,
kernel_size_w,
pad_w,
in_nbytes_c,
in_nbytes_c_log2,
in_w,
in_cb_page_padded * in_cb_npages / tile_w,
input_shape[3] / num_shards_c,
Expand All @@ -307,7 +302,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
kernel_size_w,
pad_w,
in_nbytes_c,
in_nbytes_c_log2,
in_w,
in_cb_page_padded * in_cb_npages / tile_w,
input_shape[3] / num_shards_c,
Expand Down

0 comments on commit 8dee2c2

Please sign in to comment.