Skip to content

Commit

Permalink
Revert "#13944: Redesign memory packing API (#15980)"
Browse files Browse the repository at this point in the history
This reverts commit 434bd8e.

> Conflicts:
>	tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp
  • Loading branch information
SeanNijjar committed Dec 17, 2024
1 parent c4f318b commit 47cb524
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 335 deletions.
8 changes: 4 additions & 4 deletions tests/tt_metal/tt_metal/eth/test_erisc_app_direct_send.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ bool send_over_eth(

// TODO: this should be updated to use kernel api
uint32_t active_eth_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH);
ll_api::memory const& binary_mem_send =
llrt::get_risc_binary(sender_device->build_firmware_target_path(active_eth_index, 0, 0));
ll_api::memory const& binary_mem_receive =
llrt::get_risc_binary(receiver_device->build_firmware_target_path(active_eth_index, 0, 0));
const ll_api::memory& binary_mem_send = llrt::get_risc_binary(
sender_device->build_firmware_target_path(active_eth_index, 0, 0), active_eth_index, 0, 0);
const ll_api::memory& binary_mem_receive = llrt::get_risc_binary(
receiver_device->build_firmware_target_path(active_eth_index, 0, 0), active_eth_index, 0, 0);

for (const auto& eth_core : eth_cores) {
llrt::write_hex_vec_to_core(
Expand Down
164 changes: 0 additions & 164 deletions tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,169 +101,5 @@ void construct_program(Program& program, Device* device, CoreCoord& core) {

// Test entry point. For every available device it builds a program, compiles it once and
// records the kernel binaries ("golden"). It then runs several rounds in which the kernel
// build output directories are removed, the kernel cache is cleared, and every device
// recompiles its program on its own thread; each compile is checked against the golden
// binaries, both via kernel->binaries() and via llrt::get_risc_binary() on the kernel's
// build target path. Throws (TT_THROW / TT_FATAL) on failure, returns 0 otherwise.
int main(int argc, char** argv) {
bool pass = true;

try {
////////////////////////////////////////////////////////////////////////////
// Device Setup
////////////////////////////////////////////////////////////////////////////
CoreCoord core = {0, 0};
int num_devices = tt::tt_metal::GetNumAvailableDevices();
std::vector<int> ids;
// NOTE(review): `id` is unsigned while `num_devices` and the vector element type are int.
for (unsigned int id = 0; id < num_devices; id++) {
ids.push_back(id);
}
tt::DevicePool::initialize(ids, 1, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, DispatchCoreConfig{});
auto devices = tt::DevicePool::instance().get_all_active_devices();
std::vector<Program> programs;
// kernel->binaries() returns 32B aligned binaries
// Golden binaries from the first compile, keyed by device->build_key().
std::map<uint32_t, std::vector<ll_api::memory const*>> compute_binaries;
std::map<uint32_t, std::vector<ll_api::memory const*>> brisc_binaries;
std::map<uint32_t, std::vector<ll_api::memory const*>> ncrisc_binaries;

for (int i = 0; i < num_devices; i++) {
auto device = devices[i];

////////////////////////////////////////////////////////////////////////////
// Application Setup
////////////////////////////////////////////////////////////////////////////
programs.push_back(Program());
Program& program = programs.back();

construct_program(program, device, core);

////////////////////////////////////////////////////////////////////////////
// Compile Application
////////////////////////////////////////////////////////////////////////////
// Check that binary memory objects in the kernel match the ones obtained from the persistent cache
uint32_t programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX);
const KernelGroup* kernel_group = program.kernels_on_core(core, programmable_core_index);
// The core must carry a compute kernel and both data-movement kernels (DM0 and DM1).
TT_FATAL(
kernel_group != nullptr && kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].has_value() and
kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].has_value() and
kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].has_value(),
"Error");
auto compute_kernel =
tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].value());
auto riscv0_kernel =
tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].value());
auto riscv1_kernel =
tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].value());

