Skip to content

Commit

Permalink
#8602: Make kernel group/launch msg fields indexable
Browse files Browse the repository at this point in the history
create a dispatch processor class enum, change named fields into arrays
  • Loading branch information
pgkeller committed Jun 21, 2024
1 parent 3c6631b commit b5553d3
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 185 deletions.
16 changes: 8 additions & 8 deletions tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ int main(int argc, char **argv) {
// Check that binary memory objects in the kernel match the ones obtained from the persistent cache
const KernelGroup* kernel_group = program.kernels_on_core(core, CoreType::WORKER);
TT_FATAL(
kernel_group != nullptr && kernel_group->compute_id.has_value() and
kernel_group->riscv0_id.has_value() and kernel_group->riscv1_id.has_value());
auto compute_kernel = tt_metal::detail::GetKernel(program, kernel_group->compute_id.value());
auto riscv0_kernel = tt_metal::detail::GetKernel(program, kernel_group->riscv0_id.value());
auto riscv1_kernel = tt_metal::detail::GetKernel(program, kernel_group->riscv1_id.value());
kernel_group != nullptr && kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].has_value() and
kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].has_value() and kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].has_value());
auto compute_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].value());
auto riscv0_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].value());
auto riscv1_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].value());

// Run iteration to get golden
uint32_t mask = device->build_key();
Expand Down Expand Up @@ -175,9 +175,9 @@ int main(int argc, char **argv) {
uint32_t mask = device->build_key();
tt_metal::detail::CompileProgram(device, program);
const KernelGroup* kernel_group = program.kernels_on_core(core, CoreType::WORKER);
auto compute_kernel = tt_metal::detail::GetKernel(program, kernel_group->compute_id.value());
auto riscv0_kernel = tt_metal::detail::GetKernel(program, kernel_group->riscv0_id.value());
auto riscv1_kernel = tt_metal::detail::GetKernel(program, kernel_group->riscv1_id.value());
auto compute_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE].value());
auto riscv0_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].value());
auto riscv1_kernel = tt_metal::detail::GetKernel(program, kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].value());
TT_FATAL(compute_kernel->binaries(mask) == compute_binaries.at(mask));
TT_FATAL(riscv0_kernel->binaries(mask) == brisc_binaries.at(mask));
TT_FATAL(riscv1_kernel->binaries(mask) == ncrisc_binaries.at(mask));
Expand Down
11 changes: 6 additions & 5 deletions tt_metal/detail/tt_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,16 @@ namespace tt::tt_metal{
const KernelGroup * kernel_group = program.kernels_on_core(CoreCoord(x, y), CoreType::WORKER);
if (kernel_group != nullptr) {
bool local_noc0_in_use = false; bool local_noc1_in_use = false;
if (kernel_group->riscv0_id.has_value()) {
if (kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].has_value()) {
riscv0_in_use = true;
set_global_and_local_noc_usage(kernel_group->riscv0_id.value(), local_noc0_in_use, local_noc1_in_use);
set_global_and_local_noc_usage(kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].value(), local_noc0_in_use, local_noc1_in_use);
}
if (kernel_group->riscv1_id.has_value()) {
if (kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].has_value()) {
riscv1_in_use = true;
set_global_and_local_noc_usage(kernel_group->riscv1_id.value(), local_noc0_in_use, local_noc1_in_use);
set_global_and_local_noc_usage(kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].value(), local_noc0_in_use, local_noc1_in_use);
}
if (kernel_group->riscv0_id.has_value() and kernel_group->riscv1_id.has_value()) {
if (kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM0].has_value() and
kernel_group->kernel_ids[DISPATCH_CLASS_TENSIX_DM1].has_value()) {
TT_FATAL(local_noc0_in_use and local_noc1_in_use, "Illegal NOC usage: data movement kernels on logical core {} cannot use the same NOC, doing so results in hangs!", CoreCoord(x, y).str());
}
}
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/hw/firmware/src/brisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ inline void set_ncrisc_kernel_resume_deassert_address() {
}

