Skip to content

Commit

Permalink
#13944: Redesign memory packing API
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-TT committed Dec 12, 2024
1 parent ee62a86 commit e035930
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 264 deletions.
8 changes: 4 additions & 4 deletions tests/tt_metal/tt_metal/eth/test_erisc_app_direct_send.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ bool send_over_eth(

// TODO: this should be updated to use kernel api
uint32_t active_eth_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH);
ll_api::memory const& binary_mem_send = llrt::get_risc_binary(
sender_device->build_firmware_target_path(active_eth_index, 0, 0), active_eth_index, 0, 0);
ll_api::memory const& binary_mem_receive = llrt::get_risc_binary(
receiver_device->build_firmware_target_path(active_eth_index, 0, 0), active_eth_index, 0, 0);
ll_api::memory const& binary_mem_send =
llrt::get_risc_binary(sender_device->build_firmware_target_path(active_eth_index, 0, 0));
ll_api::memory const& binary_mem_receive =
llrt::get_risc_binary(receiver_device->build_firmware_target_path(active_eth_index, 0, 0));

for (const auto& eth_core : eth_cores) {
llrt::write_hex_vec_to_core(
Expand Down
23 changes: 8 additions & 15 deletions tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ int main(int argc, char** argv) {
dm_class_idx,
0,
get_latest_kernel_binary_path(mask, riscv0_kernel));
ll_api::memory const& brisc_binary = llrt::get_risc_binary(
brisc_hex_path, 0, 0, 0, ll_api::memory::PackSpans::PACK, ll_api::memory::Relocate::XIP);
ll_api::memory const& brisc_binary =
llrt::get_risc_binary(brisc_hex_path, ll_api::memory::Loading::CONTIGUOUS_XIP);
TT_FATAL(
brisc_binary == *brisc_binaries.at(mask).at(0),
"Expected saved BRISC binary to be the same as binary in persistent cache");
Expand All @@ -219,13 +219,11 @@ int main(int argc, char** argv) {
dm_class_idx,
1,
get_latest_kernel_binary_path(mask, riscv1_kernel));
ll_api::memory::Relocate relo_type =
auto load_type =
(device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0)
? ll_api::memory::Relocate::NONE
: ll_api::memory::Relocate::XIP;

ll_api::memory const& ncrisc_binary =
llrt::get_risc_binary(ncrisc_hex_path, 0, 1, 0, ll_api::memory::PackSpans::PACK, relo_type);
? ll_api::memory::Loading::CONTIGUOUS
: ll_api::memory::Loading::CONTIGUOUS_XIP;
ll_api::memory const& ncrisc_binary = llrt::get_risc_binary(ncrisc_hex_path, load_type);
TT_FATAL(
ncrisc_binary == *ncrisc_binaries.at(mask).at(0),
"Expected saved NCRISC binary to be the same as binary in persistent cache");
Expand All @@ -236,13 +234,8 @@ int main(int argc, char** argv) {
compute_class_idx,
trisc_id,
get_latest_kernel_binary_path(mask, compute_kernel));
ll_api::memory const& trisc_binary = llrt::get_risc_binary(
trisc_hex_path,
0,
2,
trisc_id,
ll_api::memory::PackSpans::PACK,
ll_api::memory::Relocate::XIP);
ll_api::memory const& trisc_binary =
llrt::get_risc_binary(trisc_hex_path, ll_api::memory::Loading::CONTIGUOUS_XIP);
TT_FATAL(
trisc_binary == *compute_binaries.at(mask).at(trisc_id),
"Expected saved TRISC binary for {} to be the same as binary in persistent cache",
Expand Down
22 changes: 15 additions & 7 deletions tt_metal/hw/toolchain/sections.ld
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ SECTIONS
*(.text .stub .text.* .gnu.linkonce.t.*)
/* .gnu.warning sections are handled specially by elf32.em. */
*(.gnu.warning)
. = ALIGN(4);
} > REGION_CODE :text
.init.fini :
{
Expand All @@ -73,23 +74,30 @@ SECTIONS
ASSERT(SIZEOF(.init.fini) == 0, ".init/.fini sections have contents");
} > REGION_CODE :text

. = ALIGN(. + MEM_PAD, MEM_ALIGN);
#if defined(TYPE_KERNEL)
__kernel_data_lma = .;
#endif
. = ALIGN(ABSOLUTE(.) + MEM_PAD, MEM_ALIGN);

#if defined(TYPE_FIRMWARE)
__fw_export_end_text = .;
__fw_export_end_text = ABSOLUTE(.);
#if defined(TARGET_NCRISC)
PROVIDE (KERNEL_ENTRY_SYMBOL = __fw_export_end_text);
PROVIDE (KERNEL_ENTRY_SYMBOL = ABSOLUTE(__fw_export_end_text));
#endif
#endif

#if defined(TYPE_KERNEL)
__kernel_init_local_l1_base = .;
__kernel_init_local_l1_base = ABSOLUTE(.);
#endif

#if defined(TYPE_FIRMWARE)
PROVIDE(__global_pointer$ = ORIGIN(REGION_DATA) + 0x7f0);
#endif
.data DATA_START : ALIGN(4)
.data DATA_START :
#if defined (TYPE_KERNEL)
AT(__kernel_data_lma)
#endif
ALIGN(4)
{
. = .; /* Force section emission. */
__ldm_data_start = .;
Expand Down Expand Up @@ -141,8 +149,8 @@ SECTIONS
} > REGION_DATA :data

#ifdef TYPE_FIRMWARE
. = ALIGN(MEM_ALIGN);
__fw_export_ldm_end = .;
. = ALIGN(ABSOLUTE(.), MEM_ALIGN);
__fw_export_ldm_end = ABSOLUTE(.);
#endif

#ifdef TYPE_FIRMWARE
Expand Down
10 changes: 2 additions & 8 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,7 @@ void Device::initialize_firmware(const HalProgrammableCoreType &core_type, CoreC
auto [build_idx, num_build_states] = this->build_processor_type_to_index(core_type_idx, processor_class);
for (uint32_t riscv_id = build_idx; riscv_id < (build_idx + num_build_states); riscv_id++) {
ll_api::memory const& binary_mem = llrt::get_risc_binary(
firmware_build_states_[riscv_id]->get_target_out_path(""),
core_type_idx,
processor_class,
(riscv_id - build_idx));
firmware_build_states_[riscv_id]->get_target_out_path(""));
uint32_t fw_size = binary_mem.get_text_size();
if (riscv_id == 1) { // TODO: clean up how brisc/ncrisc are handled
// In this context, ncrisc_kernel_size16 is the size of the fw
Expand Down Expand Up @@ -485,10 +482,7 @@ void Device::initialize_firmware(const HalProgrammableCoreType &core_type, CoreC
auto [build_idx, num_build_states] = this->build_processor_type_to_index(core_type_idx, processor_class);
for (uint32_t eriscv_id = build_idx; eriscv_id < (build_idx + num_build_states); eriscv_id++) {
ll_api::memory const& binary_mem = llrt::get_risc_binary(
firmware_build_states_[eriscv_id]->get_target_out_path(""),
core_type_idx,
processor_class,
(eriscv_id - build_idx));
firmware_build_states_[eriscv_id]->get_target_out_path(""));
uint32_t fw_size = binary_mem.get_text_size();
log_debug(LogDevice, "ERISC fw binary size: {} in bytes", fw_size);
llrt::test_load_write_read_risc_binary(binary_mem, this->id(), virtual_core, core_type_idx, processor_class, (eriscv_id - build_idx));
Expand Down
27 changes: 7 additions & 20 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,12 @@ void DataMovementKernel::read_binaries(Device *device) {
int riscv_id = static_cast<std::underlying_type<DataMovementProcessor>::type>(this->config_.processor);
const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, dm_class_idx, riscv_id);
// TODO: from HAL
ll_api::memory::Relocate relo_type =
auto load_type =
(riscv_id == 1 && (device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0)) ?
ll_api::memory::Relocate::NONE : ll_api::memory::Relocate::XIP;
ll_api::memory::Loading::CONTIGUOUS : ll_api::memory::Loading::CONTIGUOUS_XIP;
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
// processor class is BRISC/NCRISC and each have one data movement processor type
tensix_core_type,
riscv_id,
dm_class_idx,
ll_api::memory::PackSpans::PACK,
relo_type);
load_type);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", riscv_id, binary_size);
Expand All @@ -405,15 +400,11 @@ void EthernetKernel::read_binaries(Device *device) {
const JitBuildState &build_state = device->build_kernel_state(erisc_core_type, dm_class_idx, erisc_id);
int risc_id = erisc_id + (this->config_.eth_mode == Eth::IDLE ? 6 : 5); // TODO (abhullar): clean this up when llrt helpers use HAL
// TODO: fix when active eth supports relo
ll_api::memory::Relocate relo_type = (this->config_.eth_mode == Eth::IDLE) ?
ll_api::memory::Relocate::XIP : ll_api::memory::Relocate::NONE;
auto load_type = (this->config_.eth_mode == Eth::IDLE) ?
ll_api::memory::Loading::CONTIGUOUS_XIP : ll_api::memory::Loading::DISCRETE;
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
erisc_core_type,
erisc_id,
dm_class_idx,
ll_api::memory::PackSpans::PACK,
relo_type);
load_type);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "ERISC {} kernel binary size: {} in bytes", erisc_id, binary_size);
Expand All @@ -429,11 +420,7 @@ void ComputeKernel::read_binaries(Device *device) {
const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, compute_class_idx, trisc_id);
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
tensix_core_type,
compute_class_idx,
trisc_id,
ll_api::memory::PackSpans::PACK,
ll_api::memory::Relocate::XIP);
ll_api::memory::Loading::CONTIGUOUS_XIP);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", trisc_id + 2, binary_size);
Expand Down
33 changes: 12 additions & 21 deletions tt_metal/llrt/llrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,46 +47,37 @@ using std::uint64_t;

