Skip to content

Commit

Permalink
#0: fix all-gather to instantiate larger EDM channels to fit larger p…
Browse files Browse the repository at this point in the history
…age sizes

if page size > default channel size based on default worker/channel count
  • Loading branch information
tt-snijjar committed Oct 3, 2024
1 parent 77c1a07 commit eb113ac
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,29 +109,36 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu
constexpr uint32_t total_l1_buffer_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;

this->is_sharded = input_tensor.is_sharded();
const uint32_t page_size = input_tensor.buffer()->page_size();
const std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16);
if (user_defined_num_workers.has_value()) {
this->num_eth_buffers = user_defined_num_workers.value() / num_duplicate_directions;
} else {
this->num_eth_buffers = (this->enable_bidirectional ? 8 /*1*/ : (topology != ttnn::ccl::Topology::Linear ? 8 : 4));
}
int next_eth_num_eth_buffers_to_try = this->num_eth_buffers;

if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) {
this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
}
while (!user_defined_num_workers.has_value() && this->eth_buffer_size < page_size && this->eth_buffer_size < total_l1_buffer_space && next_eth_num_eth_buffers_to_try > 0) {
this->num_eth_buffers = next_eth_num_eth_buffers_to_try;

if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) {
this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
}

this->num_workers_per_link = this->num_eth_buffers;
this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16;
// Really should be called offset_after_semaphore_region
this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added
this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset;
this->num_workers_per_link = this->num_eth_buffers;
this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16;
// Really should be called offset_after_semaphore_region
this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added
this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset;

std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16);
uint32_t const page_size = input_tensor.buffer()->page_size();
std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel)) - channel_sync_bytes_overhead;
this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size);
std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel)) - channel_sync_bytes_overhead;
this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size);

next_eth_num_eth_buffers_to_try--;
}

TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel) + this->semaphore_offset <= total_l1_buffer_space, "Error");
TT_FATAL(eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS, "Error");
TT_FATAL(this->eth_buffer_size > 0, "Internal errorin all-gather when computing ethernet buffer size.");
TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_edm_buffers_per_channel) + this->semaphore_offset <= total_l1_buffer_space, "Internal all-gather error when sizing buffers. Sized buffers too large for ethernet L1 capacity.");
}


Expand Down

0 comments on commit eb113ac

Please sign in to comment.