Skip to content

Commit

Permalink
#9587: Update CB and worker Go signals to respect max sub cmd limit i…
Browse files Browse the repository at this point in the history
…ntroduced by dispatch packed write local copy change
  • Loading branch information
tt-aho committed Jun 26, 2024
1 parent b99da79 commit a6239ce
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 90 deletions.
170 changes: 91 additions & 79 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,41 @@ void EnqueueProgramCommand::assemble_preamble_commands(bool prefetch_stall) {
}
}

template <typename PackedSubCmd>
uint32_t get_max_write_packed_sub_cmds(uint32_t data_size, uint32_t max_prefetch_cmd_size, bool no_stride) {
static_assert(
std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value or
std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value);
constexpr bool is_unicast = std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value;
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
// Approximate calculation due to alignment
uint32_t max_prefetch_size = max_prefetch_cmd_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT;
uint32_t max_prefetch_num_packed_cmds =
no_stride ? (max_prefetch_size - align(data_size * sizeof(uint32_t), L1_ALIGNMENT)) / sub_cmd_sizeB
: max_prefetch_size / (align(data_size * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB);
return min(max_prefetch_num_packed_cmds, is_unicast ? CQ_DISPATCH_CMD_PACKED_WRITE_MAX_UNICAST_SUB_CMDS : CQ_DISPATCH_CMD_PACKED_WRITE_MAX_MULTICAST_SUB_CMDS);
};

template <typename PackedSubCmd>
uint32_t insert_write_packed_payloads(const uint32_t num_sub_cmds, const uint32_t sub_cmd_sizeB, const uint32_t max_prefetch_command_size, std::vector<std::pair<uint32_t, uint32_t>>& packed_cmd_payloads) {
const uint32_t aligned_sub_cmd_sizeB = align(sub_cmd_sizeB, L1_ALIGNMENT);
const uint32_t max_packed_sub_cmds_per_cmd = get_max_write_packed_sub_cmds<PackedSubCmd>(aligned_sub_cmd_sizeB, max_prefetch_command_size, false);
uint32_t rem_num_sub_cmds = num_sub_cmds;
uint32_t cmd_payload_sizeB = 0;
while (rem_num_sub_cmds != 0) {
const uint32_t num_sub_cmds_in_cmd = std::min(max_packed_sub_cmds_per_cmd, rem_num_sub_cmds);
const uint32_t aligned_data_sizeB = aligned_sub_cmd_sizeB * num_sub_cmds_in_cmd;
const uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_sub_cmds_in_cmd * sizeof(PackedSubCmd),
L1_ALIGNMENT);
packed_cmd_payloads.emplace_back(num_sub_cmds_in_cmd, dispatch_cmd_sizeB + aligned_data_sizeB);
cmd_payload_sizeB += align(sizeof(CQPrefetchCmd) + packed_cmd_payloads.back().second, PCIE_ALIGNMENT);
rem_num_sub_cmds -= num_sub_cmds_in_cmd;
}
return cmd_payload_sizeB;
}

