Enable double/triple buffered EDM channels for reduce-scatter (#11982)
* #11853: enabled edm sender/reader objects for reduce scatter

* #11982: enabled double buffering for reduce scatter

* #11982: reverted expected message size and set it as eth buffer size

* #11982: default num buffers to 1 and enabled shrinking for buffer size 1
caixunshiren authored Sep 6, 2024
1 parent 90e4c48 commit 70a1c7b
Showing 3 changed files with 66 additions and 58 deletions.
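The idea behind the change, in brief: with num_buffers_per_channel = 1, a worker must wait for the EDM to drain its single channel buffer before the next payload can go in. With two (double) or three (triple) buffers per channel, the worker can fill one slot while the EDM forwards another. Below is a toy host-side model of that round-robin slot rotation; all names in it are hypothetical, and the real handshake is semaphore-based and lives in the EDM and the worker adapters.

#include <cstdint>
#include <iostream>

struct Channel {
    uint32_t num_buffers;     // num_buffers_per_channel: 1, 2 (double), 3 (triple)
    uint32_t in_flight = 0;   // payloads written but not yet drained by the EDM
    uint32_t write_idx = 0;   // next slot to fill, rotates round-robin

    bool try_send() {
        if (in_flight == num_buffers) return false;  // all slots full: worker stalls
        write_idx = (write_idx + 1) % num_buffers;
        ++in_flight;
        return true;
    }
    void edm_drained_one() { --in_flight; }          // EDM forwarded one slot over ethernet
};

int main() {
    Channel single{1}, dbl{2};
    // Two back-to-back sends with no intervening drain:
    std::cout << "single buffer, 2nd send ok? " << single.try_send() << single.try_send() << "\n";  // prints 10
    std::cout << "double buffer, 2nd send ok? " << dbl.try_send() << dbl.try_send() << "\n";        // prints 11
}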
File 1 of 3: host-side program factory (ReduceScatterWorkerArgBuilder / reduce_scatter_with_workers)
@@ -70,15 +70,17 @@ struct ReduceScatterWorkerArgBuilder {
WorkerTransferInfo const& worker_transfer_info,
uint32_t cb_num_pages_per_packet,
uint32_t worker_sender_semaphore_id,
- uint32_t worker_receiver_semaphore_id) :
+ uint32_t worker_receiver_semaphore_id,
+ uint32_t num_buffers_per_channel) :
device(device),
op_config(op_config),
topology_config(topology_config),
worker_input_slice(worker_input_slice),
worker_transfer_info(worker_transfer_info),
cb_num_pages_per_packet(cb_num_pages_per_packet),
worker_sender_semaphore_id(worker_sender_semaphore_id),
- worker_receiver_semaphore_id(worker_receiver_semaphore_id) {
+ worker_receiver_semaphore_id(worker_receiver_semaphore_id),
+ num_buffers_per_channel(num_buffers_per_channel) {
}

uint32_t get_total_num_math_pages(uint32_t link, uint32_t worker_idx) const {
@@ -121,13 +123,14 @@ struct ReduceScatterWorkerArgBuilder {
auto const& local_input_tensor = this->op_config.get_input_tensor(0);
auto args = std::vector<uint32_t>{
static_cast<uint32_t>(this->op_config.is_input_sharded() ? 1 : 0),
- static_cast<uint32_t>(
- this->op_config.get_input_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0)};
+ static_cast<uint32_t>(this->op_config.get_input_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0),
+ static_cast<uint32_t>(this->num_buffers_per_channel)};

std::size_t i = 0;
log_trace(tt::LogOp, "Reduce Scatter Receiver Worker CT Args:");
log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++));
log_trace(tt::LogOp, "\tsrc_is_dram: {}", args.at(i++));
+ log_trace(tt::LogOp, "\tnum_buffers_per_channel: {}", args.at(i++));
TT_ASSERT(args.size() == i, "Missed some args");

if (local_input_tensor.is_sharded()) {
@@ -229,13 +232,14 @@ struct ReduceScatterWorkerArgBuilder {
auto const& local_output_tensor = this->op_config.get_output_tensor(0);
auto args = std::vector<uint32_t>{
static_cast<uint32_t>(this->op_config.is_input_sharded() ? 1 : 0),
- static_cast<uint32_t>(
- this->op_config.get_output_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0)};
+ static_cast<uint32_t>(this->op_config.get_output_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0),
+ static_cast<uint32_t>(this->num_buffers_per_channel)};

std::size_t i = 0;
log_trace(tt::LogOp, "Reduce Scatter Sender Worker CT Args:");
log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++));
log_trace(tt::LogOp, "\tdst_is_dram: {}", args.at(i++));
+ log_trace(tt::LogOp, "\tnum_buffers_per_channel: {}", args.at(i++));
TT_ASSERT(args.size() == i, "Missed some args");

if (local_output_tensor.is_sharded()) {
@@ -327,6 +331,7 @@ struct ReduceScatterWorkerArgBuilder {
uint32_t cb_num_pages_per_packet;
uint32_t worker_sender_semaphore_id;
uint32_t worker_receiver_semaphore_id;
+ uint32_t num_buffers_per_channel;

bool src_is_dram;
bool dst_is_dram;
@@ -349,6 +354,7 @@ static void add_worker_config_to_edm_builders(
ccl::CCLOpConfig const& op_config,
std::vector<CoreCoord> const& worker_cores,
uint32_t num_channels_per_edm,
+ uint32_t num_buffers_per_channel,

std::vector<ttnn::ccl::EriscDatamoverBuilder>& clockwise_edm_builders,
std::vector<ttnn::ccl::EriscDatamoverBuilder>& counter_clockwise_edm_builders,
@@ -376,7 +382,8 @@
}

// Get the maximum message size we'd like to use. Not the actual packet size
- uint32_t expected_message_size_bytes = tensor_slicer.get_worker_slice_size_bytes(global_worker_idx);
+ uint32_t expected_message_size_bytes = (num_buffers_per_channel == 1) ? tensor_slicer.get_worker_slice_size_bytes(global_worker_idx)
+ : clockwise_edm_builders.at(link).get_eth_buffer_size_bytes();

bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear
if (sender_enabled) {
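The expected_message_size_bytes change above encodes the trade-off called out in the commit message: with a single buffer per channel the EDM is allowed to shrink the channel buffer down to the worker slice size, so the slice size is advertised; with two or more buffers the channel keeps its full eth buffer size. Restated as a standalone helper (illustrative names, not the builder API):

#include <cstdint>

// Restatement of the selection rule above. Shrinking is only safe with one
// buffer per channel; with 2+ buffers the channel stays at full eth size.
uint32_t pick_expected_message_size_bytes(
    uint32_t num_buffers_per_channel,
    uint32_t worker_slice_size_bytes,  // tensor_slicer.get_worker_slice_size_bytes(...)
    uint32_t eth_buffer_size_bytes) {  // edm_builder.get_eth_buffer_size_bytes()
    return (num_buffers_per_channel == 1) ? worker_slice_size_bytes : eth_buffer_size_bytes;
}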
@@ -825,6 +832,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
op_config,
worker_cores,
num_edm_channels,
+ num_buffers_per_channel,

cw_per_link_edm_builders,
ccw_per_link_edm_builders,
@@ -848,7 +856,8 @@
worker_transfer_info,
cb_num_pages_per_packet,
worker_sender_semaphore_id,
- worker_receiver_semaphore_id);
+ worker_receiver_semaphore_id,
+ num_buffers_per_channel);
auto [worker_receiver_kernel_id, worker_sender_kernel_id, worker_reduce_kernel_id] = build_reduce_scatter_worker_ct(
program,
op_config,
@@ -875,7 +884,8 @@
worker_transfer_info,
cb_num_pages_per_packet,
worker_sender_semaphore_id,
- worker_receiver_semaphore_id);
+ worker_receiver_semaphore_id,
+ num_buffers_per_channel);

log_trace(tt::LogOp, "worker_cores.at(global_worker_index): {}", worker_cores.at(global_worker_index));
set_reduce_scatter_worker_rt(
File 2 of 3: receiver worker kernel
@@ -14,6 +14,7 @@
#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"

#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp"
+ #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp"

using ttnn::ccl::coord_t;
using ttnn::ccl::WorkerXY;
@@ -97,19 +98,20 @@ constexpr bool is_sharded = get_compile_time_arg_val(0) == 1;

// Currently meaningless when `is_sharded=true`
constexpr bool src_is_dram = get_compile_time_arg_val(1) == 1;
+ constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(2);
static constexpr tt::tt_metal::TensorMemoryLayout input_tensor_memory_layout =
- static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(2));
+ static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(3));

// TODO: clean this up
#ifdef SHARDED_MEM_LAYOUT
static constexpr bool is_sharded_mode = true;
- static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(3);
- static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(4);
- static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(5);
- static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(6);
- static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(7);
- static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(8);
- static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(9) != 0;
+ static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(4);
+ static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(5);
+ static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(6);
+ static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(7);
+ static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(8);
+ static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(9);
+ static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(10) != 0;
#else
static constexpr bool is_sharded_mode = false;
static constexpr uint32_t input_tensor_shard_grid_height = 0;
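The renumbering above boils down to one inserted slot. For reference, the receiver kernel's compile-time argument layout after this change (the sender kernel below follows the same pattern, with dst_is_dram in slot 1):

// slot 0:     is_sharded
// slot 1:     src_is_dram
// slot 2:     num_buffers_per_channel      (new)
// slot 3:     input_tensor_memory_layout   (previously slot 2)
// slots 4-10: sharded-mode grid/shard parameters (previously slots 3-9,
//             compiled in only under SHARDED_MEM_LAYOUT)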
@@ -225,14 +227,19 @@ void kernel_main() {

volatile tt_l1_ptr uint32_t* receiver_read_semaphore_addr_ptr =
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(args.sem_addr);
- const uint64_t eth_receiver_l1_base_noc_addr =
- get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_buffer_address);
- const uint64_t eth_receiver_l1_semaphore_noc_addr =
- get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_semaphore_address);

uint32_t total_cb_pages_pushed = 0;
uint32_t total_cb_pages_pushed_to_math = 0;

+ ccl::edm::WorkerToEdmReader<ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED> reader(
+ ttnn::ccl::WorkerXY(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y),
+ args.edm_core_buffer_address,
+ num_buffers_per_channel,
+ args.edm_core_semaphore_address,
+ // (num_full_chunks > 0 ? args.full_chunk_num_pages : rem_num_pages) * args.page_size,
+ args.full_chunk_num_pages * args.page_size,
+ receiver_read_semaphore_addr_ptr);

// For the first timestep, there is no other input to reduce with, so we just send it straight to the input CB
// of the output data movement kernel - short-circuiting past the (reducer) math kernel
// For tile => shape in tiles
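WorkerToEdmReader packages up the handshake that the next hunk deletes from this kernel. A rough sketch of what the adapter plausibly does internally, reconstructed from the hand-rolled semaphore code it replaces; member and helper names here are illustrative, not the actual interface in worker_edm_adapters.hpp:

// Illustrative sketch only (WORKER_INITIATED termination assumed); every
// noc_* call below is taken from the code this commit removes.
struct WorkerToEdmReaderSketch {
    uint64_t edm_buffer_base_noc_addr;           // args.edm_core_buffer_address on the EDM core
    uint64_t edm_semaphore_noc_addr;             // args.edm_core_semaphore_address on the EDM core
    uint32_t num_buffers_per_channel;
    uint32_t buffer_size_bytes;                  // args.full_chunk_num_pages * args.page_size
    uint32_t buffer_index = 0;
    volatile tt_l1_ptr uint32_t* local_sem_ptr;  // receiver_read_semaphore_addr_ptr

    void wait_for_payload_available() {
        noc_semaphore_wait(local_sem_ptr, 1);    // EDM bumped our semaphore: payload landed
        noc_semaphore_set(local_sem_ptr, 0);
    }

    void fetch_payload_blocking(uint32_t cb_id, uint32_t n_pages, uint32_t page_size, bool last_message) {
        uint64_t slot_noc_addr = edm_buffer_base_noc_addr + buffer_index * buffer_size_bytes;
        fetch_chunk(cb_id, n_pages, page_size, slot_noc_addr);
        if (!last_message) {                     // ack so the EDM can refill this slot
            noc_semaphore_inc(edm_semaphore_noc_addr,
                ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
        }
        buffer_index = (buffer_index + 1) % num_buffers_per_channel;  // rotate to the next slot
    }

    void close() {                               // worker-initiated teardown
        noc_semaphore_inc(edm_semaphore_noc_addr,
            ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY);
    }
};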
@@ -332,21 +339,16 @@ void kernel_main() {
n_pages,
args.page_size,
last_page_of_worker);
- uint64_t eth_receiver_l1_curr_noc_addr = eth_receiver_l1_base_noc_addr;

// Fetch from EDM
- noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1);
- noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0);
- fetch_chunk(cb_id_in0, n_pages, args.page_size, eth_receiver_l1_base_noc_addr);
+ bool last_worker_message_to_edm = last_transfer && last_slice_of_worker && (p + n_pages >= worker_slice_n_pages);

+ reader.wait_for_payload_available();
+ reader.fetch_payload_blocking(cb_id_in0, n_pages, args.page_size, last_worker_message_to_edm);

total_cb_pages_pushed_to_math += n_pages;
total_cb_pages_pushed += n_pages;

- bool last_worker_message_to_edm = last_transfer && last_slice_of_worker && (p + n_pages >= worker_slice_n_pages);
- if (!last_worker_message_to_edm) {
- noc_semaphore_inc(
- eth_receiver_l1_semaphore_noc_addr,
- ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
- }
if (n_pages < args.half_cb_n_pages) {
uint32_t num_filler_pages = args.half_cb_n_pages - n_pages;
push_filler_pages_to_cb(cb_id_in0, num_filler_pages);
@@ -372,8 +374,6 @@ void kernel_main() {
push_filler_pages_to_cb(cb_id_in1, 1);
}

- noc_semaphore_inc(
- eth_receiver_l1_semaphore_noc_addr,
- ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY);
+ reader.close();
DEBUG_STATUS("DONE");
}
File 3 of 3: sender worker kernel
@@ -14,17 +14,18 @@ using ttnn::ccl::coord_t;
void kernel_main() {
constexpr bool is_sharded = get_compile_time_arg_val(0) == 1;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
+ constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(2);

constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout =
- static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(2));
+ static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(3));
#ifdef SHARDED_MEM_LAYOUT
- constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(3);
- constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(4);
- constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(5);
- constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(6);
- constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(7);
- constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(8);
- constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(9) != 0;
+ constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(4);
+ constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(5);
+ constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(6);
+ constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(7);
+ constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(8);
+ constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(9);
+ constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(10) != 0;
#endif

uint32_t arg_idx = 0;
@@ -111,12 +112,15 @@ void kernel_main() {
// Used to wait until eth sender has space available
volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr =
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(writer_send_sem_addr);
- // This is different per writer core
- const uint64_t eth_l1_sender_base_noc_addr =
- get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_base_addr);
- // Used to signal eth sender that data is available. This is different per writer core
- const uint64_t eth_l1_sender_semaphore_addr =
- get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr);

+ ccl::edm::WorkerToEdmSender<ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED> sender(
+ ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y),
+ eth_sender_l1_base_addr,
+ num_buffers_per_channel,
+ eth_sender_l1_sem_addr,
+ // (num_full_chunks > 0 ? num_pages_per_full_chunk : rem_num_pages) * page_size,
+ full_chunk_num_pages * page_size,
+ writer_send_semaphore_addr_ptr);

uint32_t total_lifetime_cb_pages_popped_from_math = 0;
while (worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) {
@@ -136,12 +140,9 @@ void kernel_main() {
for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) {
uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p);
ASSERT(n_pages > 0);
- noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1);
- noc_semaphore_set(writer_send_semaphore_addr_ptr, 0);
- send_chunk(cb_in, n_pages, page_size, eth_l1_sender_base_noc_addr);
- noc_semaphore_inc(
- eth_l1_sender_semaphore_addr,
- ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
+ sender.wait_for_empty_write_slot();
+ sender.send_payload_blocking(cb_in, n_pages, page_size);

if (i != 0) {
total_lifetime_cb_pages_popped_from_math += n_pages;
}
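Putting the sender-side pieces of this diff together, the worker's lifecycle with the adapter condenses to the abridged skeleton below. Every call is taken from the diff itself; the mapping comments to the old hand-rolled code are inferred from the removed lines, not from worker_edm_adapters.hpp:

ccl::edm::WorkerToEdmSender<ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED> sender(
    ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y),
    eth_sender_l1_base_addr,
    num_buffers_per_channel,
    eth_sender_l1_sem_addr,
    full_chunk_num_pages * page_size,
    writer_send_semaphore_addr_ptr);

for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) {
    uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p);
    sender.wait_for_empty_write_slot();   // replaces noc_semaphore_wait/set on the local semaphore
    sender.send_payload_blocking(cb_in, n_pages, page_size);  // replaces send_chunk + NEXT_MESSAGE_AVAILABLE,
                                                              // rotating slots when num_buffers_per_channel > 1
}
sender.close();                           // replaces the TERMINATE_IMMEDIATELY handshake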
@@ -205,8 +206,5 @@ void kernel_main() {
pop_filler_pages_from_cb(cb_id_in0, 1);
}

- noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1);
- noc_semaphore_set(writer_send_semaphore_addr_ptr, 0);
- noc_semaphore_inc(
- eth_l1_sender_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY);
+ sender.close();
}