inline void run_triscs() {
if (mailboxes->launch.enable_triscs) {
if (mailboxes->launch.enables[DISPATCH_CLASS_TENSIX_COMPUTE]) {
mailboxes->slave_sync.all = RUN_SYNC_MSG_ALL_TRISCS_GO;
}
}

inline void finish_ncrisc_copy_and_run() {
if (mailboxes->launch.enable_ncrisc) {
if (mailboxes->launch.enables[DISPATCH_CLASS_TENSIX_DM1]) {
mailboxes->slave_sync.ncrisc = RUN_SYNC_MSG_GO;

l1_to_ncrisc_iram_copy_wait();
Expand Down Expand Up @@ -383,8 +383,8 @@ int main() {
// Run the BRISC kernel
DEBUG_STATUS("R");
uint32_t kernel_config_base = mailboxes->launch.kernel_config_base;
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offset_brisc);
if (mailboxes->launch.enable_brisc) {
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offsets[DISPATCH_CLASS_TENSIX_DM0]);
if (mailboxes->launch.enables[DISPATCH_CLASS_TENSIX_DM0]) {
setup_cb_read_write_interfaces(num_cbs_to_early_init, mailboxes->launch.max_cb_index, true, true);
kernel_init();
} else {
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/hw/firmware/src/erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void __attribute__((section("erisc_l1_code.1"), noinline)) Application(void) {
DeviceZoneScopedMainN("ERISC-FW");
DEBUG_STATUS("R");
uint32_t kernel_config_base = mailboxes->launch.kernel_config_base;
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offset_brisc); // overloaded
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offsets[DISPATCH_CLASS_ETH_DM0]);
kernel_init();
} else {
internal_::risc_context_switch();
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/hw/firmware/src/idle_erisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ int main() {
//UC FIXME: do i need this?
setup_cb_read_write_interfaces(num_cbs_to_early_init, mailboxes->launch.max_cb_index, true, true);
uint32_t kernel_config_base = mailboxes->launch.kernel_config_base;
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offset_brisc); // overloaded
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offsets[DISPATCH_CLASS_ETH_DM0]);
kernel_init();
//} else {
// This was not initialized in kernel_init
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/hw/firmware/src/ncrisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ int main(int argc, char *argv[]) {
setup_cb_read_write_interfaces(0, mailboxes->launch.max_cb_index, true, true);

uint32_t kernel_config_base = mailboxes->launch.kernel_config_base;
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offset_ncrisc);
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offsets[DISPATCH_CLASS_TENSIX_DM1]);

DEBUG_STATUS("R");
kernel_init();
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/hw/firmware/src/trisc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ int main(int argc, char *argv[]) {
#endif

uint32_t kernel_config_base = mailboxes->launch.kernel_config_base;
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offset_trisc);
l1_arg_base = (uint32_t tt_l1_ptr *)(kernel_config_base + mailboxes->launch.rta_offsets[DISPATCH_CLASS_TENSIX_COMPUTE]);

DEBUG_STATUS("R");
kernel_init();
Expand Down
25 changes: 15 additions & 10 deletions tt_metal/hw/inc/dev_msgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,30 @@ enum dispatch_mode {
DISPATCH_MODE_HOST,
};

enum dispatch_core_processor_classes {
// Tensix processor classes
DISPATCH_CLASS_TENSIX_DM0 = 0,
DISPATCH_CLASS_TENSIX_DM1 = 1,
DISPATCH_CLASS_TENSIX_COMPUTE = 2,

// Ethernet processor classes
DISPATCH_CLASS_ETH_DM0 = 0,

DISPATCH_CLASS_MAX_PROC = 3,
};

struct launch_msg_t { // must be cacheline aligned
volatile uint16_t brisc_watcher_kernel_id;
volatile uint16_t ncrisc_watcher_kernel_id;
volatile uint16_t triscs_watcher_kernel_id;
volatile uint16_t watcher_kernel_ids[DISPATCH_CLASS_MAX_PROC];
volatile uint16_t ncrisc_kernel_size16; // size in 16 byte units

// Ring buffer of kernel configuration data
volatile uint32_t kernel_config_base;
volatile uint16_t rta_offset_brisc;
volatile uint16_t rta_offset_ncrisc;
volatile uint16_t rta_offset_trisc;
volatile uint16_t rta_offsets[DISPATCH_CLASS_MAX_PROC];

volatile uint8_t mode; // dispatch mode host/dev
volatile uint8_t brisc_noc_id;
volatile uint8_t enable_brisc;
volatile uint8_t enable_ncrisc;
volatile uint8_t enable_triscs;
volatile uint8_t enables[DISPATCH_CLASS_MAX_PROC];
volatile uint8_t max_cb_index;
volatile uint8_t enable_erisc;
volatile uint8_t run; // must be in last cacheline of this msg
};

Expand Down
58 changes: 29 additions & 29 deletions tt_metal/impl/debug/watcher_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ void create_kernel_file() {

static void log_running_kernels(const launch_msg_t *launch_msg) {
log_info("While running kernels:");
log_info(" brisc : {}", kernel_names[launch_msg->brisc_watcher_kernel_id]);
log_info(" ncrisc: {}", kernel_names[launch_msg->ncrisc_watcher_kernel_id]);
log_info(" triscs: {}", kernel_names[launch_msg->triscs_watcher_kernel_id]);
log_info(" brisc : {}", kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0]]);
log_info(" ncrisc: {}", kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1]]);
log_info(" triscs: {}", kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE]]);
}