template <typename PackedSubCmd>
void generate_runtime_args_cmds(
std::vector<HostMemDeviceCommand>& runtime_args_command_sequences,
Expand All @@ -353,17 +388,6 @@ void generate_runtime_args_cmds(
(no_stride ? 1 : num_packed_cmds) * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
return dispatch_cmd_sizeB + aligned_runtime_data_sizeB;
};
thread_local static auto get_max_num_packed_cmds =
[](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast, bool no_stride) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
// Approximate calculation due to alignment
max_size = max_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT;
uint32_t max_num_packed_cmds =
no_stride ? (max_size - align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT)) / sub_cmd_sizeB
: max_size / (align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB);
return max_num_packed_cmds;
};
thread_local static auto get_runtime_args_data_offset =
[](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) {
uint32_t sub_cmd_sizeB =
Expand All @@ -376,7 +400,7 @@ void generate_runtime_args_cmds(

uint32_t num_packed_cmds_in_seq = sub_cmds.size();
uint32_t max_packed_cmds =
get_max_num_packed_cmds(max_runtime_args_len, max_prefetch_command_size, unicast, no_stride);
get_max_write_packed_sub_cmds<PackedSubCmd>(max_runtime_args_len, max_prefetch_command_size, no_stride);
uint32_t offset_idx = 0;
if (no_stride) {
TT_FATAL(max_packed_cmds >= num_packed_cmds_in_seq);
Expand Down Expand Up @@ -568,6 +592,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
// Calculate size of command and fill program indices of data to update
// TODO: Would be nice if we could pull this out of program
uint32_t cmd_sequence_sizeB = 0;
const uint32_t max_prefetch_command_size = dispatch_constants::get(dispatch_core_type).max_prefetch_command_size();

for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.multicast_semaphores) {
uint32_t num_packed_cmds = 0;
Expand Down Expand Up @@ -613,11 +638,7 @@ void EnqueueProgramCommand::assemble_device_commands() {

const auto& circular_buffers_unique_coreranges = program.circular_buffers_unique_coreranges();
const uint16_t num_multicast_cb_sub_cmds = circular_buffers_unique_coreranges.size();
uint32_t cb_configs_payload_start =
(cmd_sequence_sizeB + CQ_PREFETCH_CMD_BARE_MIN_SIZE +
align(num_multicast_cb_sub_cmds * sizeof(CQDispatchWritePackedMulticastSubCmd), L1_ALIGNMENT)) /
sizeof(uint32_t);
uint32_t mcast_cb_payload_sizeB = 0;
std::vector<std::pair<uint32_t, uint32_t>> mcast_cb_payload;
uint16_t cb_config_size_bytes = 0;
uint32_t aligned_cb_config_size_bytes = 0;
std::vector<std::vector<uint32_t>> cb_config_payloads(
Expand Down Expand Up @@ -667,15 +688,9 @@ void EnqueueProgramCommand::assemble_device_commands() {
max_overall_base_index = max(max_overall_base_index, max_base_index);
i++;
}
cb_config_size_bytes =
(max_overall_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t);
cb_config_size_bytes = (max_overall_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t);
aligned_cb_config_size_bytes = align(cb_config_size_bytes, L1_ALIGNMENT);
const uint32_t aligned_cb_config_data_sizeB = aligned_cb_config_size_bytes * num_multicast_cb_sub_cmds;
const uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_multicast_cb_sub_cmds * sizeof(CQDispatchWritePackedMulticastSubCmd),
L1_ALIGNMENT);
mcast_cb_payload_sizeB = dispatch_cmd_sizeB + aligned_cb_config_data_sizeB;
cmd_sequence_sizeB += align(sizeof(CQPrefetchCmd) + mcast_cb_payload_sizeB, PCIE_ALIGNMENT);
cmd_sequence_sizeB += insert_write_packed_payloads<CQDispatchWritePackedMulticastSubCmd>(num_multicast_cb_sub_cmds, cb_config_size_bytes, max_prefetch_command_size, mcast_cb_payload);
}

// Program Binaries and Go Signals
Expand Down Expand Up @@ -804,6 +819,8 @@ void EnqueueProgramCommand::assemble_device_commands() {
std::vector<std::pair<const void*, uint32_t>> unicast_go_signal_data;
std::vector<CQDispatchWritePackedMulticastSubCmd> multicast_go_signal_sub_cmds;
std::vector<CQDispatchWritePackedUnicastSubCmd> unicast_go_signal_sub_cmds;
std::vector<std::pair<uint32_t, uint32_t>> multicast_go_signals_payload;
std::vector<std::pair<uint32_t, uint32_t>> unicast_go_signals_payload;
const uint32_t go_signal_sizeB = sizeof(launch_msg_t);
for (KernelGroup& kernel_group : program.get_kernel_groups(CoreType::WORKER)) {
kernel_group.launch_msg.mode = DISPATCH_MODE_DEV;
Expand All @@ -822,13 +839,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
}
}
if (multicast_go_signal_sub_cmds.size() > 0) {
uint32_t num_multicast_sub_cmds = multicast_go_signal_sub_cmds.size();
uint32_t aligned_go_signal_data_sizeB = align(sizeof(launch_msg_t), L1_ALIGNMENT) * num_multicast_sub_cmds;
uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_multicast_sub_cmds * sizeof(CQDispatchWritePackedMulticastSubCmd),
L1_ALIGNMENT);
uint32_t mcast_payload_sizeB = dispatch_cmd_sizeB + aligned_go_signal_data_sizeB;
cmd_sequence_sizeB += align(sizeof(CQPrefetchCmd) + mcast_payload_sizeB, PCIE_ALIGNMENT);
cmd_sequence_sizeB += insert_write_packed_payloads<CQDispatchWritePackedMulticastSubCmd>(multicast_go_signal_sub_cmds.size(), sizeof(launch_msg_t), max_prefetch_command_size, multicast_go_signals_payload);
}

for (KernelGroup& kernel_group : program.get_kernel_groups(CoreType::ETH)) {
Expand All @@ -847,13 +858,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
}
}
if (unicast_go_signal_sub_cmds.size() > 0) {
uint32_t num_unicast_sub_cmds = unicast_go_signal_sub_cmds.size();
uint32_t aligned_go_signal_data_sizeB = align(sizeof(launch_msg_t), L1_ALIGNMENT) * num_unicast_sub_cmds;
uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_unicast_sub_cmds * sizeof(CQDispatchWritePackedUnicastSubCmd),
L1_ALIGNMENT);
uint32_t ucast_payload_sizeB = dispatch_cmd_sizeB + aligned_go_signal_data_sizeB;
cmd_sequence_sizeB += align(sizeof(CQPrefetchCmd) + ucast_payload_sizeB, PCIE_ALIGNMENT);
cmd_sequence_sizeB += insert_write_packed_payloads<CQDispatchWritePackedUnicastSubCmd>(unicast_go_signal_sub_cmds.size(), sizeof(launch_msg_t), max_prefetch_command_size, unicast_go_signals_payload);
}