// Run iteration to get golden
uint32_t mask = device->build_key();
tt_metal::detail::CompileProgram(device, program);
// std::map::insert does not overwrite: if two devices share a build key, the first
// device's binaries remain the golden entry for that key.
compute_binaries.insert({mask, compute_kernel->binaries(mask)});
TT_FATAL(compute_binaries.at(mask).size() == 3, "Expected 3 Compute binaries!");
brisc_binaries.insert({mask, riscv0_kernel->binaries(mask)});
TT_FATAL(brisc_binaries.at(mask).size() == 1, "Expected 1 BRISC binary!");
ncrisc_binaries.insert({mask, riscv1_kernel->binaries(mask)});
TT_FATAL(ncrisc_binaries.at(mask).size() == 1, "Expected 1 NCRISC binary!");
}

// num_compiles is the number of back-to-back compiles each thread performs per round.
// The literal 3 in the loop below is the number of clean-rebuild rounds; the inner
// loops declare their own `i`, which shadows this round counter.
int num_compiles = 3;
for (int i = 0; i < 3; i++) {
std::vector<string> kernel_names = {"reader_unary_push_4", "writer_unary", "eltwise_copy_3m"};
// Delete each kernel's output directory under the device's kernel root path and
// clear the kernel cache before recompiling.
for (int i = 0; i < num_devices; i++) {
for (const auto& kernel_name : kernel_names) {
std::filesystem::remove_all(devices[i]->build_env().get_out_kernel_root_path() + kernel_name);
}
}
tt_metal::detail::ClearKernelCache();
std::vector<Program> new_programs;
for (int i = 0; i < num_devices; i++) {
auto& device = devices[i];
new_programs.push_back(Program());
Program& program = new_programs.back();
construct_program(program, device, core);
}

std::vector<std::thread> ths;
ths.reserve(num_devices);
uint32_t dm_class_idx = magic_enum::enum_integer(HalProcessorClassType::DM);
uint32_t compute_class_idx = magic_enum::enum_integer(HalProcessorClassType::COMPUTE);
for (int i = 0; i < num_devices; i++) {
auto& device = devices[i];
auto& program = new_programs[i];
// One thread per device. Everything is captured by reference; all threads are
// joined below before this round's locals go out of scope, and the golden maps
// are only read (via .at()) inside the threads.
// NOTE(review): if TT_FATAL throws (as the surrounding try/catch suggests), an
// exception raised inside this thread is not covered by main's handler — confirm.
ths.emplace_back([&] {
for (int j = 0; j < num_compiles; j++) {
uint32_t mask = device->build_key();
tt_metal::detail::CompileProgram(device, program);
uint32_t programmable_core_index =
hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX);
const KernelGroup* kernel_group = program.kernels_on_core(core, programmable_core_index);
auto compute_kernel = tt_metal::detail::GetKernel(
program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].value());
auto riscv0_kernel = tt_metal::detail::GetKernel(
program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].value());
auto riscv1_kernel = tt_metal::detail::GetKernel(
program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].value());
// The kernels' binary lists must equal the golden lists recorded for this build key.
TT_FATAL(compute_kernel->binaries(mask) == compute_binaries.at(mask), "Error");
TT_FATAL(riscv0_kernel->binaries(mask) == brisc_binaries.at(mask), "Error");
TT_FATAL(riscv1_kernel->binaries(mask) == ncrisc_binaries.at(mask), "Error");

// BRISC: DM processor index 0, always loaded as CONTIGUOUS_XIP.
std::string brisc_hex_path = device->build_kernel_target_path(
programmable_core_index,
dm_class_idx,
0,
get_latest_kernel_binary_path(device->build_env().get_out_kernel_root_path(), riscv0_kernel));
ll_api::memory const& brisc_binary =
llrt::get_risc_binary(brisc_hex_path, ll_api::memory::Loading::CONTIGUOUS_XIP);
TT_FATAL(
brisc_binary == *brisc_binaries.at(mask).at(0),
"Expected saved BRISC binary to be the same as binary in persistent cache");
// NCRISC: DM processor index 1; the loading mode depends on the architecture.
std::string ncrisc_hex_path = device->build_kernel_target_path(
programmable_core_index,
dm_class_idx,
1,
get_latest_kernel_binary_path(device->build_env().get_out_kernel_root_path(), riscv1_kernel));
// GRAYSKULL and WORMHOLE_B0 use CONTIGUOUS; every other arch uses CONTIGUOUS_XIP.
auto load_type =
(device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0)
? ll_api::memory::Loading::CONTIGUOUS
: ll_api::memory::Loading::CONTIGUOUS_XIP;
ll_api::memory const& ncrisc_binary = llrt::get_risc_binary(ncrisc_hex_path, load_type);
TT_FATAL(
ncrisc_binary == *ncrisc_binaries.at(mask).at(0),
"Expected saved NCRISC binary to be the same as binary in persistent cache");
// TRISC 0..2: one binary per compute processor, matching the 3 golden compute binaries.
for (int trisc_id = 0; trisc_id <= 2; trisc_id++) {
std::string trisc_id_str = std::to_string(trisc_id);
std::string trisc_hex_path = device->build_kernel_target_path(
programmable_core_index,
compute_class_idx,
trisc_id,
get_latest_kernel_binary_path(device->build_env().get_out_kernel_root_path(), compute_kernel));
ll_api::memory const& trisc_binary =
llrt::get_risc_binary(trisc_hex_path, ll_api::memory::Loading::CONTIGUOUS_XIP);
TT_FATAL(
trisc_binary == *compute_binaries.at(mask).at(trisc_id),
"Expected saved TRISC binary for {} to be the same as binary in persistent cache",
trisc_id_str);
}
}
});
}
// Wait for every device's compile thread before starting the next round.
for (auto& th : ths) {
th.join();
}
}
// Devices are closed only on the non-throwing path; an exception above skips this loop.
for (auto dev : devices) {
pass &= tt_metal::CloseDevice(dev);
}

} catch (const std::exception& e) {
pass = false;
// Capture the exception error message
log_error(LogTest, "{}", e.what());
// Capture system call errors that may have returned from driver/kernel
// NOTE(review): errno is logged unconditionally and may be stale if the failure
// did not come from a system call.
log_error(LogTest, "System error message: {}", std::strerror(errno));
}

