diff --git a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp
index 1c0961c00986..4bb034513540 100644
--- a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp
+++ b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp
@@ -30,6 +30,15 @@ enum AllGatherMode {
     SINGLE_TILE_HIGH_WIDTH_SHARDED
 };
 
+enum AllGatherBidirectionalMode {
+    // Splits the tensor into two and sends each half in opposite directions
+    // the full width around the ring
+    SPLIT_TENSOR,
+    // Doesn't split the tensor and sends the full tensor in both directions,
+    // half-way around the ring
+    FULL_TENSOR
+};
+
 namespace all_gather_op {
 using ccl::Topology;
 }; // namespace all_gather_op
@@ -39,6 +48,17 @@ using ccl::EriscDatamoverBuilder;
 AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim);
 
 class AllGatherConfig {
+    static AllGatherBidirectionalMode choose_bidirectional_mode(Tensor const& input_tensor) {
+        std::size_t eth_l1_capacity = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
+        std::size_t tensor_size_bytes = input_tensor.shape().volume() * input_tensor.element_size();
+        // This is currently a guesstimate. We need a lot more hard data to identify where this dividing line is.
+        bool perf_degradation_from_full_tensor_mode = tensor_size_bytes > (2 * eth_l1_capacity);
+        if (input_tensor.is_sharded() || perf_degradation_from_full_tensor_mode) {
+            return AllGatherBidirectionalMode::SPLIT_TENSOR;
+        }
+        return AllGatherBidirectionalMode::FULL_TENSOR;
+    }
+
    public:
     AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology) :
         num_links(num_links),
@@ -52,7 +72,11 @@ class AllGatherConfig {
         input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM),
         output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM),
-        mode(choose_all_gather_mode(input_tensor, output_tensor, dim))
+        mode(choose_all_gather_mode(input_tensor, output_tensor, dim)),
+
+        // Sharded currently doesn't support FULL_TENSOR bidirectional because its indexers
+        // would need updating to support this new mode
+        bidirectional_mode(choose_bidirectional_mode(input_tensor))
     {
         TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE);
         TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16);
@@ -70,12 +94,17 @@ class AllGatherConfig {
 
         // "duplicate" directions are a shorthand to enable linear/mesh all-gather topologies with
         // fewer code changes. Ideally a new concept is added amongst "num_eth_buffers", "num_workers_per_link", etc.
-        uint32_t num_duplicate_directions = topology == all_gather_op::Topology::Ring ? 1 : 2;
+        uint32_t num_duplicate_directions = (topology == all_gather_op::Topology::Ring && bidirectional_mode != AllGatherBidirectionalMode::FULL_TENSOR) ? 1 : 2;
 
         constexpr uint32_t total_l1_buffer_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
 
         this->is_sharded = input_tensor.is_sharded();
-        this->num_eth_buffers = (this->enable_bidirectional ? 8 : (this->is_sharded && topology != all_gather_op::Topology::Linear ? 8 : 4));
+        this->num_eth_buffers = (this->enable_bidirectional ? 8 /*1*/ : (this->is_sharded && topology != all_gather_op::Topology::Linear ? 8 : 4));
+
+        if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) {
+            this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
+        }
+
         if (this->is_sharded) {
             this->num_eth_buffers = std::min(this->num_eth_buffers, input_tensor.shard_spec()->num_cores());
             if ((input_tensor.shard_spec()->num_cores() / this->num_eth_buffers) % (ring_size) != 0 &&
@@ -137,6 +166,7 @@ class AllGatherConfig {
             buffer_index < get_num_edm_channels_in_clockwise_direction() : true;
     }
 
+    AllGatherBidirectionalMode get_bidirectional_mode() const { return this->bidirectional_mode; }
    uint32_t get_num_edm_channels_in_counter_clockwise_direction() const {
        // return all_gather_buffer_params::enable_bidirectional ? all_gather_buffer_params::num_buffers - all_gather_buffer_params::num_buffers / 2 : 0;
        // Force all through counter-clockwise direction
@@ -176,6 +206,7 @@ class AllGatherConfig {
     uint32_t eth_sems_l1_base_byte_address;
     const all_gather_op::Topology topology;
     AllGatherMode mode;
+    AllGatherBidirectionalMode bidirectional_mode;
     bool is_sharded;
     bool enable_bidirectional;
     const bool input_is_dram;
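
For intuition: choose_bidirectional_mode falls back to SPLIT_TENSOR whenever the input is sharded or its footprint exceeds roughly twice the ERISC L1 buffering capacity. A minimal standalone sketch of the same decision, with a stand-in capacity constant (the real value comes from eth_l1_mem::address_map and is architecture-specific):

    #include <cstddef>

    enum class BidirMode { SPLIT_TENSOR, FULL_TENSOR };

    // Stand-in for MAX_L1_LOADING_SIZE - ERISC_L1_UNRESERVED_BASE (hypothetical value).
    constexpr std::size_t kEthL1Capacity = 128 * 1024;

    BidirMode choose_mode(std::size_t tensor_size_bytes, bool is_sharded) {
        // Past ~2x the ERISC L1 capacity, FULL_TENSOR is assumed to degrade
        // performance (a guesstimate, per the comment in the change above).
        bool too_large_for_full_tensor = tensor_size_bytes > 2 * kEthL1Capacity;
        return (is_sharded || too_large_for_full_tensor) ? BidirMode::SPLIT_TENSOR
                                                         : BidirMode::FULL_TENSOR;
    }

With these illustrative numbers, a [1, 1, 32, 8192] bfloat16 tensor (32 * 8192 * 2 bytes = 512 KiB) would exceed the 256 KiB threshold and pick SPLIT_TENSOR, while a single 2 KiB tile would pick FULL_TENSOR.
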
diff --git a/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp b/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
index 685979cd5a9b..4d27e8527e63 100644
--- a/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
+++ b/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
@@ -100,5 +100,4 @@ void kernel_main() {
             pop_filler_pages_from_cb(cb_id_in0, half_cb_n_pages - rem_num_pages);
         }
     }
-
 }
diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp
index 0371d0f316e4..3c129bc2fa60 100644
--- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp
@@ -57,6 +57,83 @@ static std::tuple<CoreRangeSet, CoreRangeSet> select_worker_cores(AllGatherConfig
 }
 
+std::vector<std::vector<uint32_t>> compute_worker_sender_num_transfers(
+    AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t ring_size, uint32_t ring_index, all_gather_op::Topology topology, uint32_t direction
+) {
+    std::vector<std::vector<uint32_t>> worker_sender_num_transfers;
+    worker_sender_num_transfers.reserve(num_links);
+    for (uint32_t l = 0; l < num_links; ++l) {
+        worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm());
+        for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) {
+            uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b);
+            switch (topology) {
+                case all_gather_op::Topology::Linear:
+                    worker_num_transfers = direction == 0 ? ring_index + 1 : ring_size - ring_index;
+                    break;
+
+                case all_gather_op::Topology::Ring:
+                    switch (all_gather_config.get_bidirectional_mode()) {
+                        case tt::tt_metal::AllGatherBidirectionalMode::SPLIT_TENSOR:
+                            worker_num_transfers = ring_size - 1;
+                            break;
+
+                        case tt::tt_metal::AllGatherBidirectionalMode::FULL_TENSOR:
+                            worker_num_transfers = direction == 0 /*all_gather_config.is_buffer_in_clockwise_ring(b)*/ ?
+                                ((((ring_size - 1) - 1) / 2) + 1) :
+                                (ring_size - 1) / 2;
+                            break;
+
+                        default:
+                            TT_FATAL("Unsupported bidirectional mode");
+                    };
+                    break;
+
+                default:
+                    TT_FATAL("Unsupported topology");
+            };
+        }
+    }
+
+    return worker_sender_num_transfers;
+}
+std::vector<std::vector<uint32_t>> compute_worker_receiver_num_transfers(
+    AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t ring_size, uint32_t ring_index, all_gather_op::Topology topology, uint32_t direction) {
+    std::vector<std::vector<uint32_t>> worker_sender_num_transfers;
+    worker_sender_num_transfers.reserve(num_links);
+    for (uint32_t l = 0; l < num_links; ++l) {
+        worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm());
+        for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) {
+            uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b);
+            switch (topology) {
+                case all_gather_op::Topology::Linear:
+                    worker_num_transfers = (direction == 0 ? ring_index + 1 : ring_size - ring_index) - 1;
+                    break;
+
+                case all_gather_op::Topology::Ring:
+                    switch (all_gather_config.get_bidirectional_mode()) {
+                        case tt::tt_metal::AllGatherBidirectionalMode::SPLIT_TENSOR:
+                            worker_num_transfers = ring_size - 1;
+                            break;
+
+                        case tt::tt_metal::AllGatherBidirectionalMode::FULL_TENSOR:
+                            worker_num_transfers = direction == 0 /*all_gather_config.is_buffer_in_clockwise_ring(b)*/ ?
+                                ((((ring_size - 1) - 1) / 2) + 1) :
+                                (ring_size - 1) / 2;
+                            break;
+
+                        default:
+                            TT_FATAL("Unsupported bidirectional mode");
+                    };
+                    break;
+
+                default:
+                    TT_FATAL("Unsupported topology");
+            };
+        }
+    }
+
+    return worker_sender_num_transfers;
+}
 
 // For ring all-gather, we can send sub-sections of input tensor in opposite directions
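
In FULL_TENSOR mode each direction carries the whole tensor roughly half-way around the ring: for a ring of size R, the clockwise (direction == 0) channels make ceil((R - 1) / 2) transfers and the counter-clockwise channels floor((R - 1) / 2). A quick standalone check of the integer expressions used above:

    #include <cassert>
    #include <cstdint>

    int main() {
        for (uint32_t ring_size = 2; ring_size <= 16; ++ring_size) {
            uint32_t cw  = (((ring_size - 1) - 1) / 2) + 1;  // direction == 0
            uint32_t ccw = (ring_size - 1) / 2;              // direction == 1
            // Together the two directions reach every other chip exactly once.
            assert(cw + ccw == ring_size - 1);               // e.g. R = 8: cw = 4, ccw = 3
        }
        return 0;
    }
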
@@ -136,13 +213,14 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
     std::vector<CoreCoord> all_worker_receiver_cores;
     all_worker_receiver_cores.reserve(total_worker_core_pairs_used);
 
-
     uint32_t num_input_pages = input_tensor.buffer()->size() / input_page_size;
     uint32_t min_pages_per_link = num_input_pages / num_links;
-
-
-    const uint32_t num_full_send_directions = topology == all_gather_op::Topology::Linear ? 2 : 1;
+    bool full_send_both_directions =
+        (topology == all_gather_op::Topology::Linear ||
+         (topology == all_gather_op::Topology::Ring &&
+          all_gather_config.get_bidirectional_mode() == tt::tt_metal::AllGatherBidirectionalMode::FULL_TENSOR));
+    const uint32_t num_full_send_directions = full_send_both_directions ? 2 : 1;
     constexpr uint32_t max_num_full_send_directions = 2;
 
     std::vector<ccl::EriscDatamoverBuilder> clockwise_edm_builders;
@@ -153,8 +231,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
     auto edm_sem_addrs_per_link = std::vector<std::vector<uint32_t>>(num_links);
     auto edm_buffer_addrs_per_link = std::vector<std::vector<uint32_t>>(num_links);
     for (uint32_t link = 0; link < num_links; link++) {
-        edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm());
-        edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm());
+        edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions);
+        edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions);
 
         uint32_t edm_sem_addr = all_gather_config.get_eth_sems_l1_base_byte_address();
         uint32_t edm_buffer_addr = all_gather_config.get_eth_buffers_l1_base_byte_address();
@@ -180,22 +258,21 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
             // but if we are implementing a line topology, the number of transfers will depend on whether we
             // are setting up the forward/clockwise direction or the backward/counter-clockwise direction and also
             // how far we are from the first/last chip, depending on whether we are in the forward or backward direction
-            const uint32_t sender_num_transfers = topology == all_gather_op::Topology::Linear ?
-                (direction == 0 ? ring_index + 1 : ring_size - ring_index) :
-                ring_size - 1;
-            const uint32_t receiver_num_transfers = topology == all_gather_op::Topology::Linear ?
-                sender_num_transfers - 1 :
-                ring_size - 1;
+
+            auto const& sender_worker_num_transfers = compute_worker_sender_num_transfers(
+                all_gather_config, num_links, ring_size, ring_index, topology, direction);
+            auto const& receiver_worker_num_transfers = compute_worker_receiver_num_transfers(
+                all_gather_config, num_links, ring_size, ring_index, topology, direction);
             std::vector<CoreCoord> eth_sender_cores;
             eth_sender_cores.reserve(num_links);
             std::vector<CoreCoord> eth_receiver_cores;
             eth_receiver_cores.reserve(num_links);
             // If linear topology, the first chip in the chain will not have a "receiver" eth core (or more correctly,
             // it doesn't have an input clockwise or output counter-clockwise connection)
-            bool is_first_chip_in_chain = direction == 0 ? ring_index == 0 : ring_index == ring_size - 1;
+            bool is_first_chip_in_chain = is_linear && (direction == 0 ? ring_index == 0 : ring_index == ring_size - 1);
             // If linear topology, the last chip in the chain will not have a "sender" eth core (or more correctly,
             // it doesn't have an output clockwise or input counter-clockwise connection)
-            bool is_last_chip_in_chain = direction == 0 ? ring_index == ring_size - 1 : ring_index == 0;
+            bool is_last_chip_in_chain = is_linear && (direction == 0 ? ring_index == ring_size - 1 : ring_index == 0);
 
             uint32_t sender_socket_idx = 0;
             uint32_t receiver_socket_idx = 0;
@@ -222,10 +299,14 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 }
             }
 
-            auto is_buffer_in_clockwise_direction = [&all_gather_config,&direction](uint32_t b) {
+            auto is_buffer_in_clockwise_direction = [&all_gather_config,&direction,&topology_config](uint32_t b) {
                 TT_ASSERT(direction < max_num_full_send_directions);
-                bool in_clockwise_direction = all_gather_config.is_buffer_in_clockwise_ring(b);
-                return (direction == 0) ? in_clockwise_direction : !in_clockwise_direction;
+                if (!topology_config.is_linear && all_gather_config.get_bidirectional_mode() == tt::tt_metal::AllGatherBidirectionalMode::FULL_TENSOR) {
+                    return direction == 0;
+                } else {
+                    bool in_clockwise_direction = all_gather_config.is_buffer_in_clockwise_ring(b);
+                    return (direction == 0) ? in_clockwise_direction : !in_clockwise_direction;
+                }
             };
 
             std::vector<uint32_t> pages_per_link(num_links, min_pages_per_link);
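
Note the lambda's two regimes: on a ring in FULL_TENSOR mode, the outer direction pass alone decides orientation (every channel in pass 0 is clockwise, every channel in pass 1 counter-clockwise), while in SPLIT_TENSOR or linear setups the per-buffer ring assignment is used and flipped on the second pass. Condensed into a hypothetical free function for illustration:

    #include <cstdint>

    // Hypothetical condensed form of is_buffer_in_clockwise_direction above.
    bool is_clockwise(bool ring_full_tensor, uint32_t direction, bool buffer_in_cw_ring) {
        if (ring_full_tensor) {            // !is_linear && FULL_TENSOR
            return direction == 0;         // the whole pass shares one orientation
        }
        return direction == 0 ? buffer_in_cw_ring : !buffer_in_cw_ring;
    }
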
@@ -317,11 +398,24 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
             TT_ASSERT(rem_pages < pages_per_chunk || num_full_chunks == 0);
             TT_ASSERT(rem_pages <= max_pages_per_chunk);
-            std::vector<uint32_t> num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm());
+            std::vector<uint32_t> num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0);
+            std::vector<uint32_t> rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0);
             std::vector<bool> is_channel_shrinkable(all_gather_config.get_num_eth_buffers_per_edm(), false);
             std::vector<uint32_t> largest_packets_per_channel(all_gather_config.get_num_eth_buffers_per_edm(), 0);
-            std::vector<uint32_t> rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0);
+
+            std::vector<uint32_t> clockwise_link_buffer_num_messages_to_send;
+            std::vector<uint32_t> counter_clockwise_link_buffer_num_messages_to_send;
+            std::vector<uint32_t> edm_semaphores_base_address;
+            std::vector<uint32_t> link_buffer_sender_addresses;
+            clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm());
+            counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm());
+            edm_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm());
+            link_buffer_sender_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm());
+
             {
+                for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) {
+                    num_full_chunks_per_worker.at(b) = num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm();
+                }
                 uint32_t worker_idx = 0;
                 for (worker_idx = 0; worker_idx < num_full_chunks % all_gather_config.get_num_eth_buffers_per_edm(); ++worker_idx) {
                     num_full_chunks_per_worker.at(worker_idx)++;
@@ -330,16 +424,21 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) = rem_pages;
                     TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) * 2 <= cb_num_pages);
                 }
+                { // Logging
+                    log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (clockwise):");
+                    for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) {
+                        if (is_buffer_in_clockwise_direction(b)) {
+                            log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b));
+                        }
+                    }
+                    log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (counter-clockwise):");
+                    for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) {
+                        if (!is_buffer_in_clockwise_direction(b)) {
+                            log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b));
+                        }
+                    }
+                }
             }
-
-            std::vector<uint32_t> clockwise_link_buffer_num_messages_to_send;
-            std::vector<uint32_t> counter_clockwise_link_buffer_num_messages_to_send;
-            std::vector<uint32_t> edm_semaphores_base_address;
-            std::vector<uint32_t> link_buffer_sender_addresses;
-            clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm());
-            counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm());
-            edm_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm());
-            link_buffer_sender_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm());
 
             if (is_sharded) {
                 for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) {
                     auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator(
@@ -378,10 +477,10 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 // link num messages
                 clockwise_link_buffer_num_messages_to_send.push_back(
                     (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
-                    sender_num_transfers);
+                    sender_worker_num_transfers.at(i).at(b));
                 counter_clockwise_link_buffer_num_messages_to_send.push_back(
                     (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
-                    receiver_num_transfers);
+                    receiver_worker_num_transfers.at(i).at(b));
             }
             for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) {
                 log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b));
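
Each channel's EDM message count above is its chunk count (full chunks, plus one partial chunk when remainder pages exist) multiplied by the number of transfers that channel makes around the ring, which is where the per-link, per-buffer counts from the compute_worker_*_num_transfers helpers take effect. As a sketch:

    #include <cstdint>

    // Hypothetical helper mirroring the push_back expressions above.
    uint32_t num_edm_messages(uint32_t full_chunks, uint32_t rem_pages, uint32_t num_transfers) {
        return (full_chunks + (rem_pages > 0 ? 1 : 0)) * num_transfers;
    }
    // e.g. num_edm_messages(3, 5, 7) == (3 + 1) * 7 == 28
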
@@ -473,18 +572,18 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     // 2) num_transfers
                     std::vector<uint32_t> worker_reader_sender_ct_args = {
                         static_cast<uint32_t>(sharding_info.get_shard_type()),
-                        static_cast<uint32_t>(sender_num_transfers)
+                        static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b))
                     };
                     log_trace(tt::LogOp, "----worker_reader_sender_ct_args size={}", worker_reader_sender_ct_args.size());
                     log_trace(tt::LogOp, "\tsharding_info.get_shard_type(): {}", sharding_info.get_shard_type());
-                    log_trace(tt::LogOp, "\tnum_transfers: {}", sender_num_transfers);
+                    log_trace(tt::LogOp, "\tnum_transfers: {}", sender_worker_num_transfers.at(i).at(b));
 
                     return worker_reader_sender_ct_args;
                 } else {
                     std::vector<uint32_t> worker_reader_sender_ct_args = {
                         static_cast<uint32_t>(all_gather_config.is_input_dram()),
                         static_cast<uint32_t>(all_gather_config.is_output_dram()),
-                        static_cast<uint32_t>(sender_num_transfers),
+                        static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
                         static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
                         static_cast<uint32_t>(input_page_size),
                         static_cast<uint32_t>(output_page_size),
@@ -513,7 +612,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     log_trace(tt::LogOp, "Worker {} SR args", b);
                     log_trace(tt::LogOp, "\tall_gather_config.is_input_dram(): {}", all_gather_config.is_input_dram());
                     log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram());
-                    log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_num_transfers);
+                    log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_worker_num_transfers.at(i).at(b));
                     log_trace(tt::LogOp, "\tnum_full_chunks_per_worker.at(b): {}", num_full_chunks_per_worker.at(b));
                     log_trace(tt::LogOp, "\tinput_page_size: {}", input_page_size);
                     log_trace(tt::LogOp, "\toutput_page_size: {}", output_page_size);
@@ -650,7 +749,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 CoreCoord const& worker_eth_sender_core = is_clockwise_direction ? eth_sender_cores.at(i) : eth_receiver_cores.at(i);
 
                 std::vector<uint32_t> worker_writer_sender_ct_args = {
                     static_cast<uint32_t>(all_gather_config.is_output_dram()),
-                    static_cast<uint32_t>(sender_num_transfers),
+                    static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
                     static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
                     static_cast<uint32_t>(input_page_size),
                     static_cast<uint32_t>(output_page_size),
@@ -675,7 +774,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 };
                 log_trace(tt::LogOp, "Worker {} SW CT args", b);
                 log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram());
-                log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_num_transfers);
+                log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_worker_num_transfers.at(i).at(b));
                 log_trace(tt::LogOp, "\tnum_full_chunks_per_worker: {}", num_full_chunks_per_worker.at(b));
                 log_trace(tt::LogOp, "\tinput_page_size: {}", input_page_size);
                 log_trace(tt::LogOp, "\toutput_page_size: {}", output_page_size);
@@ -758,7 +857,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                         static_cast<uint32_t>(device->ethernet_core_from_logical_core(worker_eth_sender_core).y), // eth_sender_noc_y
                         static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)), //output_tensor_shard_arg_generator.args_struct.num_dest_cores),//pages_per_eth_l1_buffer.at(b)),
                         static_cast<uint32_t>(sender_worker_writer_semaphore_addr), // writer_send_sem_addr
-                        static_cast<uint32_t>(sender_num_transfers),
+                        static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
                         static_cast<uint32_t>(input_tensor_shard_arg_generator.args_struct.num_dest_cores),
                         static_cast<uint32_t>(cb_num_pages / 2),
                     };
@@ -774,7 +873,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     log_trace(tt::LogOp, "\teth_sender_noc_y: {}", device->ethernet_core_from_logical_core(worker_eth_sender_core).y);
                     log_trace(tt::LogOp, "\tpages_per_eth_l1_buffer: {}", pages_per_eth_l1_buffer.at(b));
                     log_trace(tt::LogOp, "\twriter_send_sem_addr: {}", sender_worker_writer_semaphore_addr);
-                    log_trace(tt::LogOp, "\tnum_transfers: {}", sender_num_transfers);
+                    log_trace(tt::LogOp, "\tnum_transfers: {}", sender_worker_num_transfers.at(i).at(b));
                     output_tensor_shard_arg_generator.dump_to_log();
 
                     return worker_writer_sender_rt_args;
@@ -820,7 +919,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
             TT_ASSERT(!is_linear ||
                 ((is_clockwise_direction && (ring_index != 0)) ||
                  (!is_clockwise_direction && ring_index != ring_size - 1))
             );
-            // TODO(snijjar): squash before merge
             uint32_t receiver_ring_index = is_linear ?
                 (is_clockwise_direction ? ring_index - 1 : ring_index + 1) :
                 (is_clockwise_direction ?
@@ -872,7 +970,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
             } else {
                 CoreCoord const& worker_eth_receiver_core = is_clockwise_direction ? eth_receiver_cores.at(i) : eth_sender_cores.at(i);
 
                 std::vector<uint32_t> worker_receiver_reader_ct_args = {
-                    static_cast<uint32_t>(receiver_num_transfers),
+                    static_cast<uint32_t>(receiver_worker_num_transfers.at(i).at(b)),
                     static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
                     static_cast<uint32_t>(input_page_size),
                     static_cast<uint32_t>(pages_per_chunk),
@@ -885,7 +983,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 };
 
                 log_trace(tt::LogOp, "Worker {} RR ct args", b);
-                log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_num_transfers);
+                log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_worker_num_transfers.at(i).at(b));
                 log_trace(tt::LogOp, "\tnum_full_chunks_per_worker: {}", num_full_chunks_per_worker.at(b));
                 log_trace(tt::LogOp, "\tinput_page_size: {}", input_page_size);
                 log_trace(tt::LogOp, "\tpages_per_chunk: {}", pages_per_chunk);
@@ -943,7 +1041,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     worker_reader_receiver_rt_args.push_back(static_cast<uint32_t>(receiver_eth_sem_addrs.at(b))); // eth_receiver_l1_semaphore_addr
                     worker_reader_receiver_rt_args.push_back(receiver_worker_semaphore_addr); // local_receiver_read_sem_addr
                     worker_reader_receiver_rt_args.push_back(pages_per_eth_l1_buffer.at(b)), //output_tensor_shard_arg_generator.args_struct.num_dest_cores), //pages_per_eth_l1_buffer.at(b)); // num_shards_per_eth_buf
-                    worker_reader_receiver_rt_args.push_back(receiver_num_transfers); // local_receiver_read_sem_addr
+                    worker_reader_receiver_rt_args.push_back(receiver_worker_num_transfers.at(i).at(b)); // local_receiver_read_sem_addr
                     worker_reader_receiver_rt_args.push_back(static_cast<uint32_t>(cb_num_pages / 2)); // local_receiver_read_sem_addr
 
                     std::copy(output_tensor_shard_addr_gen_args.begin(), output_tensor_shard_addr_gen_args.end(), std::back_inserter(worker_reader_receiver_rt_args));
@@ -1002,7 +1100,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                 } else {
                     std::vector<uint32_t> worker_writer_receiver_ct_args = {
                         static_cast<uint32_t>(all_gather_config.is_output_dram()),
-                        static_cast<uint32_t>(receiver_num_transfers),
+                        static_cast<uint32_t>(receiver_worker_num_transfers.at(i).at(b)),
                         static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
                         static_cast<uint32_t>(input_page_size),
                         static_cast<uint32_t>(output_page_size),
@@ -1029,7 +1127,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
 
                 log_trace(tt::LogOp, "Worker {} RW ct args", b);
                 log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram());
-                log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_num_transfers);
+                log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_worker_num_transfers.at(i).at(b));
                 log_trace(tt::LogOp, "\tnum_full_chunks_per_worker.at(b): {}", num_full_chunks_per_worker.at(b));
                 log_trace(tt::LogOp, "\tinput_page_size: {}", input_page_size);
                 log_trace(tt::LogOp, "\toutput_page_size: {}", output_page_size);
@@ -1094,7 +1192,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
                     worker_receive_writer_rt_args.push_back(sender_worker_reader_semaphore_addr);
 
                     worker_receive_writer_rt_args.push_back(output_tensor_shard_arg_generator.args_struct.num_dest_cores), //pages_per_eth_l1_buffer.at(b));
-                    worker_receive_writer_rt_args.push_back(receiver_num_transfers);
+                    worker_receive_writer_rt_args.push_back(receiver_worker_num_transfers.at(i).at(b));
                     worker_receive_writer_rt_args.push_back(pages_per_buffer.at(b));
                     worker_receive_writer_rt_args.push_back(static_cast<uint32_t>(cb_num_pages / 2));
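
Taken together, the multi-core changes replace every scalar sender_num_transfers / receiver_num_transfers use with a per-link, per-buffer lookup. The chunk distribution that feeds these counts is an even split with the leftover full chunks handed round-robin to the lowest-indexed channels (remainder pages additionally land on the next channel in line); a hedged standalone sketch:

    #include <cstdint>
    #include <vector>

    // Hypothetical mirror of the num_full_chunks_per_worker setup above.
    std::vector<uint32_t> distribute_chunks(uint32_t num_full_chunks, uint32_t num_channels) {
        std::vector<uint32_t> per_channel(num_channels, num_full_chunks / num_channels);
        for (uint32_t b = 0; b < num_full_chunks % num_channels; ++b) {
            per_channel.at(b)++;
        }
        return per_channel;
    }
    // e.g. distribute_chunks(10, 4) -> {3, 3, 2, 2}
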
diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp b/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp
index 1fc8a9c9469f..c0938b7f7d8d 100644
--- a/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp
+++ b/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp
@@ -79,12 +79,12 @@ KernelHandle generate_edm_kernel(
     ccl::EriscDatamoverBuilder const& edm_builder,
     CoreCoord const& eth_core,
     NOC noc_id) {
-    log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: ");
     edm_builder.dump_to_log();
 
     std::vector<uint32_t> const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args();
 
     // Ethernet Kernels
     std::vector<uint32_t> eth_sender_ct_args = edm_builder.emit_compile_time_args();
+    log_trace(tt::LogOp, "EDM core (x={},y={}):", eth_core.x, eth_core.y);
     log_trace(tt::LogOp, "CT ARGS:");
     for (auto const& s : eth_sender_ct_args) {
         log_trace(tt::LogOp, "\t{}", s);
diff --git a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp
index 0b40a9df7d60..f1c420cf4302 100644
--- a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp
+++ b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp
@@ -106,26 +106,27 @@ struct sender_receiver_index_t {
     }
 };
 
-void kernel_main() {
-    // COMPILE TIME ARGS
-    // If true, will enable this erisc's sender functionality
-    constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0;
+// COMPILE TIME ARGS
+// If true, will enable this erisc's sender functionality
+static constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0;
+
+// If true, will enable this erisc's receiver functionality
+static constexpr bool enable_receiver_side = get_compile_time_arg_val(1) != 0;
 
-    // If true, will enable this erisc's receiver functionality
-    constexpr bool enable_receiver_side = get_compile_time_arg_val(1) != 0;
+static constexpr uint32_t num_senders = get_compile_time_arg_val(2);
+static constexpr uint32_t num_receivers = get_compile_time_arg_val(3);
 
-    constexpr uint32_t num_senders = get_compile_time_arg_val(2);
-    constexpr uint32_t num_receivers = get_compile_time_arg_val(3);
+static constexpr tt::tt_metal::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode =
+    static_cast<tt::tt_metal::ccl::EriscDataMoverBufferSharingMode>(get_compile_time_arg_val(4));
 
-    constexpr tt::tt_metal::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode =
-        static_cast<tt::tt_metal::ccl::EriscDataMoverBufferSharingMode>(get_compile_time_arg_val(4));
+static constexpr tt::tt_metal::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal =
+    static_cast<tt::tt_metal::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(5));
 
-    constexpr tt::tt_metal::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal =
-        static_cast<tt::tt_metal::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(5));
+using EDM_CONFIG_T = erisc::datamover::EriscDatamoverConfig<edm_buffer_sharing_mode, terminate_on_worker_signal>;
+using ChannelBufferT = erisc::datamover::ChannelBuffer<EDM_CONFIG_T>;
+
+void kernel_main() {
 
-    constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig<edm_buffer_sharing_mode, terminate_on_worker_signal>();
-    using EDM_CONFIG_T = decltype(EDM_CONFIG);
-    using ChannelBufferT = erisc::datamover::ChannelBuffer<EDM_CONFIG_T>;
     std::array<ChannelBufferT, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS> buffer_channels; //
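
The erisc_datamover.cpp hunk hoists the kernel's compile-time arguments from locals inside kernel_main() to file-scope constants so that EDM_CONFIG_T and ChannelBufferT can be named at namespace scope, which the buffer_channels array declaration needs. The same pattern in a reduced, self-contained form (ct_arg is a stand-in for get_compile_time_arg_val, and all names here are hypothetical):

    #include <array>
    #include <cstdint>

    // Stand-in for get_compile_time_arg_val; real kernels receive these values
    // from the host at kernel-compile time.
    template <uint32_t I>
    constexpr uint32_t ct_arg() { return I; }

    enum class SharingMode : uint32_t { NOT_SHARED = 0, ROUND_ROBIN = 1 };

    // File scope: usable in the type aliases and declarations below.
    static constexpr SharingMode sharing_mode = static_cast<SharingMode>(ct_arg<1>());

    template <SharingMode M>
    struct ChannelBuffer { /* ... */ };

    using ChannelBufferT = ChannelBuffer<sharing_mode>;

    void kernel_main() {
        // The channel array can now use the alias declared above.
        std::array<ChannelBufferT, 8> buffer_channels;
        (void)buffer_channels;
    }
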