Skip to content

Commit

Permalink
#0: Split static/dependent configs
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-dma committed Dec 19, 2024
1 parent 6437257 commit c173b81
Show file tree
Hide file tree
Showing 17 changed files with 1,105 additions and 1,065 deletions.
189 changes: 94 additions & 95 deletions tt_metal/impl/dispatch/kernel_config/demux_kernel.cpp

Large diffs are not rendered by default.

39 changes: 22 additions & 17 deletions tt_metal/impl/dispatch/kernel_config/demux_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,38 @@
#pragma once
#include "fd_kernel.hpp"

typedef struct demux_config {
typedef struct demux_static_config {
std::optional<uint32_t> endpoint_id_start_index;
std::optional<uint32_t> rx_queue_start_addr_words;
std::optional<uint32_t> rx_queue_size_words;
std::optional<uint32_t> demux_fan_out; // Dependent
std::optional<uint32_t> demux_fan_out;

std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_x; // [4:7], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_y; // [4:7], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_queue_id; // [4:7]
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_network_type; // [4:7]
std::optional<uint32_t> remote_rx_network_type;

std::optional<uint32_t> test_results_buf_addr_arg;
std::optional<uint32_t> test_results_buf_size_bytes;
std::optional<uint32_t> timeout_cycles;
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_cb_log_page_size; // [26:29]
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem_id; // [26:29]
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29]
} demux_static_config_t;

typedef struct demux_dependent_config {
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_x; // [4:7], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_y; // [4:7], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_queue_start_addr_words; // [8:2:14], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> remote_tx_queue_size_words; // [9:2:15], dependent
std::optional<uint32_t> remote_rx_x; // Dependent
std::optional<uint32_t> remote_rx_y; // Dependent
std::optional<uint32_t> remote_rx_queue_id; // Dependent
std::optional<uint32_t> remote_rx_network_type;

std::optional<uint32_t> dest_endpoint_output_map_hi; // Dependent
std::optional<uint32_t> dest_endpoint_output_map_lo; // Dependent
std::optional<uint32_t> test_results_buf_addr_arg;
std::optional<uint32_t> test_results_buf_size_bytes;
std::optional<uint32_t> timeout_cycles;
std::optional<uint32_t> output_depacketize; // Dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_cb_log_page_size; // [26:29]
std::optional<uint32_t> output_depacketize; // Dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_downstream_sem_id; // [26:29], dependent
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_local_sem_id; // [26:29]
std::array<std::optional<uint32_t>, MAX_SWITCH_FAN_OUT> output_depacketize_remove_header; // [26:29]
} demux_config_t;
} demux_dependent_config_t;

