From f07ae468e0ed1a351185cdacf95aeab8c77d06ea Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Wed, 16 Oct 2024 09:03:20 +0000 Subject: [PATCH] #5560: Combine all_gather with launch_op --- .../ccl/all_reduce/device/all_reduce_op.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp index 5a2ab4fd7e86..a8f8529c3e41 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp @@ -5,6 +5,7 @@ #include "ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp" #include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" #include "ttnn/operations/ccl/all_gather/all_gather.hpp" +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "tt_metal/host_api.hpp" #include @@ -115,7 +116,7 @@ Tensor all_reduce( } TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error in all reduce op setup"); - return operation::run( + std::vector reduced_tensors = operation::run( ttnn::AllReduce{ binary_op_type, scatter_dim, @@ -129,12 +130,16 @@ Tensor all_reduce( user_defined_num_workers, user_defined_num_buffers_per_channel}, {input_tensor}); + const auto& reduced_tensor = reduced_tensors.at(0); + + return operation::run( + create_all_gather_struct(reduced_tensor, scatter_dim, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology), + {reduced_tensor}); }, {input_tensor}, output_tensors); - // Perform all_gather operation - Tensor gathered_output = ttnn::all_gather(output_tensors.at(0), scatter_dim); - return gathered_output; + + return output_tensors.at(0); } } // namespace ccl