static void dump_l1_status(FILE *f, Device *device, CoreCoord core, const launch_msg_t *launch_msg) {
Expand Down Expand Up @@ -158,13 +158,13 @@ static const char *get_riscv_name(CoreCoord core, uint32_t type) {

static string get_kernel_name(CoreCoord core, const launch_msg_t *launch_msg, uint32_t type) {
switch (type) {
case DebugBrisc:
case DebugBrisc: return kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0]];
case DebugErisc:
case DebugIErisc: return kernel_names[launch_msg->brisc_watcher_kernel_id];
case DebugNCrisc: return kernel_names[launch_msg->ncrisc_watcher_kernel_id];
case DebugIErisc: return kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_ETH_DM0]];
case DebugNCrisc: return kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1]];
case DebugTrisc0:
case DebugTrisc1:
case DebugTrisc2: return kernel_names[launch_msg->triscs_watcher_kernel_id];
case DebugTrisc2: return kernel_names[launch_msg->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE]];
default:
log_running_kernels(launch_msg);
TT_THROW("Watcher data corrupted, unexpected riscv type on core {}: {}", core.str(), type);
Expand Down Expand Up @@ -488,40 +488,40 @@ static void dump_run_mailboxes(

fprintf(f, "|");

if (launch_msg->enable_brisc == 1) {
if (launch_msg->enables[DISPATCH_CLASS_TENSIX_DM0] == 1) {
fprintf(f, "B");
} else if (launch_msg->enable_brisc == 0) {
} else if (launch_msg->enables[DISPATCH_CLASS_TENSIX_DM0] == 0) {
fprintf(f, "b");
} else {
log_running_kernels(launch_msg);
TT_THROW(
"Watcher data corruption, unexpected brisc enable on core {}: {} (expected 0 or 1)",
core.str(),
launch_msg->enable_brisc);
launch_msg->enables[DISPATCH_CLASS_TENSIX_DM0]);
}

if (launch_msg->enable_ncrisc == 1) {
if (launch_msg->enables[DISPATCH_CLASS_TENSIX_DM1] == 1) {
fprintf(f, "N");
} else if (launch_msg->enable_ncrisc == 0) {
} else if (launch_msg->enables[DISPATCH_CLASS_TENSIX_DM1] == 0) {
fprintf(f, "n");
} else {
log_running_kernels(launch_msg);
TT_THROW(
"Watcher data corruption, unexpected ncrisc enable on core {}: {} (expected 0 or 1)",
core.str(),
launch_msg->enable_ncrisc);
launch_msg->enables[DISPATCH_CLASS_TENSIX_DM1]);
}

if (launch_msg->enable_triscs == 1) {
if (launch_msg->enables[DISPATCH_CLASS_TENSIX_COMPUTE] == 1) {
fprintf(f, "T");
} else if (launch_msg->enable_triscs == 0) {
} else if (launch_msg->enables[DISPATCH_CLASS_TENSIX_COMPUTE] == 0) {
fprintf(f, "t");
} else {
log_running_kernels(launch_msg);
TT_THROW(
"Watcher data corruption, unexpected trisc enable on core {}: {} (expected 0 or 1)",
core.str(),
launch_msg->enable_triscs);
launch_msg->enables[DISPATCH_CLASS_TENSIX_COMPUTE]);
}

fprintf(f, " ");
Expand Down Expand Up @@ -565,35 +565,35 @@ static void dump_sync_regs(FILE *f, Device *device, CoreCoord core) {

static void validate_kernel_ids(
FILE *f, std::map<int, bool> &used_kernel_names, chip_id_t device_id, CoreCoord core, const launch_msg_t *launch) {
if (launch->brisc_watcher_kernel_id >= kernel_names.size()) {
if (launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0] >= kernel_names.size()) {
TT_THROW(
"Watcher data corruption, unexpected brisc kernel id on Device {} core {}: {} (last valid {})",
device_id,
core.str(),
launch->brisc_watcher_kernel_id,
launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0],
kernel_names.size());
}
used_kernel_names[launch->brisc_watcher_kernel_id] = true;
used_kernel_names[launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0]] = true;

if (launch->ncrisc_watcher_kernel_id >= kernel_names.size()) {
if (launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1] >= kernel_names.size()) {
TT_THROW(
"Watcher data corruption, unexpected ncrisc kernel id on Device {} core {}: {} (last valid {})",
device_id,
core.str(),
launch->ncrisc_watcher_kernel_id,
launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1],
kernel_names.size());
}
used_kernel_names[launch->ncrisc_watcher_kernel_id] = true;
used_kernel_names[launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1]] = true;

if (launch->triscs_watcher_kernel_id >= kernel_names.size()) {
if (launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE] >= kernel_names.size()) {
TT_THROW(
"Watcher data corruption, unexpected trisc kernel id on Device {} core {}: {} (last valid {})",
device_id,
core.str(),
launch->triscs_watcher_kernel_id,
launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE],
kernel_names.size());
}
used_kernel_names[launch->triscs_watcher_kernel_id] = true;
used_kernel_names[launch->watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE]] = true;
}