class DemuxKernel : public FDKernel {
public:
Expand All @@ -41,10 +45,11 @@ class DemuxKernel : public FDKernel {
void CreateKernel() override;
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
const demux_config_t& GetConfig() { return this->config; }
void SetPlacementCQID(int id) { this->placement_cq_id = id; }
const demux_static_config_t& GetStaticConfig() { return static_config_; }
void SetPlacementCQID(int id) { placement_cq_id_ = id; }

private:
demux_config_t config;
int placement_cq_id; // TODO: remove channel hard-coding for dispatch core manager
demux_static_config_t static_config_;
demux_dependent_config_t dependent_config_;
int placement_cq_id_; // TODO: remove channel hard-coding for dispatch core manager
};
439 changes: 221 additions & 218 deletions tt_metal/impl/dispatch/kernel_config/dispatch_kernel.cpp

Large diffs are not rendered by default.

41 changes: 24 additions & 17 deletions tt_metal/impl/dispatch/kernel_config/dispatch_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,21 @@
#pragma once
#include "fd_kernel.hpp"

typedef struct dispatch_config {
std::optional<tt_cxy_pair> upstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_s_logical_core; // Dependant

typedef struct dispatch_static_config {
std::optional<uint32_t> dispatch_cb_base; // 0
std::optional<uint32_t> dispatch_cb_log_page_size;
std::optional<uint32_t> dispatch_cb_pages;
std::optional<uint32_t> my_dispatch_cb_sem_id;
std::optional<uint32_t> upstream_dispatch_cb_sem_id; // Dependant

std::optional<uint32_t> dispatch_cb_blocks; // 5
std::optional<uint32_t> upstream_sync_sem; // Dependant
std::optional<uint32_t> command_queue_base_addr;
std::optional<uint32_t> completion_queue_base_addr;
std::optional<uint32_t> completion_queue_size;

std::optional<uint32_t> downstream_cb_base; // 10, dependent
std::optional<uint32_t> downstream_cb_size; // Dependent
std::optional<uint32_t> my_downstream_cb_sem_id;
std::optional<uint32_t> downstream_cb_sem_id; // Dependant

std::optional<uint32_t> split_dispatch_page_preamble_size; // 14
std::optional<uint32_t> split_prefetch;
std::optional<uint32_t> prefetch_h_noc_xy; // Dependent
std::optional<uint32_t> prefetch_h_local_downstream_sem_addr; // Dependent
std::optional<uint32_t> prefetch_h_max_credits;

std::optional<uint32_t> packed_write_max_unicast_sub_cmds; // 19
Expand All @@ -46,7 +35,24 @@ typedef struct dispatch_config {

std::optional<bool> is_d_variant;
std::optional<bool> is_h_variant;
} dispatch_config_t;
} dispatch_static_config_t;

typedef struct dispatch_dependent_config {
std::optional<tt_cxy_pair> upstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_s_logical_core; // Dependant

std::optional<uint32_t> upstream_dispatch_cb_sem_id; // Dependant

std::optional<uint32_t> upstream_sync_sem; // Dependant

std::optional<uint32_t> downstream_cb_base; // 10, dependent
std::optional<uint32_t> downstream_cb_size; // Dependent
std::optional<uint32_t> downstream_cb_sem_id; // Dependant

std::optional<uint32_t> prefetch_h_noc_xy; // Dependent
std::optional<uint32_t> prefetch_h_local_downstream_sem_addr; // Dependent
} dispatch_dependent_config_t;

class DispatchKernel : public FDKernel {
public:
Expand All @@ -59,15 +65,16 @@ class DispatchKernel : public FDKernel {
bool h_variant,
bool d_variant) :
FDKernel(node_id, device_id, servicing_device_id, cq_id, noc_selection) {
config.is_h_variant = h_variant;
config.is_d_variant = d_variant;
static_config_.is_h_variant = h_variant;
static_config_.is_d_variant = d_variant;
}
void CreateKernel() override;
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
void ConfigureCore() override;
const dispatch_config_t& GetConfig() { return this->config; }
const dispatch_static_config_t& GetStaticConfig() { return static_config_; }

private:
dispatch_config_t config;
dispatch_static_config_t static_config_;
dispatch_dependent_config_t dependent_config_;
};
94 changes: 47 additions & 47 deletions tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include "tt_metal/detail/tt_metal.hpp"

void DispatchSKernel::GenerateStaticConfigs() {
uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id());
uint8_t cq_id = this->cq_id;
uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_->id());
uint8_t cq_id_ = this->cq_id_;
auto& my_dispatch_constants = dispatch_constants::get(GetCoreType());

uint32_t dispatch_s_buffer_base = 0xff;
if (device->dispatch_s_enabled()) {
if (device_->dispatch_s_enabled()) {
uint32_t dispatch_buffer_base = my_dispatch_constants.dispatch_buffer_base();
if (GetCoreType() == CoreType::WORKER) {
// dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx.
Expand All @@ -25,78 +25,78 @@ void DispatchSKernel::GenerateStaticConfigs() {
dispatch_s_buffer_base = dispatch_buffer_base;
}
}
this->logical_core = dispatch_core_manager::instance().dispatcher_s_core(device->id(), channel, cq_id);
this->config.cb_base = dispatch_s_buffer_base;
this->config.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE;
this->config.cb_size = my_dispatch_constants.dispatch_s_buffer_size();
logical_core_ = dispatch_core_manager::instance().dispatcher_s_core(device_->id(), channel, cq_id_);
static_config_.cb_base = dispatch_s_buffer_base;
static_config_.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE;
static_config_.cb_size = my_dispatch_constants.dispatch_s_buffer_size();
// used by dispatch_s to sync with prefetch
this->config.my_dispatch_cb_sem_id = tt::tt_metal::CreateSemaphore(*program, this->logical_core, 0, GetCoreType());
this->config.dispatch_s_sync_sem_base_addr =
static_config_.my_dispatch_cb_sem_id = tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType());
static_config_.dispatch_s_sync_sem_base_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM);
// used by dispatch_d to signal that dispatch_s can send go signal

this->config.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG);
this->config.unicast_go_signal_addr =
static_config_.mcast_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG);
static_config_.unicast_go_signal_addr =
(hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH) != -1)
? hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG)
: 0;
this->config.distributed_dispatcher = (GetCoreType() == CoreType::ETH);
this->config.worker_sem_base_addr =
static_config_.distributed_dispatcher = (GetCoreType() == CoreType::ETH);
static_config_.worker_sem_base_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
this->config.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES;
this->config.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES;
static_config_.max_num_worker_sems = dispatch_constants::DISPATCH_MESSAGE_ENTRIES;
static_config_.max_num_go_signal_noc_data_entries = dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES;
}

