From 385dce279d8ed0bda2b6ca7f09c426609f10dede Mon Sep 17 00:00:00 2001
From: Aswinmcw <azayasankaran@tenstorrent.com>
Date: Tue, 22 Oct 2024 07:43:45 +0000
Subject: [PATCH] #5560: Initial commit to get reduce_scatter as common

---
 .../device/reduce_scatter_op.cpp | 95 ++++++++++++-------
 1 file changed, 61 insertions(+), 34 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
index 2c87dd4dd000..2f8c4bf522a2 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
@@ -9,6 +9,55 @@
 
 namespace ttnn {
 
+ReduceScatter create_reduce_scatter_struct (
+    const Tensor& input_tensor,
+    const ttnn::operations::binary::BinaryOpType binary_op_type,
+    const uint32_t scatter_dim,
+    const uint32_t num_links,
+    const MemoryConfig output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology topology
+){
+    uint32_t num_devices = devices.size();
+
+    bool is_linear = topology == ttnn::ccl::Topology::Linear;
+
+    uint32_t device_index = 0; // Initialize device index
+    std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
+    std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
+    for (uint32_t i = 0; i < num_devices; ++i) {
+        if (devices.at(i) == input_tensor.device()) {
+
+            bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
+            bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
+            device_index = i;
+            receiver_device_id = is_last_chip_in_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
+            sender_device_id = is_last_chip_in_counter_clockwise_direction ?
+                std::nullopt :
+                std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
+            break;
+        }
+    }
+    TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");
+
+    return ttnn::ReduceScatter{
+        binary_op_type,
+        scatter_dim,
+        num_links,
+        num_devices,
+        device_index,
+        receiver_device_id,
+        sender_device_id,
+        output_mem_config,
+        topology,
+        user_defined_num_workers,
+        user_defined_num_buffers_per_channel};
+}
+
 void ReduceScatter::validate(const std::vector<Tensor>& input_tensors) const {
     for (auto const& t : input_tensors) {
         TT_FATAL(
@@ -77,54 +126,32 @@ Tensor reduce_scatter(
 
     ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op);
     TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch");
+    ttnn::ccl::Topology ccl_topology = topology;
 
     auto devices = input_tensor.get_workers();
+    uint32_t num_devices = devices.size();
+    if (num_devices == 2){
+        ccl_topology = ttnn::ccl::Topology::Linear;
+    }
+
 
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
 
     operation::launch_op(
-        [binary_op_type, scatter_dim, num_links, output_mem_config, topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
+        [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
-            uint32_t num_devices = devices.size();
-            if (num_devices == 2){
-                topology = ttnn::ccl::Topology::Linear;
-            }
-            bool is_linear = topology == ttnn::ccl::Topology::Linear;
-            const auto& input_tensor = input_tensors.at(0);
-            uint32_t device_index = 0; // Initialize device index
-            std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
-            std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
-            for (uint32_t i = 0; i < num_devices; ++i) {
-                if (devices.at(i) == input_tensor.device()) {
-
-                    bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
-                    bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
-                    device_index = i;
-                    receiver_device_id = is_last_chip_in_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
-                    sender_device_id = is_last_chip_in_counter_clockwise_direction ?
-                        std::nullopt :
-                        std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
-                    break;
-                }
-            }
-            TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");
-            return operation::run(
-                ttnn::ReduceScatter{
+                create_reduce_scatter_struct(
+                    input_tensor,
                     binary_op_type,
                     scatter_dim,
                     num_links,
-                    num_devices,
-                    device_index,
-                    receiver_device_id,
-                    sender_device_id,
                     output_mem_config,
-                    topology,
                     user_defined_num_workers,
-                    user_defined_num_buffers_per_channel},
+                    user_defined_num_buffers_per_channel,
+                    devices,
+                    ccl_topology),
                 {input_tensor});
         },
         {input_tensor},
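
For context on what the extracted helper centralizes: create_reduce_scatter_struct locates the chip that owns the input tensor, records its index, and resolves its clockwise receiver ((i + 1) % num_devices) and counter-clockwise sender ((i + num_devices - 1) % num_devices), dropping the respective neighbor at the ends of a Linear chain. Below is a minimal standalone sketch of that rule, not tt-metal code: it uses plain integers in place of Device*/chip_id_t, and the Neighbors struct and resolve_neighbors function are illustrative names that do not exist in the ttnn API or in this patch.

    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Simplified model of the neighbor resolution performed by
    // create_reduce_scatter_struct; device IDs are plain integers here.
    struct Neighbors {
        std::optional<uint32_t> receiver;  // clockwise neighbor, if any
        std::optional<uint32_t> sender;    // counter-clockwise neighbor, if any
    };

    Neighbors resolve_neighbors(const std::vector<uint32_t>& device_ids, uint32_t device_index, bool is_linear) {
        const uint32_t n = static_cast<uint32_t>(device_ids.size());
        // On a Linear chain the last chip has no clockwise receiver and the first
        // chip has no counter-clockwise sender; on a Ring both directions wrap.
        const bool is_last_chip_in_clockwise_direction = is_linear && device_index == n - 1;
        const bool is_last_chip_in_counter_clockwise_direction = is_linear && device_index == 0;
        Neighbors out;
        if (!is_last_chip_in_clockwise_direction) {
            out.receiver = device_ids[(device_index + 1) % n];
        }
        if (!is_last_chip_in_counter_clockwise_direction) {
            out.sender = device_ids[(device_index + n - 1) % n];
        }
        return out;
    }

    int main() {
        const std::vector<uint32_t> ids = {10, 11, 12, 13};
        // Ring: every chip has both neighbors, wrapping around the ends.
        assert(resolve_neighbors(ids, 0, /*is_linear=*/false).sender == 13u);
        // Linear: the first chip has no sender and the last chip has no receiver.
        assert(!resolve_neighbors(ids, 0, /*is_linear=*/true).sender.has_value());
        assert(!resolve_neighbors(ids, 3, /*is_linear=*/true).receiver.has_value());
        return 0;
    }

The two-device special case added in reduce_scatter (forcing ccl_topology to Linear when num_devices == 2) maps onto the same rule: each of the two chips then keeps exactly one neighbor instead of wrapping into a ring.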