static void dump_core(
Expand Down Expand Up @@ -679,14 +679,14 @@ static void dump_core(

// Eth core only reports erisc kernel id, uses the brisc field
if (is_eth_core) {
fprintf(f, "k_id:%d", mbox_data->launch.brisc_watcher_kernel_id);
fprintf(f, "k_id:%d", mbox_data->launch.watcher_kernel_ids[DISPATCH_CLASS_ETH_DM0]);
} else {
fprintf(
f,
"k_ids:%d|%d|%d",
mbox_data->launch.brisc_watcher_kernel_id,
mbox_data->launch.ncrisc_watcher_kernel_id,
mbox_data->launch.triscs_watcher_kernel_id);
mbox_data->launch.watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM0],
mbox_data->launch.watcher_kernel_ids[DISPATCH_CLASS_TENSIX_DM1],
mbox_data->launch.watcher_kernel_ids[DISPATCH_CLASS_TENSIX_COMPUTE]);
}

// Ring buffer at the end because it can print a bunch of data
Expand Down
17 changes: 4 additions & 13 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,10 @@ void Device::initialize_firmware(CoreCoord phys_core, launch_msg_t *launch_msg)
void Device::initialize_and_launch_firmware() {
ZoneScoped;

launch_msg_t launch_msg = {
.brisc_watcher_kernel_id = 0,
.ncrisc_watcher_kernel_id = 0,
.triscs_watcher_kernel_id = 0,
.ncrisc_kernel_size16 = 0,
.mode = DISPATCH_MODE_HOST,
.brisc_noc_id = 0,
.enable_brisc = 0,
.enable_ncrisc = 0,
.enable_triscs = 0,
.enable_erisc = 0,
.run = RUN_MSG_INIT,
};
launch_msg_t launch_msg;
std::memset(&launch_msg, 0, sizeof(launch_msg_t));
launch_msg.mode = DISPATCH_MODE_HOST,
launch_msg.run = RUN_MSG_INIT,

// Download to worker cores
log_debug("Initializing firmware");
Expand Down
8 changes: 5 additions & 3 deletions tt_metal/impl/kernels/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class Kernel : public JitBuildSettings {
std::map<std::string, std::string> defines() const { return defines_; }

virtual RISCV processor() const = 0;
dispatch_core_processor_classes dispatch_class() { return this->dispatch_class_; }

virtual bool configure(Device *device, const CoreCoord &logical_core) const = 0;

Expand Down Expand Up @@ -131,6 +132,7 @@ class Kernel : public JitBuildSettings {
// TODO: break this dependency by https://github.com/tenstorrent/tt-metal/issues/3381
std::unordered_map<chip_id_t, std::vector<ll_api::memory>> binaries_;
uint16_t binary_size16_;
dispatch_core_processor_classes dispatch_class_;
std::vector<uint32_t> compile_time_args_;
std::vector< std::vector< std::vector<uint32_t>> > core_to_runtime_args_;
std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_;
Expand All @@ -149,7 +151,7 @@ class Kernel : public JitBuildSettings {

class DataMovementKernel : public Kernel {
public:
DataMovementKernel(const std::string &kernel_path, const CoreRangeSet &cr_set, const DataMovementConfig &config) : Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) {}
DataMovementKernel(const std::string &kernel_path, const CoreRangeSet &cr_set, const DataMovementConfig &config) : Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) { this->dispatch_class_ = (config.processor == DataMovementProcessor::RISCV_0) ? DISPATCH_CLASS_TENSIX_DM0 : DISPATCH_CLASS_TENSIX_DM1; }

~DataMovementKernel() {}

Expand All @@ -176,7 +178,7 @@ class DataMovementKernel : public Kernel {
class EthernetKernel : public Kernel {
public:
EthernetKernel(const std::string &kernel_path, const CoreRangeSet &cr_set, const EthernetConfig &config) :
Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) {}
Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) { this->dispatch_class_ = DISPATCH_CLASS_ETH_DM0; }

~EthernetKernel() {}

Expand All @@ -202,7 +204,7 @@ class EthernetKernel : public Kernel {

class ComputeKernel : public Kernel {
public:
ComputeKernel(const std::string &kernel_path, const CoreRangeSet &cr_set, const ComputeConfig &config) : Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) {}
ComputeKernel(const std::string &kernel_path, const CoreRangeSet &cr_set, const ComputeConfig &config) : Kernel(kernel_path, cr_set, config.compile_args, config.defines), config_(config) { this->dispatch_class_ = DISPATCH_CLASS_TENSIX_COMPUTE; }

~ComputeKernel() {}

Expand Down
Loading

0 comments on commit b5553d3

Please sign in to comment.