void DispatchSKernel::GenerateDependentConfigs() {
// Upstream
TT_ASSERT(this->upstream_kernels.size() == 1);
auto prefetch_kernel = dynamic_cast<PrefetchKernel*>(this->upstream_kernels[0]);
TT_ASSERT(upstream_kernels_.size() == 1);
auto prefetch_kernel = dynamic_cast<PrefetchKernel*>(upstream_kernels_[0]);
TT_ASSERT(prefetch_kernel);
this->config.upstream_logical_core = prefetch_kernel->GetLogicalCore();
this->config.upstream_dispatch_cb_sem_id = prefetch_kernel->GetConfig().my_dispatch_s_cb_sem_id;
dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore();
dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_dispatch_s_cb_sem_id;

// Downstream
TT_ASSERT(this->downstream_kernels.size() == 1);
auto dispatch_kernel = dynamic_cast<DispatchKernel*>(this->downstream_kernels[0]);
TT_ASSERT(downstream_kernels_.size() == 1);
auto dispatch_kernel = dynamic_cast<DispatchKernel*>(downstream_kernels_[0]);
TT_ASSERT(dispatch_kernel);
this->config.downstream_logical_core = dispatch_kernel->GetLogicalCore();
dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore();
}

void DispatchSKernel::CreateKernel() {
std::vector<uint32_t> compile_args = {
config.cb_base.value(),
config.cb_log_page_size.value(),
config.cb_size.value(),
config.my_dispatch_cb_sem_id.value(),
config.upstream_dispatch_cb_sem_id.value(),
config.dispatch_s_sync_sem_base_addr.value(),
config.mcast_go_signal_addr.value(),
config.unicast_go_signal_addr.value(),
config.distributed_dispatcher.value(),
config.worker_sem_base_addr.value(),
config.max_num_worker_sems.value(),
config.max_num_go_signal_noc_data_entries.value(),
static_config_.cb_base.value(),
static_config_.cb_log_page_size.value(),
static_config_.cb_size.value(),
static_config_.my_dispatch_cb_sem_id.value(),
dependent_config_.upstream_dispatch_cb_sem_id.value(),
static_config_.dispatch_s_sync_sem_base_addr.value(),
static_config_.mcast_go_signal_addr.value(),
static_config_.unicast_go_signal_addr.value(),
static_config_.distributed_dispatcher.value(),
static_config_.worker_sem_base_addr.value(),
static_config_.max_num_worker_sems.value(),
static_config_.max_num_go_signal_noc_data_entries.value(),
};
TT_ASSERT(compile_args.size() == 12);
auto my_virtual_core = device->virtual_core_from_logical_core(this->logical_core, GetCoreType());
auto my_virtual_core = device_->virtual_core_from_logical_core(logical_core_, GetCoreType());
auto upstream_virtual_core =
device->virtual_core_from_logical_core(config.upstream_logical_core.value(), GetCoreType());
device_->virtual_core_from_logical_core(dependent_config_.upstream_logical_core.value(), GetCoreType());
auto downstream_virtual_core =
device->virtual_core_from_logical_core(config.downstream_logical_core.value(), GetCoreType());
auto downstream_s_virtual_core = device->virtual_core_from_logical_core(UNUSED_LOGICAL_CORE, GetCoreType());
device_->virtual_core_from_logical_core(dependent_config_.downstream_logical_core.value(), GetCoreType());
auto downstream_s_virtual_core = device_->virtual_core_from_logical_core(UNUSED_LOGICAL_CORE, GetCoreType());

auto my_virtual_noc_coords = device->virtual_noc0_coordinate(noc_selection.non_dispatch_noc, my_virtual_core);
auto my_virtual_noc_coords = device_->virtual_noc0_coordinate(noc_selection_.non_dispatch_noc, my_virtual_core);
auto upstream_virtual_noc_coords =
device->virtual_noc0_coordinate(noc_selection.upstream_noc, upstream_virtual_core);
device_->virtual_noc0_coordinate(noc_selection_.upstream_noc, upstream_virtual_core);
auto downstream_virtual_noc_coords =
device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_virtual_core);
device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_virtual_core);
auto downstream_s_virtual_noc_coords =
device->virtual_noc0_coordinate(noc_selection.downstream_noc, downstream_s_virtual_core);
device_->virtual_noc0_coordinate(noc_selection_.downstream_noc, downstream_s_virtual_core);

