-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#0: Shard and Pad programming examples (#12974)
- Loading branch information
Showing
6 changed files
with
427 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
tt_metal/programming_examples/pad/kernels/pad_reader_dims_rm_interleaved.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
|
||
const uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
const uint32_t pad_addr = get_arg_val<uint32_t>(1); | ||
const uint32_t start_src_stick_id = get_arg_val<uint32_t>(2); | ||
const uint32_t row_size_diff = get_arg_val<uint32_t>(3); | ||
const uint32_t dst_N = get_arg_val<uint32_t>(4); | ||
const uint32_t data_size_bytes = get_arg_val<uint32_t>(5); | ||
const uint32_t num_rows_per_core = get_arg_val<uint32_t>(6); | ||
|
||
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; | ||
constexpr bool pad_is_dram = get_compile_time_arg_val(1) == 1; | ||
constexpr uint32_t cb_id = tt::CB::c_in0; | ||
|
||
const InterleavedAddrGen<src_is_dram> s0 = { | ||
.bank_base_address = src_addr, | ||
.page_size = data_size_bytes | ||
}; | ||
const InterleavedAddrGen<pad_is_dram> s1 = { | ||
.bank_base_address = pad_addr, | ||
.page_size = data_size_bytes | ||
}; | ||
|
||
// pad based on page | ||
uint32_t src_stick_id = start_src_stick_id; | ||
uint32_t src_start_col_idx = row_size_diff / 2; | ||
uint32_t src_end_col_idx = dst_N - src_start_col_idx; | ||
for (uint32_t i = 0; i < num_rows_per_core; i++) { | ||
for (uint32_t dst_col_idx = 0; dst_col_idx < dst_N; dst_col_idx++) { | ||
cb_reserve_back(cb_id, 1); | ||
uint32_t l1_addr = get_write_ptr(cb_id); | ||
if (dst_col_idx < src_start_col_idx || dst_col_idx >= src_end_col_idx) { | ||
// add pad value to cb | ||
uint64_t pad_noc_addr = get_noc_addr(0, s1); | ||
noc_async_read(pad_noc_addr, l1_addr, data_size_bytes); | ||
} | ||
else { | ||
// add original src data to cb | ||
uint64_t src_noc_addr = get_noc_addr(src_stick_id, s0); | ||
noc_async_read(src_noc_addr, l1_addr, data_size_bytes); | ||
src_stick_id++; | ||
} | ||
noc_async_read_barrier(); | ||
cb_push_back(cb_id, 1); | ||
l1_addr += data_size_bytes; | ||
} | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
tt_metal/programming_examples/pad/kernels/pad_writer_dims_rm_interleaved.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
|
||
void kernel_main() { | ||
|
||
const uint32_t dst_addr = get_arg_val<uint32_t>(0); | ||
const uint32_t start_dst_stick_id = get_arg_val<uint32_t>(1); | ||
const uint32_t dst_N = get_arg_val<uint32_t>(2); | ||
const uint32_t data_size_bytes = get_arg_val<uint32_t>(3); | ||
const uint32_t num_rows_per_core = get_arg_val<uint32_t>(4); | ||
|
||
constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; | ||
constexpr uint32_t cb_id = tt::CB::c_in0; | ||
|
||
const InterleavedAddrGen<dst_is_dram> s0 = { | ||
.bank_base_address = dst_addr, | ||
.page_size = data_size_bytes | ||
}; | ||
|
||
uint32_t dst_stick_id = start_dst_stick_id; | ||
for (uint32_t row_idx = 0; row_idx < num_rows_per_core; row_idx++) { | ||
for (uint32_t dst_col_idx = 0; dst_col_idx < dst_N; dst_col_idx++) { | ||
cb_wait_front(cb_id, 1); | ||
uint32_t l1_addr = get_read_ptr(cb_id); | ||
uint64_t dst_noc_addr = get_noc_addr(dst_stick_id, s0); | ||
noc_async_write(l1_addr, dst_noc_addr, data_size_bytes); | ||
noc_async_write_barrier(); | ||
dst_stick_id++; | ||
cb_pop_front(cb_id, 1); | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "tt_metal/host_api.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/detail/util.hpp" | ||
#include "tt_metal/common/bfloat16.hpp" | ||
#include "tt_metal/impl/dispatch/command_queue.hpp" | ||
#include "tt_metal/detail/tt_metal.hpp" | ||
#include "tt_metal/impl/device/device.hpp" | ||
|
||
using namespace tt; | ||
using namespace tt::tt_metal; | ||
|
||
int main(int argc, char **argv) { | ||
// get program/device | ||
int device_id = 0; | ||
Device *device = CreateDevice(device_id); | ||
CommandQueue& cq = device->command_queue(); | ||
Program program = CreateProgram(); | ||
|
||
// initialize source data | ||
constexpr uint32_t src_M = 8; | ||
constexpr uint32_t src_N = 4; | ||
constexpr uint32_t packed_data_size = sizeof(uint32_t); | ||
constexpr uint32_t unpacked_data_size = sizeof(bfloat16); | ||
constexpr uint32_t packing_ratio = packed_data_size / unpacked_data_size; | ||
uint32_t src_num_values_unpacked = src_M * src_N; | ||
uint32_t src_num_values_packed = src_num_values_unpacked / packing_ratio; | ||
std::vector<uint32_t> src_vec(src_num_values_packed, 0); | ||
// source vector = {1, 2, 3, ... , 30, 31, 32} | ||
for (uint32_t i = 0; i < src_vec.size(); i++) { | ||
bfloat16 bfloat_val1 = bfloat16(2 * i + 1); | ||
bfloat16 bfloat_val2 = bfloat16(2 * i + 2); | ||
src_vec[i] = pack_two_bfloat16_into_uint32(std::pair<bfloat16, bfloat16>(bfloat_val1, bfloat_val2)); | ||
} | ||
|
||
// create pad vector | ||
bfloat16 pad_value = bfloat16(2); | ||
std::vector<uint32_t> pad_vec(1, pack_two_bfloat16_into_uint32(std::pair<bfloat16, bfloat16>(pad_value, pad_value))); | ||
|
||
// create destination vector | ||
constexpr uint32_t dst_M = 8; | ||
constexpr uint32_t dst_N = 8; | ||
uint32_t dst_num_values_unpacked = dst_M * dst_N; | ||
uint32_t dst_num_values_packed = dst_num_values_unpacked / packing_ratio; | ||
std::vector<uint32_t> dst_vec(dst_num_values_packed, 0); | ||
|
||
// designate cores and core specs | ||
CoreCoord start_core = {0, 0}; | ||
CoreCoord end_core = {0, 3}; | ||
CoreRange cores(start_core, end_core); | ||
uint32_t num_cores = cores.size(); | ||
|
||
// configure and create DRAM buffers for input, pad, output | ||
uint32_t src_buffer_size = packed_data_size * src_num_values_packed; | ||
tt_metal::InterleavedBufferConfig input_dram_config { | ||
.device = device, | ||
.size = src_buffer_size, | ||
.page_size = packed_data_size, | ||
.buffer_type = tt_metal::BufferType::DRAM | ||
}; | ||
std::shared_ptr<tt::tt_metal::Buffer> src_buffer = CreateBuffer(input_dram_config); | ||
uint32_t src_addr = src_buffer->address(); | ||
|
||
uint32_t pad_buffer_size = packed_data_size * pad_vec.size(); | ||
tt_metal::InterleavedBufferConfig pad_dram_config { | ||
.device = device, | ||
.size = pad_buffer_size, | ||
.page_size = packed_data_size, | ||
.buffer_type = tt_metal::BufferType::DRAM | ||
}; | ||
std::shared_ptr<tt::tt_metal::Buffer> pad_buffer = CreateBuffer(pad_dram_config); | ||
uint32_t pad_addr = pad_buffer->address(); | ||
|
||
uint32_t dst_buffer_size = packed_data_size * dst_num_values_packed; | ||
tt_metal::InterleavedBufferConfig output_dram_config { | ||
.device = device, | ||
.size = dst_buffer_size, | ||
.page_size = packed_data_size, | ||
.buffer_type = tt_metal::BufferType::DRAM | ||
}; | ||
std::shared_ptr<tt::tt_metal::Buffer> dst_buffer = CreateBuffer(output_dram_config); | ||
uint32_t dst_addr = dst_buffer->address(); | ||
|
||
// configure and create circular buffer | ||
uint32_t cb_id = CB::c_in0; | ||
tt::DataFormat cb_data_format = tt::DataFormat::UInt32; | ||
CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(dst_N * packed_data_size * 2, {{cb_id, cb_data_format}}) | ||
.set_page_size(cb_id, packed_data_size); | ||
auto cb_src = tt::tt_metal::CreateCircularBuffer(program, cores, cb_config); | ||
|
||
// specify compile time args | ||
bool src_is_dram = src_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; | ||
bool pad_is_dram = pad_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; | ||
bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; | ||
std::vector<uint32_t> reader_compile_time_args = {(uint32_t) src_is_dram, | ||
(uint32_t) pad_is_dram}; | ||
std::vector<uint32_t> writer_compile_time_args = {(uint32_t) dst_is_dram}; | ||
|
||
// create kernels | ||
KernelHandle reader_id = CreateKernel(program, | ||
"tt_metal/programming_examples/pad/kernels/pad_reader_dims_rm_interleaved.cpp", | ||
cores, | ||
tt_metal::DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default, .compile_args = reader_compile_time_args}); | ||
KernelHandle writer_id = CreateKernel(program, | ||
"tt_metal/programming_examples/pad/kernels/pad_writer_dims_rm_interleaved.cpp", | ||
cores, | ||
tt_metal::DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default, .compile_args = writer_compile_time_args}); | ||
|
||
// set kernel runtime arguments | ||
uint32_t start_src_idx = 0; | ||
uint32_t start_dst_idx = 0; | ||
uint32_t num_rows_per_core = src_M / num_cores; | ||
uint32_t row_size_diff = dst_N - src_N; | ||
uint32_t num_packed_row_src = src_N / packing_ratio; | ||
uint32_t num_packed_row_dst = dst_N / packing_ratio; | ||
uint32_t num_src_sticks_per_core = num_packed_row_src * num_rows_per_core; | ||
for (uint32_t core_idx = 0; core_idx < num_cores; core_idx++) { | ||
CoreCoord core = {0, core_idx}; | ||
tt_metal::SetRuntimeArgs( | ||
program, | ||
reader_id, | ||
core, | ||
{src_addr, | ||
pad_addr, | ||
start_src_idx, | ||
row_size_diff / packing_ratio, | ||
num_packed_row_dst, | ||
packed_data_size, | ||
num_rows_per_core | ||
} | ||
); | ||
tt_metal::SetRuntimeArgs( | ||
program, | ||
writer_id, | ||
core, | ||
{dst_addr, | ||
start_dst_idx, | ||
num_packed_row_dst, | ||
packed_data_size, | ||
num_rows_per_core | ||
} | ||
); | ||
start_src_idx += num_src_sticks_per_core; | ||
start_dst_idx += num_packed_row_dst * num_rows_per_core; | ||
} | ||
|
||
printf("Padding tensor of shape (%d, %d) to shape (%d, %d) with pad value: %d\n", src_M, src_N, dst_M, dst_N, pad_value.to_uint16()); | ||
printf("Original tensor with shape (%d, %d):\n", src_M, src_N); | ||
for (uint32_t m = 0; m < src_M; m++) { | ||
for (uint32_t n = 0; n < num_packed_row_src; n++) { | ||
printf("%d ", (uint16_t)src_vec[m * num_packed_row_src + n]); | ||
printf("%d ", (uint16_t)(src_vec[m * num_packed_row_src + n] >> 16)); | ||
} | ||
printf("\n"); | ||
} | ||
printf("\n"); | ||
|
||
// dispatch program to device for execution | ||
EnqueueWriteBuffer(cq, src_buffer, src_vec.data(), false); | ||
EnqueueWriteBuffer(cq, pad_buffer, pad_vec.data(), false); | ||
EnqueueProgram(cq, program, false); | ||
EnqueueReadBuffer(cq, dst_buffer, dst_vec.data(), false); | ||
Finish(cq); | ||
|
||
printf("Padded tensor with shape (%d, %d):\n", dst_M, dst_N); | ||
for (uint32_t m = 0; m < dst_M; m++) { | ||
for (uint32_t n = 0; n < num_packed_row_dst; n++) { | ||
printf("%d ", (uint16_t)dst_vec[m * num_packed_row_dst + n]); | ||
printf("%d ", (uint16_t)(dst_vec[m * num_packed_row_dst + n] >> 16)); | ||
} | ||
printf("\n"); | ||
} | ||
|
||
CloseDevice(device); | ||
} |
42 changes: 42 additions & 0 deletions
42
tt_metal/programming_examples/sharding/kernels/reader_sharded_rm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
// export TT_METAL_DPRINT_CORES='(0,0)-(0,3)' in order to see DPRINT messages | ||
|
||
void kernel_main() { | ||
const uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
const uint32_t stick_size = get_arg_val<uint32_t>(1); | ||
const uint32_t shard_height = get_arg_val<uint32_t>(2); | ||
const uint32_t shard_width_bytes = get_arg_val<uint32_t>(3); | ||
const uint32_t padded_offset_bytes = get_arg_val<uint32_t>(4); | ||
const uint32_t start_id = get_arg_val<uint32_t>(5); | ||
const uint32_t current_core = get_arg_val<uint32_t>(6); | ||
|
||
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); | ||
constexpr bool src_is_dram = get_compile_time_arg_val(1) == 1; | ||
const InterleavedAddrGen<src_is_dram> s0 = { | ||
.bank_base_address = src_addr, | ||
.page_size = stick_size | ||
}; | ||
uint32_t stick_id = start_id; | ||
cb_reserve_back(cb_id_in0, shard_height); | ||
uint32_t l1_write_addr = get_write_ptr(cb_id_in0); | ||
DPRINT_DATA0(DPRINT << "Core (0," << current_core << "): "); | ||
for (uint32_t h = 0; h < shard_height; ++h) { | ||
uint64_t src_noc_addr = get_noc_addr(stick_id, s0); | ||
noc_async_read(src_noc_addr, l1_write_addr, stick_size); | ||
// print both BFloat16 values that are packed into the page | ||
uint32_t* read_ptr = (uint32_t*)l1_write_addr; | ||
DPRINT_DATA0(DPRINT << (uint16_t)*read_ptr << " "); | ||
DPRINT_DATA0(DPRINT << (uint16_t)(*read_ptr >> 16) << " "); | ||
stick_id++; | ||
l1_write_addr += padded_offset_bytes; | ||
} | ||
DPRINT_DATA0(DPRINT << ENDL()); | ||
noc_async_read_barrier(); | ||
cb_push_back(cb_id_in0, shard_height); | ||
} |
Oops, something went wrong.