cached_program_command_sequence.program_command_sequence = HostMemDeviceCommand(cmd_sequence_sizeB);
Expand Down Expand Up @@ -920,16 +925,27 @@ void EnqueueProgramCommand::assemble_device_commands() {
}

// CB Configs commands
cached_program_command_sequence.cb_configs_payload_start = cb_configs_payload_start;
cached_program_command_sequence.aligned_cb_config_size_bytes = aligned_cb_config_size_bytes;
if (num_multicast_cb_sub_cmds > 0) {
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedMulticastSubCmd>(
num_multicast_cb_sub_cmds,
CIRCULAR_BUFFER_CONFIG_BASE,
cb_config_size_bytes,
mcast_cb_payload_sizeB,
multicast_cb_config_sub_cmds,
multicast_cb_config_data);
uint32_t curr_sub_cmd_idx = 0;
cached_program_command_sequence.cb_configs_payloads.reserve(num_multicast_cb_sub_cmds);
uint32_t cb_config_size_words = aligned_cb_config_size_bytes / sizeof(uint32_t);
for (const auto& [num_sub_cmds_in_cmd, mcast_cb_payload_sizeB] : mcast_cb_payload) {
uint32_t write_offset_bytes = program_command_sequence.write_offset_bytes();
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedMulticastSubCmd>(
num_sub_cmds_in_cmd,
CIRCULAR_BUFFER_CONFIG_BASE,
cb_config_size_bytes,
mcast_cb_payload_sizeB,
multicast_cb_config_sub_cmds,
multicast_cb_config_data,
curr_sub_cmd_idx);
curr_sub_cmd_idx += num_sub_cmds_in_cmd;
uint32_t curr_sub_cmd_data_offset_words = (write_offset_bytes + CQ_PREFETCH_CMD_BARE_MIN_SIZE + align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedMulticastSubCmd), L1_ALIGNMENT)) / sizeof(uint32_t);
for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) {
cached_program_command_sequence.cb_configs_payloads.push_back((uint32_t *)program_command_sequence.data() + curr_sub_cmd_data_offset_words);
curr_sub_cmd_data_offset_words += cb_config_size_words;
}
}
}

