diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
index bd9e81bb9f07..d31a496dc983 100644
--- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
@@ -111,6 +111,8 @@ def run_reduce_scatter_test(
     if enable_async:
         logger.info(f"Using Async Mode for Reduce Scatter Op Dispatch")
 
+    logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}")
+
     # Generate input tensors
     canonical_input_shape = per_chip_output_shape.copy()
     canonical_input_shape[scatter_dim] *= num_devices
@@ -121,7 +123,6 @@ def run_reduce_scatter_test(
         torch.rand(canonical_input_shape).bfloat16() if not debug else torch.ones(canonical_input_shape).bfloat16()
         for _ in range(num_devices)
     ]
-
     if debug:
         input_tensors[-1] = torch.arange(numel).reshape(canonical_input_shape).bfloat16()
     for i, canonical_input_tensor in enumerate(input_tensors):
@@ -149,6 +150,7 @@ def run_reduce_scatter_test(
             ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))
         logger.info(f"Done iteration {i}")
 
+    # ttnn.visualize_mesh_device(t3k_mesh_device, tensor=output_tensor_mesh)
     # Compute golden
     # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering
     golden_canonical_out_tensor = torch.zeros(canonical_input_shape).bfloat16()
@@ -167,7 +169,7 @@ def run_reduce_scatter_test(
         eq, output = comp_pcc(tt_output_tensor, golden_output_tensors[i])
         mismatch = mismatch or not eq
         if not eq:
-            logger.error(f"output mismatch for tensor {i}")
+            logger.error(f"output mismatch for tensor {i}. Mesh device ID: {t3k_mesh_device.get_devices()[i].id()}")
         if debug:
             for w in range(tt_output_tensor.shape[0]):
                 for z in range(tt_output_tensor.shape[1]):
@@ -263,6 +265,10 @@ def test_ring_reduce_scatter_post_commit(
     "per_chip_output_shape, scatter_dim, layout",
     [
         ([1, 1, 32, 32 * 8], 3, ttnn.TILE_LAYOUT),
+        ([1, 2, 224, 32 * 8], 3, ttnn.TILE_LAYOUT),
+        ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
+        ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT),
+        ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp
index e966446d6ae9..e31c9885031a 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp
@@ -9,7 +9,6 @@
 #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
 #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp"
 
-
 using ttnn::ccl::ShardType;
 using ttnn::ccl::UNINITIALIZED_VALUE_U16;
 using ttnn::ccl::UNINITIALIZED_VALUE_U32;
@@ -431,7 +430,7 @@ FORCE_INLINE void write_chunk_v2(
 }
 
 template <typename AddrGen>
-FORCE_INLINE void read_wrapped_chunk_from_output_tensor(
+FORCE_INLINE void read_wrapped_chunk_from_output_tensor_to_address(
     uint32_t& curr_page_idx,
     uint32_t& offset_into_worker_slice,
     const ttnn::ccl::coord_t& offset_worker_slice,
@@ -440,16 +439,14 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor(
     // In tiles for tile layout
     const ttnn::ccl::coord_t& tensor_shape,
     const ttnn::ccl::coord_t& tensor_slice_shape,
-    const uint32_t cb_id,
+    const uint32_t local_l1_scratch_buffer_address,
     const AddrGen& s,
     const uint32_t num_pages,
     const uint32_t page_size,
     bool& last_page_of_worker) {
 
     // we expected caller to reset this and the last curr_page_idx when we set it true
-    ASSERT(last_page_of_worker == false);
-    cb_reserve_back(cb_id, num_pages);
-    uint32_t local_l1_read_addr = get_write_ptr(cb_id);
+    uint32_t local_l1_read_addr = local_l1_scratch_buffer_address;
 
     int32_t contig_pages = 1;
     for (uint32_t i = 0; i < num_pages; i+= contig_pages) {
@@ -498,6 +495,40 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor(
         local_l1_read_addr += page_size * contig_pages;
     }
     noc_async_read_barrier();
+}
+
+template <typename AddrGen>
+FORCE_INLINE void read_wrapped_chunk_from_output_tensor(
+    uint32_t& curr_page_idx,
+    uint32_t& offset_into_worker_slice,
+    const ttnn::ccl::coord_t& offset_worker_slice,
+    const ttnn::ccl::coord_t& worker_slice_shape,
+
+    // In tiles for tile layout
+    const ttnn::ccl::coord_t& tensor_shape,
+    const ttnn::ccl::coord_t& tensor_slice_shape,
+    const uint32_t cb_id,
+    const AddrGen& s,
+    const uint32_t num_pages,
+    const uint32_t page_size,
+    bool& last_page_of_worker) {
+
+    // we expected caller to reset this and the last curr_page_idx when we set it true
+    ASSERT(last_page_of_worker == false);
+    cb_reserve_back(cb_id, num_pages);
+
+    read_wrapped_chunk_from_output_tensor_to_address(
+        curr_page_idx,
+        offset_into_worker_slice,
+        offset_worker_slice,
+        worker_slice_shape,
+        tensor_shape,
+        tensor_slice_shape,
+        get_write_ptr(cb_id),
+        s,
+        num_pages,
+        page_size,
+        last_page_of_worker);
 
     cb_push_back(cb_id, num_pages);
 }
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send.cpp
index 155c054140df..5d636a0f2912 100644
--- a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send.cpp
@@ -20,10 +20,42 @@
 using ttnn::ccl::coord_t;
 // For the future
 using address_t = uint32_t;
-
-
 using ttnn::ccl::Shape4D;
 using tt::tt_metal::TensorMemoryLayout;
+using shape_t = Shape4D<uint32_t>;
+
+void dprint(ttnn::ccl::cmd::CclCommandTensor const& command_tensor) {
+    DPRINT << "\ttensor_slice_shape.w: " << (uint32_t)command_tensor.tensor_slice_shape.w << "\n";
+    DPRINT << "\ttensor_slice_shape.z: " << (uint32_t)command_tensor.tensor_slice_shape.z << "\n";
+    DPRINT << "\ttensor_slice_shape.y: " << (uint32_t)command_tensor.tensor_slice_shape.y << "\n";
+    DPRINT << "\ttensor_slice_shape.x: " << (uint32_t)command_tensor.tensor_slice_shape.x << "\n";
+    DPRINT << "\ttensor_slice_offset.w: " << (uint32_t)command_tensor.tensor_slice_offset.w << "\n";
+    DPRINT << "\ttensor_slice_offset.z: " << (uint32_t)command_tensor.tensor_slice_offset.z << "\n";
+    DPRINT << "\ttensor_slice_offset.y: " << (uint32_t)command_tensor.tensor_slice_offset.y << "\n";
+    DPRINT << "\ttensor_slice_offset.x: " << (uint32_t)command_tensor.tensor_slice_offset.x << "\n";
+    DPRINT << "\tworker_start_offset_in_slice.w: " << (uint32_t)command_tensor.worker_start_offset_in_slice.w << "\n";
+    DPRINT << "\tworker_start_offset_in_slice.z: " << (uint32_t)command_tensor.worker_start_offset_in_slice.z << "\n";
+    DPRINT << "\tworker_start_offset_in_slice.y: " << (uint32_t)command_tensor.worker_start_offset_in_slice.y << "\n";
+    DPRINT << "\tworker_start_offset_in_slice.x: " << (uint32_t)command_tensor.worker_start_offset_in_slice.x << "\n";
+    DPRINT << "\tworker_pages_per_slice: " << (uint32_t)command_tensor.worker_pages_per_slice << "\n";
+}
+
+void print_tensor_command(uint32_t command_index, ttnn::ccl::cmd::CclCommandTensor const& command_tensor) {
+#ifdef DEBUG_PRINT_ENABLED
+    DPRINT << "cmd[" << (uint32_t)command_index << "]:\n";
+    dprint(command_tensor);
+#endif
+}
+
+/*
+ * Convert a flattened worker offset coord value (assumed to be (0, 0, 0, worker_offset_in_pages) into the tensor slice)
+ * into a 4D coordinate value
+ */
+inline shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, shape_t const& worker_slice_offset) {
+    static_assert(sizeof(coord_t) == 2 * sizeof(uint32_t), "worker_wrapped_offset_to_coord not updated to work with 4d shape");
+    auto const y = worker_slice_offset.x / slice_shape.x;
+    return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x));
+}
 
 std::size_t get_flat_index_from_shape(const Shape4D<uint32_t> &shape, const Shape4D<uint32_t> &index) {
     std::size_t offset = index.x;
@@ -153,7 +185,6 @@ auto build_source_address_generator(std::size_t &arg_idx, address_t tensor_addre
  */
 void kernel_main() {
     std::size_t arg_idx = 0;
-    using shape_t = Shape4D<uint32_t>;
 
     ///////////////////////////////////////////////////
     // ARGS
@@ -191,6 +222,10 @@ void kernel_main() {
 
     ttnn::ccl::cmd::CclCommandTensor command_tensor;
 
+    // Don't use CBs because there appears to be a bug when the same core is both producer and consumer of a given CB
+    // Instead, open up the CB and use it as a raw scratch space
+    cb_reserve_back(cb_id, packet_size_in_pages);
+    const uint32_t local_l1_scratch_buffer_address = get_write_ptr(cb_id);
     for (std::size_t i = 0; i < num_commands; ++i) {
         // Generalized would be to get the command header info and then dispatch accordingly - if the command type is singular
         //
@@ -199,20 +234,7 @@ void kernel_main() {
         std::size_t new_arg_idx = arg_idx;
         {
-            DPRINT << "cmd[" << (uint32_t)i << "]:\n";
-            DPRINT << "\ttensor_slice_shape.w: " << (uint32_t)command_tensor.tensor_slice_shape.w << "\n";
-            DPRINT << "\ttensor_slice_shape.z: " << (uint32_t)command_tensor.tensor_slice_shape.z << "\n";
-            DPRINT << "\ttensor_slice_shape.y: " << (uint32_t)command_tensor.tensor_slice_shape.y << "\n";
-            DPRINT << "\ttensor_slice_shape.x: " << (uint32_t)command_tensor.tensor_slice_shape.x << "\n";
-            DPRINT << "\ttensor_slice_offset.w: " << (uint32_t)command_tensor.tensor_slice_offset.w << "\n";
-            DPRINT << "\ttensor_slice_offset.z: " << (uint32_t)command_tensor.tensor_slice_offset.z << "\n";
-            DPRINT << "\ttensor_slice_offset.y: " << (uint32_t)command_tensor.tensor_slice_offset.y << "\n";
-            DPRINT << "\ttensor_slice_offset.x: " << (uint32_t)command_tensor.tensor_slice_offset.x << "\n";
-            DPRINT << "\tworker_start_offset_in_slice.w: " << (uint32_t)command_tensor.worker_start_offset_in_slice.w << "\n";
-            DPRINT << "\tworker_start_offset_in_slice.z: " << (uint32_t)command_tensor.worker_start_offset_in_slice.z << "\n";
-            DPRINT << "\tworker_start_offset_in_slice.y: " << (uint32_t)command_tensor.worker_start_offset_in_slice.y << "\n";
-            DPRINT << "\tworker_start_offset_in_slice.x: " << (uint32_t)command_tensor.worker_start_offset_in_slice.x << "\n";
-            DPRINT << "\tworker_pages_per_slice: " << (uint32_t)command_tensor.worker_pages_per_slice << "\n";
+            print_tensor_command(i, command_tensor);
 
             ASSERT(ccl_command.worker_pages_per_slice > 0);
 
             // CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS
@@ -221,7 +243,9 @@ void kernel_main() {
            // const shape_t tensor_slice_start_offset = ttnn::ccl::build_from_args(arg_idx); // Should be RT
            shape_t valid_worker_slice_shape = build_wrapped_row_tensor_slice(command_tensor.worker_pages_per_slice); // Parametrizable by ct arg
 
-            shape_t const& global_offset = command_tensor.tensor_slice_offset + command_tensor.worker_start_offset_in_slice;
+            shape_t const& worker_start_offset_global = worker_wrapped_offset_to_coord(command_tensor.tensor_slice_shape, command_tensor.worker_start_offset_in_slice);
+            shape_t const& global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global;
+
             uint32_t curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset);
             uint32_t offset_into_worker_slice = 0;
 
@@ -237,7 +261,7 @@ void kernel_main() {
 
                 ASSERT(ccl_command.tensor_slice_shape.w == 1);
                 ASSERT(ccl_command.tensor_slice_shape.z == 1);
-                read_wrapped_chunk_from_output_tensor(
+                read_wrapped_chunk_from_output_tensor_to_address(
                     curr_tile_id,
                     offset_into_worker_slice,
                     ttnn::ccl::coord_t(command_tensor.worker_start_offset_in_slice.x, command_tensor.worker_start_offset_in_slice.y), // Offset into tensor slice
@@ -245,7 +269,7 @@ void kernel_main() {
                    // In tiles for tile layout
                    ttnn::ccl::coord_t(command_tensor.tensor_shape.x, command_tensor.tensor_shape.y),
                    ttnn::ccl::coord_t(command_tensor.tensor_slice_shape.x, command_tensor.tensor_slice_shape.y),
-                    cb_id,
+                    local_l1_scratch_buffer_address,
                    tensor_addrgen,
                    n_pages,
                    page_size,
@@ -253,11 +277,8 @@ void kernel_main() {
                    last_page_of_worker);
 
                // Not optimal (doesn't overlap read/write) - but good for functional
                // bringup
-                cb_wait_front(cb_id, n_pages);
-                uint32_t l1_read_addr = get_read_ptr(cb_id);
-                sender.wait_for_empty_write_slot();
-                sender.send_payload_blocking(cb_id, n_pages, page_size);
+                sender.send_payload_blocking_from_address(local_l1_scratch_buffer_address, n_pages, page_size);
            }
        }
    }
diff --git a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp
index 80aa51417978..3d7d8c91bd93 100644
--- a/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp
@@ -6,7 +6,9 @@
 
 #include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp"
 
+#ifdef DEBUG_PRINT_ENABLED
 #include "debug/dprint.h"
+#endif
 
 #include
 #include
@@ -44,7 +46,9 @@ namespace cmd {
 
 void update_command_tensor(std::size_t &arg_idx, CclCommandTensor &cmd_tensor) {
     auto cmd = CclCommandHeader::from_uint32(get_arg_val<uint32_t>(arg_idx++));
+    #ifdef DEBUG_PRINT_ENABLED
     DPRINT << "CMD (code=" << (uint32_t)cmd.code << ", arg_count=" << (uint32_t)cmd.arg_count << ")\n";
+    #endif
 
 
     for (std::size_t i = 0; i < cmd.arg_count; i++) {
@@ -55,32 +59,45 @@ void update_command_tensor(std::size_t &arg_idx, CclCommandTensor &cmd_tensor) {
        switch (static_cast<CclCommandArgCode>(get_arg_val<uint32_t>(arg_idx++))) {
            case CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES:
                CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_shape);
+                #ifdef DEBUG_PRINT_ENABLED
                DPRINT << "Updating tensor shape: (w=" << (uint32_t)cmd_tensor.tensor_shape.w << ", z=" << (uint32_t)cmd_tensor.tensor_shape.z << ", y=" << (uint32_t)cmd_tensor.tensor_shape.y << ", x=" << (uint32_t)cmd_tensor.tensor_shape.x << ")\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            case CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES:
                CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_slice_shape);
+                #ifdef DEBUG_PRINT_ENABLED
                DPRINT << "Updating tensor slice shape: (w=" << (uint32_t)cmd_tensor.tensor_slice_shape.w << ", z=" << (uint32_t)cmd_tensor.tensor_slice_shape.z << ", y=" << (uint32_t)cmd_tensor.tensor_slice_shape.y << ", x=" << (uint32_t)cmd_tensor.tensor_slice_shape.x << ")\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            case CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES:
                CclCommandArg::unpack(
                    reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.tensor_slice_offset);
+                #ifdef DEBUG_PRINT_ENABLED
+                DPRINT << "Updating tensor slice offset: (w=" << (uint32_t)cmd_tensor.tensor_slice_offset.w << ", z=" << (uint32_t)cmd_tensor.tensor_slice_offset.z << ", y=" << (uint32_t)cmd_tensor.tensor_slice_offset.y << ", x=" << (uint32_t)cmd_tensor.tensor_slice_offset.x << ")\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            case CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES:
                CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.worker_start_offset_in_slice);
+                #ifdef DEBUG_PRINT_ENABLED
                DPRINT << "Updating worker start offset in slice: (w=" << (uint32_t)cmd_tensor.worker_start_offset_in_slice.w << ", z=" << (uint32_t)cmd_tensor.worker_start_offset_in_slice.z << ", y=" << (uint32_t)cmd_tensor.worker_start_offset_in_slice.y << ", x=" << (uint32_t)cmd_tensor.worker_start_offset_in_slice.x << ")\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            case CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE:
                CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor.worker_pages_per_slice);
+                #ifdef DEBUG_PRINT_ENABLED
                DPRINT << "Updating worker pages per slice: " << (uint32_t)cmd_tensor.worker_pages_per_slice << "\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            case CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES:
                CclCommandArg::unpack(reinterpret_cast(get_arg_addr(arg_idx)), cmd_tensor);
+                #ifdef DEBUG_PRINT_ENABLED
                DPRINT << "Updating full tensor slice spec: (tensor_shape: w=" << (uint32_t)cmd_tensor.tensor_shape.w << ", z=" << (uint32_t)cmd_tensor.tensor_shape.z << ", y=" << (uint32_t)cmd_tensor.tensor_shape.y << ", x=" << (uint32_t)cmd_tensor.tensor_shape.x << ")\n";
+                #endif
                arg_idx += CclCommandArg::size_in_words();
                break;
            default:
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp
index d1f98013b1b1..b4f1f2b02ede 100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp
@@ -108,6 +108,21 @@ struct WorkerToEdmSender{
        send_payload_impl(cb_id, num_pages, page_size);
    }
 
+    /*
+     * No CB
+     */
+    FORCE_INLINE void send_payload_blocking_from_address(uint32_t source_address, uint32_t num_pages, uint32_t page_size) {
+        send_payload_from_address_impl(source_address, num_pages, page_size);
+    }
+
+    /*
+     * No CB
+     */
+    // Does not wait for CB. Assumes caller handles CB data availability
+    FORCE_INLINE void send_payload_non_blocking_from_address(uint32_t source_address, uint32_t num_pages, uint32_t page_size) {
+        send_payload_from_address_impl(source_address, num_pages, page_size);
+    }
+
    FORCE_INLINE void close() {
        if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) {
            this->wait_for_empty_write_slot();
@@ -126,6 +141,15 @@ struct WorkerToEdmSender{
    std::size_t buffer_index;
 
    private:
+    template <ttnn::ccl::EDM_IO_BLOCKING_MODE blocking_mode>
+    FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, uint32_t num_pages, uint32_t page_size) {
+        uint64_t buffer_address = this->edm_buffer_addr + (this->buffer_index * (this->buffer_size_bytes + sizeof(eth_channel_sync_t)));
+        ASSERT(num_pages * page_size <= this->buffer_size_bytes);
+        send_chunk_from_address(source_address, num_pages, page_size, buffer_address);
+        noc_semaphore_inc(edm_semaphore_addr, 1);
+        this->buffer_index = (this->buffer_index == this->last_buffer_index) ? 0 : this->buffer_index + 1;
+    }
+
    template
    FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) {
        uint64_t buffer_address = this->edm_buffer_addr + (this->buffer_index * (this->buffer_size_bytes + sizeof(eth_channel_sync_t)));
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp
index 09ea561de431..07a1f0e433c1 100644
--- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp
@@ -49,6 +49,15 @@ FORCE_INLINE void fetch_chunk(
    cb_push_back(cb_id, num_pages);
 }
 
+template <ttnn::ccl::EDM_IO_BLOCKING_MODE blocking_mode>
+FORCE_INLINE void send_chunk_from_address(
+    const uint32_t& local_l1_address, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr) {
+    noc_async_write(local_l1_address, remote_l1_write_addr, page_size * num_pages);
+    if constexpr (blocking_mode == ttnn::ccl::EDM_IO_BLOCKING_MODE::BLOCKING) {
+        noc_async_write_barrier();
+    }
+}
+
 template
 FORCE_INLINE void send_chunk(
    const uint32_t& cb_id, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr) {
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
index a471a12d36bc..fc29a4255107 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
@@ -611,7 +611,8 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
    TT_ASSERT(input_tensor_num_units_per_tensor_slice > 0);
    constexpr bool enable_bidirectional = true;
-    uint32_t max_num_workers = std::min(user_defined_num_workers.value_or(topology == Topology::Linear ? 2 : 8), input_tensor_num_units_per_tensor_slice);
+    constexpr std::size_t default_num_workers = 8;
+    uint32_t max_num_workers = std::min(user_defined_num_workers.value_or(default_num_workers), input_tensor_num_units_per_tensor_slice);
 
    if (topology == ttnn::ccl::Topology::Linear) {
        max_num_workers = std::max(max_num_workers, 2);
    }
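
Reviewer note (not part of the patch): the `worker_wrapped_offset_to_coord` helper added in `ccl_send.cpp` treats a worker's offset into a tensor slice as a flat page count stored in `.x` and wraps it across the slice's innermost dimension. Below is a minimal host-side sketch of that arithmetic; `MiniShape4D` and the `main` harness are hypothetical stand-ins for illustration, not the kernel's `Shape4D` type or API.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Stand-in for a 4D shape/coordinate in (w, z, y, x) order, used only to check the math.
struct MiniShape4D {
    uint32_t w, z, y, x;
};

// Mirrors the wrapped-offset conversion from the patch: the flat page offset in .x
// is split into a row (y) and a column (x) within the slice's innermost dimension.
MiniShape4D worker_wrapped_offset_to_coord(const MiniShape4D& slice_shape, const MiniShape4D& worker_slice_offset) {
    const uint32_t y = worker_slice_offset.x / slice_shape.x;
    return MiniShape4D{0, 0, y, worker_slice_offset.x - (y * slice_shape.x)};
}

int main() {
    // A 1x1x4x8 (w, z, y, x) slice: a flat offset of 13 pages should land at row 1, column 5.
    const MiniShape4D slice_shape{1, 1, 4, 8};
    const MiniShape4D flat_offset{0, 0, 0, 13};
    const MiniShape4D coord = worker_wrapped_offset_to_coord(slice_shape, flat_offset);
    assert(coord.y == 1 && coord.x == 5);
    std::cout << "offset 13 -> (w=0, z=0, y=" << coord.y << ", x=" << coord.x << ")\n";
    return 0;
}
```

Compiling and running this sketch (e.g. `g++ -std=c++17 check_offset.cpp && ./a.out`, a hypothetical filename) should print `offset 13 -> (w=0, z=0, y=1, x=5)`, which is the row/column the kernel would combine with `tensor_slice_offset` before computing the flat tile id.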