Make pipeline work. Clean up needed.
Signed-off-by: Nilaykumar K Patel <[email protected]>
nkpatel-tt committed Nov 21, 2024
1 parent 82c96ee commit d84a302
Showing 8 changed files with 87 additions and 27 deletions.
16 changes: 11 additions & 5 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -83,7 +83,8 @@ ParallelConfig determine_parallel_config(
uint32_t output_channels,
const CoreCoord& compute_grid_size,
ShardOrientation block_shard_orientation,
bool is_out_tiled) {
bool is_out_tiled,
bool is_non_tile_mul_width) {

uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1;
uint32_t effective_tile_width = is_out_tiled ? tt::constants::TILE_WIDTH : 1;
@@ -104,7 +105,11 @@ ParallelConfig determine_parallel_config(
uint32_t start_divisor =
block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y;
num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor);
uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x);
auto channels_per_core = std::ceil((float)input_channels / effective_tile_width);
if(is_non_tile_mul_width)
channels_per_core = input_channels / effective_tile_width;

uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, channels_per_core, block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x);
uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c;
uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw;
CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1}));
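A minimal worked sketch of the per-core channel arithmetic in the hunk above; the concrete numbers (input_channels = 40, tile width = 32) are illustrative assumptions, not values from the repository:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
    const uint32_t input_channels = 40;        // illustrative value (assumption)
    const uint32_t effective_tile_width = 32;  // TILE_WIDTH when the output is tiled
    const bool is_non_tile_mul_width = true;   // exercise the new flag

    // Default path: round the channel count up to whole tiles -> ceil(40 / 32) = 2.
    uint32_t channels_per_core =
        static_cast<uint32_t>(std::ceil(static_cast<float>(input_channels) / effective_tile_width));

    // New path: truncating integer division -> 40 / 32 = 1, so the channel grid
    // is sized from the unpadded channel count instead of the tile-padded one.
    if (is_non_tile_mul_width) {
        channels_per_core = input_channels / effective_tile_width;
    }

    std::cout << "channels_per_core = " << channels_per_core << "\n";  // prints 1
    return 0;
}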
@@ -443,12 +448,14 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_sh
bool use_non_tile_height = (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) && out_channels <= 256 && conv_config.act_block_h_override == 0 &&
(conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR && conv_config.input_channels_alignment != 16; // shallow conv variant

/*bool is_non_tile_mul_width = (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && out_channels <= 256 && conv_config.act_block_h_override == 0 &&*/
/* (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR && conv_config.input_channels_alignment != 16; // shallow conv variant*/
ParallelConfig parallel_config = input_tensor_parallel_config;
if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) {
auto block_shard_orientation =
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
ParallelConfig optimal_parallel_config = determine_parallel_config(
shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height);
shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height, false);

if (conv_config.override_sharding_config) {
TT_FATAL(conv_config.core_grid.has_value(), "Error");
@@ -752,10 +759,9 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
}
return matmul_config;
} else {
TT_ASSERT(conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0);
ttnn::operations::matmul::MatmulMultiCoreReuseMultiCastProgramConfig matmul_config = {
.compute_with_storage_grid_size = conv_parallelization_config.grid_size,
.in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c,
.in0_block_w = conv_blocking_config.act_block_w_ntiles,
.out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
.out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
.out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
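As a numeric illustration of the in0_block_w change above (the former division by grid_size_along_c is dropped), here is a tiny sketch; act_block_w_ntiles = 8 and grid_size_along_c = 4 are made-up values, not taken from the code:

#include <cstdint>
#include <iostream>

int main() {
    const uint32_t act_block_w_ntiles = 8;  // illustrative value (assumption)
    const uint32_t grid_size_along_c = 4;   // illustrative value (assumption)

    // Former behaviour: split the activation block width across the channel grid.
    const uint32_t in0_block_w_before = act_block_w_ntiles / grid_size_along_c;  // 2

    // Behaviour after this commit: pass the full activation block width through.
    const uint32_t in0_block_w_after = act_block_w_ntiles;                       // 8

    std::cout << in0_block_w_before << " -> " << in0_block_w_after << "\n";
    return 0;
}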
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -125,7 +125,8 @@ sliding_window::ParallelConfig determine_parallel_config(
uint32_t output_channels,
const CoreCoord& compute_grid_size,
ShardOrientation block_shard_orientation,
bool is_out_tiled=true);
bool is_out_tiled=true,
bool is_non_tile_mul_width=false);

uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig);

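Because both trailing parameters are defaulted, existing call sites keep compiling; a short sketch of the two equivalent call forms, assuming it is built inside the repository (the include path, namespace qualification, and argument values are placeholders, not taken from the code):

#include "ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp"  // assumed include path

using ttnn::operations::conv::conv2d::determine_parallel_config;
using ttnn::operations::sliding_window::ParallelConfig;

ParallelConfig example(const CoreCoord& grid) {
    // Legacy call: is_out_tiled defaults to true, is_non_tile_mul_width to false.
    ParallelConfig legacy = determine_parallel_config(
        TensorMemoryLayout::BLOCK_SHARDED, /*batch_size=*/1, /*input_channels=*/64,
        /*output_height=*/32, /*output_width=*/32, /*output_channels=*/64,
        grid, ShardOrientation::COL_MAJOR);

    // Fully explicit call, mirroring the updated call site in conv2d.cpp.
    ParallelConfig explicit_cfg = determine_parallel_config(
        TensorMemoryLayout::BLOCK_SHARDED, 1, 64, 32, 32, 64,
        grid, ShardOrientation::COL_MAJOR,
        /*is_out_tiled=*/true, /*is_non_tile_mul_width=*/false);

    (void)legacy;
    return explicit_cfg;
}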
8 changes: 5 additions & 3 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -213,9 +213,10 @@ void py_bind_conv2d(py::module& module) {
uint32_t output_channels,
const CoreCoord& compute_grid_size,
ShardOrientation block_shard_orientation,
bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig {
bool is_out_tiled,
bool is_non_tile_mul_width) -> ttnn::operations::sliding_window::ParallelConfig {
return ttnn::operations::conv::conv2d::determine_parallel_config(
shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled);
shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled, is_non_tile_mul_width);
},
py::arg("shard_layout"),
py::arg("batch_size"),
@@ -225,7 +226,8 @@ void py_bind_conv2d(py::module& module) {
py::arg("output_channels"),
py::arg("compute_grid_size"),
py::arg("block_shard_orientation"),
py::arg("is_out_tiled") = true);
py::arg("is_out_tiled") = true,
py::arg("is_non_tile_mul_width") = false);

module.def(
"create_sharded_memory_config_from_parallel_config",
@@ -30,9 +30,9 @@ const uint32_t act_cb_second_reader = CB::c_in7;
const uint32_t matmul_partials_cb = CB::c_intermed0;
const uint32_t tilize_mode_tilized_act_cb = CB::c_intermed1;
const uint32_t untilize_mode_reblock_cb = CB::c_intermed2;
const uint32_t untilized_padded_out_cb = CB::c_intermed3;
const uint32_t out0_cb = CB::c_out0;
const uint32_t temp_sum_cb = CB::c_intermed3;
const uint32_t untilized_padded_out_cb = CB::c_intermed4;
}
}

@@ -182,17 +182,18 @@ std::tuple<CBHandle, CBHandle> create_CBs_for_sharded_input_v2(
bool need_unpad_after_untilize = output_shard_shape[1] * output_shard_shape[0] < num_writer_output_tiles * TILE_HW;
// If only width is non-tile multiple
if (need_unpad_after_untilize && !use_non_tile_height && weight_width_sliced) {
uint32_t num_bytes_for_df = datum_size(out_df);
CircularBufferConfig compute_cb_output_config =
CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{untilized_padded_out_cb, out_df}})
.set_page_size(untilized_padded_out_cb, out_tile_size);
auto compute_cb_output = tt_metal::CreateCircularBuffer(program, core, compute_cb_output_config);
uint32_t num_bytes_for_df = datum_size(out_df);
log_debug(LogOp, "untilized padded out CB(shard widht non-tile multiple): {}, npages: {}, pagesize: {}", untilized_padded_out_cb, num_writer_output_tiles, out_tile_size * num_bytes_for_df);
CircularBufferConfig cb_output_config =
CircularBufferConfig(num_bytes_for_df * output_shard_shape[0] * output_shard_shape[1], {{out0_cb, out_df}})
.set_page_size(out0_cb, output_shard_shape[1] * num_bytes_for_df);
cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer());
cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config);
log_debug(LogOp, "output CB(shard widht non-tile multiple): {}, npages: {}, pagesize: {}", out0_cb, output_shard_shape[0] * output_shard_shape[1], output_shard_shape[1] * num_bytes_for_df);
log_debug(LogOp, "output CB(shard widht non-tile multiple): {}, npages: {}, pagesize: {}", out0_cb, output_shard_shape[0], output_shard_shape[1] * num_bytes_for_df);
} else {
auto shard_shape = output.shard_spec().value().shape;
uint32_t aligned_output_stick_nbytes = use_non_tile_height ? shard_shape[1] * output.element_size() : out_tile_size;
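A small worked example of the circular-buffer sizing used in the unpad-after-untilize branch above: one page per shard row, with each page sized to the unpadded row width. The shard shape {32, 40} and the 2-byte bfloat16 datum size are illustrative assumptions, not repository values:

#include <cstdint>
#include <iostream>

int main() {
    const uint32_t output_shard_shape[2] = {32, 40};  // rows x unpadded width (assumption)
    const uint32_t num_bytes_for_df = 2;              // bfloat16 datum size (assumption)

    // Mirrors the out0_cb configuration: total bytes for the shard, one page per shard row.
    const uint32_t page_size_bytes = output_shard_shape[1] * num_bytes_for_df;                         // 80
    const uint32_t total_cb_bytes  = num_bytes_for_df * output_shard_shape[0] * output_shard_shape[1]; // 2560
    const uint32_t npages          = total_cb_bytes / page_size_bytes;                                 // 32 == output_shard_shape[0]

    std::cout << "pages=" << npages << " page_size=" << page_size_bytes
              << " total=" << total_cb_bytes << "\n";
    return 0;
}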
@@ -1545,6 +1546,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_cores
writer_rt_args.push_back(weights_mcast_sender_semaphore_id);
writer_rt_args.push_back(weights_mcast_receiver_semaphore_id);
writer_rt_args.push_back(output.buffer()->aligned_page_size());

SetRuntimeArgs(program, writer_mcast_sender_id, core, writer_rt_args);
} else {
@@ -9,23 +9,40 @@
#include "compute_kernel_api/pack_untilize.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/matmul.h"
// #include "debug/dprint.h"
#include "debug/dprint.h"

#ifdef FUSE_BIAS
#include "compute_kernel_api/bcast.h"
#endif

#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"

#define DEBUG_PRINT 0
#define DEBUG_PRINT 1

// #include "debug_macros.h"
#define dump_unpack(a) \
do { DPRINT_UNPACK(DPRINT << "UP: "<< #a " = " << a << ENDL()); } while(false)
#define dump_pack(a) \
do { DPRINT_PACK(DPRINT << "P: "<< #a " = " << a << ENDL()); } while(false)
#define dump_math(a) \
do { DPRINT_MATH(DPRINT << "M: "<< #a " = " << a << ENDL()); } while(false)


// SliceRange srt = SliceRange{.h0 = 0, .h1 = 4, .hs = 1, .w0 = 0, .w1 = 8, .ws = 1};
// SliceRange srr = SliceRange{.h0 = 0, .h1 = 1, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1};

inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize = false) {
DPRINT_UNPACK(DPRINT << "======" << ENDL());
for (uint8_t r = 0; r < 32; ++ r) {
//for (int32_t r = 0; r < 1; ++ r) {
SliceRange sr = SliceRange{.h0 = r, .h1 = (uint8_t)(r+1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
DPRINT_UNPACK(DPRINT << (uint)r << " " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL());
}
DPRINT_UNPACK(DPRINT << "++++++" << ENDL());
}

inline void tilize_in(
uint32_t in_cb_id,
uint32_t in_subblock_h,
@@ -115,6 +132,20 @@ void MAIN {
constexpr uint32_t out_block_w = in1_block_w;
constexpr bool spill = in0_num_blocks_w > 1;

/*dump_unpack(out_subblock_h);*/
/*dump_unpack(out_subblock_w);*/
/*dump_unpack(out_subblock_num_tiles);*/
/*dump_unpack(in0_block_w);*/
/*dump_unpack(in0_num_subblocks);*/
/*dump_unpack(in0_block_num_tiles);*/
/*dump_unpack(in0_subblock_num_tiles);*/
/*dump_unpack(in0_subblock_h);*/
/*dump_unpack(in1_num_subblocks);*/
/*dump_unpack(in1_block_num_tiles);*/
/*dump_unpack(in1_block_w);*/
/*dump_unpack(in0_num_blocks_h);*/
/*dump_unpack(in0_num_blocks_w);*/
/*dump_unpack(in1_num_blocks_w);*/
// CB indices
constexpr uint32_t in0_cb_id = tt::CB::c_in0;
constexpr uint32_t in1_cb_id = tt::CB::c_in1;
@@ -222,6 +253,8 @@ void MAIN {
pack_reconfig_data_format(curr_matmul_out_cb);
#endif
uint32_t in0_index_subblock_offset = 0;
dump_unpack(in0_num_subblocks);
dump_unpack(in1_num_subblocks);
for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) {
uint32_t in1_index_subblock_offset = 0;
for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) {
@@ -291,6 +324,16 @@ void MAIN {

tile_regs_release();
cb_push_back(curr_matmul_out_cb, out_subblock_num_tiles);
/*if(in1_block_w_i == 0 && in0_block_h_i == 0 && in0_block_w_i == 7) {*/
/* dump_math(curr_matmul_out_cb);*/
/* dump_math(out_subblock_num_tiles);*/
/* dump_math(matmul_partials_cb);*/
/* dump_math(out_cb_id);*/
/* print_full_tile(curr_matmul_out_cb);*/
/* print_full_tile(curr_matmul_out_cb, 1);*/
/* print_full_tile(curr_matmul_out_cb, 2);*/
/* print_full_tile(curr_matmul_out_cb, 3);*/
/*}*/

in1_index_subblock_offset += out_subblock_w;
} // for in1_num_subblocks
@@ -393,6 +436,7 @@ void MAIN {
}
tile_regs_release();
cb_push_back(untilize_mode_out_cb_id, out_subblock_num_tiles);
print_full_tile(untilize_mode_out_cb_id);

in1_index_subblock_offset += out_subblock_w;
} // for in1_num_subblocks
@@ -10,7 +10,7 @@
#if ENABLE_DEBUG
#include "debug/dprint.h"

#define DUMP(a) \
#define dump(a) \
do { DPRINT << "Activations: "<< #a " = " << a << ENDL(); } while(false)

inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
Expand Down Expand Up @@ -77,7 +77,7 @@ void kernel_main() {
constexpr uint32_t window_inner = get_compile_time_arg_val(9);
constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(10);
constexpr uint32_t padded_conv_act_size_w = get_compile_time_arg_val(13);
constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(14);
constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(14);
constexpr uint32_t act_num_blocks_h = get_compile_time_arg_val(16);
constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(17);
constexpr uint32_t act_w_num_outer = get_compile_time_arg_val(18);
@@ -174,7 +174,7 @@ void kernel_main() {
// noc_async_read_inc_num_issued(num_issued_reads_per_block); // "false" on read
noc_async_read_barrier();
/*DPRINT << "Read activations " << ENDL();*/
/*print_pages(get_write_ptr(cb_id_act_row_major_bfloat16), 32*9, act_mcast_sender_size_bytes/2/32/9);*/
/*print_pages(get_write_ptr(cb_id_act_row_major_bfloat16), 12*32, act_mcast_sender_size_bytes/2/12/32);*/
/*print_pages(get_read_ptr(cb_id_sharded_act), 16, 64);*/
cb_push_back(cb_id_act_row_major_bfloat16, act_block_num_tiles);

@@ -230,6 +230,7 @@ void kernel_main() {
noc_semaphore_wait(act_mcast_receiver_semaphore_addr_ptr, VALID);
}
cb_push_back(cb_id_act, act_block_num_tiles);
/*print_pages(get_read_ptr(cb_id_act), 40, 9*3);*/
} // act_w_num_outer
cb_pop_front(tilized_in0_cb_id, act_block_num_tiles);
}
@@ -61,10 +61,10 @@ void kernel_main() {
constexpr uint32_t out_addr = get_compile_time_arg_val(29);

#ifdef UNPAD_UNTILIZE_OUT
constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(32);
constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(33);
constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(34);
constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(35);
constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33);
constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34);
constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35);
constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36);
#endif
uint32_t i = 0;
i+=19;
@@ -214,6 +214,7 @@ void kernel_main() {
for (uint32_t bh = 0; bh < out_block_height_num_tiles; bh++) {
/*DPRINT << "Waiting for out_block_width_ntiles: " << out_block_width_ntiles << ENDL();*/
cb_wait_front(untilized_padded_out_cb, out_block_width_ntiles);
/*print_pages(get_read_ptr(untilized_padded_out_cb), 32, 32);*/
uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb);
/*DPRINT << "src_cb_addr: " << src_cb_addr << ENDL();*/
/*DPRINT << "Done waiting for out_block_width_ntiles: " << out_block_width_ntiles << ENDL();*/
@@ -10,7 +10,7 @@
#if ENABLE_DEBUG
#include "debug/dprint.h"

#define DUMP(a) \
#define dump(a) \
do { DPRINT << "Sender: "<< #a " = " << a << ENDL(); } while(false)

inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
@@ -64,10 +64,10 @@ void kernel_main() {
constexpr uint32_t out_addr = get_compile_time_arg_val(29);

#ifdef UNPAD_UNTILIZE_OUT
constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(32);
constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(33);
constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(34);
constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(35);
constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33);
constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34);
constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35);
constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36);
#endif
uint32_t i = 0;
i+=1;
@@ -95,6 +95,7 @@ void kernel_main() {
uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t out_aligned_page_size = get_arg_val<uint32_t>(i); i+=1;
dump(out_aligned_page_size);

#ifndef SKIP_MCAST
// Set our local VALID value, to be mcasted to destinations flag address after the data has been mcasted
@@ -188,6 +189,8 @@ void kernel_main() {
}

noc_async_read_barrier();
dump(weight_tile_nbytes);
/*print_pages(get_write_ptr(cb_id_weight), 32*32, 12);*/

#ifndef SKIP_MCAST
// wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset
@@ -345,7 +348,7 @@ void kernel_main() {
for (uint32_t r = 0; r < 32; r++) {
noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes);
noc_async_read_barrier();
/*print_pages(get_noc_addr(src_cb_addr), out_block_width_bytes, 1);*/
/*print_pages(get_noc_addr(src_cb_addr), out_block_width_bytes / 2, 1);*/
src_cb_addr += out_block_width_padded_bytes;
/*dst_cb_addr += out_block_width_bytes;*/