// Program Binaries
Expand All @@ -955,42 +971,38 @@ void EnqueueProgramCommand::assemble_device_commands() {

// Go Signals
if (multicast_go_signal_sub_cmds.size() > 0) {
uint32_t num_multicast_sub_cmds = multicast_go_signal_sub_cmds.size();
uint32_t aligned_go_signal_data_sizeB = align(sizeof(launch_msg_t), L1_ALIGNMENT) * num_multicast_sub_cmds;
uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_multicast_sub_cmds * sizeof(CQDispatchWritePackedMulticastSubCmd),
L1_ALIGNMENT);
uint32_t mcast_payload_sizeB = dispatch_cmd_sizeB + aligned_go_signal_data_sizeB;
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedMulticastSubCmd>(
num_multicast_sub_cmds,
GET_MAILBOX_ADDRESS_HOST(launch),
go_signal_sizeB,
mcast_payload_sizeB,
multicast_go_signal_sub_cmds,
multicast_go_signal_data);
uint32_t curr_sub_cmd_idx = 0;
for (const auto& [num_sub_cmds_in_cmd, multicast_go_signal_payload_sizeB] : multicast_go_signals_payload) {
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedMulticastSubCmd>(
num_sub_cmds_in_cmd,
GET_MAILBOX_ADDRESS_HOST(launch),
go_signal_sizeB,
multicast_go_signal_payload_sizeB,
multicast_go_signal_sub_cmds,
multicast_go_signal_data,
curr_sub_cmd_idx);
curr_sub_cmd_idx += num_sub_cmds_in_cmd;
}
}

