diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index a9dc1cc23b9..d2d90183ef4 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -109,29 +109,36 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu constexpr uint32_t total_l1_buffer_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; this->is_sharded = input_tensor.is_sharded(); + const uint32_t page_size = input_tensor.buffer()->page_size(); + const std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); if (user_defined_num_workers.has_value()) { this->num_eth_buffers = user_defined_num_workers.value() / num_duplicate_directions; } else { this->num_eth_buffers = (this->enable_bidirectional ? 8 /*1*/ : (topology != ttnn::ccl::Topology::Linear ? 8 : 4)); } + int next_eth_num_eth_buffers_to_try = this->num_eth_buffers; - if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) { - this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions); - } + while (!user_defined_num_workers.has_value() && this->eth_buffer_size < page_size && this->eth_buffer_size < total_l1_buffer_space && next_eth_num_eth_buffers_to_try > 0) { + this->num_eth_buffers = next_eth_num_eth_buffers_to_try; + + if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) { + this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions); + } - this->num_workers_per_link = this->num_eth_buffers; - this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16; - // Really should be called offset_after_semaphore_region - this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added - this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset; + this->num_workers_per_link = this->num_eth_buffers; + this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16; + // Really should be called offset_after_semaphore_region + this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added + this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset; - std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); - uint32_t const page_size = input_tensor.buffer()->page_size(); - std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel)) - channel_sync_bytes_overhead; - this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size); + std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel)) - channel_sync_bytes_overhead; + this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size); + + next_eth_num_eth_buffers_to_try--; + } - TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel) + this->semaphore_offset <= total_l1_buffer_space, "Error"); - TT_FATAL(eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS, "Error"); + TT_FATAL(this->eth_buffer_size > 0, "Internal errorin all-gather when computing ethernet buffer size."); + TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel) + this->semaphore_offset <= total_l1_buffer_space, "Internal all-gather error when sizing buffers. Sized buffers too large for ethernet L1 capacity."); }