Skip to content

Commit

Permalink
#10415: add new full-tensor bidirectional mode to all-gather
Browse files Browse the repository at this point in the history
The new bidirectional all-gather mode is being added as a prerequisite
to all-gather + matmul fusion. In addition, this change also leads to
performance improvements, particularly for smaller all-gathers, because
fewer end-to-end latencies accumulate when the tensor is small enough to
fit in a single packet per channel/per ring index.

The new mode sends the full input tensor in both
directions around the ring, but only halfway around the ring in each
direction. This is in contrast to the prior default mode (SPLIT_TENSOR)
which would send half of the input tensor each direction, but the full
way around the ring.

This new mode is not enabled yet for sharded all-gather.
  • Loading branch information
SeanNijjar committed Jul 21, 2024
1 parent f52ec95 commit 0219a91
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 54 deletions.
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/operations/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ def run_all_gather_sharded(
else:
eq, output = comp_pcc(tt_output_tensor, unchunked_input_tensor)
if not eq:
all_eq = False
logger.error(f"output mismatch for tensor {i}")
assert all_eq, f"{i} FAILED: {output}"

Expand Down
37 changes: 34 additions & 3 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ enum AllGatherMode {
SINGLE_TILE_HIGH_WIDTH_SHARDED
};

// Selects how a bidirectional ring all-gather distributes the input tensor
// across the two ring directions. The mode is chosen per-tensor at config
// time (see AllGatherConfig::choose_bidirectional_mode below).
enum AllGatherBidirectionalMode {
// Splits the tensor into two and sends each half in opposite directions
// the full width around the ring
SPLIT_TENSOR,
// Doesn't split the tensor and sends the full tensor in both directions,
// half-way around the ring
FULL_TENSOR
};

// Alias namespace so all-gather code can name the CCL ring/linear topology
// without qualifying through ccl::.
namespace all_gather_op {
using ccl::Topology;
}  // namespace all_gather_op  (no trailing ';' — avoids -Wextra-semi)
Expand All @@ -37,6 +46,17 @@ using ccl::EriscDatamoverBuilder;
AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim);

class AllGatherConfig {
// Picks the bidirectional distribution mode for this all-gather.
//
// FULL_TENSOR is preferred (it halves the ring distance each payload must
// travel), but falls back to SPLIT_TENSOR when:
//  - the input is sharded: the sharded indexers would require updating to
//    support FULL_TENSOR — TODO confirm once sharded support lands; or
//  - the tensor is large enough that FULL_TENSOR is expected to degrade
//    performance.
static AllGatherBidirectionalMode choose_bidirectional_mode(Tensor const& input_tensor) {
    if (input_tensor.is_sharded()) {
        return AllGatherBidirectionalMode::SPLIT_TENSOR;
    }

    std::size_t eth_l1_capacity = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
    std::size_t tensor_size_bytes = input_tensor.shape().volume() * input_tensor.element_size();

    // Threshold above which FULL_TENSOR is expected to hurt rather than help.
    // This is currently a guesstimate; a lot more hard data is needed to
    // identify where the real dividing line is.
    constexpr std::size_t full_tensor_capacity_multiplier = 2;
    bool perf_degradation_from_full_tensor_mode =
        tensor_size_bytes > (full_tensor_capacity_multiplier * eth_l1_capacity);
    if (perf_degradation_from_full_tensor_mode) {
        return AllGatherBidirectionalMode::SPLIT_TENSOR;
    }
    return AllGatherBidirectionalMode::FULL_TENSOR;
}

public:
AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology) :
num_links(num_links),
Expand All @@ -50,7 +70,11 @@ class AllGatherConfig {
input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM),
output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM),

mode(choose_all_gather_mode(input_tensor, output_tensor, dim))
mode(choose_all_gather_mode(input_tensor, output_tensor, dim)),

// Sharded currently doesn't support FULL_TENSOR bidirectional due to indexers that require updating in order to support this
// new mode
bidirectional_mode(choose_bidirectional_mode(input_tensor))
{
TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE);
TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16);
Expand All @@ -68,12 +92,17 @@ class AllGatherConfig {

// "duplicate" directions are a short hand to enable linear/mesh all-gather topologies with
// less code-changes. Ideally a new concept is added amongst "num_eth_buffers", "num_workers_per_link", etc.
uint32_t num_duplicate_directions = topology == all_gather_op::Topology::Ring ? 1 : 2;
uint32_t num_duplicate_directions = (topology == all_gather_op::Topology::Ring && bidirectional_mode != AllGatherBidirectionalMode::FULL_TENSOR) ? 1 : 2;

constexpr uint32_t total_l1_buffer_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;

this->is_sharded = input_tensor.is_sharded();
this->num_eth_buffers = (this->enable_bidirectional ? 8 : (this->is_sharded && topology != all_gather_op::Topology::Linear ? 8 : 4));
this->num_eth_buffers = (this->enable_bidirectional ? 8 /*1*/ : (this->is_sharded && topology != all_gather_op::Topology::Linear ? 8 : 4));

if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) {
this->num_eth_buffers = std::min(this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
}

if (this->is_sharded) {
this->num_eth_buffers = std::min(this->num_eth_buffers, input_tensor.shard_spec()->num_cores());
if ((input_tensor.shard_spec()->num_cores() / this->num_eth_buffers) % (ring_size) != 0 &&
Expand Down Expand Up @@ -135,6 +164,7 @@ class AllGatherConfig {
buffer_index < get_num_edm_channels_in_clockwise_direction() :
true;
}
// Returns the bidirectional mode selected at construction (see choose_bidirectional_mode).
AllGatherBidirectionalMode get_bidirectional_mode() const { return this->bidirectional_mode; }
uint32_t get_num_edm_channels_in_counter_clockwise_direction() const {
// return all_gather_buffer_params::enable_bidirectional ? all_gather_buffer_params::num_buffers - all_gather_buffer_params::num_buffers / 2 : 0;
// Force all through counter-clockwise direction
Expand Down Expand Up @@ -174,6 +204,7 @@ class AllGatherConfig {
uint32_t eth_sems_l1_base_byte_address;
const all_gather_op::Topology topology;
AllGatherMode mode;
AllGatherBidirectionalMode bidirectional_mode;
bool is_sharded;
bool enable_bidirectional;
const bool input_is_dram;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,4 @@ void kernel_main() {
pop_filler_pages_from_cb(cb_id_in0, half_cb_n_pages - rem_num_pages);
}
}

}
Loading

0 comments on commit 0219a91

Please sign in to comment.