Skip to content

Commit

Permalink
Rebase on main.
Browse files Browse the repository at this point in the history
The activation reader kernel hangs because the return type of the CreateSemaphore API has changed.

Signed-off-by: Nilaykumar K Patel <[email protected]>
  • Loading branch information
nkpatel-tt committed Aug 15, 2024
1 parent ff3a449 commit 2f7ba35
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 52 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def run_conv(
# for k in range(0, filter_height):
# for l in range(0, filter_width):
# torch_weight_tensor[i, j, k, l] = 1
torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_bias_tensor = torch.zeros(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor_nchw,
torch_weight_tensor,
Expand Down
15 changes: 8 additions & 7 deletions ttnn/cpp/ttnn/operations/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_requir
}
}
auto shape = input_tensor.shard_spec()->shape;
cout << "shard shape = " << shape[0] << ", " << shape[1] << endl;
std::cout << "shard shape = " << shape[0] << ", " << shape[1] << std::endl;;
return {input_tensor, parallel_config, needs_shard_or_reshard};
}

Expand Down Expand Up @@ -533,6 +533,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
}
weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt);
if (bias_tensor.has_value()) {
std::cout << "bias had values " << std::endl;
if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
bias_tensor_ = bias_tensor.value();
auto bias_shape = bias_tensor_.get_shape();
Expand All @@ -545,7 +546,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
if (bias_tensor_.get_dtype() != weights_bias_dtype) {
bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype);
}
bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, nullopt);
bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
} else {
bias_tensor_ = convert_conv_bias_tensor_to_tiled_layout_block_sharded(
bias_tensor.value(), num_cores_c, weights_bias_dtype);
Expand All @@ -557,7 +558,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
}

ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config(
tt::tt_metal::OptimizedConvParallelizationConfig conv_parallelization_config,
tt::tt_metal::OptimizedConvParallelizationConfigNew conv_parallelization_config,
tt::tt_metal::OptimizedConvBlockConfig conv_blocking_config,
bool height_sharded,
string activation,
Expand All @@ -569,8 +570,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
.in0_block_w = conv_blocking_config.act_block_w_ntiles,
.out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
.out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
.per_core_M = conv_parallelization_config.per_core_out_matrix_height * TILE_HEIGHT,
.per_core_N = conv_parallelization_config.per_core_out_matrix_width * TILE_WIDTH,
.per_core_M = conv_parallelization_config.per_core_out_matrix_height,
.per_core_N = conv_parallelization_config.per_core_out_matrix_width,
.fuse_batch = true,
.mcast_in0 = false};
if (activation != "") {
Expand All @@ -584,8 +585,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
.in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c,
.out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
.out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
.per_core_M = conv_parallelization_config.per_core_out_matrix_height * TILE_HEIGHT,
.per_core_N = conv_parallelization_config.per_core_out_matrix_width * TILE_WIDTH,
.per_core_M = conv_parallelization_config.per_core_out_matrix_height,
.per_core_N = conv_parallelization_config.per_core_out_matrix_width,
.transpose_mcast = transpose_mcast};
if (activation != "") {
matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uin
}
#endif

