#14479: resolve PCC issue and hang in line reduce scatter case
SeanNijjar committed Oct 30, 2024
1 parent d4b7fc2 commit 9289131
Showing 7 changed files with 43 additions and 32 deletions.
@@ -619,10 +619,8 @@ def test_width_sharded_reduce_scatter_post_commit(
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
"num_devices, num_links",
[
(4, 2),
],
"num_devices",
[4],
)
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize(
@@ -641,7 +639,7 @@ def test_width_sharded_reduce_scatter_post_commit(
],
)
@pytest.mark.parametrize(
"per_chip_output_shape,output_shard_shape,shard_grid,in_shard_override,in_shard_grid_override",
"per_chip_output_shape,output_shard_shape,shard_grid,in_shard_override,in_shard_grid_override,num_links",
(
# LLama
(
@@ -650,17 +648,28 @@ def test_width_sharded_reduce_scatter_post_commit(
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
(32, 160),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
2,
),
(
(1, 1, 32, 1280),
(32, 128),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
None,
None,
2,
),
(
(1, 1, 32, 320),
(32, 32),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(4, 1))}),
None,
None,
1,
),
),
)
@pytest.mark.parametrize("topology", [ttnn.Topology.Ring, ttnn.Topology.Linear])
@pytest.mark.parametrize("topology", [ttnn.Topology.Linear])
# @pytest.mark.parametrize("topology", [ttnn.Topology.Ring, ttnn.Topology.Linear])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
def test_width_sharded_reduce_scatter_post_commit_4chip(
@@ -704,6 +713,7 @@ def test_width_sharded_reduce_scatter_post_commit_4chip(
topology=topology,
enable_async=enable_async,
num_iters=num_iters,
n_worker=2,
)


18 changes: 8 additions & 10 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
@@ -441,11 +441,10 @@ std::vector<tt_xy_pair> RingReduceScatterBaseTensorSlicer<DERIVED_SLICER_T>::cre
tt::LogOp,
"Reduce Scatter more workers instantiated than is work to be done. Some workers will be idle and do "
"nothing");
num_workers = tensor_slice_shape_in_elems.y;
for (uint32_t w = 0; w < num_workers; ++w) {
for (uint32_t w = 0; w < tensor_slice_shape_in_elems.y; ++w) {
worker_slice_shapes.emplace_back(tensor_slice_shape_in_elems.x, 1);
}
for (uint32_t w = num_workers; w < tensor_slice_shape_in_elems.x; ++w) {
for (uint32_t w = tensor_slice_shape_in_elems.y; w < num_workers; ++w) {
worker_slice_shapes.emplace_back(0, 0);
}
return worker_slice_shapes;
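
The first hunk changes how surplus workers are handled when a tensor slice has fewer rows than workers: rather than clamping num_workers and padding idle entries against the slice width, every worker beyond the row count now gets an explicit (0, 0) slice, so the returned vector still has one entry per worker and downstream code can treat a zero shape as "no work". A minimal, self-contained sketch of that pattern (slice_shape_t is a hypothetical stand-in for tt_xy_pair):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for tt_xy_pair: first = width in elements, second = row count.
    using slice_shape_t = std::pair<uint32_t, uint32_t>;

    // One full row per worker; workers beyond the row count get a zero-size slice
    // so the caller still receives exactly num_workers entries.
    std::vector<slice_shape_t> assign_row_slices(uint32_t num_workers, uint32_t rows, uint32_t row_width_elems) {
        std::vector<slice_shape_t> shapes;
        shapes.reserve(num_workers);
        for (uint32_t w = 0; w < rows && w < num_workers; ++w) {
            shapes.emplace_back(row_width_elems, 1);  // a real row of work
        }
        for (uint32_t w = rows; w < num_workers; ++w) {
            shapes.emplace_back(0, 0);                // idle worker, no work assigned
        }
        return shapes;
    }
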
@@ -713,22 +712,21 @@ std::vector<tt_xy_pair> RingReduceScatterWrappedTensorSlicer::create_worker_slic
std::size_t max_slice_size_in_tiles = max_slice_size_in_pages;

// Assign slices by assuming that the input tensor is flattened into a 1D Shape
std::size_t optim_worker_slice_len_tiles = ceil(total_num_tiles / num_workers); // Ceil so that the remainder worker will have a smaller slice
std::size_t optim_worker_slice_len_tiles = ((total_num_tiles - 1) / num_workers) + 1; // Ceil so that the remainder worker will have a smaller slice

if (max_slice_size_in_tiles < optim_worker_slice_len_tiles) { // Each worker will have a full slice
for (uint32_t w = 0; w < num_workers; ++w) {
worker_slice_shapes.emplace_back(max_slice_size_in_tiles, 1);
}
} else { // Each worker will only have one slice
uint32_t remainder_worker_len_tiles = total_num_tiles % optim_worker_slice_len_tiles;

size_t base_tiles_per_worker = total_num_tiles / num_workers;
size_t total_extra_tiles = total_num_tiles - (base_tiles_per_worker * num_workers);
for (uint32_t w = 0; w < num_workers; ++w) {
bool add_extra_tile = w < total_extra_tiles;
size_t remainder_tiles = add_extra_tile ? 1 : 0;
size_t num_tiles_this_worker = base_tiles_per_worker + remainder_tiles;
worker_slice_shapes.emplace_back(optim_worker_slice_len_tiles, 1);
}
// If there is a remainder worker, we need to adjust the last worker's slice shape to be smaller
if (remainder_worker_len_tiles > 0) {
worker_slice_shapes.back() = tt_xy_pair{remainder_worker_len_tiles, 1};
}
}

return worker_slice_shapes;
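
The second hunk fixes the ceiling computation: with two integer operands, total_num_tiles / num_workers truncates before ceil() ever sees it, so the old expression never actually rounded up; ((total_num_tiles - 1) / num_workers) + 1 is the usual integer-ceiling idiom (valid for total_num_tiles > 0). The remainder tiles are likewise spread one per worker across the first workers instead of only shrinking the last worker's slice. A small, standalone illustration of the division difference:

    #include <cmath>
    #include <cstddef>
    #include <iostream>

    int main() {
        std::size_t total_num_tiles = 10;
        std::size_t num_workers = 4;

        // Integer division truncates first (10 / 4 == 2), so std::ceil has nothing left to round.
        std::size_t truncated = static_cast<std::size_t>(std::ceil(total_num_tiles / num_workers));  // 2

        // Integer ceiling division rounds up without a float round trip.
        std::size_t rounded_up = ((total_num_tiles - 1) / num_workers) + 1;                          // 3

        std::cout << truncated << " vs " << rounded_up << "\n";  // prints "2 vs 3"
        return 0;
    }
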
@@ -139,13 +139,12 @@ class ChannelBuffer final {
*(channel_bytes_acked_addresses[i]) = 0;
}

if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) {
if (total_num_messages_to_move != 0) {
if (is_sender_side) {
// Tell the sender side workers that we're ready to accept data on this channel
increment_worker_semaphores();
}
} else {
ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED);
goto_state(STATE::DONE);
}
}
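
The effect of this change is that an EDM channel created with zero messages jumps straight to its DONE state in every termination mode, rather than signalling its workers and then waiting on traffic that never arrives, which matches the hang described in the commit title. A hedged sketch of the idea (the state and member names loosely mirror the diff; everything else is illustrative):

    #include <cstdint>

    enum class STATE { WAITING_FOR_WORKER, DONE };

    struct ChannelSketch {
        STATE state = STATE::WAITING_FOR_WORKER;
        bool is_sender_side = false;

        void increment_worker_semaphores() { /* tell attached workers the channel is ready */ }
        void goto_state(STATE s) { state = s; }

        // A channel with nothing to move is finished before it starts;
        // leaving it waiting would stall channels that legitimately carry no traffic.
        void init(uint32_t total_num_messages_to_move) {
            if (total_num_messages_to_move != 0) {
                if (is_sender_side) {
                    increment_worker_semaphores();  // workers may begin pushing data
                }
            } else {
                goto_state(STATE::DONE);            // nothing will ever arrive here
            }
        }
    };
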
12 changes: 4 additions & 8 deletions ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp
@@ -177,10 +177,8 @@ void kernel_main() {
(const WorkerXY *)workers_xy_list_addr,
false);

if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) {
if (receiver_num_messages_to_send == 0) {
num_receivers_with_no_work++;
}
if (receiver_num_messages_to_send == 0) {
num_receivers_with_no_work++;
}
}

@@ -218,10 +216,8 @@ void kernel_main() {
(volatile tt_l1_ptr uint32_t *const)sender_semaphores_base_address,
(const WorkerXY *)workers_xy_list_addr,
true);
if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) {
if (sender_num_messages_to_send == 0) {
num_senders_with_no_work++;
}
if (sender_num_messages_to_send == 0) {
num_senders_with_no_work++;
}
}

@@ -115,10 +115,11 @@ static void add_worker_config_to_edm_builders(
std::size_t expected_message_size_bytes = (num_buffers_per_channel == 1) ? tensor_slicer.get_worker_slice_size_bytes(worker_tensor_slice_index)
: sender_edm_builder.get_eth_buffer_size_bytes();
TT_ASSERT(worker_attrs.send_to_edm_semaphore_id.has_value(), "Internal error");
bool const channel_enabled = tensor_slicer.get_worker_slice_size_bytes(worker_tensor_slice_index) > 0;
ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info =
sender_edm_builder.add_sender_channel(
worker_attrs.send_to_edm_semaphore_id.value(),
1,
channel_enabled,
sender_worker_coords,
expected_message_size_bytes);
edm_interface_addresses.worker_sender_edm_semaphore_addresses.insert(
@@ -139,12 +140,13 @@ static void add_worker_config_to_edm_builders(
std::size_t expected_message_size_bytes = (num_buffers_per_channel == 1) ? tensor_slicer.get_worker_slice_size_bytes(worker_tensor_slice_index)
: receiver_edm_builder.get_eth_buffer_size_bytes();
TT_ASSERT(worker_attrs.receive_from_edm_semaphore_id.has_value());
bool const channel_enabled = tensor_slicer.get_worker_slice_size_bytes(worker_tensor_slice_index) > 0;
ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info =
receiver_edm_builder.add_receiver_channel(
worker_attrs.receive_from_edm_semaphore_id.value(),
// Since we are in worker signal EDM termination mode, we don't need to set the actual number of
// messages the EDM must forward as it will receive its finish signal from the worker instead
1,
channel_enabled,
receiver_worker_coords,
expected_message_size_bytes);
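
On the host side, the channel argument that was hard-coded to 1 is now derived from whether the worker's slice actually contains any bytes, so a worker whose slice ended up empty produces a channel with nothing to move, which the EDM (per the ChannelBuffer change above) retires immediately. A minimal sketch of that decision, with a hypothetical builder facade standing in for EriscDatamoverBuilder:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical facade; the real EriscDatamoverBuilder::add_sender_channel takes more arguments.
    struct EdmBuilderSketch {
        void add_sender_channel(uint32_t semaphore_id, uint32_t num_messages) {
            // num_messages == 0 lets the EDM mark the channel DONE right away.
            (void)semaphore_id;
            (void)num_messages;
        }
    };

    void configure_sender_channel(EdmBuilderSketch& builder, uint32_t semaphore_id, std::size_t worker_slice_size_bytes) {
        // Enable the channel only if this worker actually has data to forward.
        bool const channel_enabled = worker_slice_size_bytes > 0;
        builder.add_sender_channel(semaphore_id, channel_enabled ? 1u : 0u);
    }
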

@@ -379,7 +379,8 @@ void kernel_main() {
// For RM => shape in elements
std::size_t n_reads = 1;
uint32_t start_ring_index = args.my_ring_idx;
while (args.worker_slice_offset.x < args.tensor_slice_shape.x &&
bool work_to_do = args.tensor_slice_shape.x > 0 && args.tensor_slice_shape.y > 0;
while (work_to_do && args.worker_slice_offset.x < args.tensor_slice_shape.x &&
args.worker_slice_offset.y < args.tensor_slice_shape.y) {
// Need to reset back to the start ring index because the last iteration of the transfers read chunks
// loop won't increment after the last iteration since the increment is within the loop body
@@ -556,6 +557,8 @@ void kernel_main() {
push_filler_pages_to_cb(cb_id_in1, 1);
}

reader.close();
if (work_to_do) {
reader.close();
}
WAYPOINT("DONE");
}
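
In the reader kernel, both the main transfer loop and the final reader.close() are now gated on the slice having a nonzero shape, so a worker that was handed a (0, 0) slice neither reads anything nor tries to tear down an EDM connection it never used. The guard pattern, reduced to a sketch (coord_t and ReaderSketch are placeholders for the kernel's real types):

    #include <cstdint>

    struct coord_t { uint32_t x; uint32_t y; };

    struct ReaderSketch {
        void close() { /* handshake with the EDM to release the channel */ }
    };

    void run_reader(ReaderSketch& reader, coord_t tensor_slice_shape, coord_t worker_slice_offset) {
        // A (0, 0) slice means this worker was assigned no tiles at all.
        bool const work_to_do = tensor_slice_shape.x > 0 && tensor_slice_shape.y > 0;

        while (work_to_do && worker_slice_offset.x < tensor_slice_shape.x &&
               worker_slice_offset.y < tensor_slice_shape.y) {
            // ... read and forward one chunk of the slice ...
            worker_slice_offset.x += 1;  // placeholder advance so the sketch terminates
        }

        if (work_to_do) {
            reader.close();  // only tear down a connection that was actually exercised
        }
    }
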
@@ -171,7 +171,9 @@ void kernel_main() {
writer_send_semaphore_addr_ptr);

uint32_t total_lifetime_cb_pages_popped_from_math = 0;
while (worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) {
bool work_to_do = worker_slice_shape.x > 0 && worker_slice_shape.y > 0;
bool sends_to_edm = num_transfers > 0;
while (work_to_do && worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) {
// First phase - we only forward messages to EDM
// Set the valid_worker_slice_shape
coord_t valid_worker_slice_shape = worker_slice_shape;
@@ -260,7 +262,8 @@ void kernel_main() {
pop_filler_pages_from_cb(cb_id_in0, 1);
}

if (num_transfers > 0) {
if (sends_to_edm) {
sender.close();
}
WAYPOINT("DONE");
}
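
The writer kernel mirrors the reader-side guard and also names the num_transfers > 0 check as sends_to_edm, so only a sender connection that was actually opened gets closed. A compact sketch of that final step (SenderSketch is a placeholder, not the kernel's real interface):

    #include <cstdint>

    struct SenderSketch {
        void close() { /* signal the EDM that this worker is done sending */ }
    };

    // A worker that never forwards anything (num_transfers == 0) skips the
    // completion handshake for a connection it never opened.
    void finish_writer(SenderSketch& sender, uint32_t num_transfers) {
        bool const sends_to_edm = num_transfers > 0;
        if (sends_to_edm) {
            sender.close();
        }
    }
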
