-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4677eda
commit d3bebbb
Showing
8 changed files
with
420 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h_sharded_optimised.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
#include "compute_kernel_api/bcast.h" | ||
|
||
|
||
namespace NAMESPACE { | ||
void MAIN { | ||
constexpr uint32_t onetile = 1; | ||
uint32_t NC = get_arg_val<uint32_t>(0); | ||
uint32_t Ht = get_arg_val<uint32_t>(1); | ||
uint32_t Wt = get_arg_val<uint32_t>(2); | ||
uint32_t h_blk = get_arg_val<uint32_t>(3); | ||
uint32_t batch_b = get_arg_val<uint32_t>(4); | ||
uint32_t Ht_per_batch_b = get_arg_val<uint32_t>(5); | ||
|
||
init_bcast<BCAST_LLKOP, BCAST_DIM>(tt::CB::c_in0, tt::CB::c_in1, tt::CB::c_out0); | ||
|
||
cb_wait_front(tt::CB::c_in0, Wt*Ht); | ||
cb_reserve_back(tt::CB::c_out0, Wt*Ht); | ||
uint32_t b_offset = 0; | ||
for (uint32_t bn = 0; bn < batch_b; bn++) { | ||
for (uint32_t wt = 0; wt < Wt; wt++) { | ||
cb_wait_front(tt::CB::c_in1, onetile); | ||
for (uint32_t ht = 0; ht < Ht_per_batch_b; ht+=h_blk) { | ||
acquire_dst(tt::DstMode::Half); | ||
for (uint32_t htr = 0; htr<h_blk; htr++) { | ||
uint32_t current_index = b_offset + (ht + htr) * Wt + wt; | ||
BCAST_OP<BroadcastType::ROW>(tt::CB::c_in0, tt::CB::c_in1, current_index, 0, htr); | ||
pack_tile<true>(htr, tt::CB::c_out0, current_index); | ||
} | ||
release_dst(tt::DstMode::Half); | ||
} | ||
cb_pop_front(tt::CB::c_in1, onetile); | ||
} | ||
b_offset += Ht_per_batch_b * Wt; | ||
} | ||
cb_pop_front(tt::CB::c_in0, Wt*Ht); | ||
cb_push_back(tt::CB::c_out0, Wt*Ht); | ||
} | ||
} // NAMESPACE |
53 changes: 53 additions & 0 deletions
53
tt_eager/tt_dnn/op_library/bcast/kernels/dataflow/reader_bcast_h_sharded_optimised.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
uint32_t src1_addr = get_arg_val<uint32_t>(0); | ||
uint32_t Ht = get_arg_val<uint32_t>(1); | ||
uint32_t Wt = get_arg_val<uint32_t>(2); | ||
uint32_t offset = get_arg_val<uint32_t>(3); | ||
uint32_t batch_offset = get_arg_val<uint32_t>(4); //if weight has multiple batches | ||
uint32_t w_blk = get_arg_val<uint32_t>(5); | ||
uint32_t batch_b = get_arg_val<uint32_t>(6); | ||
|
||
//constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; | ||
constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1; | ||
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); | ||
|
||
//constexpr uint32_t cb_id_in0 = 0; | ||
constexpr uint32_t cb_id_in1 = 1; | ||
constexpr uint32_t onetile = 1; | ||
|
||
// single-tile ublocks | ||
const uint32_t tile_bytes = get_tile_size(cb_id_in1); | ||
const DataFormat data_format = get_dataformat(cb_id_in1); | ||
|
||
const InterleavedAddrGenFast<src1_is_dram> s1 = { | ||
.bank_base_address = src1_addr, | ||
.page_size = tile_bytes, | ||
.data_format = data_format | ||
}; | ||
|
||
|
||
uint32_t l1_write_addr_in0; | ||
uint32_t l1_write_addr_in1; | ||
|
||
cb_push_back(cb_id_in0, Ht * Wt); | ||
for (uint32_t b = 0; b < batch_b; b ++) { | ||
for (uint32_t wt = 0; wt < Wt; wt += w_blk) { | ||
cb_reserve_back(cb_id_in1, w_blk); | ||
l1_write_addr_in1 = get_write_ptr(cb_id_in1); | ||
for (uint32_t r = 0; r<w_blk; r++) { | ||
noc_async_read_tile(offset + wt + r, s1, l1_write_addr_in1); | ||
l1_write_addr_in1 += tile_bytes; | ||
} | ||
noc_async_read_barrier(); | ||
cb_push_back(cb_id_in1, w_blk); | ||
} | ||
offset += batch_offset; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.