Skip to content

Commit

Permalink
#0: Fixed Bias in WS
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt committed Dec 11, 2024
1 parent ce3b592 commit 3f52fd5
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def test_conv_ws(
debug = False
groups = 1

torch.manual_seed(0)
# torch.manual_seed()
conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
conv_bias_shape = [1, 1, 1, output_channels]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
uint32_t num_blocks_act_h_per_core =
(per_core_out_matrix_height_ntiles + act_block_h_ntiles - 1) / act_block_h_ntiles;
uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles;
uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width;

std::map<string, string> writer_defines;
std::map<string, string> writer_mcast_sender_defines;
Expand Down Expand Up @@ -669,7 +668,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
tilize_in0, // tilize_in0
untilize_out, // untilize_out

bias_ntiles_per_core,
bias_ntiles,

out0_cb,
num_output_tiles,
Expand All @@ -678,7 +677,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
input_num_cores, // in0_nblocks_w_tilize. Repeat tilize after all cores have done one round of MCAST.
};


uint32_t act_tile_size = tt_metal::detail::TileSize(act_df);
uint32_t tilized_act_tile_size = tt_metal::detail::TileSize(tilized_act_df);
uint32_t weight_tile_size = tt_metal::detail::TileSize(weight_df);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void kernel_main() {
const InterleavedAddrGenFast<true> s_weight = {
.bank_base_address = weight_addr_dram_base, .page_size = weight_tile_nbytes, .data_format = weight_df};
#ifdef FUSE_BIAS

cb_reserve_back(bias_cb_id, weight_block_width_ntiles);
const uint32_t bias_pagesize = get_tile_size(bias_cb_id);
const DataFormat bias_df = get_dataformat(bias_cb_id);
const InterleavedAddrGenFast<bias_in_dram> s_bias = {
Expand Down

0 comments on commit 3f52fd5

Please sign in to comment.