diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
index e7a0fd0f9bd7..07cdc9d8bf28 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -24,39 +24,8 @@ AllGather create_all_gather_struct(
     const ttnn::ccl::Topology topology ) {
     uint32_t num_devices = devices.size();
-
-    uint32_t device_index = 0; // Initialize device index
-    std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
-    std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
-
-    for (uint32_t i = 0; i < num_devices; ++i) {
-        if (devices[i] == input_tensor.device()) {
-            device_index = i;
-            switch(topology){
-                case ttnn::ccl::Topology::Ring:{
-                    // Ring topology
-                    receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
-                    sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
-                    break;
-                }
-                case ttnn::ccl::Topology::Linear:{
-                    // Linear topology
-                    bool is_last_chip_in_clockwise_direction = i == (num_devices - 1);
-                    bool is_last_chip_in_counter_clockwise_direction = i == 0;
-                    receiver_device_id = is_last_chip_in_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at(i+1)->id());
-                    sender_device_id = is_last_chip_in_counter_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at(i-1)->id());
-                    break;
-                }
-                default:
-                    TT_FATAL(false, "Invalid Topology {}, Accepted topologies are Ring and Linear currently", topology);
-            }
-            break;
-        }
-    }
+    auto [device_index, sender_device_id, receiver_device_id] =
+        getDeviceIndexAndSenderReceiverIDs(input_tensor, devices, topology);
 
     return ttnn::AllGather{
         dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), topology};
 }
diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
index 410f8aaf85ca..07438040e30d 100644
--- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
@@ -12,6 +12,35 @@
 namespace ttnn {
 namespace ccl {
 
+std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> getDeviceIndexAndSenderReceiverIDs(
+    const Tensor& input_tensor,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology& topology) {
+
+    uint32_t num_devices = devices.size();
+    bool is_linear = topology == ttnn::ccl::Topology::Linear;
+    uint32_t device_index = 0; // Initialize device index
+    for (uint32_t i = 0; i < num_devices; ++i) {
+        if (devices.at(i) == input_tensor.device()) {
+            device_index = i;
+            bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
+            bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
+
+            std::optional<chip_id_t> receiver_device_id = is_last_chip_in_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
+
+            std::optional<chip_id_t> sender_device_id = is_last_chip_in_counter_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
+
+            return {device_index, sender_device_id, receiver_device_id};
+        }
+    }
+
+    return {device_index, std::nullopt, std::nullopt}; // Device not found: return null sender and receiver IDs
+}
+
 RingTopology::RingTopology(
     Device const* device,
     Topology topology,
diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
index 231e086be3b3..e41d0f51e9ad 100644
--- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
@@ -18,6 +18,11 @@
 namespace ttnn {
 namespace ccl {
 
+std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> getDeviceIndexAndSenderReceiverIDs(
+    const Tensor& input_tensor,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology& topology);
+
 // Eventual home: ccl_topology_descriptors
 struct RingTopology {
     RingTopology(
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
index 2c87dd4dd000..6965fd34c0b6 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
@@ -89,27 +89,11 @@ Tensor reduce_scatter(
     if (num_devices == 2){
         topology = ttnn::ccl::Topology::Linear;
     }
-    bool is_linear = topology == ttnn::ccl::Topology::Linear;
 
     const auto& input_tensor = input_tensors.at(0);
-    uint32_t device_index = 0; // Initialize device index
-    std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
-    std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
-    for (uint32_t i = 0; i < num_devices; ++i) {
-        if (devices.at(i) == input_tensor.device()) {
-
-            bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
-            bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
-            device_index = i;
-            receiver_device_id = is_last_chip_in_clockwise_direction ?
-                std::nullopt :
-                std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
-            sender_device_id = is_last_chip_in_counter_clockwise_direction ?
-                std::nullopt :
-                std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
-            break;
-        }
-    }
+    auto [device_index, sender_device_id, receiver_device_id] =
+        getDeviceIndexAndSenderReceiverIDs(input_tensor, devices, topology);
+    TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error: Reduce-scatter was unable to identify either a sender or receiver device ID, and at least one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");
 
     return operation::run(
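
For reviewers, a minimal standalone sketch of the neighbor lookup that the new `getDeviceIndexAndSenderReceiverIDs` helper centralizes. Assumptions: plain `int` IDs stand in for tt-metal `Device*`/`chip_id_t`, and the `Topology`/`neighbor_ids` names here are illustrative, not part of the ttnn API. It shows how a single `is_linear` flag subsumes the old all-gather `switch`: when the flag is false the end-of-chain checks never fire, and the expressions reduce to the plain ring wrap-around.

```cpp
// Standalone illustration of the ring/linear neighbor selection unified by the helper.
// Plain ints stand in for tt-metal Device pointers and chip_id_t; names are illustrative only.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

enum class Topology { Ring, Linear };

// Same arithmetic as the refactored helper: wrap around for Ring; for Linear,
// the first chip has no sender and the last chip has no receiver.
std::tuple<uint32_t, std::optional<int>, std::optional<int>> neighbor_ids(
    const std::vector<int>& device_ids, uint32_t device_index, Topology topology) {
    const uint32_t n = static_cast<uint32_t>(device_ids.size());
    const bool is_linear = topology == Topology::Linear;
    const bool last_clockwise = is_linear && device_index == n - 1;
    const bool last_counter_clockwise = is_linear && device_index == 0;

    std::optional<int> receiver_id = last_clockwise
        ? std::nullopt
        : std::optional<int>(device_ids[(device_index + 1) % n]);
    std::optional<int> sender_id = last_counter_clockwise
        ? std::nullopt
        : std::optional<int>(device_ids[(device_index + n - 1) % n]);
    return {device_index, sender_id, receiver_id};
}

int main() {
    const std::vector<int> ids{10, 11, 12, 13};  // pretend chip IDs for a 4-device cluster

    auto [ring_idx, ring_sender, ring_receiver] = neighbor_ids(ids, 3, Topology::Ring);
    // Ring: the last device wraps around, so sender = 12 and receiver = 10.
    std::cout << "ring   idx=" << ring_idx << " sender=" << *ring_sender
              << " receiver=" << *ring_receiver << "\n";

    auto [lin_idx, lin_sender, lin_receiver] = neighbor_ids(ids, 3, Topology::Linear);
    // Linear: the last device keeps its sender (12) but has no receiver.
    std::cout << "linear idx=" << lin_idx << " sender=" << *lin_sender
              << " receiver=" << (lin_receiver ? std::to_string(*lin_receiver) : "none") << "\n";
}
```

Collapsing the two call sites onto this one code path also makes the behavior consistent: all-gather previously rejected unknown topologies with a TT_FATAL inside the loop, while the shared helper simply returns null neighbor IDs and lets the caller (as reduce-scatter now does) decide how to validate them.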