ll_api::memory const& get_risc_binary(
string const& path,
uint32_t core_type_idx,
uint32_t processor_class_idx,
uint32_t processor_type_idx,
ll_api::memory::PackSpans span_type,
ll_api::memory::Relocate relo_type) {
ll_api::memory::Loading loading) {
static struct {
std::unordered_map<std::string, std::unique_ptr<ll_api::memory>> map;
std::unordered_map<std::string, std::unique_ptr<ll_api::memory const>> map;
std::mutex mutex;
std::condition_variable cvar;
} cache;

std::unique_lock lock(cache.mutex);
auto [slot, inserted] = cache.map.try_emplace(path);
ll_api::memory const* ptr = nullptr;
if (inserted) {
// We're the first with PATH. Create and insert.
lock.unlock();
auto *ptr = new ll_api::memory(path, relo_type);

// TODO: pass pack_spans into reader, generate text/data sizes
// from segment sizes and pack there
if (span_type == ll_api::memory::PackSpans::PACK) {
uint64_t data_start = tt::tt_metal::hal.get_dev_addr(tt::tt_metal::HalProgrammableCoreType::TENSIX, tt::tt_metal::HalL1MemAddrType::LOCAL);
uint64_t text_start = (relo_type == ll_api::memory::Relocate::XIP) ?
0 :
tt::tt_metal::hal.get_base_firmware_addr(core_type_idx, processor_class_idx, processor_type_idx);
ptr->pack_data_into_text(text_start, data_start);
}
ptr = new ll_api::memory(path, loading);

lock.lock();
// maps have iterator stability, so SLOT is still valid.
slot->second = decltype(slot->second)(ptr);
// We can't wake just those waiting on this slot, so wake them
// all. Should be a rare event anyway.
cache.cvar.notify_all();
} else if (!slot->second) {
// Someone else is creating the initial entry, wait for them.
cache.cvar.wait(lock, [=] { return bool(slot->second); });
} else {
if (!slot->second) {
// Someone else is creating the initial entry, wait for them.
cache.cvar.wait(lock, [=] { return bool(slot->second); });
}
ptr = slot->second.get();
TT_ASSERT(ptr->get_loading() == loading);
}

return *slot->second.get();
return *ptr;
}

// CoreCoord core --> NOC coordinates ("functional workers" from the SOC descriptor)
Expand Down
11 changes: 2 additions & 9 deletions tt_metal/llrt/llrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,10 @@ using WorkerCore = tt_cxy_pair;
using WorkerCores = std::vector<WorkerCore>;

// Return a reference to a potentially shared binary image.
// The images are cached by path name, which is never erased.
// TODO: Remove core_type_idx, processor_class_idx,
// processor_type_idx -- the information they provide can be
// obtained directly from the binary image.
// The images are cached by path name.
ll_api::memory const& get_risc_binary(
string const& path,
uint32_t core_type_idx,
uint32_t processor_class_idx,
uint32_t processor_type_idx,
ll_api::memory::PackSpans span_type = ll_api::memory::PackSpans::NO_PACK,
ll_api::memory::Relocate relo_type = ll_api::memory::Relocate::NONE);
ll_api::memory::Loading loading = ll_api::memory::Loading::DISCRETE);

// TODO: try using "stop" method from device instead, it's the proper way of asserting reset

Expand Down
Loading

0 comments on commit e035930

Please sign in to comment.