if (pass) {
log_info(LogTest, "Test Passed");
} else {
TT_THROW("Test Failed");
}

// NOTE(review): appears redundant with the TT_THROW above when pass is false — confirm.
TT_FATAL(pass, "Error");

return 0;
}
22 changes: 7 additions & 15 deletions tt_metal/hw/toolchain/sections.ld
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ SECTIONS
*(.text .stub .text.* .gnu.linkonce.t.*)
/* .gnu.warning sections are handled specially by elf32.em. */
*(.gnu.warning)
. = ALIGN(4);
} > REGION_CODE :text
.init.fini :
{
Expand All @@ -74,30 +73,23 @@ SECTIONS
ASSERT(SIZEOF(.init.fini) == 0, ".init/.fini sections have contents");
} > REGION_CODE :text

#if defined(TYPE_KERNEL)
__kernel_data_lma = .;
#endif
. = ALIGN(ABSOLUTE(.) + MEM_PAD, MEM_ALIGN);
. = ALIGN(. + MEM_PAD, MEM_ALIGN);

#if defined(TYPE_FIRMWARE)
__fw_export_end_text = ABSOLUTE(.);
__fw_export_end_text = .;
#if defined(TARGET_NCRISC)
PROVIDE (KERNEL_ENTRY_SYMBOL = ABSOLUTE(__fw_export_end_text));
PROVIDE (KERNEL_ENTRY_SYMBOL = __fw_export_end_text);
#endif
#endif

#if defined(TYPE_KERNEL)
__kernel_init_local_l1_base = ABSOLUTE(.);
__kernel_init_local_l1_base = .;
#endif

