Skip to content

Commit

Permalink
#5560: Combine all_gather with launch_op
Browse files — browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 16, 2024
1 parent 83e452a commit f07ae46
Showing 1 changed file with 9 additions and 4 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp"
#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp"
#include "ttnn/operations/ccl/all_gather/all_gather.hpp"
#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
#include "tt_metal/host_api.hpp"

#include <cstdint>
Expand Down Expand Up @@ -115,7 +116,7 @@ Tensor all_reduce(
}
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error in all reduce op setup");

return operation::run(
std::vector<Tensor> reduced_tensors = operation::run(
ttnn::AllReduce{
binary_op_type,
scatter_dim,
Expand All @@ -129,12 +130,16 @@ Tensor all_reduce(
user_defined_num_workers,
user_defined_num_buffers_per_channel},
{input_tensor});
const auto& reduced_tensor = reduced_tensors.at(0);

return operation::run(
create_all_gather_struct(reduced_tensor, scatter_dim, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
{reduced_tensor});
},
{input_tensor},
output_tensors);
// Perform all_gather operation
Tensor gathered_output = ttnn::all_gather(output_tensors.at(0), scatter_dim);
return gathered_output;

return output_tensors.at(0);
}

} // namespace ccl
Expand Down

0 comments on commit f07ae46

Please sign in to comment.