std::map<string, string> defines = {
{"MY_NOC_X", std::to_string(my_virtual_noc_coords.x)},
{"MY_NOC_Y", std::to_string(my_virtual_noc_coords.y)},
{"UPSTREAM_NOC_INDEX", std::to_string(this->noc_selection.upstream_noc)}, // Unused, remove later
{"UPSTREAM_NOC_INDEX", std::to_string(noc_selection_.upstream_noc)}, // Unused, remove later
{"UPSTREAM_NOC_X", std::to_string(upstream_virtual_noc_coords.x)},
{"UPSTREAM_NOC_Y", std::to_string(upstream_virtual_noc_coords.y)},
{"DOWNSTREAM_NOC_X", std::to_string(downstream_virtual_noc_coords.x)},
Expand All @@ -108,11 +108,11 @@ void DispatchSKernel::CreateKernel() {
}

void DispatchSKernel::ConfigureCore() {
if (!this->device->distributed_dispatcher()) {
if (!device_->distributed_dispatcher()) {
return;
}
// Just need to clear the dispatch message
tt::log_warning("Configure Dispatch S (device {} core {})", device->id(), logical_core.str());
tt::log_warning("Configure Dispatch S (device {} core {})", device_->id(), logical_core_.str());
std::vector<uint32_t> zero = {0x0};
auto& my_dispatch_constants = dispatch_constants::get(GetCoreType());
uint32_t dispatch_s_sync_sem_base_addr =
Expand All @@ -124,7 +124,7 @@ void DispatchSKernel::ConfigureCore() {
dispatch_s_sync_sem_base_addr + my_dispatch_constants.get_dispatch_message_offset(i);
uint32_t dispatch_message_addr =
dispatch_message_base_addr + my_dispatch_constants.get_dispatch_message_offset(i);
detail::WriteToDeviceL1(device, this->logical_core, dispatch_s_sync_sem_addr, zero, GetCoreType());
detail::WriteToDeviceL1(device, this->logical_core, dispatch_message_addr, zero, GetCoreType());
detail::WriteToDeviceL1(device_, logical_core_, dispatch_s_sync_sem_addr, zero, GetCoreType());
detail::WriteToDeviceL1(device_, logical_core_, dispatch_message_addr, zero, GetCoreType());
}
}
19 changes: 11 additions & 8 deletions tt_metal/impl/dispatch/kernel_config/dispatch_s_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
#pragma once
#include "fd_kernel.hpp"

typedef struct dispatch_s_config {
std::optional<tt_cxy_pair> upstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_logical_core; // Dependant

typedef struct dispatch_s_static_config {
std::optional<uint32_t> cb_base;
std::optional<uint32_t> cb_log_page_size;
std::optional<uint32_t> cb_size;
std::optional<uint32_t> my_dispatch_cb_sem_id;
std::optional<uint32_t> upstream_dispatch_cb_sem_id; // Dependent
std::optional<uint32_t> dispatch_s_sync_sem_base_addr;

std::optional<uint32_t> mcast_go_signal_addr;
Expand All @@ -21,7 +17,13 @@ typedef struct dispatch_s_config {
std::optional<uint32_t> worker_sem_base_addr;
std::optional<uint32_t> max_num_worker_sems;
std::optional<uint32_t> max_num_go_signal_noc_data_entries;
} dispatch_s_config_t;
} dispatch_s_static_config_t;

typedef struct dispatch_s_dependent_config {
std::optional<tt_cxy_pair> upstream_logical_core; // Dependant
std::optional<tt_cxy_pair> downstream_logical_core; // Dependant
std::optional<uint32_t> upstream_dispatch_cb_sem_id; // Dependent
} dispatch_s_dependent_config_t;

class DispatchSKernel : public FDKernel {
public:
Expand All @@ -32,8 +34,9 @@ class DispatchSKernel : public FDKernel {
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
void ConfigureCore() override;
const dispatch_s_config_t& GetConfig() { return this->config; }
const dispatch_s_static_config_t& GetStaticConfig() { return static_config_; }

private:
dispatch_s_config_t config;
dispatch_s_static_config_t static_config_;
dispatch_s_dependent_config_t dependent_config_;
};
Loading

0 comments on commit c173b81

Please sign in to comment.