int temp = 0;
FORCE_INLINE
void read_channels(uint32_t& l1_write_addr_act, const uint32_t act_l1_read_addr, const uint32_t reader_channel_idx,
const uint32_t conv_act_c_read_bytes, const uint32_t coalesced_read_bytes, const uint32_t stride_h_bytes) {
Expand All @@ -32,12 +31,6 @@ void read_channels(uint32_t& l1_write_addr_act, const uint32_t act_l1_read_addr,
#pragma GCC unroll unroll_factor
for (uint32_t inner = 0; inner < WINDOW_INNER; inner++) {
noc_async_read_one_packet_with_state<true>(act_l1_read_addr_plus_offset, l1_write_addr_act);
if(temp < 10) {
/*DPRINT << "stride h bytes = " << stride_h_bytes << ENDL();*/
/*DPRINT << "window inner = " << WINDOW_INNER << ENDL();*/
/*print_pages(act_l1_read_addr_plus_offset,coalesced_read_bytes/2 ,1);*/
temp++;
}
l1_write_addr_act += coalesced_read_bytes;
// +2 is hard-coded, TODO: generalize
act_l1_read_addr_plus_offset += stride_h_bytes;
Expand Down Expand Up @@ -95,7 +88,7 @@ void kernel_main() {
// Set up local VALID value, to be mcasted to destinations flag address after the data has been mcasted
volatile tt_l1_ptr uint32_t* act_mcast_sender_semaphore_valid_addr_ptr = &l1_array[0];
act_mcast_sender_semaphore_valid_addr_ptr[0] = 1; // Load const 1 to be used as semaphore valid value sent from sender to receivers
uint32_t act_mcast_sender_semaphore_valid_addr = reinterpret_cast<uint32_t>(&l1_array[0]);
uint32_t act_mcast_sender_semaphore_valid_addr = get_semaphore(reinterpret_cast<uint32_t>(&l1_array[0]));

// Set up remote VALID value
volatile tt_l1_ptr uint32_t* act_mcast_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(act_mcast_receiver_semaphore_addr);
Expand All @@ -117,11 +110,7 @@ void kernel_main() {

// TODO: need to make the read coalescing optimization cleaner
// currently works for the case of num_coalesced_reads == weight_size_w since these reads are contiguous on both src/dst side
/*DPRINT << "weight_size_w = " << weight_size_w << ENDL();*/
/*DPRINT << "conv_act_c_read_bytes = " << conv_act_c_read_bytes << ENDL();*/
/*DPRINT << "conv_act_size_w = " << conv_act_size_w << ENDL();*/
constexpr uint32_t coalesced_read_bytes = weight_size_w * conv_act_c_read_bytes;
/*DPRINT << "coalesced_read_bytes = " << coalesced_read_bytes << ENDL();*/


// Fully create act matrix and tilize it before mcast
Expand All @@ -131,8 +120,6 @@ void kernel_main() {

// Reset reader_idx to finish act_block_h_datums
uint32_t reader_idx = 0;
/*DPRINT << "act_num_blocks_h = " << act_num_blocks_h << ENDL();*/
/*DPRINT << "act_w_num_outer = " << act_w_num_outer << ENDL();*/
for (uint32_t nbh = 0; nbh < act_num_blocks_h; nbh++) {
cb_reserve_back(cb_id_act_row_major_bfloat16, act_block_num_tiles);
uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_row_major_bfloat16);
Expand All @@ -152,9 +139,6 @@ void kernel_main() {
// incrementing num issued in one shot is actually slower
// noc_async_read_inc_num_issued(num_issued_reads_per_block); // "false" on read
noc_async_read_barrier();
/*DPRINT << "Read activations " << ENDL();*/
/*print_pages(get_write_ptr(cb_id_act_row_major_bfloat16), 32*9, act_mcast_sender_size_bytes/2/32/9);*/
/*print_pages(get_read_ptr(cb_id_sharded_act), 16, 64);*/
cb_push_back(cb_id_act_row_major_bfloat16, act_block_num_tiles);

// Round robin self-mcast and receive tilized act matrix in cb_id_act
Expand Down Expand Up @@ -190,6 +174,11 @@ void kernel_main() {
} else {
// MCAST RECEIVER: receive entire tilized input from sender core
// Set act semaphore value to INVALID
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << act_w_num_outer << ENDL();
noc_semaphore_set(act_mcast_receiver_semaphore_addr_ptr, INVALID);

// Atomic increment source core counter
Expand All @@ -200,9 +189,17 @@ void kernel_main() {
act_mcast_sender_semaphore_noc_addr = get_noc_addr(act_mcast_sender_noc_y[act_w_outer_i], act_mcast_sender_noc_x, act_mcast_sender_semaphore_addr);
}
noc_semaphore_inc(act_mcast_sender_semaphore_noc_addr, 1);
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << act_w_num_outer << ENDL();

// wait on act semaphore value to become VALID (set by mcast sender after it multicasts data)
noc_semaphore_wait(act_mcast_receiver_semaphore_addr_ptr, VALID);
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << "in reader ekrnel" << ENDL();
DPRINT << act_w_num_outer << ENDL();
}
cb_push_back(cb_id_act, act_block_num_tiles);
} // act_w_num_outer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ void kernel_main() {
// mcast args
uint32_t weights_mcast_sender_noc_x = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_sender_noc_y = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_sender_semaphore_addr = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_receiver_semaphore_addr = get_arg_val<uint32_t>(i); i+=1;
DPRINT << "in receiver kernel sender semaphore id = " << i << ENDL();
uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t out_aligned_page_size = get_arg_val<uint32_t>(i); i+=1;

volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(weights_mcast_receiver_semaphore_addr);
Expand All @@ -108,17 +109,13 @@ void kernel_main() {
uint32_t out_block_w_start_tile_id = out_start_tile_id;
uint32_t out_block_w_start_tile_id_w = out_start_tile_id_w;
uint32_t weight_start_tile_id = out_start_tile_id_w;
/*DPRINT << "weight receiver out_num_blocks_w = " << out_num_blocks_w << ENDL();*/
/*DPRINT << "weight receiver out_num_blocks_h = " << out_num_blocks_h << ENDL();*/
/*DPRINT << "weight receiver out_num_blocks_h = " << out_num_blocks_h << ENDL();*/
/*DPRINT << "weight receiver out_num_blocks_h = " << out_num_blocks_h << ENDL();*/
/*DPRINT << "weight receiver out_num_blocks_h = " << out_num_blocks_h << ENDL();*/
/*DPRINT << "weight receiver weight_block_height_num_outer = " << weight_block_height_num_outer << ENDL();*/
DPRINT << "weight receiver out_num_blocks_h = " << out_num_blocks_h << ENDL();
DPRINT << "weight receiver weight_block_height_num_outer = " << weight_block_height_num_outer << ENDL();
int temp = 0;
for (uint32_t bw = 0; bw < out_num_blocks_w; bw++) {
uint32_t out_block_h_start_tile_id = out_block_w_start_tile_id;
uint32_t out_block_h_start_tile_id_h = out_start_tile_id_h;
/*DPRINT << "weight receiver weight_block_height_num_outer = " << weight_block_height_num_outer << ENDL();*/
DPRINT << "weight receiver weight_block_height_num_outer = " << weight_block_height_num_outer << ENDL();
for(uint32_t bh = 0; bh < out_num_blocks_h; bh++) {
// MCAST RECEIVE WEIGHTS
// read weight blocks inner dim
Expand All @@ -131,7 +128,9 @@ void kernel_main() {
noc_semaphore_set(weights_mcast_receiver_semaphore_addr_ptr, INVALID);

// Atomic increment source core counter
DPRINT << "weights_mcast_sender_semaphore_addr = " << weights_mcast_sender_semaphore_addr << ENDL();
noc_semaphore_inc(weights_mcast_sender_semaphore_noc_addr, 1);
DPRINT << "semaphore is se" << ENDL();

// wait on weights semaphore value to become VALID (set by mcast sender after it multicasts data)
noc_semaphore_wait(weights_mcast_receiver_semaphore_addr_ptr, VALID);
Expand All @@ -140,6 +139,7 @@ void kernel_main() {
/*print_pages(weight_write_l1_addr, 32*32, weight_block_num_tiles);*/
temp++;
}
DPRINT << "pushing back" << ENDL();
cb_push_back(cb_id_weight, weight_block_num_tiles);
} // for weight_block_height_num_outer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ void kernel_main() {
uint32_t weights_mcast_dest_noc_end_y = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_num_dests = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_num_cores = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_sender_semaphore_addr = get_arg_val<uint32_t>(i); i+=1;
uint32_t weights_mcast_receiver_semaphore_addr = get_arg_val<uint32_t>(i); i+=1;
DPRINT << "weights_mcast_sender_semaphore_addr = " << i << ENDL();
uint32_t weights_mcast_sender_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(i)); i+=1;
uint32_t out_aligned_page_size = get_arg_val<uint32_t>(i); i+=1;

/*DPRINT <<" weights mcast num dests "<< weights_mcast_num_dests << ENDL();*/
Expand Down Expand Up @@ -174,6 +175,14 @@ void kernel_main() {
/*DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();*/
/*DPRINT << " num_blocks_weight_h = " << num_blocks_weight_h << ENDL();*/
/*DPRINT << "weight_block_height_ntiles = " << weight_block_height_ntiles << ENDL();*/
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
DPRINT << "waiting " << ENDL();
cb_reserve_back(cb_id_weight, weight_block_num_tiles);
//DPRINT << "Reserved " << weight_block_num_tiles << " tiles in cb_id_weight" << ENDL();
uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight);
Expand Down Expand Up @@ -202,11 +211,17 @@ void kernel_main() {
}
DPRINT << "Initiated read for weights" << ENDL();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();

noc_async_read_barrier();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();
DPRINT << "weight tile nbytes = " << weight_tile_nbytes << ENDL();
/*while(temp < 14*(int)weight_block_width_ntiles) {*/
while(weight_tile_h_outer_i == 2 && temp < 1) {
uint32_t addr = weights_start_address;
/*print_pages(addr, weight_tile_nbytes/2, 5, 0);*/
print_pages(addr, weight_tile_nbytes/2, 5, 0);
weights_start_address += weight_tile_nbytes;
DPRINT << ENDL();
temp++;
Expand All @@ -215,7 +230,15 @@ void kernel_main() {
#ifndef SKIP_MCAST
// wait until all weights mcast destinations have atomically incremented the weights semaphore_addr (i.e. its value should be weights_mcast_num_dests), then reset
// the semaphore_addr value back to zero for the next block
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weights_mcast_sender_semaphore_addr = " << weights_mcast_sender_semaphore_addr << ENDL();
noc_semaphore_wait(weights_mcast_sender_semaphore_addr_ptr, weights_mcast_num_dests);
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
DPRINT << "weight_mcast_num_dests = " << weights_mcast_num_dests << ENDL();
noc_semaphore_set(weights_mcast_sender_semaphore_addr_ptr, 0);

// Now we have the block in the CB address, we can mcast to dests!
Expand All @@ -238,7 +261,7 @@ void kernel_main() {
cb_push_back(cb_id_weight, weight_block_num_tiles);
DPRINT << "Pushed back " << weight_block_num_tiles << " tiles in cb_id_weight" << ENDL();
} // for weight_block_height_num_outer
//DPRINT << "Done with weights" << ENDL();
DPRINT << "Done with weights" << ENDL();
DPRINT << "FUSE BIAS: " << FUSE_BIAS << ENDL();
//DPRINT << "load bias: " << load_bias << ENDL();*/
#ifdef FUSE_BIAS
Expand Down
Loading

0 comments on commit 2f7ba35

Please sign in to comment.