diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 7070e631abe5..f9e409046017 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -26,14 +26,27 @@ AllGather create_all_gather_struct( uint32_t num_devices = devices.size(); uint32_t device_index = 0; // Initialize device index - uint32_t receiver_device_id = 0; // Initialize receiver device ID - uint32_t sender_device_id = 0; // Initialize sender device ID + std::optional receiver_device_id = std::nullopt; // Initialize receiver device ID + std::optional sender_device_id = std::nullopt; // Initialize sender device ID for (uint32_t i = 0; i < num_devices; ++i) { if (devices[i] == input_tensor.device()) { device_index = i; - receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring - sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring + if (topology == ttnn::ccl::Topology::Ring) { + // Ring topology + receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring + sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring + } else if (topology == ttnn::ccl::Topology::Linear) { + // Linear topology + bool is_last_chip_in_clockwise_direction = i == (num_devices - 1); + bool is_last_chip_in_counter_clockwise_direction = i == 0; + receiver_device_id = is_last_chip_in_clockwise_direction ? + std::nullopt : + std::optional(devices.at(i+1)->id()); + sender_device_id = is_last_chip_in_counter_clockwise_direction ? + std::nullopt : + std::optional(devices.at(i-1)->id()); + } break; } } @@ -220,7 +233,7 @@ Tensor all_gather( const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { - TT_FATAL(topology != ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology"); + TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology"); const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();