From 385dce279d8ed0bda2b6ca7f09c426609f10dede Mon Sep 17 00:00:00 2001
From: Aswinmcw <azayasankaran@tenstorrent.com>
Date: Tue, 22 Oct 2024 07:43:45 +0000
Subject: [PATCH] #5560: Initial commit to get reduce_scatter as common

---
 .../device/reduce_scatter_op.cpp              | 95 ++++++++++++-------
 1 file changed, 61 insertions(+), 34 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
index 2c87dd4dd000..2f8c4bf522a2 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
@@ -9,6 +9,55 @@
 
 namespace ttnn {
 
+ReduceScatter create_reduce_scatter_struct (
+    const Tensor& input_tensor,
+    const ttnn::operations::binary::BinaryOpType binary_op_type,
+    const uint32_t scatter_dim,
+    const uint32_t num_links,
+    const MemoryConfig output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology topology
+){
+    uint32_t num_devices = devices.size();
+
+    bool is_linear = topology == ttnn::ccl::Topology::Linear;
+
+    uint32_t device_index = 0; // Initialize device index
+    std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
+    std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
+    for (uint32_t i = 0; i < num_devices; ++i) {
+        if (devices.at(i) == input_tensor.device()) {
+
+            bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
+            bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
+            device_index = i;
+            receiver_device_id = is_last_chip_in_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
+            sender_device_id = is_last_chip_in_counter_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
+            break;
+        }
+    }
+    TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");
+
+    return ttnn::ReduceScatter{
+                    binary_op_type,
+                    scatter_dim,
+                    num_links,
+                    num_devices,
+                    device_index,
+                    receiver_device_id,
+                    sender_device_id,
+                    output_mem_config,
+                    topology,
+                    user_defined_num_workers,
+                    user_defined_num_buffers_per_channel};
+}
+
 void ReduceScatter::validate(const std::vector<Tensor>& input_tensors) const {
     for (auto const& t : input_tensors) {
         TT_FATAL(
@@ -77,54 +126,32 @@ Tensor reduce_scatter(
     ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op);
     TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch");
 
+    ttnn::ccl::Topology ccl_topology = topology;
     auto devices = input_tensor.get_workers();
+    uint32_t num_devices = devices.size();
+    if (num_devices == 2){
+        ccl_topology = ttnn::ccl::Topology::Linear;
+    }
+
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_op(
-        [binary_op_type, scatter_dim, num_links, output_mem_config, topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
+        [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
 
-            uint32_t num_devices = devices.size();
-            if (num_devices == 2){
-                topology = ttnn::ccl::Topology::Linear;
-            }
-            bool is_linear = topology == ttnn::ccl::Topology::Linear;
-
             const auto& input_tensor = input_tensors.at(0);
-            uint32_t device_index = 0; // Initialize device index
-            std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
-            std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
-            for (uint32_t i = 0; i < num_devices; ++i) {
-                if (devices.at(i) == input_tensor.device()) {
-
-                    bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
-                    bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
-                    device_index = i;
-                    receiver_device_id = is_last_chip_in_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
-                    sender_device_id = is_last_chip_in_counter_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
-                    break;
-                }
-            }
-            TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");
-
             return operation::run(
-                ttnn::ReduceScatter{
+                create_reduce_scatter_struct(
+                    input_tensor,
                     binary_op_type,
                     scatter_dim,
                     num_links,
-                    num_devices,
-                    device_index,
-                    receiver_device_id,
-                    sender_device_id,
                     output_mem_config,
-                    topology,
                     user_defined_num_workers,
-                    user_defined_num_buffers_per_channel},
+                    user_defined_num_buffers_per_channel,
+                    devices,
+                    ccl_topology),
                 {input_tensor});
         },
      {input_tensor},