#if defined(TYPE_FIRMWARE)
PROVIDE(__global_pointer$ = ORIGIN(REGION_DATA) + 0x7f0);
#endif
.data DATA_START :
#if defined (TYPE_KERNEL)
AT(__kernel_data_lma)
#endif
ALIGN(4)
.data DATA_START : ALIGN(4)
{
. = .; /* Force section emission. */
__ldm_data_start = .;
Expand Down Expand Up @@ -149,8 +141,8 @@ SECTIONS
} > REGION_DATA :data

#ifdef TYPE_FIRMWARE
. = ALIGN(ABSOLUTE(.), MEM_ALIGN);
__fw_export_ldm_end = ABSOLUTE(.);
. = ALIGN(MEM_ALIGN);
__fw_export_ldm_end = .;
#endif

#ifdef TYPE_FIRMWARE
Expand Down
10 changes: 8 additions & 2 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,10 @@ void Device::initialize_firmware(const HalProgrammableCoreType &core_type, CoreC
auto [build_idx, num_build_states] = this->build_processor_type_to_index(core_type_idx, processor_class);
for (uint32_t riscv_id = build_idx; riscv_id < (build_idx + num_build_states); riscv_id++) {
ll_api::memory const& binary_mem = llrt::get_risc_binary(
firmware_build_states_[riscv_id]->get_target_out_path(""));
firmware_build_states_[riscv_id]->get_target_out_path(""),
core_type_idx,
processor_class,
(riscv_id - build_idx));
uint32_t fw_size = binary_mem.get_text_size();
if (riscv_id == 1) { // TODO: clean up how brisc/ncrisc are handled
// In this context, ncrisc_kernel_size16 is the size of the fw
Expand Down Expand Up @@ -482,7 +485,10 @@ void Device::initialize_firmware(const HalProgrammableCoreType &core_type, CoreC
auto [build_idx, num_build_states] = this->build_processor_type_to_index(core_type_idx, processor_class);
for (uint32_t eriscv_id = build_idx; eriscv_id < (build_idx + num_build_states); eriscv_id++) {
ll_api::memory const& binary_mem = llrt::get_risc_binary(
firmware_build_states_[eriscv_id]->get_target_out_path(""));
firmware_build_states_[eriscv_id]->get_target_out_path(""),
core_type_idx,
processor_class,
(eriscv_id - build_idx));
uint32_t fw_size = binary_mem.get_text_size();
log_debug(LogDevice, "ERISC fw binary size: {} in bytes", fw_size);
llrt::test_load_write_read_risc_binary(binary_mem, this->id(), virtual_core, core_type_idx, processor_class, (eriscv_id - build_idx));
Expand Down
27 changes: 20 additions & 7 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,17 @@ void DataMovementKernel::read_binaries(Device *device) {
int riscv_id = static_cast<std::underlying_type<DataMovementProcessor>::type>(this->config_.processor);
const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, dm_class_idx, riscv_id);
// TODO: from HAL
auto load_type =
ll_api::memory::Relocate relo_type =
(riscv_id == 1 && (device->arch() == tt::ARCH::GRAYSKULL || device->arch() == tt::ARCH::WORMHOLE_B0)) ?
ll_api::memory::Loading::CONTIGUOUS : ll_api::memory::Loading::CONTIGUOUS_XIP;
ll_api::memory::Relocate::NONE : ll_api::memory::Relocate::XIP;
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
load_type);
// processor class is BRISC/NCRISC and each have one data movement processor type
tensix_core_type,
riscv_id,
dm_class_idx,
ll_api::memory::PackSpans::PACK,
relo_type);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", riscv_id, binary_size);
Expand All @@ -400,11 +405,15 @@ void EthernetKernel::read_binaries(Device *device) {
const JitBuildState &build_state = device->build_kernel_state(erisc_core_type, dm_class_idx, erisc_id);
int risc_id = erisc_id + (this->config_.eth_mode == Eth::IDLE ? 6 : 5); // TODO (abhullar): clean this up when llrt helpers use HAL
// TODO: fix when active eth supports relo
auto load_type = (this->config_.eth_mode == Eth::IDLE) ?
ll_api::memory::Loading::CONTIGUOUS_XIP : ll_api::memory::Loading::DISCRETE;
ll_api::memory::Relocate relo_type = (this->config_.eth_mode == Eth::IDLE) ?
ll_api::memory::Relocate::XIP : ll_api::memory::Relocate::NONE;
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
load_type);
erisc_core_type,
erisc_id,
dm_class_idx,
ll_api::memory::PackSpans::PACK,
relo_type);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "ERISC {} kernel binary size: {} in bytes", erisc_id, binary_size);
Expand All @@ -420,7 +429,11 @@ void ComputeKernel::read_binaries(Device *device) {
const JitBuildState &build_state = device->build_kernel_state(tensix_core_type, compute_class_idx, trisc_id);
ll_api::memory const& binary_mem = llrt::get_risc_binary(
build_state.get_target_out_path(this->kernel_full_name_),
ll_api::memory::Loading::CONTIGUOUS_XIP);
tensix_core_type,
compute_class_idx,
trisc_id,
ll_api::memory::PackSpans::PACK,
ll_api::memory::Relocate::XIP);
binaries.push_back(&binary_mem);
uint32_t binary_size = binary_mem.get_packed_size();
log_debug(LogLoader, "RISC {} kernel binary size: {} in bytes", trisc_id + 2, binary_size);
Expand Down
33 changes: 21 additions & 12 deletions tt_metal/llrt/llrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,46 @@ using std::uint64_t;

