diff --git a/tt_metal/impl/dispatch/kernel_config/demux_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/demux_kernel.cpp index d89b61cab5cc..463bc7864268 100644 --- a/tt_metal/impl/dispatch/kernel_config/demux_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/demux_kernel.cpp @@ -10,90 +10,90 @@ void DemuxKernel::GenerateStaticConfigs() { uint16_t channel = - tt::Cluster::instance().get_assigned_channel_for_device(this->servicing_device_id); // TODO: this can be mmio - this->logical_core = - dispatch_core_manager::instance().demux_core(this->servicing_device_id, channel, this->placement_cq_id); - this->config.endpoint_id_start_index = 0xD1; - this->config.rx_queue_start_addr_words = + tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id_); // TODO: this can be mmio + logical_core_ = dispatch_core_manager::instance().demux_core(servicing_device_id_, channel, placement_cq_id_); + static_config_.endpoint_id_start_index = 0xD1; + static_config_.rx_queue_start_addr_words = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED) >> 4; - this->config.rx_queue_size_words = 0x10000 >> 4; + static_config_.rx_queue_size_words = 0x10000 >> 4; + static_config_.demux_fan_out = downstream_kernels_.size(); - this->config.remote_rx_network_type = DispatchRemoteNetworkType::NOC0; + static_config_.remote_rx_network_type = DispatchRemoteNetworkType::NOC0; - this->config.test_results_buf_addr_arg = 0; - this->config.test_results_buf_size_bytes = 0; - this->config.timeout_cycles = 0; + static_config_.test_results_buf_addr_arg = 0; + static_config_.test_results_buf_size_bytes = 0; + static_config_.timeout_cycles = 0; // TODO: Do we need an upstream sem here? - for (int idx = 0; idx < this->downstream_kernels.size(); idx++) { - FDKernel* k = this->downstream_kernels[idx]; - this->config.remote_tx_queue_id[idx] = 0; - this->config.remote_tx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.output_depacketize_cb_log_page_size[idx] = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + for (int idx = 0; idx < downstream_kernels_.size(); idx++) { + FDKernel* k = downstream_kernels_[idx]; + static_config_.remote_tx_queue_id[idx] = 0; + static_config_.remote_tx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + static_config_.output_depacketize_cb_log_page_size[idx] = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; // Only connected dispatchers need a semaphore. TODO: can initialize anyways, but this matches previous // implementation if (dynamic_cast(k)) { - this->config.output_depacketize_local_sem_id[idx] = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + static_config_.output_depacketize_local_sem_id[idx] = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); } - this->config.output_depacketize_remove_header[idx] = 1; + static_config_.output_depacketize_remove_header[idx] = 1; } } void DemuxKernel::GenerateDependentConfigs() { auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); // Upstream, expect EthTunneler or DEMUX - TT_ASSERT(this->upstream_kernels.size() == 1); - if (auto us = dynamic_cast(this->upstream_kernels[0])) { - this->config.remote_rx_x = us->GetVirtualCore().x; - this->config.remote_rx_y = us->GetVirtualCore().y; - this->config.remote_rx_queue_id = us->GetConfig().vc_count.value() * 2 - 1; - } else if (auto us = dynamic_cast(this->upstream_kernels[0])) { - this->config.remote_rx_x = us->GetVirtualCore().x; - this->config.remote_rx_y = us->GetVirtualCore().y; - this->config.remote_rx_queue_id = us->GetDownstreamPort(this) + 1; // TODO: can this be cleaned up? + TT_ASSERT(upstream_kernels_.size() == 1); + if (auto us = dynamic_cast(upstream_kernels_[0])) { + dependent_config_.remote_rx_x = us->GetVirtualCore().x; + dependent_config_.remote_rx_y = us->GetVirtualCore().y; + dependent_config_.remote_rx_queue_id = us->GetStaticConfig().vc_count.value() * 2 - 1; + } else if (auto us = dynamic_cast(upstream_kernels_[0])) { + dependent_config_.remote_rx_x = us->GetVirtualCore().x; + dependent_config_.remote_rx_y = us->GetVirtualCore().y; + dependent_config_.remote_rx_queue_id = us->GetDownstreamPort(this) + 1; // TODO: can this be cleaned up? // TODO: why is just this one different? Just match previous implementation for now if (us->GetDownstreamPort(this) == 1) { - this->config.endpoint_id_start_index = - this->config.endpoint_id_start_index.value() + this->downstream_kernels.size(); + static_config_.endpoint_id_start_index = + static_config_.endpoint_id_start_index.value() + downstream_kernels_.size(); } } else { TT_FATAL(false, "Unexpected kernel type upstream of DEMUX"); } // Downstream, expect DISPATCH_H or DEMUX - TT_ASSERT(this->downstream_kernels.size() <= MAX_SWITCH_FAN_OUT && this->downstream_kernels.size() > 0); - this->config.demux_fan_out = this->downstream_kernels.size(); - this->config.output_depacketize = 0; // Populated per downstream kernel - for (int idx = 0; idx < this->downstream_kernels.size(); idx++) { - FDKernel* k = this->downstream_kernels[idx]; - this->config.remote_tx_x[idx] = k->GetVirtualCore().x; - this->config.remote_tx_y[idx] = k->GetVirtualCore().y; + TT_ASSERT(downstream_kernels_.size() <= MAX_SWITCH_FAN_OUT && downstream_kernels_.size() > 0); + dependent_config_.output_depacketize = 0; // Populated per downstream kernel + for (int idx = 0; idx < downstream_kernels_.size(); idx++) { + FDKernel* k = downstream_kernels_[idx]; + dependent_config_.remote_tx_x[idx] = k->GetVirtualCore().x; + dependent_config_.remote_tx_y[idx] = k->GetVirtualCore().y; // Expect downstream to be either a DISPATCH or another DEMUX if (auto dispatch_kernel = dynamic_cast(k)) { - this->config.remote_tx_queue_start_addr_words[idx] = - dispatch_kernel->GetConfig().dispatch_cb_base.value() >> 4; - this->config.remote_tx_queue_size_words[idx] = - ((1 << dispatch_kernel->GetConfig().dispatch_cb_log_page_size.value()) * - dispatch_kernel->GetConfig().dispatch_cb_pages.value()) >> + dependent_config_.remote_tx_queue_start_addr_words[idx] = + dispatch_kernel->GetStaticConfig().dispatch_cb_base.value() >> 4; + dependent_config_.remote_tx_queue_size_words[idx] = + ((1 << dispatch_kernel->GetStaticConfig().dispatch_cb_log_page_size.value()) * + dispatch_kernel->GetStaticConfig().dispatch_cb_pages.value()) >> 4; - this->config.output_depacketize = - this->config.output_depacketize.value() | (1 << idx); // Only depacketize for dispatch downstream - this->config.output_depacketize_downstream_sem_id[idx] = dispatch_kernel->GetConfig().my_dispatch_cb_sem_id; + dependent_config_.output_depacketize = + dependent_config_.output_depacketize.value() | (1 << idx); // Only depacketize for dispatch downstream + dependent_config_.output_depacketize_downstream_sem_id[idx] = + dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id; uint32_t dest_map_array[4] = {0, 1, 2, 3}; // TODO: how to set these generically? Currently just matching // the hard-coded previous implementation uint64_t dest_endpoint_output_map = packet_switch_dest_pack(dest_map_array, 4); - this->config.dest_endpoint_output_map_hi = (uint32_t)(dest_endpoint_output_map >> 32); - this->config.dest_endpoint_output_map_lo = (uint32_t)(dest_endpoint_output_map & 0xFFFFFFFF); + dependent_config_.dest_endpoint_output_map_hi = (uint32_t)(dest_endpoint_output_map >> 32); + dependent_config_.dest_endpoint_output_map_lo = (uint32_t)(dest_endpoint_output_map & 0xFFFFFFFF); } else if (auto demux_kernel = dynamic_cast(k)) { - this->config.remote_tx_queue_start_addr_words[idx] = - demux_kernel->GetConfig().rx_queue_start_addr_words.value(); - this->config.remote_tx_queue_size_words[idx] = 0x1000; // TODO: hard-coded on previous implementation + dependent_config_.remote_tx_queue_start_addr_words[idx] = + demux_kernel->GetStaticConfig().rx_queue_start_addr_words.value(); + dependent_config_.remote_tx_queue_size_words[idx] = 0x1000; // TODO: hard-coded on previous implementation // Match previous implementation where downstream demux has output_depacketize fields zeroed out. TODO: can // remove this later - this->config.output_depacketize_downstream_sem_id[idx] = 0; + dependent_config_.output_depacketize_downstream_sem_id[idx] = 0; uint64_t dest_endpoint_output_map; - if (device->num_hw_cqs() == 1) { + if (device_->num_hw_cqs() == 1) { uint32_t dest_map_array[4] = {0, 0, 1, 1}; // TODO: how to set these generically? Currently just // matching the hard-coded previous implementation dest_endpoint_output_map = packet_switch_dest_pack(dest_map_array, 4); @@ -101,24 +101,24 @@ void DemuxKernel::GenerateDependentConfigs() { uint32_t dest_map_array[8] = {0, 0, 0, 0, 1, 1, 1, 1}; dest_endpoint_output_map = packet_switch_dest_pack(dest_map_array, 8); } - this->config.dest_endpoint_output_map_hi = (uint32_t)(dest_endpoint_output_map >> 32); - this->config.dest_endpoint_output_map_lo = (uint32_t)(dest_endpoint_output_map & 0xFFFFFFFF); + dependent_config_.dest_endpoint_output_map_hi = (uint32_t)(dest_endpoint_output_map >> 32); + dependent_config_.dest_endpoint_output_map_lo = (uint32_t)(dest_endpoint_output_map & 0xFFFFFFFF); } else { TT_FATAL(false, "Unexpected kernel type downstream of DEMUX"); } } // TODO: this is just to match the previous implementation hard-code, remove later if (!tt::Cluster::instance().is_galaxy_cluster()) { - this->config.output_depacketize = 0x3; + dependent_config_.output_depacketize = 0x3; } } void DemuxKernel::CreateKernel() { std::vector compile_args = { - config.endpoint_id_start_index.value(), - config.rx_queue_start_addr_words.value(), - config.rx_queue_size_words.value(), - config.demux_fan_out.value(), + static_config_.endpoint_id_start_index.value(), + static_config_.rx_queue_start_addr_words.value(), + static_config_.rx_queue_size_words.value(), + static_config_.demux_fan_out.value(), 0, 0, 0, @@ -131,66 +131,65 @@ void DemuxKernel::CreateKernel() { 0, 0, 0, // Populate remote_tx_queue_start_addr_words & remote_tx_queue_size_words after - config.remote_rx_x.value(), - config.remote_rx_y.value(), - config.remote_rx_queue_id.value(), - config.remote_rx_network_type.value(), - config.dest_endpoint_output_map_hi.value(), - config.dest_endpoint_output_map_lo.value(), - config.test_results_buf_addr_arg.value(), - config.test_results_buf_size_bytes.value(), - config.timeout_cycles.value(), - config.output_depacketize.value(), + dependent_config_.remote_rx_x.value(), + dependent_config_.remote_rx_y.value(), + dependent_config_.remote_rx_queue_id.value(), + static_config_.remote_rx_network_type.value(), + dependent_config_.dest_endpoint_output_map_hi.value(), + dependent_config_.dest_endpoint_output_map_lo.value(), + static_config_.test_results_buf_addr_arg.value(), + static_config_.test_results_buf_size_bytes.value(), + static_config_.timeout_cycles.value(), + dependent_config_.output_depacketize.value(), 0, 0, 0, 0 // Populate output_depacketize_config after }; for (int idx = 0; idx < MAX_SWITCH_FAN_OUT; idx++) { - if (config.remote_tx_x[idx]) { - compile_args[4 + idx] |= (config.remote_tx_x[idx].value() & 0xFF); - compile_args[4 + idx] |= (config.remote_tx_y[idx].value() & 0xFF) << 8; - compile_args[4 + idx] |= (config.remote_tx_queue_id[idx].value() & 0xFF) << 16; - compile_args[4 + idx] |= (config.remote_tx_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_tx_x[idx]) { + compile_args[4 + idx] |= (dependent_config_.remote_tx_x[idx].value() & 0xFF); + compile_args[4 + idx] |= (dependent_config_.remote_tx_y[idx].value() & 0xFF) << 8; + compile_args[4 + idx] |= (static_config_.remote_tx_queue_id[idx].value() & 0xFF) << 16; + compile_args[4 + idx] |= (static_config_.remote_tx_network_type[idx].value() & 0xFF) << 24; } - if (config.remote_tx_queue_start_addr_words[idx]) { - compile_args[8 + idx * 2] = config.remote_tx_queue_start_addr_words[idx].value(); - compile_args[9 + idx * 2] = config.remote_tx_queue_size_words[idx].value(); + if (dependent_config_.remote_tx_queue_start_addr_words[idx]) { + compile_args[8 + idx * 2] = dependent_config_.remote_tx_queue_start_addr_words[idx].value(); + compile_args[9 + idx * 2] = dependent_config_.remote_tx_queue_size_words[idx].value(); } - if (config.output_depacketize_cb_log_page_size[idx]) { + if (static_config_.output_depacketize_cb_log_page_size[idx]) { // To match previous implementation, zero these out if output_depacketize is not set. TODO: don't have to do // this - if (config.output_depacketize.value() & (1 << idx)) { - compile_args[26 + idx] |= (config.output_depacketize_cb_log_page_size[idx].value() & 0xFF); - compile_args[26 + idx] |= (config.output_depacketize_downstream_sem_id[idx].value() & 0xFF) << 8; - compile_args[26 + idx] |= (config.output_depacketize_local_sem_id[idx].value() & 0xFF) << 16; - compile_args[26 + idx] |= (config.output_depacketize_remove_header[idx].value() & 0xFF) << 24; + if (dependent_config_.output_depacketize.value() & (1 << idx)) { + compile_args[26 + idx] |= (static_config_.output_depacketize_cb_log_page_size[idx].value() & 0xFF); + compile_args[26 + idx] |= (dependent_config_.output_depacketize_downstream_sem_id[idx].value() & 0xFF) + << 8; + compile_args[26 + idx] |= (static_config_.output_depacketize_local_sem_id[idx].value() & 0xFF) << 16; + compile_args[26 + idx] |= (static_config_.output_depacketize_remove_header[idx].value() & 0xFF) << 24; } } } TT_ASSERT(compile_args.size() == 30); - const auto& grid_size = device->grid_size(); + const auto& grid_size = device_->grid_size(); tt_cxy_pair my_virtual_core = - tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(this->logical_core, GetCoreType()); + tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(logical_core_, GetCoreType()); std::map defines = { // All of these unused, remove later - {"MY_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.x, 0))}, - {"MY_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.y, 0))}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, + {"MY_NOC_X", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.x, 0))}, + {"MY_NOC_Y", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.y, 0))}, + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, {"UPSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.x, 0))}, {"UPSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_SLAVE_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_SLAVE_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"SKIP_NOC_LOGGING", "1"}}; configure_kernel_variant(dispatch_kernel_file_names[DEMUX], compile_args, defines, false, false, false); } diff --git a/tt_metal/impl/dispatch/kernel_config/demux_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/demux_kernel.hpp index 66f42e815cd6..2fa982d1e662 100644 --- a/tt_metal/impl/dispatch/kernel_config/demux_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/demux_kernel.hpp @@ -4,34 +4,38 @@ #pragma once #include "fd_kernel.hpp" -typedef struct demux_config { +typedef struct demux_static_config { std::optional endpoint_id_start_index; std::optional rx_queue_start_addr_words; std::optional rx_queue_size_words; - std::optional demux_fan_out; // Dependent + std::optional demux_fan_out; - std::array, MAX_SWITCH_FAN_OUT> remote_tx_x; // [4:7], dependent - std::array, MAX_SWITCH_FAN_OUT> remote_tx_y; // [4:7], dependent std::array, MAX_SWITCH_FAN_OUT> remote_tx_queue_id; // [4:7] std::array, MAX_SWITCH_FAN_OUT> remote_tx_network_type; // [4:7] + std::optional remote_rx_network_type; + + std::optional test_results_buf_addr_arg; + std::optional test_results_buf_size_bytes; + std::optional timeout_cycles; + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_cb_log_page_size; // [26:29] + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem_id; // [26:29] + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29] +} demux_static_config_t; + +typedef struct demux_dependent_config { + std::array, MAX_SWITCH_FAN_OUT> remote_tx_x; // [4:7], dependent + std::array, MAX_SWITCH_FAN_OUT> remote_tx_y; // [4:7], dependent std::array, MAX_SWITCH_FAN_OUT> remote_tx_queue_start_addr_words; // [8:2:14], dependent std::array, MAX_SWITCH_FAN_OUT> remote_tx_queue_size_words; // [9:2:15], dependent std::optional remote_rx_x; // Dependent std::optional remote_rx_y; // Dependent std::optional remote_rx_queue_id; // Dependent - std::optional remote_rx_network_type; std::optional dest_endpoint_output_map_hi; // Dependent std::optional dest_endpoint_output_map_lo; // Dependent - std::optional test_results_buf_addr_arg; - std::optional test_results_buf_size_bytes; - std::optional timeout_cycles; - std::optional output_depacketize; // Dependent - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_cb_log_page_size; // [26:29] + std::optional output_depacketize; // Dependent std::array, MAX_SWITCH_FAN_OUT> output_depacketize_downstream_sem_id; // [26:29], dependent - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem_id; // [26:29] - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29] -} demux_config_t; +} demux_dependent_config_t; class DemuxKernel : public FDKernel { public: @@ -41,10 +45,11 @@ class DemuxKernel : public FDKernel { void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; - const demux_config_t& GetConfig() { return this->config; } - void SetPlacementCQID(int id) { this->placement_cq_id = id; } + const demux_static_config_t& GetStaticConfig() { return static_config_; } + void SetPlacementCQID(int id) { placement_cq_id_ = id; } private: - demux_config_t config; - int placement_cq_id; // TODO: remove channel hard-coding for dispatch core manager + demux_static_config_t static_config_; + demux_dependent_config_t dependent_config_; + int placement_cq_id_; // TODO: remove channel hard-coding for dispatch core manager }; diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.cpp index caa60607b1ce..b60a09f4aea8 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.cpp @@ -11,163 +11,165 @@ #include "tt_metal/detail/tt_metal.hpp" void DispatchKernel::GenerateStaticConfigs() { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); - uint8_t cq_id = this->cq_id; + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + uint8_t cq_id_ = this->cq_id_; auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - if (this->config.is_h_variant.value() && this->config.is_d_variant.value()) { + if (static_config_.is_h_variant.value() && this->static_config_.is_d_variant.value()) { uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); - uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id, cq_size); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); + uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id_, cq_size); uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start; - uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id); + uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_); uint32_t completion_queue_start_addr = issue_queue_start_addr + issue_queue_size; - uint32_t completion_queue_size = device->sysmem_manager().get_completion_queue_size(cq_id); - - this->logical_core = dispatch_core_manager::instance().dispatcher_core(device->id(), channel, cq_id); - this->config.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.dispatch_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; - this->config.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); - this->config.my_dispatch_cb_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - - this->config.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; - this->config.command_queue_base_addr = command_queue_start_addr; - this->config.completion_queue_base_addr = completion_queue_start_addr; - this->config.completion_queue_size = completion_queue_size; - - this->config.my_downstream_cb_sem_id = 0; // unused - - this->config.split_dispatch_page_preamble_size = 0; // unused - this->config.split_prefetch = false; // split_prefetcher - this->config.prefetch_h_noc_xy = 0; // unused prefetch noc_xy - this->config.prefetch_h_local_downstream_sem_addr = 0; // unused prefetch_local_downstream_sem_addr - this->config.prefetch_h_max_credits = 0; // unused prefetch_downstream_buffer_pages - - this->config.packed_write_max_unicast_sub_cmds = - device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y; - this->config.dispatch_s_sync_sem_base_addr = + uint32_t completion_queue_size = device_->sysmem_manager().get_completion_queue_size(cq_id_); + + logical_core_ = dispatch_core_manager::instance().dispatcher_core(device_->id(), channel, cq_id_); + static_config_.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.dispatch_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + static_config_.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); + static_config_.my_dispatch_cb_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + + static_config_.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; + static_config_.command_queue_base_addr = command_queue_start_addr; + static_config_.completion_queue_base_addr = completion_queue_start_addr; + static_config_.completion_queue_size = completion_queue_size; + + static_config_.my_downstream_cb_sem_id = 0; // unused + + static_config_.split_dispatch_page_preamble_size = 0; // unused + static_config_.split_prefetch = false; // split_prefetcher + dependent_config_.prefetch_h_noc_xy = 0; // unused prefetch noc_xy + dependent_config_.prefetch_h_local_downstream_sem_addr = 0; // unused prefetch_local_downstream_sem_addr + static_config_.prefetch_h_max_credits = 0; // unused prefetch_downstream_buffer_pages + + static_config_.packed_write_max_unicast_sub_cmds = + device_->compute_with_storage_grid_size().x * device_->compute_with_storage_grid_size().y; + static_config_.dispatch_s_sync_sem_base_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); - this->config.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; - this->config.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; - this->config.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); - this->config.unicast_go_signal_addr = + static_config_.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; + static_config_.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; + static_config_.mcast_go_signal_addr = + hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); + static_config_.unicast_go_signal_addr = (hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH) != -1) ? hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG) : 0; - this->config.distributed_dispatcher = (GetCoreType() == CoreType::ETH); + static_config_.distributed_dispatcher = (GetCoreType() == CoreType::ETH); - this->config.host_completion_q_wr_ptr = + static_config_.host_completion_q_wr_ptr = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_wr_ptr = + static_config_.dev_completion_q_wr_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_rd_ptr = + static_config_.dev_completion_q_rd_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD); - } else if (this->config.is_h_variant.value()) { + } else if (static_config_.is_h_variant.value()) { // DISPATCH_H services a remote chip, and so has a different channel - channel = tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id); + channel = tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id_); uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); - uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id, cq_size); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); + uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id_, cq_size); uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start; - uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id); + uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_); uint32_t completion_queue_start_addr = issue_queue_start_addr + issue_queue_size; - uint32_t completion_queue_size = device->sysmem_manager().get_completion_queue_size(cq_id); + uint32_t completion_queue_size = device_->sysmem_manager().get_completion_queue_size(cq_id_); - this->logical_core = dispatch_core_manager::instance().dispatcher_core(servicing_device_id, channel, cq_id); - this->config.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.dispatch_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; - this->config.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); - this->config.my_dispatch_cb_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + logical_core_ = dispatch_core_manager::instance().dispatcher_core(servicing_device_id_, channel, cq_id_); + static_config_.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.dispatch_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + static_config_.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); + static_config_.my_dispatch_cb_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); - this->config.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; - this->config.command_queue_base_addr = command_queue_start_addr; - this->config.completion_queue_base_addr = completion_queue_start_addr; - this->config.completion_queue_size = completion_queue_size; + static_config_.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; + static_config_.command_queue_base_addr = command_queue_start_addr; + static_config_.completion_queue_base_addr = completion_queue_start_addr; + static_config_.completion_queue_size = completion_queue_size; - this->config.my_downstream_cb_sem_id = 0; // Unused + static_config_.my_downstream_cb_sem_id = 0; // Unused - this->config.split_dispatch_page_preamble_size = 0; - this->config.split_prefetch = true; + static_config_.split_dispatch_page_preamble_size = 0; + static_config_.split_prefetch = true; // TODO: why is this hard-coded to 1 CQ on Galaxy? if (tt::Cluster::instance().is_galaxy_cluster()) { - this->config.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(1); + static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(1); } else { - this->config.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs()); + static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()); } - this->config.packed_write_max_unicast_sub_cmds = - device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y; - this->config.dispatch_s_sync_sem_base_addr = 0; // Unused - this->config.max_num_worker_sems = 1; // Used for array sizing, set to 1 even if unused - this->config.max_num_go_signal_noc_data_entries = 1; // Used for array sizing, sset to 1 even if unused - this->config.mcast_go_signal_addr = 0; // Unused - this->config.unicast_go_signal_addr = 0; // Unused - this->config.distributed_dispatcher = 0; // Unused + static_config_.packed_write_max_unicast_sub_cmds = + device_->compute_with_storage_grid_size().x * device_->compute_with_storage_grid_size().y; + static_config_.dispatch_s_sync_sem_base_addr = 0; // Unused + static_config_.max_num_worker_sems = 1; // Used for array sizing, set to 1 even if unused + static_config_.max_num_go_signal_noc_data_entries = 1; // Used for array sizing, sset to 1 even if unused + static_config_.mcast_go_signal_addr = 0; // Unused + static_config_.unicast_go_signal_addr = 0; // Unused + static_config_.distributed_dispatcher = 0; // Unused - this->config.host_completion_q_wr_ptr = + static_config_.host_completion_q_wr_ptr = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_wr_ptr = + static_config_.dev_completion_q_wr_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_rd_ptr = + static_config_.dev_completion_q_rd_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD); - } else if (this->config.is_d_variant.value()) { + } else if (static_config_.is_d_variant.value()) { uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); - uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id, cq_size); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); + uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id_, cq_size); uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start; - uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id); + uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_); uint32_t completion_queue_start_addr = issue_queue_start_addr + issue_queue_size; - uint32_t completion_queue_size = device->sysmem_manager().get_completion_queue_size(cq_id); - - this->logical_core = dispatch_core_manager::instance().dispatcher_d_core(device->id(), channel, cq_id); - this->config.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.dispatch_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); - this->config.my_dispatch_cb_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - - this->config.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; - this->config.command_queue_base_addr = 0; // These are unused for DISPATCH_D - this->config.completion_queue_base_addr = 0; - this->config.completion_queue_size = 0; - - this->config.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, - this->logical_core, - my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs()), + uint32_t completion_queue_size = device_->sysmem_manager().get_completion_queue_size(cq_id_); + + logical_core_ = dispatch_core_manager::instance().dispatcher_d_core(device_->id(), channel, cq_id_); + static_config_.dispatch_cb_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.dispatch_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.dispatch_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); + static_config_.my_dispatch_cb_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + + static_config_.dispatch_cb_blocks = dispatch_constants::DISPATCH_BUFFER_SIZE_BLOCKS; + static_config_.command_queue_base_addr = 0; // These are unused for DISPATCH_D + static_config_.completion_queue_base_addr = 0; + static_config_.completion_queue_size = 0; + + static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, + logical_core_, + my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()), GetCoreType()); // Apparently unused - this->config.split_dispatch_page_preamble_size = sizeof(dispatch_packet_header_t); - this->config.split_prefetch = true; - this->config.prefetch_h_noc_xy = 0; - this->config.prefetch_h_local_downstream_sem_addr = 1; - this->config.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs()); + static_config_.split_dispatch_page_preamble_size = sizeof(dispatch_packet_header_t); + static_config_.split_prefetch = true; + dependent_config_.prefetch_h_noc_xy = 0; + dependent_config_.prefetch_h_local_downstream_sem_addr = 1; + static_config_.prefetch_h_max_credits = my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()); // To match with previous implementation, need to use grid size from mmio device. TODO: that doesn't seem // correct though? - auto mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + auto mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id_); const auto& dispatch_core_config = dispatch_core_manager::instance().get_dispatch_core_config(mmio_device_id); CoreCoord remote_grid_size = - tt::get_compute_grid_size(mmio_device_id, device->num_hw_cqs(), dispatch_core_config); - this->config.packed_write_max_unicast_sub_cmds = remote_grid_size.x * remote_grid_size.y; - this->config.dispatch_s_sync_sem_base_addr = + tt::get_compute_grid_size(mmio_device_id, device_->num_hw_cqs(), dispatch_core_config); + static_config_.packed_write_max_unicast_sub_cmds = remote_grid_size.x * remote_grid_size.y; + static_config_.dispatch_s_sync_sem_base_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); - this->config.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; - this->config.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; - this->config.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); - this->config.unicast_go_signal_addr = + static_config_.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; + static_config_.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; + static_config_.mcast_go_signal_addr = + hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); + static_config_.unicast_go_signal_addr = (hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH) != -1) ? hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG) : 0; - this->config.distributed_dispatcher = (GetCoreType() == CoreType::ETH); + static_config_.distributed_dispatcher = (GetCoreType() == CoreType::ETH); - this->config.host_completion_q_wr_ptr = + static_config_.host_completion_q_wr_ptr = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_wr_ptr = + static_config_.dev_completion_q_wr_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR); - this->config.dev_completion_q_rd_ptr = + static_config_.dev_completion_q_rd_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD); } else { TT_FATAL(false, "DispatchKernel must be one of (or both) H and D variants"); @@ -176,90 +178,91 @@ void DispatchKernel::GenerateStaticConfigs() { void DispatchKernel::GenerateDependentConfigs() { auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - if (this->config.is_h_variant.value() && this->config.is_d_variant.value()) { + if (static_config_.is_h_variant.value() && this->static_config_.is_d_variant.value()) { // Upstream - TT_ASSERT(this->upstream_kernels.size() == 1); - auto prefetch_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto prefetch_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(prefetch_kernel); - this->config.upstream_logical_core = prefetch_kernel->GetLogicalCore(); - this->config.upstream_dispatch_cb_sem_id = prefetch_kernel->GetConfig().my_downstream_cb_sem_id; - this->config.upstream_sync_sem = prefetch_kernel->GetConfig().downstream_sync_sem_id; + dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore(); + dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id; + dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id; // Downstream - if (device->dispatch_s_enabled()) { - TT_ASSERT(this->downstream_kernels.size() == 1); - auto dispatch_s_kernel = dynamic_cast(this->downstream_kernels[0]); + if (device_->dispatch_s_enabled()) { + TT_ASSERT(downstream_kernels_.size() == 1); + auto dispatch_s_kernel = dynamic_cast(downstream_kernels_[0]); TT_ASSERT(dispatch_s_kernel); - this->config.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); + dependent_config_.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); } else { // If no dispatch_s, no downstream - TT_ASSERT(this->downstream_kernels.size() == 0); - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE; + TT_ASSERT(downstream_kernels_.size() == 0); + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE; } - this->config.downstream_logical_core = UNUSED_LOGICAL_CORE; - this->config.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.downstream_cb_size = + dependent_config_.downstream_logical_core = UNUSED_LOGICAL_CORE; + dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); + dependent_config_.downstream_cb_size = (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * my_dispatch_constants.dispatch_buffer_pages(); - this->config.downstream_cb_sem_id = UNUSED_SEM_ID; - } else if (this->config.is_h_variant.value()) { + dependent_config_.downstream_cb_sem_id = UNUSED_SEM_ID; + } else if (static_config_.is_h_variant.value()) { // Upstream, expect DEMUX - TT_ASSERT(this->upstream_kernels.size() == 1); - auto demux_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto demux_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(demux_kernel); - this->config.upstream_logical_core = demux_kernel->GetLogicalCore(); + dependent_config_.upstream_logical_core = demux_kernel->GetLogicalCore(); int demux_idx = demux_kernel->GetDownstreamPort(this); // Need to know which port this kernel connects to upstream - this->config.upstream_dispatch_cb_sem_id = - demux_kernel->GetConfig().output_depacketize_local_sem_id[demux_idx].value(); - this->config.upstream_sync_sem = 0; // Unused + dependent_config_.upstream_dispatch_cb_sem_id = + demux_kernel->GetStaticConfig().output_depacketize_local_sem_id[demux_idx].value(); + dependent_config_.upstream_sync_sem = 0; // Unused // Downstream, no official downstream core but use the field to connect is to the PREFETCH_H that we need to // write to when resuming sending of commands post exec_buf stall. - TT_ASSERT(this->downstream_kernels.size() == 1); - auto prefetch_h_kernel = dynamic_cast(this->downstream_kernels[0]); + TT_ASSERT(downstream_kernels_.size() == 1); + auto prefetch_h_kernel = dynamic_cast(downstream_kernels_[0]); TT_ASSERT(prefetch_h_kernel); - this->config.downstream_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; - this->config.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding( + dependent_config_.downstream_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; + dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding( prefetch_h_kernel->GetVirtualCore().x, prefetch_h_kernel->GetVirtualCore().y); - this->config.prefetch_h_local_downstream_sem_addr = prefetch_h_kernel->GetConfig().my_downstream_cb_sem_id; - this->config.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); // Unused - this->config.downstream_cb_size = (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * - my_dispatch_constants.dispatch_buffer_pages(); // Unused - this->config.downstream_cb_sem_id = 0; // Unused - } else if (this->config.is_d_variant.value()) { + dependent_config_.prefetch_h_local_downstream_sem_addr = + prefetch_h_kernel->GetStaticConfig().my_downstream_cb_sem_id; + dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); // Unused + dependent_config_.downstream_cb_size = (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * + my_dispatch_constants.dispatch_buffer_pages(); // Unused + dependent_config_.downstream_cb_sem_id = 0; // Unused + } else if (static_config_.is_d_variant.value()) { // Upstream, expect a PREFETCH_D - TT_ASSERT(this->upstream_kernels.size() == 1); - auto prefetch_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto prefetch_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(prefetch_kernel); - this->config.upstream_logical_core = prefetch_kernel->GetLogicalCore(); - this->config.upstream_dispatch_cb_sem_id = prefetch_kernel->GetConfig().my_downstream_cb_sem_id; - this->config.upstream_sync_sem = prefetch_kernel->GetConfig().downstream_sync_sem_id; + dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore(); + dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id; + dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id; // Downstream, expect a MUX_D and a DISPATCH_S if enabled - auto dispatch_s_kernel = dynamic_cast(this->downstream_kernels[0]); - auto mux_kernel = dynamic_cast(this->downstream_kernels[0]); - if (device->dispatch_s_enabled()) { - TT_ASSERT(this->downstream_kernels.size() == 2); - mux_kernel = dynamic_cast(this->downstream_kernels[1]); + auto dispatch_s_kernel = dynamic_cast(downstream_kernels_[0]); + auto mux_kernel = dynamic_cast(downstream_kernels_[0]); + if (device_->dispatch_s_enabled()) { + TT_ASSERT(downstream_kernels_.size() == 2); + mux_kernel = dynamic_cast(downstream_kernels_[1]); if (!dispatch_s_kernel) { - dispatch_s_kernel = dynamic_cast(this->downstream_kernels[1]); - mux_kernel = dynamic_cast(this->downstream_kernels[0]); + dispatch_s_kernel = dynamic_cast(downstream_kernels_[1]); + mux_kernel = dynamic_cast(downstream_kernels_[0]); } TT_ASSERT(dispatch_s_kernel); - this->config.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); + dependent_config_.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); } else { - TT_ASSERT(this->downstream_kernels.size() == 1); - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE; + TT_ASSERT(downstream_kernels_.size() == 1); + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE; } TT_ASSERT(mux_kernel); - this->config.downstream_logical_core = mux_kernel->GetLogicalCore(); + dependent_config_.downstream_logical_core = mux_kernel->GetLogicalCore(); // Some configs depend on which port this kernel connects to on the downstream kernel int dispatch_d_idx = mux_kernel->GetUpstreamPort(this); // Need the port that this connects to downstream - this->config.downstream_cb_size = (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * - my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs()); - this->config.downstream_cb_base = - my_dispatch_constants.dispatch_buffer_base() + this->config.downstream_cb_size.value() * dispatch_d_idx; - this->config.downstream_cb_sem_id = dispatch_d_idx; + dependent_config_.downstream_cb_size = (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * + my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()); + dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base() + + dependent_config_.downstream_cb_size.value() * dispatch_d_idx; + dependent_config_.downstream_cb_sem_id = dispatch_d_idx; } else { TT_FATAL(false, "DispatchKernel must be one of (or both) H and D variants"); } @@ -267,65 +270,65 @@ void DispatchKernel::GenerateDependentConfigs() { void DispatchKernel::CreateKernel() { std::vector compile_args = { - config.dispatch_cb_base.value(), - config.dispatch_cb_log_page_size.value(), - config.dispatch_cb_pages.value(), - config.my_dispatch_cb_sem_id.value(), - config.upstream_dispatch_cb_sem_id.value(), - - config.dispatch_cb_blocks.value(), - config.upstream_sync_sem.value(), - config.command_queue_base_addr.value(), - config.completion_queue_base_addr.value(), - config.completion_queue_size.value(), - - config.downstream_cb_base.value(), - config.downstream_cb_size.value(), - config.my_downstream_cb_sem_id.value(), - config.downstream_cb_sem_id.value(), - - config.split_dispatch_page_preamble_size.value(), - config.split_prefetch.value(), - config.prefetch_h_noc_xy.value(), - config.prefetch_h_local_downstream_sem_addr.value(), - config.prefetch_h_max_credits.value(), - - config.packed_write_max_unicast_sub_cmds.value(), - config.dispatch_s_sync_sem_base_addr.value(), - config.max_num_worker_sems.value(), - config.max_num_go_signal_noc_data_entries.value(), - config.mcast_go_signal_addr.value(), - config.unicast_go_signal_addr.value(), - config.distributed_dispatcher.value(), - - config.host_completion_q_wr_ptr.value(), - config.dev_completion_q_wr_ptr.value(), - config.dev_completion_q_rd_ptr.value(), - - config.is_d_variant.value(), - config.is_h_variant.value(), + static_config_.dispatch_cb_base.value(), + static_config_.dispatch_cb_log_page_size.value(), + static_config_.dispatch_cb_pages.value(), + static_config_.my_dispatch_cb_sem_id.value(), + dependent_config_.upstream_dispatch_cb_sem_id.value(), + + static_config_.dispatch_cb_blocks.value(), + dependent_config_.upstream_sync_sem.value(), + static_config_.command_queue_base_addr.value(), + static_config_.completion_queue_base_addr.value(), + static_config_.completion_queue_size.value(), + + dependent_config_.downstream_cb_base.value(), + dependent_config_.downstream_cb_size.value(), + static_config_.my_downstream_cb_sem_id.value(), + dependent_config_.downstream_cb_sem_id.value(), + + static_config_.split_dispatch_page_preamble_size.value(), + static_config_.split_prefetch.value(), + dependent_config_.prefetch_h_noc_xy.value(), + dependent_config_.prefetch_h_local_downstream_sem_addr.value(), + static_config_.prefetch_h_max_credits.value(), + + static_config_.packed_write_max_unicast_sub_cmds.value(), + static_config_.dispatch_s_sync_sem_base_addr.value(), + static_config_.max_num_worker_sems.value(), + static_config_.max_num_go_signal_noc_data_entries.value(), + static_config_.mcast_go_signal_addr.value(), + static_config_.unicast_go_signal_addr.value(), + static_config_.distributed_dispatcher.value(), + + static_config_.host_completion_q_wr_ptr.value(), + static_config_.dev_completion_q_wr_ptr.value(), + static_config_.dev_completion_q_rd_ptr.value(), + + static_config_.is_d_variant.value(), + static_config_.is_h_variant.value(), }; TT_ASSERT(compile_args.size() == 31); - auto my_virtual_core = device->virtual_core_from_logical_core(this->logical_core, GetCoreType()); + auto my_virtual_core = device_->virtual_core_from_logical_core(logical_core_, GetCoreType()); auto upstream_virtual_core = - device->virtual_core_from_logical_core(config.upstream_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.upstream_logical_core.value(), GetCoreType()); auto downstream_virtual_core = - device->virtual_core_from_logical_core(config.downstream_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.downstream_logical_core.value(), GetCoreType()); auto downstream_s_virtual_core = - device->virtual_core_from_logical_core(config.downstream_s_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.downstream_s_logical_core.value(), GetCoreType()); - auto my_virtual_noc_coords = device->virtual_noc0_coordinate(noc_selection.non_dispatch_noc, my_virtual_core); + auto my_virtual_noc_coords = device_->virtual_noc0_coordinate(noc_selection_.non_dispatch_noc, my_virtual_core); auto upstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.upstream_noc, upstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.upstream_noc, upstream_virtual_core); auto downstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_virtual_core); auto downstream_s_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_s_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_s_virtual_core); std::map defines = { {"MY_NOC_X", std::to_string(my_virtual_noc_coords.x)}, {"MY_NOC_Y", std::to_string(my_virtual_noc_coords.y)}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, {"UPSTREAM_NOC_X", std::to_string(upstream_virtual_noc_coords.x)}, {"UPSTREAM_NOC_Y", std::to_string(upstream_virtual_noc_coords.y)}, {"DOWNSTREAM_NOC_X", std::to_string(downstream_virtual_noc_coords.x)}, @@ -344,24 +347,24 @@ void DispatchKernel::ConfigureCore() { my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); uint32_t dispatch_message_base_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - tt::log_warning("Configure Dispatch (device {} core {})", device->id(), logical_core.str()); + tt::log_warning("Configure Dispatch (device {} core {})", device_->id(), logical_core_.str()); for (uint32_t i = 0; i < dispatch_constants::DISPATCH_MESSAGE_ENTRIES; i++) { uint32_t dispatch_s_sync_sem_addr = dispatch_s_sync_sem_base_addr + my_dispatch_constants.get_dispatch_message_offset(i); uint32_t dispatch_message_addr = dispatch_message_base_addr + my_dispatch_constants.get_dispatch_message_offset(i); - detail::WriteToDeviceL1(device, this->logical_core, dispatch_s_sync_sem_addr, zero, GetCoreType()); - detail::WriteToDeviceL1(device, this->logical_core, dispatch_message_addr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, dispatch_s_sync_sem_addr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, dispatch_message_addr, zero, GetCoreType()); } // For DISPATCH_D, need to clear completion q events - if (!this->config.is_h_variant.value() && this->config.is_d_variant.value()) { - tt::log_warning("Configure Dispatch D Counters (device {} core {})", device->id(), logical_core.str()); + if (!static_config_.is_h_variant.value() && this->static_config_.is_d_variant.value()) { + tt::log_warning("Configure Dispatch D Counters (device {} core {})", device_->id(), logical_core_.str()); uint32_t completion_q0_last_event_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); uint32_t completion_q1_last_event_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); - detail::WriteToDeviceL1(device, logical_core, completion_q0_last_event_ptr, zero, GetCoreType()); - detail::WriteToDeviceL1(device, logical_core, completion_q1_last_event_ptr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, completion_q0_last_event_ptr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, completion_q1_last_event_ptr, zero, GetCoreType()); } } diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.hpp index 5a283156176e..431d54c8605e 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch_kernel.hpp @@ -4,32 +4,21 @@ #pragma once #include "fd_kernel.hpp" -typedef struct dispatch_config { - std::optional upstream_logical_core; // Dependant - std::optional downstream_logical_core; // Dependant - std::optional downstream_s_logical_core; // Dependant - +typedef struct dispatch_static_config { std::optional dispatch_cb_base; // 0 std::optional dispatch_cb_log_page_size; std::optional dispatch_cb_pages; std::optional my_dispatch_cb_sem_id; - std::optional upstream_dispatch_cb_sem_id; // Dependant std::optional dispatch_cb_blocks; // 5 - std::optional upstream_sync_sem; // Dependant std::optional command_queue_base_addr; std::optional completion_queue_base_addr; std::optional completion_queue_size; - std::optional downstream_cb_base; // 10, dependent - std::optional downstream_cb_size; // Dependent std::optional my_downstream_cb_sem_id; - std::optional downstream_cb_sem_id; // Dependant std::optional split_dispatch_page_preamble_size; // 14 std::optional split_prefetch; - std::optional prefetch_h_noc_xy; // Dependent - std::optional prefetch_h_local_downstream_sem_addr; // Dependent std::optional prefetch_h_max_credits; std::optional packed_write_max_unicast_sub_cmds; // 19 @@ -46,7 +35,24 @@ typedef struct dispatch_config { std::optional is_d_variant; std::optional is_h_variant; -} dispatch_config_t; +} dispatch_static_config_t; + +typedef struct dispatch_dependent_config { + std::optional upstream_logical_core; // Dependant + std::optional downstream_logical_core; // Dependant + std::optional downstream_s_logical_core; // Dependant + + std::optional upstream_dispatch_cb_sem_id; // Dependant + + std::optional upstream_sync_sem; // Dependant + + std::optional downstream_cb_base; // 10, dependent + std::optional downstream_cb_size; // Dependent + std::optional downstream_cb_sem_id; // Dependant + + std::optional prefetch_h_noc_xy; // Dependent + std::optional prefetch_h_local_downstream_sem_addr; // Dependent +} dispatch_dependent_config_t; class DispatchKernel : public FDKernel { public: @@ -59,15 +65,16 @@ class DispatchKernel : public FDKernel { bool h_variant, bool d_variant) : FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection) { - config.is_h_variant = h_variant; - config.is_d_variant = d_variant; + static_config_.is_h_variant = h_variant; + static_config_.is_d_variant = d_variant; } void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; void ConfigureCore() override; - const dispatch_config_t& GetConfig() { return this->config; } + const dispatch_static_config_t& GetStaticConfig() { return static_config_; } private: - dispatch_config_t config; + dispatch_static_config_t static_config_; + dispatch_dependent_config_t dependent_config_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.cpp index 36090db6a54a..b7b891de08dc 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.cpp @@ -9,12 +9,12 @@ #include "tt_metal/detail/tt_metal.hpp" void DispatchSKernel::GenerateStaticConfigs() { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); - uint8_t cq_id = this->cq_id; + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + uint8_t cq_id_ = this->cq_id_; auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); uint32_t dispatch_s_buffer_base = 0xff; - if (device->dispatch_s_enabled()) { + if (device_->dispatch_s_enabled()) { uint32_t dispatch_buffer_base = my_dispatch_constants.dispatch_buffer_base(); if (GetCoreType() == CoreType::WORKER) { // dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx. @@ -25,78 +25,78 @@ void DispatchSKernel::GenerateStaticConfigs() { dispatch_s_buffer_base = dispatch_buffer_base; } } - this->logical_core = dispatch_core_manager::instance().dispatcher_s_core(device->id(), channel, cq_id); - this->config.cb_base = dispatch_s_buffer_base; - this->config.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE; - this->config.cb_size = my_dispatch_constants.dispatch_s_buffer_size(); + logical_core_ = dispatch_core_manager::instance().dispatcher_s_core(device_->id(), channel, cq_id_); + static_config_.cb_base = dispatch_s_buffer_base; + static_config_.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE; + static_config_.cb_size = my_dispatch_constants.dispatch_s_buffer_size(); // used by dispatch_s to sync with prefetch - this->config.my_dispatch_cb_sem_id = tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - this->config.dispatch_s_sync_sem_base_addr = + static_config_.my_dispatch_cb_sem_id = tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + static_config_.dispatch_s_sync_sem_base_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); // used by dispatch_d to signal that dispatch_s can send go signal - this->config.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); - this->config.unicast_go_signal_addr = + static_config_.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); + static_config_.unicast_go_signal_addr = (hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH) != -1) ? hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG) : 0; - this->config.distributed_dispatcher = (GetCoreType() == CoreType::ETH); - this->config.worker_sem_base_addr = + static_config_.distributed_dispatcher = (GetCoreType() == CoreType::ETH); + static_config_.worker_sem_base_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - this->config.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; - this->config.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; + static_config_.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; + static_config_.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES; } void DispatchSKernel::GenerateDependentConfigs() { // Upstream - TT_ASSERT(this->upstream_kernels.size() == 1); - auto prefetch_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto prefetch_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(prefetch_kernel); - this->config.upstream_logical_core = prefetch_kernel->GetLogicalCore(); - this->config.upstream_dispatch_cb_sem_id = prefetch_kernel->GetConfig().my_dispatch_s_cb_sem_id; + dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore(); + dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_dispatch_s_cb_sem_id; // Downstream - TT_ASSERT(this->downstream_kernels.size() == 1); - auto dispatch_kernel = dynamic_cast(this->downstream_kernels[0]); + TT_ASSERT(downstream_kernels_.size() == 1); + auto dispatch_kernel = dynamic_cast(downstream_kernels_[0]); TT_ASSERT(dispatch_kernel); - this->config.downstream_logical_core = dispatch_kernel->GetLogicalCore(); + dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore(); } void DispatchSKernel::CreateKernel() { std::vector compile_args = { - config.cb_base.value(), - config.cb_log_page_size.value(), - config.cb_size.value(), - config.my_dispatch_cb_sem_id.value(), - config.upstream_dispatch_cb_sem_id.value(), - config.dispatch_s_sync_sem_base_addr.value(), - config.mcast_go_signal_addr.value(), - config.unicast_go_signal_addr.value(), - config.distributed_dispatcher.value(), - config.worker_sem_base_addr.value(), - config.max_num_worker_sems.value(), - config.max_num_go_signal_noc_data_entries.value(), + static_config_.cb_base.value(), + static_config_.cb_log_page_size.value(), + static_config_.cb_size.value(), + static_config_.my_dispatch_cb_sem_id.value(), + dependent_config_.upstream_dispatch_cb_sem_id.value(), + static_config_.dispatch_s_sync_sem_base_addr.value(), + static_config_.mcast_go_signal_addr.value(), + static_config_.unicast_go_signal_addr.value(), + static_config_.distributed_dispatcher.value(), + static_config_.worker_sem_base_addr.value(), + static_config_.max_num_worker_sems.value(), + static_config_.max_num_go_signal_noc_data_entries.value(), }; TT_ASSERT(compile_args.size() == 12); - auto my_virtual_core = device->virtual_core_from_logical_core(this->logical_core, GetCoreType()); + auto my_virtual_core = device_->virtual_core_from_logical_core(logical_core_, GetCoreType()); auto upstream_virtual_core = - device->virtual_core_from_logical_core(config.upstream_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.upstream_logical_core.value(), GetCoreType()); auto downstream_virtual_core = - device->virtual_core_from_logical_core(config.downstream_logical_core.value(), GetCoreType()); - auto downstream_s_virtual_core = device->virtual_core_from_logical_core(UNUSED_LOGICAL_CORE, GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.downstream_logical_core.value(), GetCoreType()); + auto downstream_s_virtual_core = device_->virtual_core_from_logical_core(UNUSED_LOGICAL_CORE, GetCoreType()); - auto my_virtual_noc_coords = device->virtual_noc0_coordinate(noc_selection.non_dispatch_noc, my_virtual_core); + auto my_virtual_noc_coords = device_->virtual_noc0_coordinate(noc_selection_.non_dispatch_noc, my_virtual_core); auto upstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.upstream_noc, upstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.upstream_noc, upstream_virtual_core); auto downstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_virtual_core); auto downstream_s_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_s_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_s_virtual_core); std::map defines = { {"MY_NOC_X", std::to_string(my_virtual_noc_coords.x)}, {"MY_NOC_Y", std::to_string(my_virtual_noc_coords.y)}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, // Unused, remove later + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, // Unused, remove later {"UPSTREAM_NOC_X", std::to_string(upstream_virtual_noc_coords.x)}, {"UPSTREAM_NOC_Y", std::to_string(upstream_virtual_noc_coords.y)}, {"DOWNSTREAM_NOC_X", std::to_string(downstream_virtual_noc_coords.x)}, @@ -108,11 +108,11 @@ void DispatchSKernel::CreateKernel() { } void DispatchSKernel::ConfigureCore() { - if (!this->device->distributed_dispatcher()) { + if (!device_->distributed_dispatcher()) { return; } // Just need to clear the dispatch message - tt::log_warning("Configure Dispatch S (device {} core {})", device->id(), logical_core.str()); + tt::log_warning("Configure Dispatch S (device {} core {})", device_->id(), logical_core_.str()); std::vector zero = {0x0}; auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); uint32_t dispatch_s_sync_sem_base_addr = @@ -124,7 +124,7 @@ void DispatchSKernel::ConfigureCore() { dispatch_s_sync_sem_base_addr + my_dispatch_constants.get_dispatch_message_offset(i); uint32_t dispatch_message_addr = dispatch_message_base_addr + my_dispatch_constants.get_dispatch_message_offset(i); - detail::WriteToDeviceL1(device, this->logical_core, dispatch_s_sync_sem_addr, zero, GetCoreType()); - detail::WriteToDeviceL1(device, this->logical_core, dispatch_message_addr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, dispatch_s_sync_sem_addr, zero, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, dispatch_message_addr, zero, GetCoreType()); } } diff --git a/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.hpp index 6af59c5bd09f..a9293e7bd1f1 100644 --- a/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.hpp @@ -4,15 +4,11 @@ #pragma once #include "fd_kernel.hpp" -typedef struct dispatch_s_config { - std::optional upstream_logical_core; // Dependant - std::optional downstream_logical_core; // Dependant - +typedef struct dispatch_s_static_config { std::optional cb_base; std::optional cb_log_page_size; std::optional cb_size; std::optional my_dispatch_cb_sem_id; - std::optional upstream_dispatch_cb_sem_id; // Dependent std::optional dispatch_s_sync_sem_base_addr; std::optional mcast_go_signal_addr; @@ -21,7 +17,13 @@ typedef struct dispatch_s_config { std::optional worker_sem_base_addr; std::optional max_num_worker_sems; std::optional max_num_go_signal_noc_data_entries; -} dispatch_s_config_t; +} dispatch_s_static_config_t; + +typedef struct dispatch_s_dependent_config { + std::optional upstream_logical_core; // Dependant + std::optional downstream_logical_core; // Dependant + std::optional upstream_dispatch_cb_sem_id; // Dependent +} dispatch_s_dependent_config_t; class DispatchSKernel : public FDKernel { public: @@ -32,8 +34,9 @@ class DispatchSKernel : public FDKernel { void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; void ConfigureCore() override; - const dispatch_s_config_t& GetConfig() { return this->config; } + const dispatch_s_static_config_t& GetStaticConfig() { return static_config_; } private: - dispatch_s_config_t config; + dispatch_s_static_config_t static_config_; + dispatch_s_dependent_config_t dependent_config_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.cpp index 6681a3063376..cf522fb7b31c 100644 --- a/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.cpp @@ -10,134 +10,134 @@ void EthRouterKernel::GenerateStaticConfigs() { auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - if (this->as_mux) { + if (as_mux_) { uint16_t channel = - tt::Cluster::instance().get_assigned_channel_for_device(this->servicing_device_id); // TODO: can be mmio - this->logical_core = - dispatch_core_manager::instance().mux_core(this->servicing_device_id, channel, placement_cq_id); - this->config.rx_queue_start_addr_words = my_dispatch_constants.dispatch_buffer_base() >> 4; + tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id_); // TODO: can be mmio + logical_core_ = dispatch_core_manager::instance().mux_core(servicing_device_id_, channel, placement_cq_id_); + static_config_.rx_queue_start_addr_words = my_dispatch_constants.dispatch_buffer_base() >> 4; // TODO: why is this hard-coded NUM_CQS=1 for galaxy? if (tt::Cluster::instance().is_galaxy_cluster()) { - this->config.rx_queue_size_words = my_dispatch_constants.mux_buffer_size(1) >> 4; + static_config_.rx_queue_size_words = my_dispatch_constants.mux_buffer_size(1) >> 4; } else { - this->config.rx_queue_size_words = my_dispatch_constants.mux_buffer_size(device->num_hw_cqs()) >> 4; + static_config_.rx_queue_size_words = my_dispatch_constants.mux_buffer_size(device_->num_hw_cqs()) >> 4; } - this->config.kernel_status_buf_addr_arg = 0; - this->config.kernel_status_buf_size_bytes = 0; - this->config.timeout_cycles = 0; - this->config.output_depacketize = {0x0}; - this->config.output_depacketize_log_page_size = {0x0}; - this->config.output_depacketize_downstream_sem = {0x0}; - this->config.output_depacketize_local_sem = {0x0}; - this->config.output_depacketize_remove_header = {0x0}; + static_config_.kernel_status_buf_addr_arg = 0; + static_config_.kernel_status_buf_size_bytes = 0; + static_config_.timeout_cycles = 0; + dependent_config_.output_depacketize = {0x0}; + static_config_.output_depacketize_log_page_size = {0x0}; + dependent_config_.output_depacketize_downstream_sem = {0x0}; + static_config_.output_depacketize_local_sem = {0x0}; + static_config_.output_depacketize_remove_header = {0x0}; - for (int idx = 0; idx < this->upstream_kernels.size(); idx++) { - this->config.input_packetize[idx] = 0x1; - this->config.input_packetize_log_page_size[idx] = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; - this->config.input_packetize_local_sem[idx] = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - this->config.remote_rx_queue_id[idx] = 1; + for (int idx = 0; idx < upstream_kernels_.size(); idx++) { + static_config_.input_packetize[idx] = 0x1; + static_config_.input_packetize_log_page_size[idx] = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + static_config_.input_packetize_local_sem[idx] = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + dependent_config_.remote_rx_queue_id[idx] = 1; } // Mux fowrads all VCs - this->config.fwd_vc_count = this->config.vc_count; + static_config_.fwd_vc_count = this->static_config_.vc_count; } else { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); - this->logical_core = dispatch_core_manager::instance().demux_d_core(device->id(), channel, placement_cq_id); - this->config.rx_queue_start_addr_words = + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + logical_core_ = dispatch_core_manager::instance().demux_d_core(device_->id(), channel, placement_cq_id_); + static_config_.rx_queue_start_addr_words = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED) >> 4; - this->config.rx_queue_size_words = 0x8000 >> 4; + static_config_.rx_queue_size_words = 0x8000 >> 4; - this->config.kernel_status_buf_addr_arg = 0; - this->config.kernel_status_buf_size_bytes = 0; - this->config.timeout_cycles = 0; - this->config.output_depacketize = {0x0}; + static_config_.kernel_status_buf_addr_arg = 0; + static_config_.kernel_status_buf_size_bytes = 0; + static_config_.timeout_cycles = 0; + dependent_config_.output_depacketize = {0x0}; - this->config.input_packetize = {0x0}; - this->config.input_packetize_log_page_size = {0x0}; - this->config.input_packetize_upstream_sem = {0x0}; - this->config.input_packetize_local_sem = {0x0}; - this->config.input_packetize_src_endpoint = {0x0}; - this->config.input_packetize_dst_endpoint = {0x0}; + static_config_.input_packetize = {0x0}; + static_config_.input_packetize_log_page_size = {0x0}; + dependent_config_.input_packetize_upstream_sem = {0x0}; + static_config_.input_packetize_local_sem = {0x0}; + dependent_config_.input_packetize_src_endpoint = {0x0}; + dependent_config_.input_packetize_dst_endpoint = {0x0}; - this->config.fwd_vc_count = this->config.vc_count; + static_config_.fwd_vc_count = this->static_config_.vc_count; uint32_t created_semaphores = 0; - for (int idx = 0; idx < this->downstream_kernels.size(); idx++) { + for (int idx = 0; idx < downstream_kernels_.size(); idx++) { // Forwward VCs are the ones that don't connect to a prefetch - if (auto pk = dynamic_cast(this->downstream_kernels[idx])) { - this->config.fwd_vc_count = this->config.fwd_vc_count.value() - 1; - this->config.output_depacketize_local_sem[idx] = // TODO: to match for now, init one per vc after - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + if (auto pk = dynamic_cast(downstream_kernels_[idx])) { + static_config_.fwd_vc_count = this->static_config_.fwd_vc_count.value() - 1; + static_config_.output_depacketize_local_sem[idx] = // TODO: to match for now, init one per vc after + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); created_semaphores++; } } if (created_semaphores == 0) { // Just to match previous implementation - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); } - for (int idx = 0; idx < this->config.vc_count.value(); idx++) { - this->config.output_depacketize_log_page_size[idx] = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.output_depacketize_remove_header[idx] = 0; + for (int idx = 0; idx < static_config_.vc_count.value(); idx++) { + static_config_.output_depacketize_log_page_size[idx] = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.output_depacketize_remove_header[idx] = 0; } } } void EthRouterKernel::GenerateDependentConfigs() { - if (this->as_mux) { + if (as_mux_) { // Upstream, expect PRETETCH_Hs - TT_ASSERT(this->upstream_kernels.size() <= MAX_SWITCH_FAN_IN && this->upstream_kernels.size() > 0); + TT_ASSERT(upstream_kernels_.size() <= MAX_SWITCH_FAN_IN && upstream_kernels_.size() > 0); // Downstream, expect US_TUNNELER_REMOTE - TT_ASSERT(this->downstream_kernels.size() == 1); - auto tunneler_kernel = dynamic_cast(this->downstream_kernels[0]); + TT_ASSERT(downstream_kernels_.size() == 1); + auto tunneler_kernel = dynamic_cast(downstream_kernels_[0]); TT_ASSERT(tunneler_kernel); uint32_t router_id = tunneler_kernel->GetRouterId(this, true); - for (int idx = 0; idx < this->upstream_kernels.size(); idx++) { - auto prefetch_kernel = dynamic_cast(this->upstream_kernels[idx]); + for (int idx = 0; idx < upstream_kernels_.size(); idx++) { + auto prefetch_kernel = dynamic_cast(upstream_kernels_[idx]); TT_ASSERT(prefetch_kernel); - this->config.remote_tx_x[idx] = tunneler_kernel->GetVirtualCore().x; - this->config.remote_tx_y[idx] = tunneler_kernel->GetVirtualCore().y; - this->config.remote_tx_queue_id[idx] = idx + MAX_SWITCH_FAN_IN * router_id; - this->config.remote_tx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.remote_tx_queue_start_addr_words[idx] = - tunneler_kernel->GetConfig().in_queue_start_addr_words.value() + - (idx + router_id * MAX_SWITCH_FAN_IN) * tunneler_kernel->GetConfig().in_queue_size_words.value(); - this->config.remote_tx_queue_size_words[idx] = tunneler_kernel->GetConfig().in_queue_size_words.value(); + dependent_config_.remote_tx_x[idx] = tunneler_kernel->GetVirtualCore().x; + dependent_config_.remote_tx_y[idx] = tunneler_kernel->GetVirtualCore().y; + dependent_config_.remote_tx_queue_id[idx] = idx + MAX_SWITCH_FAN_IN * router_id; + dependent_config_.remote_tx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_tx_queue_start_addr_words[idx] = + tunneler_kernel->GetStaticConfig().in_queue_start_addr_words.value() + + (idx + router_id * MAX_SWITCH_FAN_IN) * tunneler_kernel->GetStaticConfig().in_queue_size_words.value(); + dependent_config_.remote_tx_queue_size_words[idx] = + tunneler_kernel->GetStaticConfig().in_queue_size_words.value(); - this->config.remote_rx_x[idx] = prefetch_kernel->GetVirtualCore().x; - this->config.remote_rx_y[idx] = prefetch_kernel->GetVirtualCore().y; - this->config.remote_rx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_rx_x[idx] = prefetch_kernel->GetVirtualCore().x; + dependent_config_.remote_rx_y[idx] = prefetch_kernel->GetVirtualCore().y; + dependent_config_.remote_rx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.input_packetize_upstream_sem[idx] = - prefetch_kernel->GetConfig().my_downstream_cb_sem_id.value(); + dependent_config_.input_packetize_upstream_sem[idx] = + prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value(); } uint32_t src_id_start = 0xA1 + router_id * MAX_SWITCH_FAN_IN; uint32_t dst_id_start = 0xB1 + router_id * MAX_SWITCH_FAN_IN; - this->config.input_packetize_src_endpoint = { + dependent_config_.input_packetize_src_endpoint = { src_id_start, src_id_start + 1, src_id_start + 2, src_id_start + 3}; - this->config.input_packetize_dst_endpoint = { + dependent_config_.input_packetize_dst_endpoint = { dst_id_start, dst_id_start + 1, dst_id_start + 2, dst_id_start + 3}; } else { // Upstream, expect US_TUNNELER_LOCAL - TT_ASSERT(this->upstream_kernels.size() == 1); - auto us_tunneler_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto us_tunneler_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(us_tunneler_kernel); // Upstream queues connect to the upstream tunneler, as many queues as we have VCs - for (int idx = 0; idx < config.vc_count.value(); idx++) { - this->config.remote_rx_x[idx] = us_tunneler_kernel->GetVirtualCore().x; - this->config.remote_rx_y[idx] = us_tunneler_kernel->GetVirtualCore().y; + for (int idx = 0; idx < static_config_.vc_count.value(); idx++) { + dependent_config_.remote_rx_x[idx] = us_tunneler_kernel->GetVirtualCore().x; + dependent_config_.remote_rx_y[idx] = us_tunneler_kernel->GetVirtualCore().y; // Queue id starts counting after the input VCs - this->config.remote_rx_queue_id[idx] = us_tunneler_kernel->GetRouterQueueIdOffset(this, false) + idx; - this->config.remote_rx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_rx_queue_id[idx] = us_tunneler_kernel->GetRouterQueueIdOffset(this, false) + idx; + dependent_config_.remote_rx_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; } // Downstream, expect PREFETCH_D/US_TUNNELER_REMOTE - TT_ASSERT(this->downstream_kernels.size() <= MAX_SWITCH_FAN_OUT && this->downstream_kernels.size() > 0); + TT_ASSERT(downstream_kernels_.size() <= MAX_SWITCH_FAN_OUT && downstream_kernels_.size() > 0); std::vector prefetch_kernels; EthTunnelerKernel* ds_tunneler_kernel = nullptr; - for (auto k : this->downstream_kernels) { + for (auto k : downstream_kernels_) { if (auto pk = dynamic_cast(k)) { prefetch_kernels.push_back(pk); } else if (auto tk = dynamic_cast(k)) { @@ -150,37 +150,37 @@ void EthRouterKernel::GenerateDependentConfigs() { // Populate remote_tx_* for prefetch kernels, assume they are connected "first" uint32_t remote_idx = 0; for (auto prefetch_kernel : prefetch_kernels) { - this->config.remote_tx_x[remote_idx] = prefetch_kernel->GetVirtualCore().x; - this->config.remote_tx_y[remote_idx] = prefetch_kernel->GetVirtualCore().y; - this->config.remote_tx_queue_id[remote_idx] = 0; // Prefetch queue id always 0 - this->config.remote_tx_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.remote_tx_queue_start_addr_words[remote_idx] = - prefetch_kernel->GetConfig().cmddat_q_base.value() >> 4; - this->config.remote_tx_queue_size_words[remote_idx] = - prefetch_kernel->GetConfig().cmddat_q_size.value() >> 4; - this->config.output_depacketize[remote_idx] = 1; - this->config.output_depacketize_downstream_sem[remote_idx] = - prefetch_kernel->GetConfig().my_upstream_cb_sem_id; + dependent_config_.remote_tx_x[remote_idx] = prefetch_kernel->GetVirtualCore().x; + dependent_config_.remote_tx_y[remote_idx] = prefetch_kernel->GetVirtualCore().y; + dependent_config_.remote_tx_queue_id[remote_idx] = 0; // Prefetch queue id always 0 + dependent_config_.remote_tx_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_tx_queue_start_addr_words[remote_idx] = + prefetch_kernel->GetStaticConfig().cmddat_q_base.value() >> 4; + dependent_config_.remote_tx_queue_size_words[remote_idx] = + prefetch_kernel->GetStaticConfig().cmddat_q_size.value() >> 4; + dependent_config_.output_depacketize[remote_idx] = 1; + dependent_config_.output_depacketize_downstream_sem[remote_idx] = + prefetch_kernel->GetStaticConfig().my_upstream_cb_sem_id; remote_idx++; } // Populate remote_tx_* for the downstream tunneler, as many queues as we have fwd VCs if (ds_tunneler_kernel) { - for (int idx = 0; idx < config.fwd_vc_count.value(); idx++) { - this->config.remote_tx_x[remote_idx] = ds_tunneler_kernel->GetVirtualCore().x; - this->config.remote_tx_y[remote_idx] = ds_tunneler_kernel->GetVirtualCore().y; - this->config.remote_tx_queue_id[remote_idx] = + for (int idx = 0; idx < static_config_.fwd_vc_count.value(); idx++) { + dependent_config_.remote_tx_x[remote_idx] = ds_tunneler_kernel->GetVirtualCore().x; + dependent_config_.remote_tx_y[remote_idx] = ds_tunneler_kernel->GetVirtualCore().y; + dependent_config_.remote_tx_queue_id[remote_idx] = ds_tunneler_kernel->GetRouterQueueIdOffset(this, true) + idx; - this->config.remote_tx_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.remote_tx_queue_start_addr_words[remote_idx] = - ds_tunneler_kernel->GetConfig().in_queue_start_addr_words.value() + - ds_tunneler_kernel->GetConfig().in_queue_size_words.value() * - (this->config.remote_tx_queue_id[remote_idx].value()); - this->config.remote_tx_queue_size_words[remote_idx] = - ds_tunneler_kernel->GetConfig().in_queue_size_words.value(); + dependent_config_.remote_tx_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_tx_queue_start_addr_words[remote_idx] = + ds_tunneler_kernel->GetStaticConfig().in_queue_start_addr_words.value() + + ds_tunneler_kernel->GetStaticConfig().in_queue_size_words.value() * + (dependent_config_.remote_tx_queue_id[remote_idx].value()); + dependent_config_.remote_tx_queue_size_words[remote_idx] = + ds_tunneler_kernel->GetStaticConfig().in_queue_size_words.value(); // Don't depacketize when sending to tunneler - this->config.output_depacketize[remote_idx] = 0; - this->config.output_depacketize_downstream_sem[remote_idx] = 0; + dependent_config_.output_depacketize[remote_idx] = 0; + dependent_config_.output_depacketize_downstream_sem[remote_idx] = 0; remote_idx++; } } @@ -190,9 +190,9 @@ void EthRouterKernel::GenerateDependentConfigs() { void EthRouterKernel::CreateKernel() { std::vector compile_args{ 0, // Unused - config.rx_queue_start_addr_words.value(), - config.rx_queue_size_words.value(), - config.vc_count.value(), + static_config_.rx_queue_start_addr_words.value(), + static_config_.rx_queue_size_words.value(), + static_config_.vc_count.value(), 0, 0, 0, @@ -211,9 +211,9 @@ void EthRouterKernel::CreateKernel() { 0, // Populate remote_rx_* after 0, 0, // Unused - config.kernel_status_buf_addr_arg.value(), - config.kernel_status_buf_size_bytes.value(), - config.timeout_cycles.value(), + static_config_.kernel_status_buf_addr_arg.value(), + static_config_.kernel_status_buf_size_bytes.value(), + static_config_.timeout_cycles.value(), 0, // Populate output_depacketize after 0, 0, @@ -227,72 +227,71 @@ void EthRouterKernel::CreateKernel() { 0, // input_packetize_dst_endpoint }; // Some unused values, just hardcode them to match for checking purposes... - if (!this->as_mux) { + if (!as_mux_) { compile_args[0] = 0xB1; // compile_args[21] = 84; } for (int idx = 0; idx < MAX_SWITCH_FAN_OUT; idx++) { - if (config.remote_tx_x[idx]) { - compile_args[4 + idx] |= (config.remote_tx_x[idx].value() & 0xFF); - compile_args[4 + idx] |= (config.remote_tx_y[idx].value() & 0xFF) << 8; - compile_args[4 + idx] |= (config.remote_tx_queue_id[idx].value() & 0xFF) << 16; - compile_args[4 + idx] |= (config.remote_tx_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_tx_x[idx]) { + compile_args[4 + idx] |= (dependent_config_.remote_tx_x[idx].value() & 0xFF); + compile_args[4 + idx] |= (dependent_config_.remote_tx_y[idx].value() & 0xFF) << 8; + compile_args[4 + idx] |= (dependent_config_.remote_tx_queue_id[idx].value() & 0xFF) << 16; + compile_args[4 + idx] |= (dependent_config_.remote_tx_network_type[idx].value() & 0xFF) << 24; } - if (config.remote_tx_queue_start_addr_words[idx]) { - compile_args[8 + idx * 2] = config.remote_tx_queue_start_addr_words[idx].value(); - compile_args[9 + idx * 2] = config.remote_tx_queue_size_words[idx].value(); + if (dependent_config_.remote_tx_queue_start_addr_words[idx]) { + compile_args[8 + idx * 2] = dependent_config_.remote_tx_queue_start_addr_words[idx].value(); + compile_args[9 + idx * 2] = dependent_config_.remote_tx_queue_size_words[idx].value(); } - if (config.output_depacketize[idx]) { - compile_args[25] |= (config.output_depacketize[idx].value() & 0x1) << idx; - if (config.output_depacketize[idx].value() & 0x1) { // To match previous implementation - compile_args[26 + idx] |= (config.output_depacketize_log_page_size[idx].value() & 0xFF); - compile_args[26 + idx] |= (config.output_depacketize_downstream_sem[idx].value() & 0xFF) << 8; - compile_args[26 + idx] |= (config.output_depacketize_local_sem[idx].value() & 0xFF) << 16; - compile_args[26 + idx] |= (config.output_depacketize_remove_header[idx].value() & 0xFF) << 24; + if (dependent_config_.output_depacketize[idx]) { + compile_args[25] |= (dependent_config_.output_depacketize[idx].value() & 0x1) << idx; + if (dependent_config_.output_depacketize[idx].value() & 0x1) { // To match previous implementation + compile_args[26 + idx] |= (static_config_.output_depacketize_log_page_size[idx].value() & 0xFF); + compile_args[26 + idx] |= (dependent_config_.output_depacketize_downstream_sem[idx].value() & 0xFF) + << 8; + compile_args[26 + idx] |= (static_config_.output_depacketize_local_sem[idx].value() & 0xFF) << 16; + compile_args[26 + idx] |= (static_config_.output_depacketize_remove_header[idx].value() & 0xFF) << 24; } } } for (int idx = 0; idx < MAX_SWITCH_FAN_IN; idx++) { - if (config.remote_rx_x[idx]) { - compile_args[16 + idx] |= (config.remote_rx_x[idx].value() & 0xFF); - compile_args[16 + idx] |= (config.remote_rx_y[idx].value() & 0xFF) << 8; - compile_args[16 + idx] |= (config.remote_rx_queue_id[idx].value() & 0xFF) << 16; - compile_args[16 + idx] |= (config.remote_rx_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_rx_x[idx]) { + compile_args[16 + idx] |= (dependent_config_.remote_rx_x[idx].value() & 0xFF); + compile_args[16 + idx] |= (dependent_config_.remote_rx_y[idx].value() & 0xFF) << 8; + compile_args[16 + idx] |= (dependent_config_.remote_rx_queue_id[idx].value() & 0xFF) << 16; + compile_args[16 + idx] |= (dependent_config_.remote_rx_network_type[idx].value() & 0xFF) << 24; } - if (config.input_packetize[idx]) { - compile_args[30 + idx] |= (config.input_packetize[idx].value() & 0xFF); - compile_args[30 + idx] |= (config.input_packetize_log_page_size[idx].value() & 0xFF) << 8; - compile_args[30 + idx] |= (config.input_packetize_upstream_sem[idx].value() & 0xFF) << 16; - compile_args[30 + idx] |= (config.input_packetize_local_sem[idx].value() & 0xFF) << 24; + if (static_config_.input_packetize[idx]) { + compile_args[30 + idx] |= (static_config_.input_packetize[idx].value() & 0xFF); + compile_args[30 + idx] |= (static_config_.input_packetize_log_page_size[idx].value() & 0xFF) << 8; + compile_args[30 + idx] |= (dependent_config_.input_packetize_upstream_sem[idx].value() & 0xFF) << 16; + compile_args[30 + idx] |= (static_config_.input_packetize_local_sem[idx].value() & 0xFF) << 24; } - if (config.input_packetize_src_endpoint[idx]) { - compile_args[34] |= (config.input_packetize_src_endpoint[idx].value() & 0xFF) << (8 * idx); + if (dependent_config_.input_packetize_src_endpoint[idx]) { + compile_args[34] |= (dependent_config_.input_packetize_src_endpoint[idx].value() & 0xFF) << (8 * idx); } - if (config.input_packetize_dst_endpoint[idx]) { - compile_args[35] |= (config.input_packetize_dst_endpoint[idx].value() & 0xFF) << (8 * idx); + if (dependent_config_.input_packetize_dst_endpoint[idx]) { + compile_args[35] |= (dependent_config_.input_packetize_dst_endpoint[idx].value() & 0xFF) << (8 * idx); } } TT_ASSERT(compile_args.size() == 36); - const auto& grid_size = device->grid_size(); + const auto& grid_size = device_->grid_size(); std::map defines = { // All of these unused, remove later - {"MY_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.x, 0))}, - {"MY_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.y, 0))}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, + {"MY_NOC_X", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.x, 0))}, + {"MY_NOC_Y", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.y, 0))}, + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, {"UPSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.x, 0))}, {"UPSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_SLAVE_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_SLAVE_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"SKIP_NOC_LOGGING", "1"}}; configure_kernel_variant(dispatch_kernel_file_names[PACKET_ROUTER_MUX], compile_args, defines, false, false, false); } diff --git a/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.hpp index 752683c25126..8915044f2210 100644 --- a/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/eth_router_kernel.hpp @@ -4,12 +4,25 @@ #pragma once #include "fd_kernel.hpp" -typedef struct eth_router_config { +typedef struct eth_router_static_config { std::optional vc_count; // Set from arch level std::optional fwd_vc_count; // # of VCs continuing on to the next chip std::optional rx_queue_start_addr_words; // 1 std::optional rx_queue_size_words; + std::optional kernel_status_buf_addr_arg; // 22 + std::optional kernel_status_buf_size_bytes; + std::optional timeout_cycles; + + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_log_page_size; // [26:29] + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem; // [26:29] + std::array, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29] + std::array, MAX_SWITCH_FAN_IN> input_packetize; // [30:33] + std::array, MAX_SWITCH_FAN_IN> input_packetize_log_page_size; // [30:33] + std::array, MAX_SWITCH_FAN_IN> input_packetize_local_sem; // [30:33] +} eth_router_static_config_t; + +typedef struct eth_router_dependent_config { std::array, MAX_SWITCH_FAN_OUT> remote_tx_x; // [4:7], dependent std::array, MAX_SWITCH_FAN_OUT> remote_tx_y; // [4:7], dependent std::array, MAX_SWITCH_FAN_OUT> remote_tx_queue_id; // [4:7], dependent @@ -21,22 +34,12 @@ typedef struct eth_router_config { std::array, MAX_SWITCH_FAN_IN> remote_rx_queue_id; // [16:19], dependent std::array, MAX_SWITCH_FAN_IN> remote_rx_network_type; // [17:19], dependent - std::optional kernel_status_buf_addr_arg; // 22 - std::optional kernel_status_buf_size_bytes; - std::optional timeout_cycles; - std::array, MAX_SWITCH_FAN_OUT> output_depacketize; // 25, dependent - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_log_page_size; // [26:29] std::array, MAX_SWITCH_FAN_OUT> output_depacketize_downstream_sem; // [26:29], dependent - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem; // [26:29] - std::array, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29] - std::array, MAX_SWITCH_FAN_IN> input_packetize; // [30:33] - std::array, MAX_SWITCH_FAN_IN> input_packetize_log_page_size; // [30:33] std::array, MAX_SWITCH_FAN_IN> input_packetize_upstream_sem; // [30:33], dependent - std::array, MAX_SWITCH_FAN_IN> input_packetize_local_sem; // [30:33] std::array, MAX_SWITCH_FAN_IN> input_packetize_src_endpoint; // Dependent std::array, MAX_SWITCH_FAN_IN> input_packetize_dst_endpoint; // Dependent -} eth_router_config_t; +} eth_router_dependent_config_t; class EthRouterKernel : public FDKernel { public: @@ -47,16 +50,17 @@ class EthRouterKernel : public FDKernel { uint8_t cq_id, noc_selection_t noc_selection, bool as_mux) : - FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection), as_mux(as_mux) {} + FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection), as_mux_(as_mux) {} void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; - const eth_router_config_t& GetConfig() { return this->config; } - void SetVCCount(uint32_t vc_count) { this->config.vc_count = vc_count; } - void SetPlacementCQID(int id) { this->placement_cq_id = id; } + const eth_router_static_config_t& GetStaticConfig() { return static_config_; } + void SetVCCount(uint32_t vc_count) { static_config_.vc_count = vc_count; } + void SetPlacementCQID(int id) { placement_cq_id_ = id; } private: - eth_router_config_t config; - int placement_cq_id; // TODO: remove channel hard-coding for dispatch core manager - bool as_mux; + eth_router_static_config_t static_config_; + eth_router_dependent_config_t dependent_config_; + int placement_cq_id_; // TODO: remove channel hard-coding for dispatch core manager + bool as_mux_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.cpp index 7a1ac1e6e641..664ff6094dc6 100644 --- a/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.cpp @@ -10,43 +10,43 @@ #include "tt_metal/detail/tt_metal.hpp" void EthTunnelerKernel::GenerateStaticConfigs() { - chip_id_t downstream_device_id = FDKernel::GetDownstreamDeviceId(device_id); + chip_id_t downstream_device_id = FDKernel::GetDownstreamDeviceId(device_id_); // For MMIO devices, the above function just gets one of the possible downstream devices, we've populated this // specific case with servicing_device_id - if (device->is_mmio_capable()) { - downstream_device_id = servicing_device_id; + if (device_->is_mmio_capable()) { + downstream_device_id = servicing_device_id_; } if (this->IsRemote()) { uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(downstream_device_id); - this->logical_core = - dispatch_core_manager::instance().tunneler_core(device->id(), downstream_device_id, channel, cq_id); + logical_core_ = + dispatch_core_manager::instance().tunneler_core(device_->id(), downstream_device_id, channel, cq_id_); } else { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); - this->logical_core = dispatch_core_manager::instance().us_tunneler_core_local(device->id(), channel, cq_id); + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + logical_core_ = dispatch_core_manager::instance().us_tunneler_core_local(device_->id(), channel, cq_id_); } - this->config.endpoint_id_start_index = 0xDACADACA; - this->config.in_queue_start_addr_words = 0x19000 >> 4; - this->config.in_queue_size_words = 0x4000 >> 4; - this->config.kernel_status_buf_addr_arg = 0x39000; - this->config.kernel_status_buf_size_bytes = 0x7000; - this->config.timeout_cycles = 0; + static_config_.endpoint_id_start_index = 0xDACADACA; + static_config_.in_queue_start_addr_words = 0x19000 >> 4; + static_config_.in_queue_size_words = 0x4000 >> 4; + static_config_.kernel_status_buf_addr_arg = 0x39000; + static_config_.kernel_status_buf_size_bytes = 0x7000; + static_config_.timeout_cycles = 0; } void EthTunnelerKernel::GenerateDependentConfigs() { if (this->IsRemote()) { // For remote tunneler, we don't actually have the device constructed for the paired tunneler, so can't pull // info from it. Core coord can be computed without the device, and relevant fields match this tunneler. - chip_id_t downstream_device_id = FDKernel::GetDownstreamDeviceId(device_id); + chip_id_t downstream_device_id = FDKernel::GetDownstreamDeviceId(device_id_); uint16_t downstream_channel = tt::Cluster::instance().get_assigned_channel_for_device(downstream_device_id); tt_cxy_pair paired_logical_core = - dispatch_core_manager::instance().us_tunneler_core_local(downstream_device_id, downstream_channel, cq_id); + dispatch_core_manager::instance().us_tunneler_core_local(downstream_device_id, downstream_channel, cq_id_); tt_cxy_pair paired_physical_coord = tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(paired_logical_core, CoreType::ETH); // Upstream, we expect a US_TUNNELER_LOCAL and one or more PACKET_ROUTER EthTunnelerKernel* tunneler_kernel = nullptr; std::vector router_kernels; - for (auto k : this->upstream_kernels) { + for (auto k : upstream_kernels_) { if (auto rk = dynamic_cast(k)) { router_kernels.push_back(rk); } else if (auto tk = dynamic_cast(k)) { @@ -60,108 +60,114 @@ void EthTunnelerKernel::GenerateDependentConfigs() { // Remote sender is the upstream packet router, one queue per router output lane. int remote_idx = 0; for (auto router_kernel : router_kernels) { - uint32_t router_vc_count = router_kernel->GetConfig().vc_count.value(); - uint32_t router_fwd_vc_count = router_kernel->GetConfig().fwd_vc_count.value(); + uint32_t router_vc_count = router_kernel->GetStaticConfig().vc_count.value(); + uint32_t router_fwd_vc_count = router_kernel->GetStaticConfig().fwd_vc_count.value(); for (int idx = 0; idx < router_fwd_vc_count; idx++) { - this->config.remote_sender_x[remote_idx] = router_kernel->GetVirtualCore().x; - this->config.remote_sender_y[remote_idx] = router_kernel->GetVirtualCore().y; + dependent_config_.remote_sender_x[remote_idx] = router_kernel->GetVirtualCore().x; + dependent_config_.remote_sender_y[remote_idx] = router_kernel->GetVirtualCore().y; // Router output lane ids start after it's input lane ids, assume after lanes that go to on-device // kernels - this->config.remote_sender_queue_id[remote_idx] = + dependent_config_.remote_sender_queue_id[remote_idx] = router_vc_count + idx + router_vc_count - router_fwd_vc_count; - this->config.remote_sender_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_sender_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; remote_idx++; } } // Last upstream connection is the return path from other tunneler - this->config.remote_sender_x[this->config.vc_count.value() - 1] = paired_physical_coord.x; - this->config.remote_sender_y[this->config.vc_count.value() - 1] = paired_physical_coord.y; - this->config.remote_sender_queue_id[this->config.vc_count.value() - 1] = this->config.vc_count.value() * 2 - 1; - this->config.remote_sender_network_type[this->config.vc_count.value() - 1] = + dependent_config_.remote_sender_x[this->static_config_.vc_count.value() - 1] = paired_physical_coord.x; + dependent_config_.remote_sender_y[this->static_config_.vc_count.value() - 1] = paired_physical_coord.y; + dependent_config_.remote_sender_queue_id[this->static_config_.vc_count.value() - 1] = + this->static_config_.vc_count.value() * 2 - 1; + dependent_config_.remote_sender_network_type[this->static_config_.vc_count.value() - 1] = (uint32_t)DispatchRemoteNetworkType::ETH; - this->config.inner_stop_mux_d_bypass = 0; + dependent_config_.inner_stop_mux_d_bypass = 0; // Downstream, we expect the same US_TUNNELER_LOCAL and a DEMUX (tunnel start)/MUX_D (non-tunnel start) - TT_ASSERT(this->downstream_kernels.size() == 2); - auto ds_tunneler_kernel = dynamic_cast(this->downstream_kernels[0]); - auto other_ds_kernel = this->downstream_kernels[1]; + TT_ASSERT(downstream_kernels_.size() == 2); + auto ds_tunneler_kernel = dynamic_cast(downstream_kernels_[0]); + auto other_ds_kernel = downstream_kernels_[1]; if (!ds_tunneler_kernel) { - ds_tunneler_kernel = dynamic_cast(this->downstream_kernels[1]); - auto other_ds_kernel = this->downstream_kernels[0]; + ds_tunneler_kernel = dynamic_cast(downstream_kernels_[1]); + auto other_ds_kernel = downstream_kernels_[0]; } TT_ASSERT(ds_tunneler_kernel == tunneler_kernel); - for (uint32_t idx = 0; idx < this->config.vc_count.value(); idx++) { - if (idx == this->config.vc_count.value() - 1) { + for (uint32_t idx = 0; idx < static_config_.vc_count.value(); idx++) { + if (idx == static_config_.vc_count.value() - 1) { // Last VC is the return VC, driving a DEMUX or MUX_D - this->config.remote_receiver_x[idx] = other_ds_kernel->GetVirtualCore().x; - this->config.remote_receiver_y[idx] = other_ds_kernel->GetVirtualCore().y; - this->config.remote_receiver_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_receiver_x[idx] = other_ds_kernel->GetVirtualCore().x; + dependent_config_.remote_receiver_y[idx] = other_ds_kernel->GetVirtualCore().y; + dependent_config_.remote_receiver_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; if (auto demux_kernel = dynamic_cast(other_ds_kernel)) { - this->config.remote_receiver_queue_start[idx] = demux_kernel->GetConfig().rx_queue_start_addr_words; - this->config.remote_receiver_queue_size[idx] = demux_kernel->GetConfig().rx_queue_size_words; - this->config.remote_receiver_queue_id[idx] = 0; // DEMUX input queue id always 0 + dependent_config_.remote_receiver_queue_start[idx] = + demux_kernel->GetStaticConfig().rx_queue_start_addr_words; + dependent_config_.remote_receiver_queue_size[idx] = + demux_kernel->GetStaticConfig().rx_queue_size_words; + dependent_config_.remote_receiver_queue_id[idx] = 0; // DEMUX input queue id always 0 } else if (auto mux_kernel = dynamic_cast(other_ds_kernel)) { - this->config.remote_receiver_queue_start[idx] = - mux_kernel->GetConfig().rx_queue_start_addr_words.value() + - mux_kernel->GetConfig().rx_queue_size_words.value() * - (mux_kernel->GetConfig().mux_fan_in.value() - 1); - this->config.remote_receiver_queue_size[idx] = mux_kernel->GetConfig().rx_queue_size_words; + dependent_config_.remote_receiver_queue_start[idx] = + mux_kernel->GetStaticConfig().rx_queue_start_addr_words.value() + + mux_kernel->GetStaticConfig().rx_queue_size_words.value() * + (mux_kernel->GetStaticConfig().mux_fan_in.value() - 1); + dependent_config_.remote_receiver_queue_size[idx] = + mux_kernel->GetStaticConfig().rx_queue_size_words; // MUX input queue id for tunneler is the last one (counting up from 0) - this->config.remote_receiver_queue_id[idx] = mux_kernel->GetConfig().mux_fan_in.value() - 1; + dependent_config_.remote_receiver_queue_id[idx] = + mux_kernel->GetStaticConfig().mux_fan_in.value() - 1; } else { TT_FATAL(false, "Unexpected kernel type downstream of ETH_TUNNELER"); } } else { - this->config.remote_receiver_x[idx] = paired_physical_coord.x; - this->config.remote_receiver_y[idx] = paired_physical_coord.y; + dependent_config_.remote_receiver_x[idx] = paired_physical_coord.x; + dependent_config_.remote_receiver_y[idx] = paired_physical_coord.y; // Tunneler upstream queue ids start counting up from 0 - this->config.remote_receiver_queue_id[idx] = idx; - this->config.remote_receiver_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::ETH; - this->config.remote_receiver_queue_start[idx] = - this->config.in_queue_start_addr_words.value() + idx * this->config.in_queue_size_words.value(); - this->config.remote_receiver_queue_size[idx] = this->config.in_queue_size_words; + dependent_config_.remote_receiver_queue_id[idx] = idx; + dependent_config_.remote_receiver_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::ETH; + dependent_config_.remote_receiver_queue_start[idx] = + static_config_.in_queue_start_addr_words.value() + + idx * this->static_config_.in_queue_size_words.value(); + dependent_config_.remote_receiver_queue_size[idx] = this->static_config_.in_queue_size_words; } } } else { // Upstream, we expect a US_TUNNELER_REMOTE and a MUX_D. Same deal where upstream tunneler may not be populated // yet since its device may not be created yet. - chip_id_t upstream_device_id = FDKernel::GetUpstreamDeviceId(device_id); - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id); + chip_id_t upstream_device_id = FDKernel::GetUpstreamDeviceId(device_id_); + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id_); tt_cxy_pair paired_logical_core = - dispatch_core_manager::instance().tunneler_core(upstream_device_id, device_id, channel, cq_id); + dispatch_core_manager::instance().tunneler_core(upstream_device_id, device_id_, channel, cq_id_); tt_cxy_pair paired_physical_coord = tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(paired_logical_core, CoreType::ETH); - TT_ASSERT(this->upstream_kernels.size() == 2); - auto tunneler_kernel = dynamic_cast(this->upstream_kernels[0]); - auto mux_kernel = dynamic_cast(this->upstream_kernels[1]); + TT_ASSERT(upstream_kernels_.size() == 2); + auto tunneler_kernel = dynamic_cast(upstream_kernels_[0]); + auto mux_kernel = dynamic_cast(upstream_kernels_[1]); if (!tunneler_kernel) { - tunneler_kernel = dynamic_cast(this->upstream_kernels[1]); - mux_kernel = dynamic_cast(this->upstream_kernels[0]); + tunneler_kernel = dynamic_cast(upstream_kernels_[1]); + mux_kernel = dynamic_cast(upstream_kernels_[0]); } TT_ASSERT(tunneler_kernel && mux_kernel); TT_ASSERT(tunneler_kernel->IsRemote()); - for (uint32_t idx = 0; idx < this->config.vc_count.value(); idx++) { - if (idx == this->config.vc_count.value() - 1) { + for (uint32_t idx = 0; idx < static_config_.vc_count.value(); idx++) { + if (idx == static_config_.vc_count.value() - 1) { // Last VC is the return VC, driven by the mux - this->config.remote_sender_x[idx] = mux_kernel->GetVirtualCore().x; - this->config.remote_sender_y[idx] = mux_kernel->GetVirtualCore().y; + dependent_config_.remote_sender_x[idx] = mux_kernel->GetVirtualCore().x; + dependent_config_.remote_sender_y[idx] = mux_kernel->GetVirtualCore().y; // MUX output queue id is counted after all of it's inputs - this->config.remote_sender_queue_id[idx] = mux_kernel->GetConfig().mux_fan_in.value(); - this->config.remote_sender_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_sender_queue_id[idx] = mux_kernel->GetStaticConfig().mux_fan_in.value(); + dependent_config_.remote_sender_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; } else { - this->config.remote_sender_x[idx] = paired_physical_coord.x; - this->config.remote_sender_y[idx] = paired_physical_coord.y; + dependent_config_.remote_sender_x[idx] = paired_physical_coord.x; + dependent_config_.remote_sender_y[idx] = paired_physical_coord.y; // Tunneler downstream queue ids start counting after the upstream ones - this->config.remote_sender_queue_id[idx] = this->config.vc_count.value() + idx; - this->config.remote_sender_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::ETH; + dependent_config_.remote_sender_queue_id[idx] = this->static_config_.vc_count.value() + idx; + dependent_config_.remote_sender_network_type[idx] = (uint32_t)DispatchRemoteNetworkType::ETH; } } // Downstream, we expect the same US_TUNNELER_REMOTE and one or more VC_PACKER_ROUTER EthTunnelerKernel* ds_tunneler_kernel = nullptr; std::vector router_kernels; - for (auto k : this->downstream_kernels) { + for (auto k : downstream_kernels_) { if (auto rk = dynamic_cast(k)) { router_kernels.push_back(rk); } else if (auto tk = dynamic_cast(k)) { @@ -175,35 +181,37 @@ void EthTunnelerKernel::GenerateDependentConfigs() { // Remote receiver is the downstream router, one queue per router input lane int remote_idx = 0; for (auto router_kernel : router_kernels) { - for (int idx = 0; idx < router_kernel->GetConfig().vc_count.value(); idx++) { - this->config.remote_receiver_x[remote_idx] = router_kernel->GetVirtualCore().x; - this->config.remote_receiver_y[remote_idx] = router_kernel->GetVirtualCore().y; - this->config.remote_receiver_queue_id[remote_idx] = idx; // Queue ids start counting from 0 at input - this->config.remote_receiver_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.remote_receiver_queue_start[remote_idx] = - router_kernel->GetConfig().rx_queue_start_addr_words.value() + - idx * router_kernel->GetConfig().rx_queue_size_words.value(); - this->config.remote_receiver_queue_size[remote_idx] = - router_kernel->GetConfig().rx_queue_size_words.value(); + for (int idx = 0; idx < router_kernel->GetStaticConfig().vc_count.value(); idx++) { + dependent_config_.remote_receiver_x[remote_idx] = router_kernel->GetVirtualCore().x; + dependent_config_.remote_receiver_y[remote_idx] = router_kernel->GetVirtualCore().y; + dependent_config_.remote_receiver_queue_id[remote_idx] = + idx; // Queue ids start counting from 0 at input + dependent_config_.remote_receiver_network_type[remote_idx] = (uint32_t)DispatchRemoteNetworkType::NOC0; + dependent_config_.remote_receiver_queue_start[remote_idx] = + router_kernel->GetStaticConfig().rx_queue_start_addr_words.value() + + idx * router_kernel->GetStaticConfig().rx_queue_size_words.value(); + dependent_config_.remote_receiver_queue_size[remote_idx] = + router_kernel->GetStaticConfig().rx_queue_size_words.value(); remote_idx++; } } // Last receiver connection is the return VC, connected to the paired tunneler - uint32_t return_vc_id = this->config.vc_count.value() - 1; - this->config.remote_receiver_x[return_vc_id] = paired_physical_coord.x; - this->config.remote_receiver_y[return_vc_id] = paired_physical_coord.y; - this->config.remote_receiver_queue_id[return_vc_id] = return_vc_id; - this->config.remote_receiver_network_type[return_vc_id] = (uint32_t)DispatchRemoteNetworkType::ETH; - this->config.remote_receiver_queue_start[return_vc_id] = - this->config.in_queue_start_addr_words.value() + (return_vc_id) * this->config.in_queue_size_words.value(); - this->config.remote_receiver_queue_size[return_vc_id] = this->config.in_queue_size_words; - this->config.inner_stop_mux_d_bypass = 0; + uint32_t return_vc_id = static_config_.vc_count.value() - 1; + dependent_config_.remote_receiver_x[return_vc_id] = paired_physical_coord.x; + dependent_config_.remote_receiver_y[return_vc_id] = paired_physical_coord.y; + dependent_config_.remote_receiver_queue_id[return_vc_id] = return_vc_id; + dependent_config_.remote_receiver_network_type[return_vc_id] = (uint32_t)DispatchRemoteNetworkType::ETH; + dependent_config_.remote_receiver_queue_start[return_vc_id] = + static_config_.in_queue_start_addr_words.value() + + (return_vc_id) * this->static_config_.in_queue_size_words.value(); + dependent_config_.remote_receiver_queue_size[return_vc_id] = this->static_config_.in_queue_size_words; + dependent_config_.inner_stop_mux_d_bypass = 0; // For certain chips in a tunnel (between first stop and end of tunnel, not including), we do a bypass - if (this->config.vc_count.value() > (device->num_hw_cqs() + 1) && - this->config.vc_count.value() < (4 * device->num_hw_cqs() + 1)) { - this->config.inner_stop_mux_d_bypass = + if (static_config_.vc_count.value() > (device_->num_hw_cqs() + 1) && + static_config_.vc_count.value() < (4 * device_->num_hw_cqs() + 1)) { + dependent_config_.inner_stop_mux_d_bypass = (return_vc_id << 24) | - (((tunneler_kernel->GetConfig().vc_count.value() - device->num_hw_cqs()) * 2 - 1) << 16) | + (((tunneler_kernel->GetStaticConfig().vc_count.value() - device_->num_hw_cqs()) * 2 - 1) << 16) | (paired_physical_coord.y << 8) | (paired_physical_coord.x); } } @@ -211,10 +219,10 @@ void EthTunnelerKernel::GenerateDependentConfigs() { void EthTunnelerKernel::CreateKernel() { std::vector compile_args = { - config.endpoint_id_start_index.value(), - config.vc_count.value(), // # Tunnel lanes = VC count - config.in_queue_start_addr_words.value(), - config.in_queue_size_words.value(), + static_config_.endpoint_id_start_index.value(), + static_config_.vc_count.value(), // # Tunnel lanes = VC count + static_config_.in_queue_start_addr_words.value(), + static_config_.in_queue_size_words.value(), 0, 0, 0, @@ -255,54 +263,52 @@ void EthTunnelerKernel::CreateKernel() { 0, 0, 0, // Populate remote_sender_* after - config.kernel_status_buf_addr_arg.value(), - config.kernel_status_buf_size_bytes.value(), - config.timeout_cycles.value(), - config.inner_stop_mux_d_bypass.value()}; + static_config_.kernel_status_buf_addr_arg.value(), + static_config_.kernel_status_buf_size_bytes.value(), + static_config_.timeout_cycles.value(), + dependent_config_.inner_stop_mux_d_bypass.value()}; for (int idx = 0; idx < MAX_TUNNEL_LANES; idx++) { - if (config.remote_receiver_x[idx]) { - compile_args[4 + idx] |= (config.remote_receiver_x[idx].value() & 0xFF); - compile_args[4 + idx] |= (config.remote_receiver_y[idx].value() & 0xFF) << 8; - compile_args[4 + idx] |= (config.remote_receiver_queue_id[idx].value() & 0xFF) << 16; - compile_args[4 + idx] |= (config.remote_receiver_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_receiver_x[idx]) { + compile_args[4 + idx] |= (dependent_config_.remote_receiver_x[idx].value() & 0xFF); + compile_args[4 + idx] |= (dependent_config_.remote_receiver_y[idx].value() & 0xFF) << 8; + compile_args[4 + idx] |= (dependent_config_.remote_receiver_queue_id[idx].value() & 0xFF) << 16; + compile_args[4 + idx] |= (dependent_config_.remote_receiver_network_type[idx].value() & 0xFF) << 24; } - if (config.remote_receiver_queue_start[idx]) { - compile_args[14 + idx * 2] = config.remote_receiver_queue_start[idx].value(); - compile_args[15 + idx * 2] = config.remote_receiver_queue_size[idx].value(); + if (dependent_config_.remote_receiver_queue_start[idx]) { + compile_args[14 + idx * 2] = dependent_config_.remote_receiver_queue_start[idx].value(); + compile_args[15 + idx * 2] = dependent_config_.remote_receiver_queue_size[idx].value(); } else { compile_args[15 + idx * 2] = 2; // Dummy size for unused VCs } - if (config.remote_sender_x[idx]) { - compile_args[34 + idx] |= (config.remote_sender_x[idx].value() & 0xFF); - compile_args[34 + idx] |= (config.remote_sender_y[idx].value() & 0xFF) << 8; - compile_args[34 + idx] |= (config.remote_sender_queue_id[idx].value() & 0xFF) << 16; - compile_args[34 + idx] |= (config.remote_sender_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_sender_x[idx]) { + compile_args[34 + idx] |= (dependent_config_.remote_sender_x[idx].value() & 0xFF); + compile_args[34 + idx] |= (dependent_config_.remote_sender_y[idx].value() & 0xFF) << 8; + compile_args[34 + idx] |= (dependent_config_.remote_sender_queue_id[idx].value() & 0xFF) << 16; + compile_args[34 + idx] |= (dependent_config_.remote_sender_network_type[idx].value() & 0xFF) << 24; } } TT_ASSERT(compile_args.size() == 48); - const auto& grid_size = device->grid_size(); + const auto& grid_size = device_->grid_size(); std::map defines = { // All of these unused, remove later - {"MY_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.x, 0))}, - {"MY_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.y, 0))}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, + {"MY_NOC_X", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.x, 0))}, + {"MY_NOC_Y", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.y, 0))}, + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, {"UPSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.x, 0))}, {"UPSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_SLAVE_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_SLAVE_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"SKIP_NOC_LOGGING", "1"}}; configure_kernel_variant( - dispatch_kernel_file_names[this->is_remote ? US_TUNNELER_REMOTE : US_TUNNELER_LOCAL], + dispatch_kernel_file_names[is_remote_ ? US_TUNNELER_REMOTE : US_TUNNELER_LOCAL], compile_args, defines, true, @@ -311,22 +317,22 @@ void EthTunnelerKernel::CreateKernel() { } uint32_t EthTunnelerKernel::GetRouterQueueIdOffset(FDKernel* k, bool upstream) { - uint32_t queue_id = (upstream) ? 0 : this->config.vc_count.value(); - std::vector& kernels = (upstream) ? upstream_kernels : downstream_kernels; + uint32_t queue_id = (upstream) ? 0 : static_config_.vc_count.value(); + std::vector& kernels = (upstream) ? upstream_kernels_ : downstream_kernels_; for (auto kernel : kernels) { if (auto router_kernel = dynamic_cast(kernel)) { if (k == kernel) { return queue_id; } - queue_id += (upstream) ? router_kernel->GetConfig().fwd_vc_count.value() - : router_kernel->GetConfig().vc_count.value(); + queue_id += (upstream) ? router_kernel->GetStaticConfig().fwd_vc_count.value() + : router_kernel->GetStaticConfig().vc_count.value(); } } TT_ASSERT(false, "Couldn't find router kernel"); return queue_id; } uint32_t EthTunnelerKernel::GetRouterId(FDKernel* k, bool upstream) { - std::vector& search = (upstream) ? upstream_kernels : downstream_kernels; + std::vector& search = (upstream) ? upstream_kernels_ : downstream_kernels_; uint32_t router_id = 0; for (auto kernel : search) { if (auto router_kernel = dynamic_cast(kernel)) { diff --git a/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.hpp index ec5ddcb52a13..c21853ffc537 100644 --- a/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/eth_tunneler_kernel.hpp @@ -4,12 +4,18 @@ #pragma once #include "fd_kernel.hpp" -typedef struct eth_tunneler_config { +typedef struct eth_tunneler_static_config { std::optional endpoint_id_start_index; std::optional vc_count; // Set from arch level std::optional in_queue_start_addr_words; std::optional in_queue_size_words; + std::optional kernel_status_buf_addr_arg; + std::optional kernel_status_buf_size_bytes; + std::optional timeout_cycles; +} eth_tunneler_static_config_t; + +typedef struct eth_tunneler_dependent_config { std::array, MAX_TUNNEL_LANES> remote_receiver_x; // [4:13], dependent std::array, MAX_TUNNEL_LANES> remote_receiver_y; // [4:13], dependent std::array, MAX_TUNNEL_LANES> remote_receiver_queue_id; // [4:13], dependent @@ -21,11 +27,8 @@ typedef struct eth_tunneler_config { std::array, MAX_TUNNEL_LANES> remote_sender_queue_id; // [34:43], dependent std::array, MAX_TUNNEL_LANES> remote_sender_network_type; // [34:43], dependent - std::optional kernel_status_buf_addr_arg; - std::optional kernel_status_buf_size_bytes; - std::optional timeout_cycles; std::optional inner_stop_mux_d_bypass; // Dependent -} eth_tunneler_config_t; +} eth_tunneler_dependent_config_t; class EthTunnelerKernel : public FDKernel { public: @@ -36,7 +39,7 @@ class EthTunnelerKernel : public FDKernel { uint8_t cq_id, noc_selection_t noc_selection, bool is_remote) : - FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection), is_remote(is_remote) {} + FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection), is_remote_(is_remote) {} void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; @@ -44,15 +47,14 @@ class EthTunnelerKernel : public FDKernel { // Tunneler kernel is the exception in that it's always on ethernet core even if dispatch is on tensix. return CoreType::ETH; } - const eth_tunneler_config_t& GetConfig() { return this->config; } - bool IsRemote() { return this->is_remote; } - void SetVCCount(uint32_t vc_count) { this->config.vc_count = vc_count; } + const eth_tunneler_static_config_t& GetStaticConfig() { return static_config_; } + bool IsRemote() { return is_remote_; } + void SetVCCount(uint32_t vc_count) { static_config_.vc_count = vc_count; } uint32_t GetRouterQueueIdOffset(FDKernel* k, bool upstream); uint32_t GetRouterId(FDKernel* k, bool upstream); private: - eth_tunneler_config_t config; - bool is_remote; - bool is_tunnel_start = true; - bool is_tunnel_end = true; + eth_tunneler_static_config_t static_config_; + eth_tunneler_dependent_config_t dependent_config_; + bool is_remote_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/fd_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/fd_kernel.cpp index 2430f883cdcf..4a828b7f4aa3 100644 --- a/tt_metal/impl/dispatch/kernel_config/fd_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/fd_kernel.cpp @@ -117,30 +117,30 @@ void FDKernel::configure_kernel_variant( if (tt::llrt::RunTimeOptions::get_instance().watcher_dispatch_disabled()) { defines["FORCE_WATCHER_OFF"] = "1"; } - if (!tt::DPrintServerReadsDispatchCores(this->device)) { + if (!tt::DPrintServerReadsDispatchCores(device_)) { defines["FORCE_DPRINT_OFF"] = "1"; } defines.insert(defines_in.begin(), defines_in.end()); if (GetCoreType() == CoreType::WORKER) { tt::tt_metal::CreateKernel( - *program, + *program_, path, - this->logical_core, + logical_core_, tt::tt_metal::DataMovementConfig{ .processor = send_to_brisc ? tt::tt_metal::DataMovementProcessor::RISCV_0 : tt::tt_metal::DataMovementProcessor::RISCV_1, - .noc = this->noc_selection.non_dispatch_noc, + .noc = noc_selection_.non_dispatch_noc, .compile_args = compile_args, .defines = defines}); } else { tt::tt_metal::CreateKernel( - *program, + *program_, path, - this->logical_core, + logical_core_, tt::tt_metal::EthernetConfig{ .eth_mode = is_active_eth_core ? Eth::SENDER : Eth::IDLE, - .noc = this->noc_selection.non_dispatch_noc, + .noc = noc_selection_.non_dispatch_noc, .compile_args = compile_args, .defines = defines}); } diff --git a/tt_metal/impl/dispatch/kernel_config/fd_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/fd_kernel.hpp index 7250895edda7..38e6594e993f 100644 --- a/tt_metal/impl/dispatch/kernel_config/fd_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/fd_kernel.hpp @@ -7,9 +7,9 @@ #include "impl/program/program.hpp" #include "tt_metal/impl/dispatch/kernels/packet_queue_ctrl.hpp" -#define UNUSED_LOGICAL_CORE tt_cxy_pair(this->device->id(), 0, 0) +#define UNUSED_LOGICAL_CORE tt_cxy_pair(device_->id(), 0, 0) // TODO: Just to make match with previous implementation, remove later -#define UNUSED_LOGICAL_CORE_ADJUSTED tt_cxy_pair(servicing_device_id, 0, 0) +#define UNUSED_LOGICAL_CORE_ADJUSTED tt_cxy_pair(servicing_device_id_, 0, 0) #define UNUSED_SEM_ID 0 typedef struct { @@ -45,11 +45,11 @@ class FDKernel { public: FDKernel( int node_id, chip_id_t device_id, chip_id_t servicing_device_id, uint8_t cq_id, noc_selection_t noc_selection) : - node_id(node_id), - device_id(device_id), - servicing_device_id(servicing_device_id), - cq_id(cq_id), - noc_selection(noc_selection) {}; + node_id_(node_id), + device_id_(device_id), + servicing_device_id_(servicing_device_id), + cq_id_(cq_id), + noc_selection_(noc_selection) {}; virtual ~FDKernel() = default; // Populate the static configs for this kernel (ones that do not depend on configs from other kernels), including @@ -77,22 +77,22 @@ class FDKernel { DispatchWorkerType type); // Register another kernel as upstream/downstream of this one - void AddUpstreamKernel(FDKernel* upstream) { this->upstream_kernels.push_back(upstream); } - void AddDownstreamKernel(FDKernel* downstream) { this->downstream_kernels.push_back(downstream); } + void AddUpstreamKernel(FDKernel* upstream) { upstream_kernels_.push_back(upstream); } + void AddDownstreamKernel(FDKernel* downstream) { downstream_kernels_.push_back(downstream); } - virtual CoreType GetCoreType() { return dispatch_core_manager::instance().get_dispatch_core_type(device->id()); } - tt_cxy_pair GetLogicalCore() { return logical_core; } + virtual CoreType GetCoreType() { return dispatch_core_manager::instance().get_dispatch_core_type(device_->id()); } + tt_cxy_pair GetLogicalCore() { return logical_core_; } tt_cxy_pair GetVirtualCore() { - return tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(logical_core, GetCoreType()); + return tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(logical_core_, GetCoreType()); } - chip_id_t GetDeviceId() { return this->device_id; } // Since this->device may not exist yet + chip_id_t GetDeviceId() { return device_id_; } // Since this->device may not exist yet // Get the port index for which a given kernel is upstream/downstream of this one - int GetUpstreamPort(FDKernel* other) { return GetPort(other, this->upstream_kernels); } - int GetDownstreamPort(FDKernel* other) { return GetPort(other, this->downstream_kernels); } + int GetUpstreamPort(FDKernel* other) { return GetPort(other, this->upstream_kernels_); } + int GetDownstreamPort(FDKernel* other) { return GetPort(other, this->downstream_kernels_); } void AddDeviceAndProgram(Device* device, Program* program) { - this->device = device; - this->program = program; + device_ = device; + program_ = program; }; protected: @@ -118,15 +118,15 @@ class FDKernel { static chip_id_t GetDownstreamDeviceId(chip_id_t device_id); static uint32_t GetTunnelStop(chip_id_t device_id); - Device* device = nullptr; // Set at configuration time by AddDeviceAndProgram() - Program* program = nullptr; - tt_cxy_pair logical_core; - chip_id_t device_id; - chip_id_t servicing_device_id; // Remote chip that this PREFETCH_H/DISPATCH_H is servicing - int node_id; - uint8_t cq_id; - noc_selection_t noc_selection; - - std::vector upstream_kernels; - std::vector downstream_kernels; + Device* device_ = nullptr; // Set at configuration time by AddDeviceAndProgram() + Program* program_ = nullptr; + tt_cxy_pair logical_core_; + chip_id_t device_id_; + chip_id_t servicing_device_id_; // Remote chip that this PREFETCH_H/DISPATCH_H is servicing + int node_id_; + uint8_t cq_id_; + noc_selection_t noc_selection_; + + std::vector upstream_kernels_; + std::vector downstream_kernels_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/mux_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/mux_kernel.cpp index 9c02ad173682..6629f65afc48 100644 --- a/tt_metal/impl/dispatch/kernel_config/mux_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/mux_kernel.cpp @@ -10,144 +10,145 @@ #include "tt_metal/detail/tt_metal.hpp" void MuxKernel::GenerateStaticConfigs() { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); - this->logical_core = dispatch_core_manager::instance().mux_d_core(this->device->id(), channel, this->cq_id); + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + logical_core_ = dispatch_core_manager::instance().mux_d_core(device_->id(), channel, this->cq_id_); auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - this->config.reserved = 0; - this->config.rx_queue_start_addr_words = my_dispatch_constants.dispatch_buffer_base() >> 4; - this->config.rx_queue_size_words = ((1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * - my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs())) >> - 4; - this->config.mux_fan_in = this->upstream_kernels.size(); - for (int idx = 0; idx < this->upstream_kernels.size(); idx++) { - this->config.remote_rx_network_type[idx] = DispatchRemoteNetworkType::NOC0; + static_config_.reserved = 0; + static_config_.rx_queue_start_addr_words = my_dispatch_constants.dispatch_buffer_base() >> 4; + static_config_.rx_queue_size_words = ((1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * + my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs())) >> + 4; + static_config_.mux_fan_in = upstream_kernels_.size(); + for (int idx = 0; idx < upstream_kernels_.size(); idx++) { + static_config_.remote_rx_network_type[idx] = DispatchRemoteNetworkType::NOC0; } - this->config.tx_network_type = (uint32_t)DispatchRemoteNetworkType::NOC0; - this->config.test_results_buf_addr_arg = 0; - this->config.test_results_buf_size_bytes = 0; - this->config.timeout_cycles = 0; - this->config.output_depacketize = 0x0; - this->config.output_depacketize_info = 0x0; + static_config_.tx_network_type = (uint32_t)DispatchRemoteNetworkType::NOC0; + static_config_.test_results_buf_addr_arg = 0; + static_config_.test_results_buf_size_bytes = 0; + static_config_.timeout_cycles = 0; + static_config_.output_depacketize = 0x0; + static_config_.output_depacketize_info = 0x0; - for (int idx = 0; idx < this->upstream_kernels.size(); idx++) { + for (int idx = 0; idx < upstream_kernels_.size(); idx++) { // Only connected dispatchers need a semaphore. TODO: can initialize anyways, but this matches previous // implementation - if (dynamic_cast(this->upstream_kernels[idx])) { - this->config.input_packetize_local_sem[idx] = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + if (dynamic_cast(upstream_kernels_[idx])) { + static_config_.input_packetize_local_sem[idx] = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); } } } void MuxKernel::GenerateDependentConfigs() { // Upstream, expect DISPATCH_D or TUNNELER - TT_ASSERT(this->upstream_kernels.size() <= MAX_SWITCH_FAN_IN && this->upstream_kernels.size() > 0); + TT_ASSERT(upstream_kernels_.size() <= MAX_SWITCH_FAN_IN && upstream_kernels_.size() > 0); uint32_t num_upstream_dispatchers = 0; - for (int idx = 0; idx < this->upstream_kernels.size(); idx++) { - FDKernel* k = this->upstream_kernels[idx]; - this->config.remote_rx_x[idx] = k->GetVirtualCore().x; - this->config.remote_rx_y[idx] = k->GetVirtualCore().y; - this->config.input_packetize_log_page_size[idx] = + for (int idx = 0; idx < upstream_kernels_.size(); idx++) { + FDKernel* k = upstream_kernels_[idx]; + dependent_config_.remote_rx_x[idx] = k->GetVirtualCore().x; + dependent_config_.remote_rx_y[idx] = k->GetVirtualCore().y; + dependent_config_.input_packetize_log_page_size[idx] = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; // Does this ever change? if (auto dispatch_kernel = dynamic_cast(k)) { - this->config.input_packetize[idx] = 0x1; - this->config.input_packetize_upstream_sem[idx] = dispatch_kernel->GetConfig().my_downstream_cb_sem_id; - this->config.remote_rx_queue_id[idx] = 1; + dependent_config_.input_packetize[idx] = 0x1; + dependent_config_.input_packetize_upstream_sem[idx] = + dispatch_kernel->GetStaticConfig().my_downstream_cb_sem_id; + dependent_config_.remote_rx_queue_id[idx] = 1; num_upstream_dispatchers++; } else if (auto tunneler_kernel = dynamic_cast(k)) { // Don't need to packetize input from tunneler - this->config.input_packetize[idx] = 0x0; - this->config.input_packetize_upstream_sem[idx] = 0; - this->config.remote_rx_queue_id[idx] = tunneler_kernel->GetConfig().vc_count.value() * 2 - 1; + dependent_config_.input_packetize[idx] = 0x0; + dependent_config_.input_packetize_upstream_sem[idx] = 0; + dependent_config_.remote_rx_queue_id[idx] = tunneler_kernel->GetStaticConfig().vc_count.value() * 2 - 1; } else { TT_FATAL(false, "Unexpected kernel type upstream of MUX"); } } - uint32_t src_id = 0xC1 + (FDKernel::GetTunnelStop(device_id) - 1) * num_upstream_dispatchers; - uint32_t dest_id = 0xD1 + (FDKernel::GetTunnelStop(device_id) - 1) * num_upstream_dispatchers; - this->config.input_packetize_src_endpoint = packet_switch_4B_pack(src_id, src_id + 1, src_id + 2, src_id + 3); - this->config.input_packetize_dest_endpoint = packet_switch_4B_pack(dest_id, dest_id + 1, dest_id + 2, dest_id + 3); + uint32_t src_id = 0xC1 + (FDKernel::GetTunnelStop(device_id_) - 1) * num_upstream_dispatchers; + uint32_t dest_id = 0xD1 + (FDKernel::GetTunnelStop(device_id_) - 1) * num_upstream_dispatchers; + static_config_.input_packetize_src_endpoint = packet_switch_4B_pack(src_id, src_id + 1, src_id + 2, src_id + 3); + static_config_.input_packetize_dest_endpoint = + packet_switch_4B_pack(dest_id, dest_id + 1, dest_id + 2, dest_id + 3); // Downstream, expect TUNNELER - TT_ASSERT(this->downstream_kernels.size() == 1); - FDKernel* ds = this->downstream_kernels[0]; + TT_ASSERT(downstream_kernels_.size() == 1); + FDKernel* ds = downstream_kernels_[0]; auto tunneler_kernel = dynamic_cast(ds); TT_ASSERT(ds); - this->config.remote_tx_queue_start_addr_words = - tunneler_kernel->GetConfig().in_queue_start_addr_words.value() + - (tunneler_kernel->GetConfig().vc_count.value() - 1) * tunneler_kernel->GetConfig().in_queue_size_words.value(); - this->config.remote_tx_queue_size_words = tunneler_kernel->GetConfig().in_queue_size_words; - this->config.remote_tx_x = ds->GetVirtualCore().x; - this->config.remote_tx_y = ds->GetVirtualCore().y; - this->config.remote_tx_queue_id = tunneler_kernel->GetConfig().vc_count.value() - 1; + dependent_config_.remote_tx_queue_start_addr_words = + tunneler_kernel->GetStaticConfig().in_queue_start_addr_words.value() + + (tunneler_kernel->GetStaticConfig().vc_count.value() - 1) * + tunneler_kernel->GetStaticConfig().in_queue_size_words.value(); + dependent_config_.remote_tx_queue_size_words = tunneler_kernel->GetStaticConfig().in_queue_size_words; + dependent_config_.remote_tx_x = ds->GetVirtualCore().x; + dependent_config_.remote_tx_y = ds->GetVirtualCore().y; + dependent_config_.remote_tx_queue_id = tunneler_kernel->GetStaticConfig().vc_count.value() - 1; } void MuxKernel::CreateKernel() { std::vector compile_args = { - config.reserved.value(), - config.rx_queue_start_addr_words.value(), - config.rx_queue_size_words.value(), - config.mux_fan_in.value(), + static_config_.reserved.value(), + static_config_.rx_queue_start_addr_words.value(), + static_config_.rx_queue_size_words.value(), + static_config_.mux_fan_in.value(), 0, 0, 0, 0, // Populate remote_rx_config after - config.remote_tx_queue_start_addr_words.value(), - config.remote_tx_queue_size_words.value(), - config.remote_tx_x.value(), - config.remote_tx_y.value(), - config.remote_tx_queue_id.value(), - config.tx_network_type.value(), - config.test_results_buf_addr_arg.value(), - config.test_results_buf_size_bytes.value(), - config.timeout_cycles.value(), - config.output_depacketize.value(), - config.output_depacketize_info.value(), + dependent_config_.remote_tx_queue_start_addr_words.value(), + dependent_config_.remote_tx_queue_size_words.value(), + dependent_config_.remote_tx_x.value(), + dependent_config_.remote_tx_y.value(), + dependent_config_.remote_tx_queue_id.value(), + static_config_.tx_network_type.value(), + static_config_.test_results_buf_addr_arg.value(), + static_config_.test_results_buf_size_bytes.value(), + static_config_.timeout_cycles.value(), + static_config_.output_depacketize.value(), + static_config_.output_depacketize_info.value(), 0, 0, 0, 0, // Populate input_packetize_config after - config.input_packetize_src_endpoint.value(), - config.input_packetize_dest_endpoint.value()}; + static_config_.input_packetize_src_endpoint.value(), + static_config_.input_packetize_dest_endpoint.value()}; for (int idx = 0; idx < MAX_SWITCH_FAN_IN; idx++) { - if (config.remote_rx_x[idx]) { - compile_args[4 + idx] |= (config.remote_rx_x[idx].value() & 0xFF); - compile_args[4 + idx] |= (config.remote_rx_y[idx].value() & 0xFF) << 8; - compile_args[4 + idx] |= (config.remote_rx_queue_id[idx].value() & 0xFF) << 16; - compile_args[4 + idx] |= (config.remote_rx_network_type[idx].value() & 0xFF) << 24; + if (dependent_config_.remote_rx_x[idx]) { + compile_args[4 + idx] |= (dependent_config_.remote_rx_x[idx].value() & 0xFF); + compile_args[4 + idx] |= (dependent_config_.remote_rx_y[idx].value() & 0xFF) << 8; + compile_args[4 + idx] |= (dependent_config_.remote_rx_queue_id[idx].value() & 0xFF) << 16; + compile_args[4 + idx] |= (static_config_.remote_rx_network_type[idx].value() & 0xFF) << 24; } - if (config.input_packetize[idx]) { + if (dependent_config_.input_packetize[idx]) { // Zero out if input packetize not set to match previous implementation. TODO: don't have to do this - if (config.input_packetize[idx].value() != 0) { - compile_args[19 + idx] |= (config.input_packetize[idx].value() & 0xFF); - compile_args[19 + idx] |= (config.input_packetize_log_page_size[idx].value() & 0xFF) << 8; - compile_args[19 + idx] |= (config.input_packetize_upstream_sem[idx].value() & 0xFF) << 16; - compile_args[19 + idx] |= (config.input_packetize_local_sem[idx].value() & 0xFF) << 24; + if (dependent_config_.input_packetize[idx].value() != 0) { + compile_args[19 + idx] |= (dependent_config_.input_packetize[idx].value() & 0xFF); + compile_args[19 + idx] |= (dependent_config_.input_packetize_log_page_size[idx].value() & 0xFF) << 8; + compile_args[19 + idx] |= (dependent_config_.input_packetize_upstream_sem[idx].value() & 0xFF) << 16; + compile_args[19 + idx] |= (static_config_.input_packetize_local_sem[idx].value() & 0xFF) << 24; } } } TT_ASSERT(compile_args.size() == 25); - const auto& grid_size = device->grid_size(); + const auto& grid_size = device_->grid_size(); std::map defines = { // All of these unused, remove later - {"MY_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.x, 0))}, - {"MY_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.non_dispatch_noc, grid_size.y, 0))}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, + {"MY_NOC_X", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.x, 0))}, + {"MY_NOC_Y", std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.non_dispatch_noc, grid_size.y, 0))}, + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, {"UPSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.x, 0))}, {"UPSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.upstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.upstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"DOWNSTREAM_SLAVE_NOC_X", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.x, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.x, 0))}, {"DOWNSTREAM_SLAVE_NOC_Y", - std::to_string(tt::tt_metal::hal.noc_coordinate(this->noc_selection.downstream_noc, grid_size.y, 0))}, + std::to_string(tt::tt_metal::hal.noc_coordinate(noc_selection_.downstream_noc, grid_size.y, 0))}, {"SKIP_NOC_LOGGING", "1"}}; configure_kernel_variant(dispatch_kernel_file_names[MUX_D], compile_args, defines, false, false, false); } diff --git a/tt_metal/impl/dispatch/kernel_config/mux_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/mux_kernel.hpp index 19138d6007cc..2bb85a1d2983 100644 --- a/tt_metal/impl/dispatch/kernel_config/mux_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/mux_kernel.hpp @@ -4,35 +4,38 @@ #pragma once #include "fd_kernel.hpp" -typedef struct mux_config { +typedef struct mux_static_config { std::optional reserved; std::optional rx_queue_start_addr_words; std::optional rx_queue_size_words; std::optional mux_fan_in; - std::array, MAX_SWITCH_FAN_IN> remote_rx_x; // [4:7], dependent - std::array, MAX_SWITCH_FAN_IN> remote_rx_y; // [4:7], dependent - std::array, MAX_SWITCH_FAN_IN> remote_rx_queue_id; // [4:7], dependent std::array, MAX_SWITCH_FAN_IN> remote_rx_network_type; // [4:7] - std::optional remote_tx_queue_start_addr_words; // Dependent - std::optional remote_tx_queue_size_words; // Dependent - std::optional remote_tx_x; // Dependent - std::optional remote_tx_y; // Dependent - std::optional remote_tx_queue_id; // Dependent std::optional tx_network_type; std::optional test_results_buf_addr_arg; std::optional test_results_buf_size_bytes; std::optional timeout_cycles; std::optional output_depacketize; - std::optional output_depacketize_info; // Packed, pack with above same is input? - std::array, MAX_SWITCH_FAN_IN> input_packetize; // Dependent - std::array, MAX_SWITCH_FAN_IN> input_packetize_log_page_size; // Dependent - std::array, MAX_SWITCH_FAN_IN> input_packetize_upstream_sem; // Dependent + std::optional output_depacketize_info; // Packed, pack with above same is input? std::array, MAX_SWITCH_FAN_IN> input_packetize_local_sem; std::optional input_packetize_src_endpoint; // Packed w/ max 4 assumption std::optional input_packetize_dest_endpoint; // Same as src -} mux_config_t; +} mux_static_config_t; + +typedef struct mux_dependent_config { + std::array, MAX_SWITCH_FAN_IN> remote_rx_x; // [4:7], dependent + std::array, MAX_SWITCH_FAN_IN> remote_rx_y; // [4:7], dependent + std::array, MAX_SWITCH_FAN_IN> remote_rx_queue_id; // [4:7], dependent + std::optional remote_tx_queue_start_addr_words; // Dependent + std::optional remote_tx_queue_size_words; // Dependent + std::optional remote_tx_x; // Dependent + std::optional remote_tx_y; // Dependent + std::optional remote_tx_queue_id; // Dependent + std::array, MAX_SWITCH_FAN_IN> input_packetize; // Dependent + std::array, MAX_SWITCH_FAN_IN> input_packetize_log_page_size; // Dependent + std::array, MAX_SWITCH_FAN_IN> input_packetize_upstream_sem; // Dependent +} mux_dependent_config_t; class MuxKernel : public FDKernel { public: @@ -42,8 +45,9 @@ class MuxKernel : public FDKernel { void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; - const mux_config_t& GetConfig() { return this->config; } + const mux_static_config_t& GetStaticConfig() { return static_config_; } private: - mux_config_t config; + mux_static_config_t static_config_; + mux_dependent_config_t dependent_config_; }; diff --git a/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.cpp b/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.cpp index 9fc3d29560f1..7370d826695a 100644 --- a/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.cpp +++ b/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.cpp @@ -10,52 +10,52 @@ #include "tt_metal/detail/tt_metal.hpp" void PrefetchKernel::GenerateStaticConfigs() { - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); - uint8_t cq_id = this->cq_id; + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); + uint8_t cq_id_ = this->cq_id_; auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - if (this->config.is_h_variant.value() && this->config.is_d_variant.value()) { + if (static_config_.is_h_variant.value() && this->static_config_.is_d_variant.value()) { uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); - uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id, cq_size); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); + uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id_, cq_size); uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start; - uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id); + uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_); - this->logical_core = dispatch_core_manager::instance().prefetcher_core(device->id(), channel, cq_id); + logical_core_ = dispatch_core_manager::instance().prefetcher_core(device_->id(), channel, cq_id_); - this->config.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.downstream_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; - this->config.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); - this->config.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, this->logical_core, my_dispatch_constants.dispatch_buffer_pages(), GetCoreType()); + dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.downstream_cb_log_page_size = dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + static_config_.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); + static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, logical_core_, my_dispatch_constants.dispatch_buffer_pages(), GetCoreType()); - this->config.pcie_base = issue_queue_start_addr; - this->config.pcie_size = issue_queue_size; - this->config.prefetch_q_base = + static_config_.pcie_base = issue_queue_start_addr; + static_config_.pcie_size = issue_queue_size; + static_config_.prefetch_q_base = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED); - this->config.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); - this->config.prefetch_q_rd_ptr_addr = + static_config_.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); + static_config_.prefetch_q_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD); - this->config.prefetch_q_pcie_rd_ptr_addr = + static_config_.prefetch_q_pcie_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD); - this->config.cmddat_q_base = my_dispatch_constants.cmddat_q_base(); - this->config.cmddat_q_size = my_dispatch_constants.cmddat_q_size(); + static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base(); + static_config_.cmddat_q_size = my_dispatch_constants.cmddat_q_size(); - this->config.scratch_db_base = my_dispatch_constants.scratch_db_base(); - this->config.scratch_db_size = my_dispatch_constants.scratch_db_size(); - this->config.downstream_sync_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); + static_config_.scratch_db_base = my_dispatch_constants.scratch_db_base(); + static_config_.scratch_db_size = my_dispatch_constants.scratch_db_size(); + static_config_.downstream_sync_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); // prefetch_d only - this->config.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); - this->config.my_upstream_cb_sem_id = 0; - this->config.upstream_cb_sem_id = 0; - this->config.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; + static_config_.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); + static_config_.my_upstream_cb_sem_id = 0; + dependent_config_.upstream_cb_sem_id = 0; + static_config_.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; uint32_t dispatch_s_buffer_base = 0xff; - if (device->dispatch_s_enabled()) { + if (device_->dispatch_s_enabled()) { uint32_t dispatch_buffer_base = my_dispatch_constants.dispatch_buffer_base(); if (GetCoreType() == CoreType::WORKER) { // dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx. @@ -67,101 +67,101 @@ void PrefetchKernel::GenerateStaticConfigs() { dispatch_s_buffer_base = dispatch_buffer_base; } } - this->config.dispatch_s_buffer_base = dispatch_s_buffer_base; - this->config.my_dispatch_s_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, this->logical_core, my_dispatch_constants.dispatch_s_buffer_pages(), GetCoreType()); - this->config.dispatch_s_buffer_size = my_dispatch_constants.dispatch_s_buffer_size(); - this->config.dispatch_s_cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE; - } else if (this->config.is_h_variant.value()) { + static_config_.dispatch_s_buffer_base = dispatch_s_buffer_base; + static_config_.my_dispatch_s_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, logical_core_, my_dispatch_constants.dispatch_s_buffer_pages(), GetCoreType()); + static_config_.dispatch_s_buffer_size = my_dispatch_constants.dispatch_s_buffer_size(); + static_config_.dispatch_s_cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE; + } else if (static_config_.is_h_variant.value()) { // PREFETCH_H services a remote chip, and so has a different channel - channel = tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id); + channel = tt::Cluster::instance().get_assigned_channel_for_device(servicing_device_id_); uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); - uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id, cq_size); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); + uint32_t command_queue_start_addr = get_absolute_cq_offset(channel, cq_id_, cq_size); uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start; - uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id); + uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_); - this->logical_core = dispatch_core_manager::instance().prefetcher_core(servicing_device_id, channel, cq_id); + logical_core_ = dispatch_core_manager::instance().prefetcher_core(servicing_device_id_, channel, cq_id_); - this->config.downstream_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.downstream_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; if (tt::Cluster::instance().is_galaxy_cluster()) { // TODO: whys is this hard-coded for galaxy? - this->config.downstream_cb_pages = my_dispatch_constants.mux_buffer_pages(1); + static_config_.downstream_cb_pages = my_dispatch_constants.mux_buffer_pages(1); } else { - this->config.downstream_cb_pages = my_dispatch_constants.mux_buffer_pages(device->num_hw_cqs()); + static_config_.downstream_cb_pages = my_dispatch_constants.mux_buffer_pages(device_->num_hw_cqs()); } - this->config.pcie_base = issue_queue_start_addr; - this->config.pcie_size = issue_queue_size; - this->config.prefetch_q_base = + static_config_.pcie_base = issue_queue_start_addr; + static_config_.pcie_size = issue_queue_size; + static_config_.prefetch_q_base = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED); - this->config.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); - this->config.prefetch_q_rd_ptr_addr = + static_config_.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); + static_config_.prefetch_q_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD); - this->config.prefetch_q_pcie_rd_ptr_addr = + static_config_.prefetch_q_pcie_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD); - this->config.cmddat_q_base = my_dispatch_constants.cmddat_q_base(); - this->config.cmddat_q_size = my_dispatch_constants.cmddat_q_size(); + static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base(); + static_config_.cmddat_q_size = my_dispatch_constants.cmddat_q_size(); - this->config.scratch_db_base = my_dispatch_constants.scratch_db_base(); - this->config.scratch_db_size = my_dispatch_constants.scratch_db_size(); - this->config.downstream_sync_sem_id = 0; // Unused for prefetch_h + static_config_.scratch_db_base = my_dispatch_constants.scratch_db_base(); + static_config_.scratch_db_size = my_dispatch_constants.scratch_db_size(); + static_config_.downstream_sync_sem_id = 0; // Unused for prefetch_h - this->config.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); - this->config.my_upstream_cb_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - this->config.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, this->logical_core, this->config.downstream_cb_pages.value(), GetCoreType()); + static_config_.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); + static_config_.my_upstream_cb_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, logical_core_, static_config_.downstream_cb_pages.value(), GetCoreType()); tt::tt_metal::CreateSemaphore( - *program, this->logical_core, 0, GetCoreType()); // TODO: what is this third semaphore for? - this->config.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; + *program_, logical_core_, 0, GetCoreType()); // TODO: what is this third semaphore for? + static_config_.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; // PREFETCH_H has no DISPATCH_S - this->config.dispatch_s_buffer_base = 0; - this->config.my_dispatch_s_cb_sem_id = 0; - this->config.dispatch_s_buffer_size = 0; - this->config.dispatch_s_cb_log_page_size = 0; - } else if (this->config.is_d_variant.value()) { - this->logical_core = dispatch_core_manager::instance().prefetcher_d_core(device->id(), channel, cq_id); - - this->config.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.downstream_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); - - this->config.pcie_base = 0; - this->config.pcie_size = 0; - this->config.prefetch_q_base = 0; - this->config.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); - this->config.prefetch_q_rd_ptr_addr = + static_config_.dispatch_s_buffer_base = 0; + static_config_.my_dispatch_s_cb_sem_id = 0; + static_config_.dispatch_s_buffer_size = 0; + static_config_.dispatch_s_cb_log_page_size = 0; + } else if (static_config_.is_d_variant.value()) { + logical_core_ = dispatch_core_manager::instance().prefetcher_d_core(device_->id(), channel, cq_id_); + + dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.downstream_cb_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages(); + + static_config_.pcie_base = 0; + static_config_.pcie_size = 0; + static_config_.prefetch_q_base = 0; + static_config_.prefetch_q_size = my_dispatch_constants.prefetch_q_size(); + static_config_.prefetch_q_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD); - this->config.prefetch_q_pcie_rd_ptr_addr = + static_config_.prefetch_q_pcie_rd_ptr_addr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD); - this->config.cmddat_q_base = my_dispatch_constants.dispatch_buffer_base(); - this->config.cmddat_q_size = my_dispatch_constants.prefetch_d_buffer_size(); + static_config_.cmddat_q_base = my_dispatch_constants.dispatch_buffer_base(); + static_config_.cmddat_q_size = my_dispatch_constants.prefetch_d_buffer_size(); uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); - this->config.scratch_db_base = (my_dispatch_constants.dispatch_buffer_base() + - my_dispatch_constants.prefetch_d_buffer_size() + pcie_alignment - 1) & - (~(pcie_alignment - 1)); - this->config.scratch_db_size = my_dispatch_constants.scratch_db_size(); - this->config.downstream_sync_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - - this->config.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); - this->config.my_upstream_cb_sem_id = - tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType()); - this->config.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, - this->logical_core, + static_config_.scratch_db_base = (my_dispatch_constants.dispatch_buffer_base() + + my_dispatch_constants.prefetch_d_buffer_size() + pcie_alignment - 1) & + (~(pcie_alignment - 1)); + static_config_.scratch_db_size = my_dispatch_constants.scratch_db_size(); + static_config_.downstream_sync_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + + static_config_.cmddat_q_pages = my_dispatch_constants.prefetch_d_buffer_pages(); + static_config_.my_upstream_cb_sem_id = + tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType()); + static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, + logical_core_, my_dispatch_constants.dispatch_buffer_pages(), GetCoreType()); // TODO: this is out of order to match previous implementation - this->config.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; - this->config.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; + static_config_.cmddat_q_log_page_size = dispatch_constants::PREFETCH_D_BUFFER_LOG_PAGE_SIZE; + static_config_.cmddat_q_blocks = dispatch_constants::PREFETCH_D_BUFFER_BLOCKS; uint32_t dispatch_s_buffer_base = 0xff; - if (device->dispatch_s_enabled() || true) { // Just to make it match previous implementation + if (device_->dispatch_s_enabled() || true) { // Just to make it match previous implementation uint32_t dispatch_buffer_base = my_dispatch_constants.dispatch_buffer_base(); if (GetCoreType() == CoreType::WORKER) { // dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx. @@ -173,13 +173,13 @@ void PrefetchKernel::GenerateStaticConfigs() { dispatch_s_buffer_base = dispatch_buffer_base; } } - this->config.dispatch_s_buffer_base = dispatch_s_buffer_base; - this->config.my_dispatch_s_cb_sem_id = tt::tt_metal::CreateSemaphore( - *program, this->logical_core, my_dispatch_constants.dispatch_s_buffer_pages(), GetCoreType()); - this->config.dispatch_s_buffer_size = my_dispatch_constants.dispatch_s_buffer_size(); - this->config.dispatch_s_cb_log_page_size = device->dispatch_s_enabled() - ? dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE - : dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; + static_config_.dispatch_s_buffer_base = dispatch_s_buffer_base; + static_config_.my_dispatch_s_cb_sem_id = tt::tt_metal::CreateSemaphore( + *program_, logical_core_, my_dispatch_constants.dispatch_s_buffer_pages(), GetCoreType()); + static_config_.dispatch_s_buffer_size = my_dispatch_constants.dispatch_s_buffer_size(); + static_config_.dispatch_s_cb_log_page_size = device_->dispatch_s_enabled() + ? dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE + : dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE; } else { TT_FATAL(false, "PrefetchKernel must be one of (or both) H and D variants"); } @@ -187,106 +187,110 @@ void PrefetchKernel::GenerateStaticConfigs() { void PrefetchKernel::GenerateDependentConfigs() { auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); - if (this->config.is_h_variant.value() && this->config.is_d_variant.value()) { + if (static_config_.is_h_variant.value() && this->static_config_.is_d_variant.value()) { // Upstream - TT_ASSERT(this->upstream_kernels.size() == 0); - this->config.upstream_logical_core = UNUSED_LOGICAL_CORE; - this->config.upstream_cb_sem_id = 0; // Used in prefetch_d only + TT_ASSERT(upstream_kernels_.size() == 0); + dependent_config_.upstream_logical_core = UNUSED_LOGICAL_CORE; + dependent_config_.upstream_cb_sem_id = 0; // Used in prefetch_d only // Downstream - if (device->dispatch_s_enabled()) { - TT_ASSERT(this->downstream_kernels.size() == 2); + if (device_->dispatch_s_enabled()) { + TT_ASSERT(downstream_kernels_.size() == 2); } else { - TT_ASSERT(this->downstream_kernels.size() == 1); + TT_ASSERT(downstream_kernels_.size() == 1); } bool found_dispatch = false; bool found_dispatch_s = false; - for (FDKernel* k : this->downstream_kernels) { + for (FDKernel* k : downstream_kernels_) { if (auto dispatch_kernel = dynamic_cast(k)) { TT_ASSERT(!found_dispatch, "PREFETCH kernel has multiple downstream DISPATCH kernels."); found_dispatch = true; - this->config.downstream_logical_core = dispatch_kernel->GetLogicalCore(); - this->config.downstream_cb_sem_id = dispatch_kernel->GetConfig().my_dispatch_cb_sem_id; + dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore(); + dependent_config_.downstream_cb_sem_id = dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id; } else if (auto dispatch_s_kernel = dynamic_cast(k)) { TT_ASSERT(!found_dispatch_s, "PREFETCH kernel has multiple downstream DISPATCH kernels."); found_dispatch_s = true; - this->config.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); - this->config.downstream_dispatch_s_cb_sem_id = dispatch_s_kernel->GetConfig().my_dispatch_cb_sem_id; + dependent_config_.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); + dependent_config_.downstream_dispatch_s_cb_sem_id = + dispatch_s_kernel->GetStaticConfig().my_dispatch_cb_sem_id; } else { TT_FATAL(false, "Unrecognized downstream kernel."); } } - if (device->dispatch_s_enabled()) { + if (device_->dispatch_s_enabled()) { // Should have found dispatch_s in the downstream kernels TT_ASSERT(found_dispatch && found_dispatch_s); } else { // No dispatch_s, just write 0s to the configs dependent on it TT_ASSERT(found_dispatch && ~found_dispatch_s); - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE; - this->config.downstream_dispatch_s_cb_sem_id = UNUSED_SEM_ID; + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE; + dependent_config_.downstream_dispatch_s_cb_sem_id = UNUSED_SEM_ID; } - } else if (this->config.is_h_variant.value()) { + } else if (static_config_.is_h_variant.value()) { // Upstream, just host so no dispatch core - TT_ASSERT(this->upstream_kernels.size() == 0); - this->config.upstream_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; - this->config.upstream_cb_sem_id = 0; // Used in prefetch_d only + TT_ASSERT(upstream_kernels_.size() == 0); + dependent_config_.upstream_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; + dependent_config_.upstream_cb_sem_id = 0; // Used in prefetch_d only // Downstream, expect just one ROUTER - TT_ASSERT(this->downstream_kernels.size() == 1); - auto router_kernel = dynamic_cast(this->downstream_kernels[0]); + TT_ASSERT(downstream_kernels_.size() == 1); + auto router_kernel = dynamic_cast(downstream_kernels_[0]); TT_ASSERT(router_kernel); - this->config.downstream_logical_core = router_kernel->GetLogicalCore(); - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; + dependent_config_.downstream_logical_core = router_kernel->GetLogicalCore(); + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE_ADJUSTED; uint32_t router_idx = router_kernel->GetUpstreamPort(this); // Need the port that this connects to downstream - this->config.downstream_cb_base = (router_kernel->GetConfig().rx_queue_start_addr_words.value() << 4) + - (router_kernel->GetConfig().rx_queue_size_words.value() << 4) * router_idx; - this->config.downstream_cb_sem_id = router_kernel->GetConfig().input_packetize_local_sem[router_idx]; - this->config.downstream_dispatch_s_cb_sem_id = 0; // No downstream DISPATCH_S in this case - } else if (this->config.is_d_variant.value()) { + dependent_config_.downstream_cb_base = + (router_kernel->GetStaticConfig().rx_queue_start_addr_words.value() << 4) + + (router_kernel->GetStaticConfig().rx_queue_size_words.value() << 4) * router_idx; + dependent_config_.downstream_cb_sem_id = router_kernel->GetStaticConfig().input_packetize_local_sem[router_idx]; + dependent_config_.downstream_dispatch_s_cb_sem_id = 0; // No downstream DISPATCH_S in this case + } else if (static_config_.is_d_variant.value()) { // Upstream, expect just one ROUTER - TT_ASSERT(this->upstream_kernels.size() == 1); - auto router_kernel = dynamic_cast(this->upstream_kernels[0]); + TT_ASSERT(upstream_kernels_.size() == 1); + auto router_kernel = dynamic_cast(upstream_kernels_[0]); TT_ASSERT(router_kernel); - this->config.upstream_logical_core = router_kernel->GetLogicalCore(); + dependent_config_.upstream_logical_core = router_kernel->GetLogicalCore(); int router_idx = router_kernel->GetDownstreamPort(this); - this->config.upstream_cb_sem_id = router_kernel->GetConfig().output_depacketize_local_sem[router_idx]; + dependent_config_.upstream_cb_sem_id = + router_kernel->GetStaticConfig().output_depacketize_local_sem[router_idx]; // Downstream, expect a DISPATCH_D and s DISPATCH_S - if (device->dispatch_s_enabled()) { - TT_ASSERT(this->downstream_kernels.size() == 2); + if (device_->dispatch_s_enabled()) { + TT_ASSERT(downstream_kernels_.size() == 2); } else { - TT_ASSERT(this->downstream_kernels.size() == 1); + TT_ASSERT(downstream_kernels_.size() == 1); } bool found_dispatch = false; bool found_dispatch_s = false; - for (FDKernel* k : this->downstream_kernels) { + for (FDKernel* k : downstream_kernels_) { if (auto dispatch_kernel = dynamic_cast(k)) { TT_ASSERT(!found_dispatch, "PREFETCH kernel has multiple downstream DISPATCH kernels."); found_dispatch = true; - this->config.downstream_logical_core = dispatch_kernel->GetLogicalCore(); - this->config.downstream_cb_sem_id = dispatch_kernel->GetConfig().my_dispatch_cb_sem_id; + dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore(); + dependent_config_.downstream_cb_sem_id = dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id; } else if (auto dispatch_s_kernel = dynamic_cast(k)) { TT_ASSERT(!found_dispatch_s, "PREFETCH kernel has multiple downstream DISPATCH kernels."); found_dispatch_s = true; - this->config.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); - this->config.downstream_dispatch_s_cb_sem_id = dispatch_s_kernel->GetConfig().my_dispatch_cb_sem_id; + dependent_config_.downstream_s_logical_core = dispatch_s_kernel->GetLogicalCore(); + dependent_config_.downstream_dispatch_s_cb_sem_id = + dispatch_s_kernel->GetStaticConfig().my_dispatch_cb_sem_id; } else { TT_FATAL(false, "Unrecognized downstream kernel."); } } - if (device->dispatch_s_enabled()) { + if (device_->dispatch_s_enabled()) { // Should have found dispatch_s in the downstream kernels TT_ASSERT(found_dispatch && found_dispatch_s); } else { // No dispatch_s, just write 0s to the configs dependent on it TT_ASSERT(found_dispatch && ~found_dispatch_s); - this->config.downstream_s_logical_core = UNUSED_LOGICAL_CORE; - this->config.downstream_dispatch_s_cb_sem_id = - device->dispatch_s_enabled() ? UNUSED_SEM_ID : 1; // Just to make it match previous implementation + dependent_config_.downstream_s_logical_core = UNUSED_LOGICAL_CORE; + dependent_config_.downstream_dispatch_s_cb_sem_id = + device_->dispatch_s_enabled() ? UNUSED_SEM_ID : 1; // Just to make it match previous implementation } } else { TT_FATAL(false, "PrefetchKernel must be one of (or both) H and D variants"); @@ -295,56 +299,56 @@ void PrefetchKernel::GenerateDependentConfigs() { void PrefetchKernel::CreateKernel() { std::vector compile_args = { - config.downstream_cb_base.value(), - config.downstream_cb_log_page_size.value(), - config.downstream_cb_pages.value(), - config.my_downstream_cb_sem_id.value(), - config.downstream_cb_sem_id.value(), - config.pcie_base.value(), - config.pcie_size.value(), - config.prefetch_q_base.value(), - config.prefetch_q_size.value(), - config.prefetch_q_rd_ptr_addr.value(), - config.prefetch_q_pcie_rd_ptr_addr.value(), - config.cmddat_q_base.value(), - config.cmddat_q_size.value(), - config.scratch_db_base.value(), - config.scratch_db_size.value(), - config.downstream_sync_sem_id.value(), - config.cmddat_q_pages.value(), - config.my_upstream_cb_sem_id.value(), - config.upstream_cb_sem_id.value(), - config.cmddat_q_log_page_size.value(), - config.cmddat_q_blocks.value(), - config.dispatch_s_buffer_base.value(), - config.my_dispatch_s_cb_sem_id.value(), - config.downstream_dispatch_s_cb_sem_id.value(), - config.dispatch_s_buffer_size.value(), - config.dispatch_s_cb_log_page_size.value(), - config.is_d_variant.value(), - config.is_h_variant.value(), + dependent_config_.downstream_cb_base.value(), + static_config_.downstream_cb_log_page_size.value(), + static_config_.downstream_cb_pages.value(), + static_config_.my_downstream_cb_sem_id.value(), + dependent_config_.downstream_cb_sem_id.value(), + static_config_.pcie_base.value(), + static_config_.pcie_size.value(), + static_config_.prefetch_q_base.value(), + static_config_.prefetch_q_size.value(), + static_config_.prefetch_q_rd_ptr_addr.value(), + static_config_.prefetch_q_pcie_rd_ptr_addr.value(), + static_config_.cmddat_q_base.value(), + static_config_.cmddat_q_size.value(), + static_config_.scratch_db_base.value(), + static_config_.scratch_db_size.value(), + static_config_.downstream_sync_sem_id.value(), + static_config_.cmddat_q_pages.value(), + static_config_.my_upstream_cb_sem_id.value(), + dependent_config_.upstream_cb_sem_id.value(), + static_config_.cmddat_q_log_page_size.value(), + static_config_.cmddat_q_blocks.value(), + static_config_.dispatch_s_buffer_base.value(), + static_config_.my_dispatch_s_cb_sem_id.value(), + dependent_config_.downstream_dispatch_s_cb_sem_id.value(), + static_config_.dispatch_s_buffer_size.value(), + static_config_.dispatch_s_cb_log_page_size.value(), + static_config_.is_d_variant.value(), + static_config_.is_h_variant.value(), }; TT_ASSERT(compile_args.size() == 28); - auto my_virtual_core = device->virtual_core_from_logical_core(this->logical_core, GetCoreType()); + auto my_virtual_core = device_->virtual_core_from_logical_core(logical_core_, GetCoreType()); auto upstream_virtual_core = - device->virtual_core_from_logical_core(config.upstream_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.upstream_logical_core.value(), GetCoreType()); auto downstream_virtual_core = - device->virtual_core_from_logical_core(config.downstream_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.downstream_logical_core.value(), GetCoreType()); auto downstream_s_virtual_core = - device->virtual_core_from_logical_core(config.downstream_s_logical_core.value(), GetCoreType()); + device_->virtual_core_from_logical_core(dependent_config_.downstream_s_logical_core.value(), GetCoreType()); - auto my_virtual_noc_coords = device->virtual_noc0_coordinate(noc_selection.non_dispatch_noc, my_virtual_core); + auto my_virtual_noc_coords = device_->virtual_noc0_coordinate(noc_selection_.non_dispatch_noc, my_virtual_core); auto upstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.upstream_noc, upstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.upstream_noc, upstream_virtual_core); auto downstream_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_virtual_core); auto downstream_s_virtual_noc_coords = - device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_s_virtual_core); + device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_s_virtual_core); std::map defines = { {"MY_NOC_X", std::to_string(my_virtual_noc_coords.x)}, {"MY_NOC_Y", std::to_string(my_virtual_noc_coords.y)}, - {"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, // Unused, remove later + {"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, // Unused, remove later {"UPSTREAM_NOC_X", std::to_string(upstream_virtual_noc_coords.x)}, {"UPSTREAM_NOC_Y", std::to_string(upstream_virtual_noc_coords.y)}, {"DOWNSTREAM_NOC_X", std::to_string(downstream_virtual_noc_coords.x)}, @@ -352,7 +356,7 @@ void PrefetchKernel::CreateKernel() { {"DOWNSTREAM_SLAVE_NOC_X", std::to_string(downstream_s_virtual_noc_coords.x)}, {"DOWNSTREAM_SLAVE_NOC_Y", std::to_string(downstream_s_virtual_noc_coords.y)}, }; - this->configure_kernel_variant( + configure_kernel_variant( dispatch_kernel_file_names[PREFETCH], compile_args, defines, @@ -366,13 +370,13 @@ void PrefetchKernel::CreateKernel() { void PrefetchKernel::ConfigureCore() { // Only H-type prefetchers need L1 configuration - if (this->config.is_h_variant.value()) { - tt::log_warning("Configure Prefetch H (device {} core {})", device->id(), logical_core.str()); + if (static_config_.is_h_variant.value()) { + tt::log_warning("Configure Prefetch H (device {} core {})", device_->id(), logical_core_.str()); // Initialize the FetchQ - uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); + uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id()); auto& my_dispatch_constants = dispatch_constants::get(GetCoreType()); uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED); - uint32_t cq_size = device->sysmem_manager().get_cq_size(); + uint32_t cq_size = device_->sysmem_manager().get_cq_size(); std::vector prefetch_q(my_dispatch_constants.prefetch_q_entries(), 0); uint32_t prefetch_q_base = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED); @@ -393,11 +397,10 @@ void PrefetchKernel::ConfigureCore() { uint32_t completion_q1_last_event_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); std::vector prefetch_q_pcie_rd_ptr_addr_data = { - get_absolute_cq_offset(channel, cq_id, cq_size) + cq_start}; + get_absolute_cq_offset(channel, cq_id_, cq_size) + cq_start}; + detail::WriteToDeviceL1(device_, logical_core_, prefetch_q_rd_ptr, prefetch_q_rd_ptr_addr_data, GetCoreType()); detail::WriteToDeviceL1( - device, this->logical_core, prefetch_q_rd_ptr, prefetch_q_rd_ptr_addr_data, GetCoreType()); - detail::WriteToDeviceL1( - device, this->logical_core, prefetch_q_pcie_rd_ptr, prefetch_q_pcie_rd_ptr_addr_data, GetCoreType()); - detail::WriteToDeviceL1(device, this->logical_core, prefetch_q_base, prefetch_q, GetCoreType()); + device_, logical_core_, prefetch_q_pcie_rd_ptr, prefetch_q_pcie_rd_ptr_addr_data, GetCoreType()); + detail::WriteToDeviceL1(device_, logical_core_, prefetch_q_base, prefetch_q, GetCoreType()); } } diff --git a/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.hpp b/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.hpp index 4d103998cd8c..2410a1ca2758 100644 --- a/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.hpp +++ b/tt_metal/impl/dispatch/kernel_config/prefetch_kernel.hpp @@ -4,16 +4,10 @@ #pragma once #include "fd_kernel.hpp" -typedef struct prefetch_config { - std::optional upstream_logical_core; // Dependant - std::optional downstream_logical_core; // Dependant - std::optional downstream_s_logical_core; // Dependant - - std::optional downstream_cb_base; // Dependent +typedef struct prefetch_static_config { std::optional downstream_cb_log_page_size; std::optional downstream_cb_pages; std::optional my_downstream_cb_sem_id; - std::optional downstream_cb_sem_id; // Dependant std::optional pcie_base; std::optional pcie_size; @@ -33,20 +27,31 @@ typedef struct prefetch_config { // Used for prefetch_d std::optional cmddat_q_pages; std::optional my_upstream_cb_sem_id; - std::optional upstream_cb_sem_id; // Dependant std::optional cmddat_q_log_page_size; std::optional cmddat_q_blocks; // Used for prefetch_d <--> dispatch_s data path std::optional dispatch_s_buffer_base; std::optional my_dispatch_s_cb_sem_id; - std::optional downstream_dispatch_s_cb_sem_id; // Dependant std::optional dispatch_s_buffer_size; std::optional dispatch_s_cb_log_page_size; std::optional is_d_variant; std::optional is_h_variant; -} prefetch_config_t; +} prefetch_static_config_t; + +typedef struct prefetch_dependent_config { + std::optional upstream_logical_core; // Dependant + std::optional downstream_logical_core; // Dependant + std::optional downstream_s_logical_core; // Dependant + + std::optional downstream_cb_base; // Dependent + std::optional downstream_cb_sem_id; // Dependant + + std::optional upstream_cb_sem_id; // Dependant + + std::optional downstream_dispatch_s_cb_sem_id; // Dependant +} prefetch_dependent_config_t; class PrefetchKernel : public FDKernel { public: @@ -59,15 +64,16 @@ class PrefetchKernel : public FDKernel { bool h_variant, bool d_variant) : FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection) { - config.is_h_variant = h_variant; - config.is_d_variant = d_variant; + static_config_.is_h_variant = h_variant; + static_config_.is_d_variant = d_variant; } void CreateKernel() override; void GenerateStaticConfigs() override; void GenerateDependentConfigs() override; void ConfigureCore() override; - const prefetch_config_t& GetConfig() { return this->config; } + const prefetch_static_config_t& GetStaticConfig() { return static_config_; } private: - prefetch_config_t config; + prefetch_static_config_t static_config_; + prefetch_dependent_config_t dependent_config_; }; diff --git a/tt_metal/impl/dispatch/topology.cpp b/tt_metal/impl/dispatch/topology.cpp index fca82c3d56c0..18b0762cd9dd 100644 --- a/tt_metal/impl/dispatch/topology.cpp +++ b/tt_metal/impl/dispatch/topology.cpp @@ -420,9 +420,7 @@ std::vector get_nodes(const std::set& device_ if (num_hw_cqs == 1) { return single_chip_arch_1cq; } else { - // Special case here, single-device can either have dispatch_s or no dispatch_s, depending on the dispatch - // core type. This is only an issue for single-chip, since multi-chip always has ethernet dispatch (and - // therefore no dispatch_s). TODO: determine whether dispatch_s is inserted at this level, instead of inside + // TODO: determine whether dispatch_s is inserted at this level, instead of inside // Device::dispatch_s_enabled(). if (dispatch_core_manager::instance().get_dispatch_core_type(0) == CoreType::WORKER) { return single_chip_arch_2cq_dispatch_s;