if (unicast_go_signal_sub_cmds.size() > 0) {
uint32_t num_unicast_sub_cmds = unicast_go_signal_sub_cmds.size();
uint32_t aligned_go_signal_data_sizeB = align(sizeof(launch_msg_t), L1_ALIGNMENT) * num_unicast_sub_cmds;
uint32_t dispatch_cmd_sizeB = align(
sizeof(CQDispatchCmd) + num_unicast_sub_cmds * sizeof(CQDispatchWritePackedUnicastSubCmd),
L1_ALIGNMENT);
uint32_t ucast_payload_sizeB = dispatch_cmd_sizeB + aligned_go_signal_data_sizeB;
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedUnicastSubCmd>(
num_unicast_sub_cmds,
GET_ETH_MAILBOX_ADDRESS_HOST(launch),
go_signal_sizeB,
ucast_payload_sizeB,
unicast_go_signal_sub_cmds,
unicast_go_signal_data);
uint32_t curr_sub_cmd_idx = 0;
for (const auto& [num_sub_cmds_in_cmd, unicast_go_signal_payload_sizeB] : unicast_go_signals_payload) {
program_command_sequence.add_dispatch_write_packed<CQDispatchWritePackedUnicastSubCmd>(
num_sub_cmds_in_cmd,
GET_ETH_MAILBOX_ADDRESS_HOST(launch),
go_signal_sizeB,
unicast_go_signal_payload_sizeB,
unicast_go_signal_sub_cmds,
unicast_go_signal_data,
curr_sub_cmd_idx);
curr_sub_cmd_idx += num_sub_cmds_in_cmd;
}
}
} else {
auto& program_command_sequence = cached_program_command_sequence.program_command_sequence;
uint32_t* cb_config_payload =
(uint32_t*)program_command_sequence.data() + cached_program_command_sequence.cb_configs_payload_start;
uint32_t aligned_cb_config_size_bytes = cached_program_command_sequence.aligned_cb_config_size_bytes;
uint32_t i = 0;
for (const auto& cbs_on_core_range : cached_program_command_sequence.circular_buffers_on_core_ranges) {
uint32_t* cb_config_payload = cached_program_command_sequence.cb_configs_payloads[i];
for (const shared_ptr<CircularBuffer>& cb : cbs_on_core_range) {
const uint32_t cb_address = cb->address() >> 4;
const uint32_t cb_size = cb->size() >> 4;
Expand All @@ -1005,7 +1017,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> 4;
}
}
cb_config_payload += aligned_cb_config_size_bytes / sizeof(uint32_t);
i++;
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ class EnqueueProgramCommand : public Command {
std::vector<HostMemDeviceCommand> runtime_args_command_sequences;
uint32_t runtime_args_fetch_size_bytes;
HostMemDeviceCommand program_command_sequence;
uint32_t cb_configs_payload_start;
uint32_t aligned_cb_config_size_bytes;
std::vector<uint32_t *> cb_configs_payloads;
std::vector<std::vector<std::shared_ptr<CircularBuffer>>> circular_buffers_on_core_ranges;
};
thread_local static std::unordered_map<uint64_t, CachedProgramCommandSequence> cached_program_command_sequences;
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/impl/dispatch/cq_commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ struct CQDispatchWritePagedCmd {
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE = 0x00;
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST = 0x01;
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE = 0x02;

struct CQDispatchWritePackedCmd {
uint8_t flags; // see above
uint16_t count; // number of sub-cmds (max 1020 unicast, 510 mcast). Max num sub-cmds = (dispatch_constants::TRANSFER_PAGE_SIZE - sizeof(CQDispatchCmd)) / sizeof(CQDispatchWritePacked*castSubCmd)
Expand All @@ -178,6 +179,9 @@ struct CQDispatchWritePackedMulticastSubCmd {
uint32_t num_mcast_dests;
} __attribute__((packed));

constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_MAX_UNICAST_SUB_CMDS = 108; // GS 120 - 1 row TODO: this should be a compile time arg passed in from host
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_MAX_MULTICAST_SUB_CMDS = CQ_DISPATCH_CMD_PACKED_WRITE_MAX_UNICAST_SUB_CMDS * sizeof(CQDispatchWritePackedUnicastSubCmd) / sizeof(CQDispatchWritePackedMulticastSubCmd);

struct CQDispatchWritePackedLargeSubCmd {
uint32_t noc_xy_addr;
uint32_t addr;
Expand Down
7 changes: 4 additions & 3 deletions tt_metal/impl/dispatch/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class DeviceCommand {

void *data() const { return this->cmd_region; }

uint32_t write_offset_bytes() const { return this->cmd_write_offsetB; }

vector_memcpy_aligned<uint32_t> cmd_vector() const { return this->cmd_region_vector; }

void add_dispatch_wait(
Expand Down Expand Up @@ -407,9 +409,8 @@ class DeviceCommand {
std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value);
bool multicast = std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value;

static constexpr uint32_t max_num_packed_sub_cmds =
(dispatch_constants::TRANSFER_PAGE_SIZE - sizeof(CQDispatchCmd)) / sizeof(PackedSubCmd);
TT_ASSERT(
constexpr uint32_t max_num_packed_sub_cmds = std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value ? CQ_DISPATCH_CMD_PACKED_WRITE_MAX_UNICAST_SUB_CMDS : CQ_DISPATCH_CMD_PACKED_WRITE_MAX_MULTICAST_SUB_CMDS;
TT_FATAL(
num_sub_cmds <= max_num_packed_sub_cmds,
"Max number of packed sub commands are {} but requesting {}",
max_num_packed_sub_cmds,
Expand Down
Loading

0 comments on commit a6239ce

Please sign in to comment.