ll_api::memory const& get_risc_binary(
string const& path,
ll_api::memory::Loading loading) {
uint32_t core_type_idx,
uint32_t processor_class_idx,
uint32_t processor_type_idx,
ll_api::memory::PackSpans span_type,
ll_api::memory::Relocate relo_type) {
static struct {
std::unordered_map<std::string, std::unique_ptr<ll_api::memory const>> map;
std::unordered_map<std::string, std::unique_ptr<ll_api::memory>> map;
std::mutex mutex;
std::condition_variable cvar;
} cache;

std::unique_lock lock(cache.mutex);
auto [slot, inserted] = cache.map.try_emplace(path);
ll_api::memory const* ptr = nullptr;
if (inserted) {
// We're the first with PATH. Create and insert.
lock.unlock();
ptr = new ll_api::memory(path, loading);
auto *ptr = new ll_api::memory(path, relo_type);

// TODO: pass pack_spans into reader, generate text/data sizes
// from segment sizes and pack there
if (span_type == ll_api::memory::PackSpans::PACK) {
uint64_t data_start = tt::tt_metal::hal.get_dev_addr(tt::tt_metal::HalProgrammableCoreType::TENSIX, tt::tt_metal::HalL1MemAddrType::LOCAL);
uint64_t text_start = (relo_type == ll_api::memory::Relocate::XIP) ?
0 :
tt::tt_metal::hal.get_base_firmware_addr(core_type_idx, processor_class_idx, processor_type_idx);
ptr->pack_data_into_text(text_start, data_start);
}

lock.lock();
// maps have iterator stability, so SLOT is still valid.
slot->second = decltype(slot->second)(ptr);
// We can't wake just those waiting on this slot, so wake them
// all. Should be a rare event anyway.
cache.cvar.notify_all();
} else {
if (!slot->second) {
// Someone else is creating the initial entry, wait for them.
cache.cvar.wait(lock, [=] { return bool(slot->second); });
}
ptr = slot->second.get();
TT_ASSERT(ptr->get_loading() == loading);
} else if (!slot->second) {
// Someone else is creating the initial entry, wait for them.
cache.cvar.wait(lock, [=] { return bool(slot->second); });
}

return *ptr;
return *slot->second.get();
}

// CoreCoord core --> NOC coordinates ("functional workers" from the SOC descriptor)
Expand Down
Loading

0 comments on commit 47cb524

Please sign in to comment.