From c40e47d6005f5e991c36042e7ef2d0b372e5fcba Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Tue, 2 Jul 2024 12:33:31 +0000 Subject: [PATCH 01/10] #9486: Merge line_all_gather to TTNN --- .../ccl/all_gather/all_gather_op.hpp | 29 +++++++++++++++++++ .../ccl/all_gather/all_gather_pybind.hpp | 22 ++++++++++++++ .../ccl/all_gather/device/all_gather_op.hpp | 6 ++++ 3 files changed, 57 insertions(+) diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp index 94c84f41a38..520fc802bda 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp @@ -22,9 +22,38 @@ struct ExecuteAllGather { } }; +struct ExecuteLineAllGather { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, + true, + false, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + static ttnn::Tensor execute_on_main_thread( + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + return tt::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); + } +}; + } // namespace ccl } // namespace operations constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); +constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); + } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp index c3aa4015c49..c8ecafae369 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp @@ -65,6 +65,28 @@ void py_bind_all_gather(py::module& module) { >>> output = ttnn.all_gather(tensor, dim=0) )doc"); + + detail::bind_ccl_operation( + module, + ttnn::line_all_gather, + R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. 
+ + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.line_all_gather(tensor, dim=0) + + )doc"); } } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index c3469068f73..cb82a498bf4 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -743,6 +743,12 @@ Tensor all_gather( const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt); +Tensor line_all_gather( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt); + } // namespace ccl } // namespace operations From 20a9371994c2674bcdc6978dd02380cc716469c3 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Thu, 4 Jul 2024 11:51:57 +0000 Subject: [PATCH 02/10] #9486: Move CCL kernel files to TTNN --- .../ops/ccl/test_all_gather_utils.cpp | 3 - ttnn/cpp/pybind11/operations/__init__.hpp | 6 + .../ccl/all_gather/all_gather_op.hpp | 11 +- .../ccl/all_gather/all_gather_pybind.hpp | 18 ++- .../ccl/all_gather/ccl_all_gather_pybind.hpp | 72 ++++++++++++ .../ccl/all_gather/device/all_gather_op.hpp | 70 ++++++------ .../all_gather/device/ccl_all_gather_op.hpp | 46 ++++++++ .../multi_core/all_gather_op_multi_core.cpp | 4 +- .../ccl_line_all_gather_pybind.hpp | 106 ++++++++++++++++++ .../device/ccl_line_all_gather_op.hpp | 64 +++++++++++ .../device/line_all_gather_op.cpp | 56 +++++++++ .../device/line_all_gather_op.hpp | 52 +++++++++ 12 files changed, 459 insertions(+), 49 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp diff --git a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp index 35335223344..2a49e90728b 100644 --- a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp +++ b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp @@ -227,9 +227,6 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetIntraCoreStrideInSh uint32_t ring_size = 8; auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size); ASSERT_EQ(stride, 3); - } - { - uint32_t input_shard_grid_size = 16; uint32_t num_workers = 4; uint32_t ring_size = 8; auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size); diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index dc40707d640..99402370b64 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -62,6 +62,12 @@ void py_module(py::module& module) { auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); unary_backward::py_module(m_unary_backward); + + auto m_ccl_all_gather = module.def_submodule("ccl_all_gather", "collective communication operations"); + ccl_all_gather::py_module(m_ccl_all_gather); + + auto m_ccl_line_all_gather = module.def_submodule("ccl_line_all_gather", "collective communication 
operations "); + ccl_line_all_gather::py_module(m_ccl_line_all_gather); auto m_ccl = module.def_submodule("ccl", "collective communication operations"); ccl::py_bind_all_gather(m_ccl); diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp index 520fc802bda..2dd167b3db2 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp @@ -4,13 +4,18 @@ #pragma once +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +======= +#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp #include "ttnn/cpp/ttnn/multi_device.hpp" namespace ttnn { namespace operations { namespace ccl { +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp struct ExecuteAllGather { static ttnn::Tensor execute_on_main_thread( @@ -22,6 +27,8 @@ struct ExecuteAllGather { } }; +======= +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp struct ExecuteLineAllGather { static inline const std::array input_tensor_schemas() { return {ttnn::TensorSchema{ @@ -45,15 +52,13 @@ struct ExecuteLineAllGather { const uint32_t dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt) { - return tt::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); + return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); } }; } // namespace ccl } // namespace operations -constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); - constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp index c8ecafae369..eeb5d5a46e9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp @@ -8,19 +8,27 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp #include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" +======= +#include "ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp" +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp #include "ttnn/types.hpp" namespace py = pybind11; namespace ttnn { namespace operations { -namespace ccl { +namespace ccl_line_all_gather { namespace detail { template +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +======= +void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +>>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp bind_registered_operation( module, operation, @@ -43,6 +51,7 @@ void bind_all_gather(py::module& module, const ccl_operation_t& operation, const } // namespace detail +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp void py_bind_all_gather(py::module& module) { detail::bind_all_gather( module, @@ -65,8 +74,11 @@ void py_bind_all_gather(py::module& module) { >>> output = ttnn.all_gather(tensor, dim=0) )doc"); +======= +void py_module(py::module& module) { +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp - detail::bind_ccl_operation( + detail::bind_ccl_line_all_gather( module, ttnn::line_all_gather, R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor @@ -89,6 +101,6 @@ void py_bind_all_gather(py::module& module) { )doc"); } -} // namespace ccl +} // namespace ccl_line_all_gather } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp new file mode 100644 index 00000000000..97b1fc3b055 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace ccl_all_gather { + +namespace detail { + +template +void bind_ccl_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config) -> ttnn::Tensor { + return self(input_tensor, dim, num_links, memory_config); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt}); +} + +} // namespace detail + + +void py_module(py::module& module) { + detail::bind_ccl_all_gather( + module, + ttnn::all_gather, + R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. 
+ + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.all_gather(tensor, dim=0) + + )doc"); +} + +} // namespace ccl_all_gather +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index cb82a498bf4..2f84fceefe2 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -29,10 +29,10 @@ enum AllGatherMode { }; namespace all_gather_op { -using ccl::Topology; +using tt::tt_metal::ccl::Topology; }; // namespace all_gather_op -using ccl::EriscDatamoverBuilder; +using tt::tt_metal::ccl::EriscDatamoverBuilder; AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim); @@ -255,13 +255,13 @@ struct ShardedAllGatherConfig { switch(input_tensor.memory_config().memory_layout) { case TensorMemoryLayout::WIDTH_SHARDED: - this->shard_type = ccl::ShardType::Width; + this->shard_type = tt::tt_metal::ccl::ShardType::Width; break; case TensorMemoryLayout::BLOCK_SHARDED: - this->shard_type = ccl::ShardType::Block; + this->shard_type = tt::tt_metal::ccl::ShardType::Block; break; case TensorMemoryLayout::HEIGHT_SHARDED: - this->shard_type = ccl::ShardType::Height; + this->shard_type = tt::tt_metal::ccl::ShardType::Height; break; case TensorMemoryLayout::INTERLEAVED: case TensorMemoryLayout::SINGLE_BANK: @@ -285,7 +285,7 @@ struct ShardedAllGatherConfig { return single_tile_shard_on_dim; } - ccl::ShardType get_shard_type() const { + tt::tt_metal::ccl::ShardType get_shard_type() const { TT_ASSERT(is_sharding_enabled, "Tried getting sharding config for non-sharded tensor"); return shard_type; } @@ -293,7 +293,7 @@ struct ShardedAllGatherConfig { private: bool requires_post_all_gather_reshard; bool single_tile_shard_on_dim; - ccl::ShardType shard_type; + tt::tt_metal::ccl::ShardType shard_type; bool is_sharding_enabled; }; @@ -302,7 +302,7 @@ struct ShardedAllGatherConfig { struct ShardAddrGenArgGenerator { using shard_cores_t = CoreRangeSet; - ShardAddrGenArgGenerator(ccl::ShardAddrGenArgs const& args_struct) : + ShardAddrGenArgGenerator(tt::tt_metal::ccl::ShardAddrGenArgs const& args_struct) : args_struct(args_struct), initialized(true) {} ShardAddrGenArgGenerator() : initialized(false) {} @@ -312,14 +312,14 @@ struct ShardAddrGenArgGenerator { std::vector args; args.reserve(7 * this->args_struct.num_dest_cores * 2); - TT_ASSERT(this->args_struct.shard_size_in_bytes != ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.total_chunks_per_core != ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.shards_start_address != ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.starting_core_index != ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.starting_chunk_into_shard != ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.intra_core_stride_in_shards != ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.num_dest_cores != ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shard_size_in_bytes != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.total_chunks_per_core != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shards_start_address != 
tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.starting_core_index != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.starting_chunk_into_shard != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.intra_core_stride_in_shards != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.num_dest_cores != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); TT_ASSERT(this->args_struct.dest_cores.size() != 0); args.push_back(this->args_struct.is_clockwise); @@ -331,7 +331,7 @@ struct ShardAddrGenArgGenerator { args.push_back(this->args_struct.intra_core_stride_in_shards); args.push_back(this->args_struct.contiguous_chunks_before_stride); args.push_back(this->args_struct.num_dest_cores); - for (ccl::WorkerXY const& core : this->args_struct.dest_cores) { + for (tt::tt_metal::ccl::WorkerXY const& core : this->args_struct.dest_cores) { args.push_back(core.to_uint32()); } @@ -360,7 +360,7 @@ struct ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.starting_core_index < this->args_struct.dest_cores.size()); } - ccl::ShardAddrGenArgs args_struct; + tt::tt_metal::ccl::ShardAddrGenArgs args_struct; bool initialized; }; @@ -392,7 +392,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat } InputTensorShardAddrGenArgGenerator( Device const* device, - ccl::CclOpShardedTensorConfig *input_tensor_config, + tt::tt_metal::ccl::CclOpShardedTensorConfig *input_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -425,7 +425,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat this->args_struct.dest_cores.reserve(dest_core_coords.size()); std::transform(dest_core_coords.begin(), dest_core_coords.end(), std::back_inserter(this->args_struct.dest_cores), [&device](CoreCoord const& core) { - return ccl::WorkerXY( + return tt::tt_metal::ccl::WorkerXY( static_cast(device->worker_core_from_logical_core(core).x), static_cast(device->worker_core_from_logical_core(core).y) ); @@ -444,7 +444,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { static std::vector compute_worker_coord_worker_dest_cores ( - ccl::ShardType shard_type, + tt::tt_metal::ccl::ShardType shard_type, std::vector const& global_shard_dest_cores, uint32_t input_num_shards, uint32_t output_num_shards, @@ -494,8 +494,8 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { } - static std::vector compute_worker_dest_cores ( - ccl::ShardType shard_type, + static std::vector compute_worker_dest_cores ( + tt::tt_metal::ccl::ShardType shard_type, Device const& device, CoreRangeSet const& shard_core_range, uint32_t input_num_shards, @@ -512,11 +512,11 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { worker_index, is_shard_orientation_row_major); - std::vector dest_cores_of_worker; + std::vector dest_cores_of_worker; dest_cores_of_worker.reserve(worker_coord_worker_dest_cores.size()); std::transform(worker_coord_worker_dest_cores.begin(), worker_coord_worker_dest_cores.end(), std::back_inserter(dest_cores_of_worker), [&device](CoreCoord const& core) { - return ccl::WorkerXY( + return tt::tt_metal::ccl::WorkerXY( static_cast(device.worker_core_from_logical_core(core).x), 
static_cast(device.worker_core_from_logical_core(core).y) ); @@ -588,8 +588,8 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { OutputTensorShardAddrGenArgGenerator( AllGatherConfig const& all_gather_config, Device const* device, - ccl::CclOpShardedTensorConfig *input_tensor_config, - ccl::CclOpShardedTensorConfig *output_tensor_config, + tt::tt_metal::ccl::CclOpShardedTensorConfig *input_tensor_config, + tt::tt_metal::ccl::CclOpShardedTensorConfig *output_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -617,7 +617,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { uint32_t input_num_shards = sharded_tensor_num_cores; uint32_t output_num_shards = input_num_shards * ring_size; this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - ccl::ShardType::Width, + tt::tt_metal::ccl::ShardType::Width, *device, tensor_shard_grid, input_num_shards, @@ -630,7 +630,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.dest_cores.size() > 0); std::vector const& global_shard_dest_cores = corerange_to_cores(tensor_shard_grid, std::nullopt, is_shard_orientation_row_major); CoreCoord const& dest_core_coord = global_shard_dest_cores.at(global_starting_dest_worker_index); - ccl::WorkerXY noc0_starting_dest_core_xy( + tt::tt_metal::ccl::WorkerXY noc0_starting_dest_core_xy( static_cast(device->worker_core_from_logical_core(dest_core_coord).x), static_cast(device->worker_core_from_logical_core(dest_core_coord).y) ); @@ -679,7 +679,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { args.push_back(args_struct.is_clockwise); args.push_back(args_struct.curr_core_index); args.push_back(args_struct.total_num_cores); - for (ccl::WorkerXY const& core : args_struct.dest_cores) { + for (tt::tt_metal::ccl::WorkerXY const& core : args_struct.dest_cores) { args.push_back(core.to_uint32()); } @@ -717,7 +717,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { auto const& tensor_shard_grid = input_tensor.buffer()->shard_spec().grid(); this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - ccl::ShardType::Width, + tt::tt_metal::ccl::ShardType::Width, *device, tensor_shard_grid, tensor_shard_grid.num_cores(), @@ -730,7 +730,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { this->initialized = true; } - ccl::FullWorkerGridShardAddrGenArgs args_struct; + tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs args_struct; bool initialized; }; @@ -743,12 +743,6 @@ Tensor all_gather( const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt); -Tensor line_all_gather( - const Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt); - } // namespace ccl } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp new file mode 100644 index 00000000000..d1f975692be --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +#include "ttnn/cpp/ttnn/multi_device.hpp" + +namespace ttnn { +namespace operations { +namespace ccl { + +struct ExecuteAllGather { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, + true, + false, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + static ttnn::Tensor execute_on_main_thread( + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); + } +}; + +} // namespace ccl +} // namespace operations + +constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index 2280d7b352e..c9d81acd9ad 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -403,8 +403,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { uint32_t num_workers_per_eth_buffer = std::min(workers_per_link, all_gather_config.get_num_eth_buffers_per_edm() - worker_index); - std::vector sender_worker_coords; - std::vector receiver_worker_coords; + std::vector sender_worker_coords; + std::vector receiver_worker_coords; for (uint32_t w = b * num_workers_per_eth_buffer; w < (b + 1) * num_workers_per_eth_buffer; ++w) { sender_worker_coords.push_back( ttnn::ccl::WorkerXY( diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp new file mode 100644 index 00000000000..eeb5d5a46e9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +#include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" +======= +#include "ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp" +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace ccl_line_all_gather { + +namespace detail { + +template +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +======= +void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +>>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config) -> ttnn::Tensor { + return self(input_tensor, dim, num_links, memory_config); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt}); +} + +} // namespace detail + + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +void py_bind_all_gather(py::module& module) { + detail::bind_all_gather( + module, + ttnn::all_gather, + R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.all_gather(tensor, dim=0) + + )doc"); +======= +void py_module(py::module& module) { +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp + + detail::bind_ccl_line_all_gather( + module, + ttnn::line_all_gather, + R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.line_all_gather(tensor, dim=0) + + )doc"); +} + +} // namespace ccl_line_all_gather +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp new file mode 100644 index 00000000000..2dd167b3db2 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +======= +#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" +>>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp +#include "ttnn/cpp/ttnn/multi_device.hpp" + +namespace ttnn { +namespace operations { +namespace ccl { + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +struct ExecuteAllGather { + + static ttnn::Tensor execute_on_main_thread( + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); + } +}; + +======= +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp +struct ExecuteLineAllGather { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, + true, + false, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + static ttnn::Tensor execute_on_main_thread( + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); + } +}; + +} // namespace ccl +} // namespace operations + +constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp index 4c32817bdfb..2a900f153ef 100644 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp @@ -14,6 +14,12 @@ namespace ttnn { +<<<<<<< HEAD +======= +namespace utils { + + +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN void LineAllGather::validate(const std::vector &input_tensors) const { TT_FATAL(input_tensors.size() == 1); const auto& input_tensor = input_tensors[0]; @@ -84,6 +90,51 @@ operation::ProgramWithCallbacks LineAllGather::create_program(const std::vector< }; } +<<<<<<< HEAD +======= + + +std::vector line_all_gather_impl(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config, const all_gather_op::Topology topology) { + + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + + std::vector output_tensors = std::vector(input_tensors.size()); + + bool is_ring = topology == all_gather_op::Topology::Ring; + uint32_t num_inputs = static_cast(input_tensors.size()); + for (uint32_t i = 0; i < input_tensors.size(); ++i) { + output_tensors[i] = Tensor(operation::get_workers_for_op_output({input_tensors[i]})); + // Extract these tensors in the main thread, since they're used to get the sender and receiver device ids + // Dont get the device in the main thread, since it can cause stalls in async mode. + const Tensor& tensor_on_receiver = input_tensors[(i + 1) % num_inputs]; + const Tensor& tensor_on_sender = input_tensors[i == 0 ? 
num_inputs - 1 : i - 1]; + // Package output in vector, to populate it with launch_op + std::vector output_for_curr_device = {output_tensors[i]}; + operation::launch_op( + [is_ring, dim, num_links, i, num_inputs, output_mem_config, topology] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (num_inputs - 1); + bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; + + std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? + std::nullopt : + std::optional(input_tensors.at(1).device()->id()); + std::optional sender_device_id = is_last_chip_in_counter_clockwise_direction ? + std::nullopt : + std::optional(input_tensors.at(2).device()->id()); + return operation::run(LineAllGather{dim, num_links, num_inputs, i, receiver_device_id, sender_device_id, output_mem_config,topology}, {input_tensors.at(0)}); + }, + {input_tensors[i], tensor_on_receiver, tensor_on_sender}, output_for_curr_device); + } + return output_tensors; +} + +std::vector line_all_gather(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config) { + return line_all_gather_impl(input_tensors, dim, num_links, output_mem_config, all_gather_op::Topology::Linear); +} + +} // namespace utils + +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN namespace operations { namespace ccl { @@ -117,8 +168,13 @@ Tensor line_all_gather( } return operation::run( +<<<<<<< HEAD ttnn::LineAllGather{ dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), ttnn::all_gather_op::Topology::Linear}, +======= + ttnn::utils::LineAllGather{ + dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), ttnn::utils::all_gather_op::Topology::Linear}, +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN {input_tensor}); }, {input_tensor}, diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp index c6171bc57f0..df3a85a94f7 100644 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp @@ -8,11 +8,19 @@ #include "common/core_coord.h" #include "impl/buffers/buffer.hpp" #include "tensor/tensor.hpp" +<<<<<<< HEAD #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +======= +#include "tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "tt_dnn/op_library/ccl/ccl_common.hpp" +>>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN #include "tt_dnn/op_library/run_operation.hpp" @@ -22,11 +30,21 @@ namespace ttnn { +<<<<<<< HEAD namespace all_gather_op { using ccl::Topology; }; // namespace all_gather_op using ccl::EriscDatamoverBuilder; +======= +namespace utils { + +namespace all_gather_op { +using tt::tt_metal::ccl::Topology; +}; // namespace all_gather_op + +using tt::tt_metal::ccl::EriscDatamoverBuilder; +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN struct LineAllGather { @@ -43,8 +61,42 @@ struct LineAllGather { std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; +<<<<<<< HEAD +}; + +======= + + static constexpr auto attribute_names = std::forward_as_tuple( + "dim", + "num_links", + "ring_size", + "ring_index", + "receiver_device_id", + "sender_device_id", + "output_mem_config", + "topology"); + + const auto attribute_values() const { + return std::forward_as_tuple( + dim, num_links, ring_size, ring_index, receiver_device_id, sender_device_id, output_mem_config, topology); + } }; +// All Gather Variants +std::vector line_all_gather_impl( + const std::vector& input_tensors, + const uint32_t dim, + const uint32_t num_links, + const MemoryConfig& output_mem_config, + const all_gather_op::Topology topology); +std::vector line_all_gather( + const std::vector &input_tensors, + const uint32_t dim, + const uint32_t num_links = 1, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +} // namespace utils +>>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN namespace operations { namespace ccl { From 52ace56430ca565c480f22340005529249d823a2 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Thu, 11 Jul 2024 09:18:17 +0000 Subject: [PATCH 03/10] #9486: Move CCL common to TTNN --- ...st_all_gather_sharded_indexing_helpers.cpp | 2 +- .../ops/ccl/test_all_gather_utils.cpp | 5 +- .../ccl/all_gather/device/all_gather_op.hpp | 86 +- .../multi_core/all_gather_op_multi_core.cpp | 4 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 22 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 8 +- .../ccl/ccl_host_datastructures.hpp | 10 +- ttnn/cpp/ttnn/operations/ccl/edm/README.md | 0 .../ccl/edm/erisc_async_datamover.hpp | 748 ++++++++++++++++++ .../operations/ccl/edm/erisc_datamover.cpp | 343 ++++++++ .../ccl/kernels/edm/erisc_async_datamover.hpp | 18 + .../ccl/kernels/edm/erisc_datamover.cpp | 17 + .../device/line_all_gather_op.hpp | 53 -- .../hetergeneous_data_structs.hpp | 2 +- 14 files changed, 1197 insertions(+), 121 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/README.md create mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp diff --git a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp index 3fa4d26db90..fe2c6e7b4bc 100644 --- a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp +++ b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp @@ -289,4 +289,4 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT } } } -} +} \ No newline at end of file diff --git a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp index 
2a49e90728b..ad1a0555f4e 100644 --- a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp +++ b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp @@ -227,6 +227,9 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetIntraCoreStrideInSh uint32_t ring_size = 8; auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size); ASSERT_EQ(stride, 3); + } + { + uint32_t input_shard_grid_size = 16; uint32_t num_workers = 4; uint32_t ring_size = 8; auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size); @@ -1361,4 +1364,4 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 6); ASSERT_EQ(contiguous_chunk_count, 1); -} +} \ No newline at end of file diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 2f84fceefe2..536e18451e2 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -29,10 +29,10 @@ enum AllGatherMode { }; namespace all_gather_op { -using tt::tt_metal::ccl::Topology; +using ttnn::utils::ccl::Topology; }; // namespace all_gather_op -using tt::tt_metal::ccl::EriscDatamoverBuilder; +using ttnn::utils::ccl::EriscDatamoverBuilder; AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim); @@ -255,13 +255,13 @@ struct ShardedAllGatherConfig { switch(input_tensor.memory_config().memory_layout) { case TensorMemoryLayout::WIDTH_SHARDED: - this->shard_type = tt::tt_metal::ccl::ShardType::Width; + this->shard_type = ttnn::utils::ccl::ShardType::Width; break; case TensorMemoryLayout::BLOCK_SHARDED: - this->shard_type = tt::tt_metal::ccl::ShardType::Block; + this->shard_type = ttnn::utils::ccl::ShardType::Block; break; case TensorMemoryLayout::HEIGHT_SHARDED: - this->shard_type = tt::tt_metal::ccl::ShardType::Height; + this->shard_type = ttnn::utils::ccl::ShardType::Height; break; case TensorMemoryLayout::INTERLEAVED: case TensorMemoryLayout::SINGLE_BANK: @@ -285,7 +285,7 @@ struct ShardedAllGatherConfig { return single_tile_shard_on_dim; } - tt::tt_metal::ccl::ShardType get_shard_type() const { + ttnn::utils::ccl::ShardType get_shard_type() const { TT_ASSERT(is_sharding_enabled, "Tried getting sharding config for non-sharded tensor"); return shard_type; } @@ -293,7 +293,7 @@ struct ShardedAllGatherConfig { private: bool requires_post_all_gather_reshard; bool single_tile_shard_on_dim; - tt::tt_metal::ccl::ShardType shard_type; + ttnn::utils::ccl::ShardType shard_type; bool is_sharding_enabled; }; @@ -302,7 +302,7 @@ struct ShardedAllGatherConfig { struct ShardAddrGenArgGenerator { using shard_cores_t = CoreRangeSet; - ShardAddrGenArgGenerator(tt::tt_metal::ccl::ShardAddrGenArgs const& args_struct) : + ShardAddrGenArgGenerator(ccl::ShardAddrGenArgs const& args_struct) : args_struct(args_struct), initialized(true) {} ShardAddrGenArgGenerator() : initialized(false) {} @@ -312,14 +312,14 @@ struct ShardAddrGenArgGenerator { std::vector args; args.reserve(7 * this->args_struct.num_dest_cores * 2); - TT_ASSERT(this->args_struct.shard_size_in_bytes != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.total_chunks_per_core != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); - 
TT_ASSERT(this->args_struct.shards_start_address != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.starting_core_index != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.starting_chunk_into_shard != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.intra_core_stride_in_shards != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.num_dest_cores != tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shard_size_in_bytes != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.total_chunks_per_core != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shards_start_address != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.starting_core_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.starting_chunk_into_shard != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.intra_core_stride_in_shards != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.num_dest_cores != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); TT_ASSERT(this->args_struct.dest_cores.size() != 0); args.push_back(this->args_struct.is_clockwise); @@ -331,7 +331,7 @@ struct ShardAddrGenArgGenerator { args.push_back(this->args_struct.intra_core_stride_in_shards); args.push_back(this->args_struct.contiguous_chunks_before_stride); args.push_back(this->args_struct.num_dest_cores); - for (tt::tt_metal::ccl::WorkerXY const& core : this->args_struct.dest_cores) { + for (ccl::WorkerXY const& core : this->args_struct.dest_cores) { args.push_back(core.to_uint32()); } @@ -360,7 +360,7 @@ struct ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.starting_core_index < this->args_struct.dest_cores.size()); } - tt::tt_metal::ccl::ShardAddrGenArgs args_struct; + ttnn::utils::ccl::ShardAddrGenArgs args_struct; bool initialized; }; @@ -392,7 +392,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat } InputTensorShardAddrGenArgGenerator( Device const* device, - tt::tt_metal::ccl::CclOpShardedTensorConfig *input_tensor_config, + ttnn::utils::ccl::CclOpShardedTensorConfig *input_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -425,7 +425,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat this->args_struct.dest_cores.reserve(dest_core_coords.size()); std::transform(dest_core_coords.begin(), dest_core_coords.end(), std::back_inserter(this->args_struct.dest_cores), [&device](CoreCoord const& core) { - return tt::tt_metal::ccl::WorkerXY( + return ttnn::utils::ccl::WorkerXY( static_cast(device->worker_core_from_logical_core(core).x), static_cast(device->worker_core_from_logical_core(core).y) ); @@ -444,7 +444,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { static std::vector compute_worker_coord_worker_dest_cores ( - tt::tt_metal::ccl::ShardType shard_type, + ttnn::utils::ccl::ShardType shard_type, std::vector const& global_shard_dest_cores, uint32_t input_num_shards, uint32_t output_num_shards, @@ -494,8 +494,8 @@ struct OutputTensorShardAddrGenArgGenerator 
final : ShardAddrGenArgGenerator { } - static std::vector compute_worker_dest_cores ( - tt::tt_metal::ccl::ShardType shard_type, + static std::vector compute_worker_dest_cores ( + ttnn::utils::ccl::ShardType shard_type, Device const& device, CoreRangeSet const& shard_core_range, uint32_t input_num_shards, @@ -512,11 +512,11 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { worker_index, is_shard_orientation_row_major); - std::vector dest_cores_of_worker; + std::vector dest_cores_of_worker; dest_cores_of_worker.reserve(worker_coord_worker_dest_cores.size()); std::transform(worker_coord_worker_dest_cores.begin(), worker_coord_worker_dest_cores.end(), std::back_inserter(dest_cores_of_worker), [&device](CoreCoord const& core) { - return tt::tt_metal::ccl::WorkerXY( + return ttnn::utils::ccl::WorkerXY( static_cast(device.worker_core_from_logical_core(core).x), static_cast(device.worker_core_from_logical_core(core).y) ); @@ -588,8 +588,8 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { OutputTensorShardAddrGenArgGenerator( AllGatherConfig const& all_gather_config, Device const* device, - tt::tt_metal::ccl::CclOpShardedTensorConfig *input_tensor_config, - tt::tt_metal::ccl::CclOpShardedTensorConfig *output_tensor_config, + ttnn::utils::ccl::CclOpShardedTensorConfig *input_tensor_config, + ttnn::utils::ccl::CclOpShardedTensorConfig *output_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -617,7 +617,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { uint32_t input_num_shards = sharded_tensor_num_cores; uint32_t output_num_shards = input_num_shards * ring_size; this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - tt::tt_metal::ccl::ShardType::Width, + ttnn::utils::ccl::ShardType::Width, *device, tensor_shard_grid, input_num_shards, @@ -630,7 +630,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.dest_cores.size() > 0); std::vector const& global_shard_dest_cores = corerange_to_cores(tensor_shard_grid, std::nullopt, is_shard_orientation_row_major); CoreCoord const& dest_core_coord = global_shard_dest_cores.at(global_starting_dest_worker_index); - tt::tt_metal::ccl::WorkerXY noc0_starting_dest_core_xy( + ttnn::utils::ccl::WorkerXY noc0_starting_dest_core_xy( static_cast(device->worker_core_from_logical_core(dest_core_coord).x), static_cast(device->worker_core_from_logical_core(dest_core_coord).y) ); @@ -655,17 +655,17 @@ struct FullWorkerGridShardAddrGenArgGenerator { args.reserve(12 + args_struct.total_num_cores); TT_ASSERT(args_struct.dest_cores.size() > 0, "dest_cores was uninitialized"); - TT_ASSERT(args_struct.tile_size_in_bytes != ccl::UNINITIALIZED_VALUE_U32, "tile_size_in_bytes was uninitialized"); - TT_ASSERT(args_struct.shards_start_address != ccl::UNINITIALIZED_VALUE_U32, "shards_start_address was uninitialized"); - TT_ASSERT(args_struct.curr_core_index != ccl::UNINITIALIZED_VALUE_U16, "curr_core_index was uninitialized"); - TT_ASSERT(args_struct.total_num_cores != ccl::UNINITIALIZED_VALUE_U16, "total_num_cores was uninitialized"); - TT_ASSERT(args_struct.curr_shard_tile_x != ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_x was uninitialized"); - TT_ASSERT(args_struct.curr_shard_tile_y != ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_y was uninitialized"); - TT_ASSERT(args_struct.curr_tile_index != ccl::UNINITIALIZED_VALUE_U16, "curr_tile_index was 
uninitialized"); - TT_ASSERT(args_struct.curr_shard != ccl::UNINITIALIZED_VALUE_U16, "curr_shard was uninitialized"); - TT_ASSERT(args_struct.input_shard_num_tiles_x != ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_x was uninitialized"); - TT_ASSERT(args_struct.input_shard_num_tiles_y != ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_y was uninitialized"); - TT_ASSERT(args_struct.total_shards_x != ccl::UNINITIALIZED_VALUE_U16, "total_shards_x was uninitialized"); + TT_ASSERT(args_struct.tile_size_in_bytes != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32, "tile_size_in_bytes was uninitialized"); + TT_ASSERT(args_struct.shards_start_address != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32, "shards_start_address was uninitialized"); + TT_ASSERT(args_struct.curr_core_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_core_index was uninitialized"); + TT_ASSERT(args_struct.total_num_cores != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "total_num_cores was uninitialized"); + TT_ASSERT(args_struct.curr_shard_tile_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_x was uninitialized"); + TT_ASSERT(args_struct.curr_shard_tile_y != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_y was uninitialized"); + TT_ASSERT(args_struct.curr_tile_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_tile_index was uninitialized"); + TT_ASSERT(args_struct.curr_shard != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard was uninitialized"); + TT_ASSERT(args_struct.input_shard_num_tiles_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_x was uninitialized"); + TT_ASSERT(args_struct.input_shard_num_tiles_y != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_y was uninitialized"); + TT_ASSERT(args_struct.total_shards_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "total_shards_x was uninitialized"); args.push_back(args_struct.tile_size_in_bytes); args.push_back(args_struct.shards_start_address); @@ -679,7 +679,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { args.push_back(args_struct.is_clockwise); args.push_back(args_struct.curr_core_index); args.push_back(args_struct.total_num_cores); - for (tt::tt_metal::ccl::WorkerXY const& core : args_struct.dest_cores) { + for (ccl::WorkerXY const& core : args_struct.dest_cores) { args.push_back(core.to_uint32()); } @@ -717,7 +717,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { auto const& tensor_shard_grid = input_tensor.buffer()->shard_spec().grid(); this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - tt::tt_metal::ccl::ShardType::Width, + ttnn::utils::ccl::ShardType::Width, *device, tensor_shard_grid, tensor_shard_grid.num_cores(), @@ -730,7 +730,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { this->initialized = true; } - tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs args_struct; + ttnn::utils::ccl::FullWorkerGridShardAddrGenArgs args_struct; bool initialized; }; diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index c9d81acd9ad..2280d7b352e 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -403,8 +403,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& for (uint32_t b = 0; b < 
all_gather_config.get_num_eth_buffers_per_edm(); ++b) { uint32_t num_workers_per_eth_buffer = std::min(workers_per_link, all_gather_config.get_num_eth_buffers_per_edm() - worker_index); - std::vector sender_worker_coords; - std::vector receiver_worker_coords; + std::vector sender_worker_coords; + std::vector receiver_worker_coords; for (uint32_t w = b * num_workers_per_eth_buffer; w < (b + 1) * num_workers_per_eth_buffer; ++w) { sender_worker_coords.push_back( ttnn::ccl::WorkerXY( diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 16ea367a731..2c9dac84c8b 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -45,7 +45,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( auto eth_sender_core = topology_config.eth_sender_cores.at(i); log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); auto eth_sender_kernel = - ccl::generate_edm_kernel(program, device, clockwise_edm_builders.at(i), eth_sender_core, sender_noc); + ttnn::utils::ccl::generate_edm_kernel(program, device, clockwise_edm_builders.at(i), eth_sender_core, sender_noc); log_trace( tt::LogOp, "RingIndex: {}. Link {}. Clockwise EDM Core (x={},y={})", @@ -59,7 +59,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( if (is_counter_clockwise_direction_edm_enabled) { log_trace(tt::LogOp, "EDM COUNTER CLOCKWISE KERNEL RT ARGS: "); auto eth_receiver_core = topology_config.eth_receiver_cores.at(i); - auto eth_receiver_kernel = ccl::generate_edm_kernel( + auto eth_receiver_kernel = ttnn::utils::ccl::generate_edm_kernel( program, device, counter_clockwise_edm_builders.at(i), eth_receiver_core, receiver_noc); log_trace( tt::LogOp, @@ -75,7 +75,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, - ccl::EriscDatamoverBuilder const& edm_builder, + ttnn::utils::ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id) { log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); @@ -110,29 +110,29 @@ KernelHandle generate_edm_kernel( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ccl::EriscDataMoverTerminationMode termination_mode) { + ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ttnn::utils::ccl::EriscDataMoverTerminationMode termination_mode) { TT_ASSERT(num_channels > 0); std::vector edm_sem_addresses(num_channels, 0); std::vector edm_buffer_addresses(num_channels, 0); - uint32_t edm_sem_addr = ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); - uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); + uint32_t edm_sem_addr = ttnn::utils::ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); + uint32_t edm_buffer_addr = ttnn::utils::ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); - const uint32_t buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); + const uint32_t buffer_size = ttnn::utils::ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); for (std::size_t c = 0; c < num_channels; ++c) { edm_sem_addresses.at(c) = edm_sem_addr; - edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; + edm_sem_addr += 
ttnn::utils::ccl::EriscDatamoverConfig::semaphore_size; edm_buffer_addresses.at(c) = edm_buffer_addr; edm_buffer_addr += buffer_size; TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } - return ccl::EriscDatamoverBuilder( + return ttnn::utils::ccl::EriscDatamoverBuilder( buffer_size, - ccl::EriscDatamoverConfig::get_edm_handshake_address(), + ttnn::utils::ccl::EriscDatamoverConfig::get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode, diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 9a71b4b3034..6a4bb07953c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -318,8 +318,8 @@ class RingReduceScatterTensorSlicer : public LegacyCclTensorSlicer { uint32_t max_slice_size_in_bytes, uint32_t half_cb_n_pages); - ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index) { - return ccl::InterleavedTensorWorkerSlice( + ttnn::utils::ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index) { + return ttnn::utils::ccl::InterleavedTensorWorkerSlice( this->flattened_tensor_shape, this->tensor_slice_shape, this->worker_slice_shapes.at(global_worker_index), @@ -452,7 +452,7 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, - ccl::EriscDatamoverBuilder const& edm_builder, + ttnn::utils::ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id); @@ -468,7 +468,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index 55066f63eea..39bca5454ef 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -167,8 +167,8 @@ class EriscDatamoverBuilder { uint32_t eth_buffer_size_bytes; uint32_t handshake_addr; uint32_t const num_channel_buffers; - ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; - ccl::EriscDataMoverTerminationMode const termination_mode; + ttnn::utils::ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; + ttnn::utils::ccl::EriscDataMoverTerminationMode const termination_mode; uint32_t num_senders; uint32_t num_receivers; @@ -187,9 +187,9 @@ class EriscDatamoverBuilder { uint32_t handshake_addr, std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ccl::EriscDataMoverTerminationMode termination_mode = - ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : + ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ttnn::utils::ccl::EriscDataMoverTerminationMode termination_mode = + ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : local_semaphore_addresses(local_semaphore_addresses), local_buffer_addresses(local_buffer_addresses), 
eth_buffer_size_bytes(eth_buffer_size), diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/README.md b/ttnn/cpp/ttnn/operations/ccl/edm/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp new file mode 100644 index 00000000000..1fea3f29555 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp @@ -0,0 +1,748 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "dataflow_api.h" +#include "debug/assert.h" +#include "eth_l1_address_map.h" +#include "ethernet/dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc.h" + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +using ttnn::ccl::EriscDataMoverBufferSharingMode; +using ttnn::ccl::EriscDataMoverTerminationMode; +using ttnn::ccl::EriscDataMoverWorkerSignal; +======= +using ttnn::utils::ccl::EriscDataMoverBufferSharingMode; +using ttnn::utils::ccl::EriscDataMoverTerminationMode; +using ttnn::utils::ccl::EriscDataMoverWorkerSignal; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp + +namespace erisc { +namespace datamover { + +template +struct EriscDatamoverConfig { + static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = buffer_sharing_mode; + static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = termination_mode; +}; + +template +struct edm_worker_index {}; + +template <> +struct edm_worker_index { + uint16_t worker_index = 0; +}; + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +using ttnn::ccl::WorkerXY; +======= +using ttnn::utils::ccl::WorkerXY; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp + +/* + * The `ChannelBuffer` is a building block of the Erisc Data Mover (EDM). For every concurrent transaction + * channel managed by the EDM, there is a `ChannelBuffer` associated with the. The `ChannelBuffer` manages + * state for the transaction channel, holds information such as buffer and semaphore addresses, and has helper + * functions to more easily check semaphore and ack statuses and to send/receive data and/or semaphore updates. 
+ */ +// template +template +class ChannelBuffer final { + static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = EDM_CONFIG::BUFFER_SHARING_MODE; + static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = EDM_CONFIG::TERMINATION_MODE; + static_assert( + BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::NOT_SHARED || + BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN, + "The only BufferSharding modes supported are NOT_SHARED and ROUND_ROBIN"); + + public: + enum STATE : uint8_t { + DONE = 0, + + // For sender: means we are ready to tell the worker(s) that the buffer is available for writing into + // + SIGNALING_WORKER, + + // For sender: we are waiting for the payload to arrive in L1; we are checking local semaphore for worker + // completion For receiver: we are waiting for worker to complete pull of payload from L1; we are checking local + // semaphore for worker completion + WAITING_FOR_WORKER, + + // For sender: means workers have signalled (via semaphores) that the buffer payload is + // ready in L1 + // For receiver: + READY_FOR_ETH_TRANSFER, + + // For sender: means we are waiting for ack from receiver that payload was received + // For receiver: means we are waitinf for a payload from sender + WAITING_FOR_ETH, + }; + + // for default initialization in arrays + ChannelBuffer() : + local_semaphore_address(0), + worker_coords(0), + address(0), + size_in_bytes(0), + worker_semaphore_l1_address(0), + num_workers(0), + num_messages_moved(0), + channel_bytes_sent_address(0), + channel_bytes_acked_address(0), + total_num_messages_to_move(0), + state(STATE::DONE) {} + + ChannelBuffer( + uint32_t eth_transaction_channel, + size_t address, + size_t size_in_bytes, + uint32_t worker_semaphore_l1_address, + uint32_t num_workers, + uint32_t total_num_messages_to_move, + volatile tt_l1_ptr uint32_t *const local_semaphore_address, + tt_l1_ptr const WorkerXY *worker_coords, + bool is_sender_side) : + eth_transaction_channel(eth_transaction_channel), + local_semaphore_address(local_semaphore_address), + worker_coords(worker_coords), + address(address), + size_in_bytes(size_in_bytes), + worker_semaphore_l1_address(worker_semaphore_l1_address), + num_workers(num_workers), + num_messages_moved(0), + channel_bytes_sent_address(&erisc_info->channels[eth_transaction_channel].bytes_sent), + channel_bytes_acked_address(&erisc_info->channels[eth_transaction_channel].receiver_ack), + total_num_messages_to_move(total_num_messages_to_move), + state(is_sender_side ? STATE::WAITING_FOR_WORKER : STATE::WAITING_FOR_ETH), + is_sender_completion_pending(false), + is_sender_side(is_sender_side) { + clear_local_semaphore(); + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp + if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { +======= + if (TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { +>>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp + if (is_sender_side) { + // Tell the sender side workers that we're ready to accept data on this channel + increment_worker_semaphores(); + } + } else { +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp + ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); +======= + ASSERT(TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp + goto_state(STATE::DONE); + } + }; + + // Resets the semaphore in local L1, which workers write to remotely. + FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } + + // Increment the semaphore in the remote L1s of every worker associated with this ChannelBuffer + FORCE_INLINE void increment_worker_semaphores() { + if constexpr (BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::NOT_SHARED) { + // We have to be careful that the worker x/y matches for the `noc_index` + // active on the erisc + for (std::size_t i = 0; i < this->num_workers; i++) { + WorkerXY worker_xy = this->worker_coords[i]; + uint64_t worker_semaphore_address = + get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); + + noc_semaphore_inc(worker_semaphore_address, 1); + } + } else if (BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN) { + WorkerXY worker_xy = this->worker_coords[this->worker_index.worker_index]; + uint64_t worker_semaphore_address = + get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); + + noc_semaphore_inc(worker_semaphore_address, 1); + this->worker_index.worker_index++; + if (this->worker_index.worker_index >= this->num_workers) { + this->worker_index.worker_index = 0; + } + } else { + ASSERT(false); // Not implemented + } + } + + [[nodiscard]] FORCE_INLINE bool is_local_semaphore_full() const { + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + ASSERT(*(this->local_semaphore_address) <= this->num_workers); + } + return *(this->local_semaphore_address) == this->num_workers; + } + + [[nodiscard]] FORCE_INLINE bool is_active() const { + return this->num_messages_moved < this->total_num_messages_to_move; + } + + [[nodiscard]] STATE get_state() const { return this->state; } + + FORCE_INLINE void goto_state(STATE s) { this->state = s; } + + [[nodiscard]] FORCE_INLINE bool is_waiting_for_workers_core() const { + return this->state == STATE::WAITING_FOR_WORKER; + } + [[nodiscard]] FORCE_INLINE bool is_ready_to_signal_workers() const { + return this->state == STATE::SIGNALING_WORKER; + } + [[nodiscard]] FORCE_INLINE bool is_waiting_for_remote_eth_core() const { + return this->state == STATE::WAITING_FOR_ETH; + } + [[nodiscard]] FORCE_INLINE bool is_ready_for_eth_transfer() const { + return this->state == STATE::READY_FOR_ETH_TRANSFER; + } + [[nodiscard]] FORCE_INLINE bool is_done() const { return this->state == STATE::DONE; } + + [[nodiscard]] FORCE_INLINE uint32_t get_eth_transaction_channel() const { + ASSERT(this->eth_transaction_channel < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); + return this->eth_transaction_channel; + } + [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->address; } + [[nodiscard]] FORCE_INLINE std::size_t 
get_size_in_bytes() const { return this->size_in_bytes; } + [[nodiscard]] FORCE_INLINE std::size_t get_current_payload_size() const { return this->get_size_in_bytes(); } + + [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { return this->address; } + + FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } + FORCE_INLINE void increment_messages_moved() { this->num_messages_moved++; } + + [[nodiscard]] FORCE_INLINE bool all_messages_moved() { + return this->num_messages_moved == this->total_num_messages_to_move; + } + + FORCE_INLINE void set_send_completion_pending(bool value) { this->is_sender_completion_pending = value; } + [[nodiscard]] FORCE_INLINE bool is_send_completion_pending() const { return this->is_sender_completion_pending; } + + FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0; } + FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0; } + FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0; } + volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_address; } + volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_address; } + + public: + uint32_t eth_transaction_channel; // + volatile tt_l1_ptr uint32_t *const local_semaphore_address; + WorkerXY const *const worker_coords; + std::size_t const address; + std::size_t const size_in_bytes; + // Even for multiple workers, this address will be the same + std::size_t const worker_semaphore_l1_address; + uint32_t const num_workers; + uint32_t num_messages_moved; + volatile tt_l1_ptr uint32_t *const channel_bytes_sent_address; + volatile tt_l1_ptr uint32_t *const channel_bytes_acked_address; + const uint32_t total_num_messages_to_move; + STATE state; + edm_worker_index worker_index; + bool is_sender_completion_pending; + bool is_sender_side; +}; + +template +class QueueIndexPointer { + public: + QueueIndexPointer(uint8_t queue_size) : ptr(0), size(queue_size), wrap_around(queue_size * 2) { + // FWASSERT(queue_size < 128); + } + + [[nodiscard("index was called without consuming the result. Did you mean to call it?")]] T index() const { + return this->ptr >= this->size ? this->ptr - this->size : this->ptr; + } + [[nodiscard("raw_index was called without consuming the result. Did you mean to call it?")]] inline T raw_index() + const { + return this->ptr; + } + [[nodiscard("distance was called without consuming the result. Did you mean to call it?")]] inline static T + distance(QueueIndexPointer ptr, QueueIndexPointer ackptr) { + // FWASSERT(ptr.size == ackptr.size); + return ackptr.ptr > ptr.ptr ? (ptr.wrap_around - ackptr.ptr) + ptr.ptr : ptr.ptr - ackptr.ptr; + } + [[nodiscard("full was called without consuming the result. Did you mean to call it?")]] inline static T full( + QueueIndexPointer ptr, QueueIndexPointer ackptr) { + // FWASSERT(ptr.size == ackptr.size); + return distance(ptr.ptr, ackptr.ptr) >= ptr.size; + } + [[nodiscard("empty was called without consuming the result. Did you mean to call it?")]] inline static T empty( + QueueIndexPointer ptr, QueueIndexPointer ackptr) { + // FWASSERT(ptr.size == ackptr.size); + return ptr.ptr == ackptr.ptr; + } + inline void increment() { this->ptr = this->next_pointer(); } + [[nodiscard( + "next_index was called without consuming the result. 
Did you mean to call it?")]] inline QueueIndexPointer + next_index() const { + return QueueIndexPointer(this->next_pointer(), this->size); + } + // Compares indices since the raw index is not visible to the user + inline bool operator==(const QueueIndexPointer &other) const { return this->ptr == other.ptr; } + inline bool operator!=(const QueueIndexPointer &other) const { return this->ptr != other.ptr; } + + private: + inline T next_pointer() { + T next_ptr = (this->ptr + 1); + next_ptr = next_ptr == wrap_around ? 0 : next_ptr; + return next_ptr; + } + QueueIndexPointer(T ptr, uint8_t queue_size) : ptr(ptr), size(queue_size), wrap_around(queue_size * 2) {} + T ptr; + uint8_t size; + uint8_t wrap_around; +}; + +FORCE_INLINE void eth_setup_handshake(std::uint32_t handshake_register_address, bool is_sender) { + reinterpret_cast(handshake_register_address)[4] = 1; + reinterpret_cast(handshake_register_address)[5] = 1; + reinterpret_cast(handshake_register_address)[6] = 0x1c0ffee1; + reinterpret_cast(handshake_register_address)[7] = 0x1c0ffee2; + + erisc_info->channels[0].receiver_ack = 0; + for (uint32_t i = 1; i < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS; i++) { + erisc_info->channels[i].bytes_sent = 0; + erisc_info->channels[i].receiver_ack = 0; + } + *(volatile tt_l1_ptr uint32_t *)handshake_register_address = 0; + if (is_sender) { + eth_wait_receiver_done(); + eth_send_bytes(handshake_register_address, handshake_register_address, 16); + eth_wait_for_receiver_done(); + } else { + eth_wait_for_bytes(16); + eth_receiver_channel_done(0); + } +} + +template +FORCE_INLINE void initialize_transaction_buffer_addresses( + uint32_t max_concurrent_transactions, + uint32_t first_buffer_base_address, + uint32_t num_bytes_per_send, + std::array &transaction_channel_buffer_addresses) { + uint32_t buffer_address = first_buffer_base_address; + for (uint32_t i = 0; i < max_concurrent_transactions; i++) { + transaction_channel_buffer_addresses[i] = buffer_address; + buffer_address += num_bytes_per_send; + } +} + +///////////////////////////////////////////// +// SENDER SIDE HELPERS +///////////////////////////////////////////// + +template +FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { + bool did_something = false; + if (sender_buffer_channel.eth_is_receiver_channel_send_done()) { + bool need_to_send_completion = sender_buffer_channel.is_send_completion_pending(); + if (!sender_buffer_channel.is_send_completion_pending() && !eth_txq_is_busy()) { + static constexpr std::size_t ETH_BYTES_TO_WORDS_SHIFT = 4; + eth_send_bytes_over_channel_payload_only( + sender_buffer_channel.get_buffer_address(), + sender_buffer_channel.get_remote_eth_buffer_address(), + sender_buffer_channel.get_current_payload_size(), + sender_buffer_channel.get_eth_transaction_channel(), + sender_buffer_channel.get_current_payload_size(), + sender_buffer_channel.get_current_payload_size() >> ETH_BYTES_TO_WORDS_SHIFT); + + sender_buffer_channel.set_send_completion_pending(true); + need_to_send_completion = true; + did_something = true; + } + + if (need_to_send_completion && !eth_txq_is_busy()) { + eth_send_payload_complete_signal_over_channel( + sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size()); + sender_buffer_channel.set_send_completion_pending(false); + sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + did_something = true; + } + } + + return did_something; +} + +template +FORCE_INLINE bool 
sender_notify_workers_if_buffer_available_sequence( + ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + bool channel_done = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + channel_done = sender_buffer_channel.all_messages_moved(); + } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + // Nothing to do here because in this termination mode, we must check the signal in a different state + } else { + ASSERT(false); + } + + sender_buffer_channel.clear_local_semaphore(); + sender_buffer_channel.increment_worker_semaphores(); + + if (!channel_done) { + sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + } else { + sender_buffer_channel.goto_state(ChannelBuffer::DONE); + num_senders_complete++; + } + + return true; +} + +template +FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( + ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + bool did_something = false; + + bool transimission_acked_by_receiver = sender_buffer_channel.eth_is_receiver_channel_send_acked() || + sender_buffer_channel.eth_is_receiver_channel_send_done(); + if (transimission_acked_by_receiver) { + eth_clear_sender_channel_ack(sender_buffer_channel.get_eth_transaction_channel()); + sender_buffer_channel.increment_messages_moved(); + sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + sender_notify_workers_if_buffer_available_sequence(sender_buffer_channel, num_senders_complete); + did_something = true; + } + + return did_something; +} + +/* + * + */ +template +FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( + ChannelBuffer &sender_channel_buffer, uint32_t &num_senders_complete) { + bool did_something = false; + + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + if (*sender_channel_buffer.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY) { + sender_channel_buffer.clear_local_semaphore(); + sender_channel_buffer.goto_state(ChannelBuffer::DONE); + num_senders_complete++; + return true; + } + } + + bool read_finished = sender_channel_buffer.is_local_semaphore_full(); + if (read_finished) { + // We can clear the semaphore, and wait for space on receiver + // sender_channel_buffer.clear_local_semaphore(); + sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); + did_something = true; + + erisc::datamover::sender_eth_send_data_sequence(sender_channel_buffer); + } + + return did_something; +} + +///////////////////////////////////////////// +// RECEIVER SIDE HELPERS +///////////////////////////////////////////// + +/* + * + */ +template +FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { + buffer_channel.clear_local_semaphore(); + buffer_channel.increment_worker_semaphores(); + buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + + return true; +} + +/* + * If payload received, notify (send ack to) sender so sender knows it can free up its local buffer + * + */ +template +FORCE_INLINE bool receiver_eth_accept_payload_sequence( + ChannelBuffer &buffer_channel, + uint32_t &num_receivers_complete, + uint32_t eth_transaction_ack_word_addr) { + bool did_something = false; + + if (buffer_channel.eth_bytes_are_available_on_channel()) { + if (!eth_txq_is_busy()) { + eth_receiver_channel_ack(buffer_channel.get_eth_transaction_channel(), 
eth_transaction_ack_word_addr); + buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + did_something = true; + + // FIXME: Decouple these so we can still signal workers even if eth command queue is busy + // Prefer sending eth ack first, but notify workers even if we have to come back to + // send the eth ack later + receiver_eth_notify_workers_payload_available_sequence(buffer_channel); + } + } + + return did_something; +} + +/* + * Does something if we are waiting for workers to complete their read and the read is complete: + * - increment messages moved (that transfer is done) + * - notifies sender it is safe to send next payload + * - clear local semaphore + */ +template +FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( + ChannelBuffer &buffer_channel, + uint32_t &num_receivers_complete, + uint32_t eth_transaction_complete_addr) { + bool did_something = false; + + bool workers_are_finished_reading = buffer_channel.is_local_semaphore_full(); + + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + // May have already gotten final termination signal by this point so check for that too + workers_are_finished_reading = + workers_are_finished_reading || + (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + + bool can_notify_sender_of_buffer_available = workers_are_finished_reading; + if (can_notify_sender_of_buffer_available) { + if (!eth_txq_is_busy()) { + eth_receiver_channel_done(buffer_channel.get_eth_transaction_channel()); + buffer_channel.increment_messages_moved(); + + bool channel_done = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + channel_done = buffer_channel.all_messages_moved(); + } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + channel_done = (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } else { + ASSERT(false); + } + + if (!channel_done) { + buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + } else { + buffer_channel.goto_state(ChannelBuffer::DONE); + num_receivers_complete++; + } + + did_something = true; + } + } + + return did_something; +} + +//////////////////////////// +// DEPRECATED +//////////////////////////// +namespace deprecated { +// This namespace exists to support non-decoupled mode microbenchmarks until those are available +// in decoupled mode + +FORCE_INLINE bool sender_buffer_pool_full( + const QueueIndexPointer noc_reader_buffer_wrptr, + const QueueIndexPointer noc_reader_buffer_ackptr, + const QueueIndexPointer eth_sender_rdptr, + const QueueIndexPointer eth_sender_ackptr) { + return QueueIndexPointer::full(noc_reader_buffer_wrptr, eth_sender_ackptr); +} + +FORCE_INLINE bool sender_buffer_pool_empty( + const QueueIndexPointer noc_reader_buffer_wrptr, + const QueueIndexPointer noc_reader_buffer_ackptr, + const QueueIndexPointer eth_sender_rdptr, + const QueueIndexPointer eth_sender_ackptr) { + return QueueIndexPointer::empty(eth_sender_rdptr, noc_reader_buffer_wrptr); +} + +FORCE_INLINE bool sender_buffer_available_for_eth_send( + const QueueIndexPointer noc_reader_buffer_wrptr, + const QueueIndexPointer noc_reader_buffer_ackptr, + const QueueIndexPointer eth_sender_rdptr, + const QueueIndexPointer eth_sender_ackptr) { + return eth_sender_rdptr != noc_reader_buffer_ackptr; +} + +template +FORCE_INLINE bool sender_eth_send_data_sequence( + std::array 
&transaction_channel_sender_buffer_addresses, + std::array &transaction_channel_receiver_buffer_addresses, + uint32_t local_eth_l1_src_addr, + uint32_t remote_eth_l1_dst_addr, + uint32_t num_bytes, + uint32_t num_bytes_per_send, + uint32_t num_bytes_per_send_word_size, + QueueIndexPointer noc_reader_buffer_wrptr, + QueueIndexPointer noc_reader_buffer_ackptr, + QueueIndexPointer ð_sender_rdptr, + QueueIndexPointer ð_sender_ackptr) { + bool did_something = false; + bool data_ready_for_send = sender_buffer_available_for_eth_send( + noc_reader_buffer_wrptr, noc_reader_buffer_ackptr, eth_sender_rdptr, eth_sender_ackptr); + if (data_ready_for_send) { + bool consumer_ready_to_accept = eth_is_receiver_channel_send_done(eth_sender_rdptr.index()); + if (consumer_ready_to_accept) { + // Queue up another send + uint32_t sender_buffer_address = transaction_channel_sender_buffer_addresses[eth_sender_rdptr.index()]; + uint32_t receiver_buffer_address = transaction_channel_receiver_buffer_addresses[eth_sender_rdptr.index()]; + + eth_send_bytes_over_channel( + sender_buffer_address, + receiver_buffer_address, + num_bytes, + eth_sender_rdptr.index(), + num_bytes_per_send, + num_bytes_per_send_word_size); + eth_sender_rdptr.increment(); + did_something = true; + } + } + + return did_something; +} + +FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( + const QueueIndexPointer noc_reader_buffer_wrptr, + const QueueIndexPointer noc_reader_buffer_ackptr, + QueueIndexPointer ð_sender_rdptr, + QueueIndexPointer ð_sender_ackptr, + uint32_t &num_eth_sends_acked) { + bool did_something = false; + bool eth_sends_unacknowledged = QueueIndexPointer::distance(eth_sender_rdptr, eth_sender_ackptr) > 0; + if (eth_sends_unacknowledged) { + bool transimission_acked_by_receiver = eth_is_receiver_channel_send_acked(eth_sender_ackptr.index()) || + eth_is_receiver_channel_send_done(eth_sender_ackptr.index()); + if (transimission_acked_by_receiver) { + num_eth_sends_acked++; + eth_sender_ackptr.increment(); + + did_something = true; + } + } + + return did_something; +} + +FORCE_INLINE bool sender_is_noc_read_in_progress( + const QueueIndexPointer noc_reader_buffer_wrptr, + const QueueIndexPointer noc_reader_buffer_ackptr) { + return noc_reader_buffer_wrptr != noc_reader_buffer_ackptr; +} + +FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( + QueueIndexPointer &noc_reader_buffer_wrptr, + QueueIndexPointer &noc_reader_buffer_ackptr, + const uint8_t noc_index) { + bool did_something = false; + + bool noc_read_is_in_progress = sender_is_noc_read_in_progress(noc_reader_buffer_wrptr, noc_reader_buffer_ackptr); + if (noc_read_is_in_progress) { +#if EMULATE_DRAM_READ_CYCLES == 1 + bool read_finished = emulated_dram_read_cycles_finished(); +#else + bool read_finished = ncrisc_noc_reads_flushed(noc_index); +#endif + if (read_finished) { + noc_reader_buffer_ackptr.increment(); + did_something = true; + } + } + + return did_something; +} + +///////////////////////////////////////////// +// RECEIVER SIDE HELPERS +///////////////////////////////////////////// + +FORCE_INLINE bool receiver_is_noc_write_in_progress( + const QueueIndexPointer noc_writer_buffer_wrptr, + const QueueIndexPointer noc_writer_buffer_ackptr) { + return noc_writer_buffer_wrptr != noc_writer_buffer_ackptr; +} + +bool receiver_eth_accept_payload_sequence( + QueueIndexPointer noc_writer_buffer_wrptr, + QueueIndexPointer noc_writer_buffer_ackptr, + QueueIndexPointer ð_receiver_ptr, + QueueIndexPointer ð_receiver_ackptr, + uint32_t 
eth_channel_sync_ack_addr) { + bool did_something = false; + bool receive_pointers_full = QueueIndexPointer::full(eth_receiver_ptr, eth_receiver_ackptr); + + if (!receive_pointers_full) { + if (eth_bytes_are_available_on_channel(eth_receiver_ptr.index())) { + // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)eth_receiver_ptr << "\n"; + eth_receiver_channel_ack(eth_receiver_ptr.index(), eth_channel_sync_ack_addr); + eth_receiver_ptr.increment(); + did_something = true; + } + } + + return did_something; +} + +FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( + QueueIndexPointer &noc_writer_buffer_wrptr, + QueueIndexPointer &noc_writer_buffer_ackptr, + const uint8_t noc_index) { + bool did_something = false; + + bool noc_write_is_in_progress = + receiver_is_noc_write_in_progress(noc_writer_buffer_wrptr, noc_writer_buffer_ackptr); + if (noc_write_is_in_progress) { +#if EMULATE_DRAM_READ_CYCLES == 1 + bool write_finished = emulated_dram_write_cycles_finished(); +#else + bool writes_finished = ncrisc_noc_nonposted_writes_sent(noc_index); +#endif + if (writes_finished) { + // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_ackptr + // << "\n"; + noc_writer_buffer_ackptr.increment(); + + did_something = true; + } + } + + return did_something; +} + +FORCE_INLINE bool receiver_eth_send_ack_to_sender_sequence( + const QueueIndexPointer noc_writer_buffer_wrptr, + const QueueIndexPointer noc_writer_buffer_ackptr, + QueueIndexPointer ð_receiver_rdptr, + QueueIndexPointer ð_receiver_ackptr, + uint32_t &num_eth_sends_acked) { + bool did_something = false; + bool eth_sends_unacknowledged = eth_receiver_rdptr != eth_receiver_ackptr; + if (eth_sends_unacknowledged) { + // If data is done being sent out of this local l1 buffer and to the destination(s), + // then we can safely send the ack and increment the ackptr + bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index); + // bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_flushed(noc_index); + if (buffer_writes_flushed) { + // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_wrptr + // << "\n"; + eth_receiver_channel_done(eth_receiver_ackptr.index()); + num_eth_sends_acked++; + eth_receiver_ackptr.increment(); + // DPRINT << "rx: Sending eth ack. ackptr incrementing to " << (uint32_t)eth_receiver_ackptr.index() << + // "\n"; + + did_something = true; + } + } + + return did_something; +} + +}; // namespace deprecated + +}; // namespace datamover +}; // namespace erisc diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp new file mode 100644 index 00000000000..a137ec25813 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp @@ -0,0 +1,343 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "eth_l1_address_map.h" +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" +======= +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" +>>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp + +// Args Schema: +// 1) handshake addr +// 2) sender channels offset (indicates for the erisc channels, where the senders start +// so sender and receivers don't clash when paired with sender/receiver on the other +// end of the link.) +// 3) sender num channels (How many erisc channels to use. ALso how many buffers to instantiate) +// Informs how many times to iterate through the next group of args +// 4) sender_buffer_address +// 5) sender_num_messages_to_send +// 6) sender_channel_size +// 7) sender_semaphores_base_address +// 8) worker_semaphore_address +// 9) sender_num_workers +// Informs how many worker X/Y coords to accept in the next loop. Each X/Y pair is 2 uint16s +// 10) worker_coord(s) +// ... +// Repeat from step 2 for receiver side + +// Intended only for (performance) test use cases +FORCE_INLINE void eth_setup_handshake2(std::uint32_t handshake_register_address, bool is_sender) { + if (is_sender) { + DPRINT << "eth_send_bytes\n"; + eth_send_bytes(handshake_register_address, handshake_register_address, 16); + DPRINT << "eth_wait_for_receiver_done\n"; + eth_wait_for_receiver_done(); + } else { + DPRINT << "eth_wait_for_bytes\n"; + eth_wait_for_bytes(16); + DPRINT << "wait eth_receiver_done\n"; + eth_receiver_channel_done(0); + } +} + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +using ttnn::ccl::WorkerXY; +======= +using ttnn::utils::ccl::WorkerXY; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp + +template +struct sender_receiver_index_t { + static constexpr bool ZERO_SENDERS = num_senders == 0; + static constexpr bool ZERO_RECEIVERS = num_receivers == 0; + static constexpr bool NUM_SENDERS_IS_POW_2 = !ZERO_SENDERS && (((num_senders - 1) & num_senders) == 0); + static constexpr bool NUM_RECEIVERS_IS_POW_2 = !ZERO_RECEIVERS && (((num_receivers - 1) & num_receivers) == 0); + static constexpr uint16_t SENDER_INCR_MASK = !ZERO_SENDERS ? num_senders - 1 : 0; + static constexpr uint16_t RECEIVER_INCR_MASK = !ZERO_RECEIVERS ? 
num_receivers - 1 : 0; + static constexpr uint16_t COMBINED_INCR_MASK = SENDER_INCR_MASK << 8 | RECEIVER_INCR_MASK; + static constexpr uint16_t COMBINED_INCR = (1 << 8) | 1; + union { + struct { + uint8_t sender; + uint8_t receiver; + }; + uint16_t combined; + } index; + union { + struct { + uint8_t sender; + uint8_t receiver; + }; + uint16_t combined; + } real_index; + union { + struct { + uint8_t sender; + uint8_t receiver; + }; + uint16_t combined; + } start; + + sender_receiver_index_t(uint8_t send_start, uint8_t receive_start, uint8_t num_send, uint8_t num_receive) { + start.sender = send_start; + start.receiver = receive_start; + index.sender = 0; + index.receiver = 0; + real_index.sender = send_start; + real_index.receiver = receive_start; + } + + FORCE_INLINE void increment() { + if constexpr (NUM_SENDERS_IS_POW_2 and NUM_RECEIVERS_IS_POW_2) { + index.combined = (index.combined + COMBINED_INCR) & COMBINED_INCR_MASK; + real_index.combined = start.combined + index.combined; + } else if constexpr (ZERO_RECEIVERS and NUM_SENDERS_IS_POW_2) { + index.sender = (index.sender + 1) & SENDER_INCR_MASK; + real_index.sender = start.sender + index.sender; + } else if constexpr (ZERO_SENDERS and NUM_RECEIVERS_IS_POW_2) { + index.receiver = (index.receiver + 1) & RECEIVER_INCR_MASK; + real_index.receiver = start.receiver + index.receiver; + } else { + index.combined += COMBINED_INCR; + index.sender = index.sender >= num_senders ? 0 : index.sender; + index.receiver = index.receiver >= num_receivers ? 0 : index.receiver; + real_index.combined = start.combined + index.combined; + } + } +}; + +void kernel_main() { + // COMPILE TIME ARGS + // If true, will enable this erisc's sender functionality + constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0; + + // If true, will enable this erisc's receiver functionality + constexpr bool enable_receiver_side = get_compile_time_arg_val(1) != 0; + + constexpr uint32_t num_senders = get_compile_time_arg_val(2); + constexpr uint32_t num_receivers = get_compile_time_arg_val(3); + +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp + constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static_cast(get_compile_time_arg_val(4)); + + constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static_cast(get_compile_time_arg_val(5)); +======= + constexpr ttnn::utils::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static_cast(get_compile_time_arg_val(4)); + + constexpr ttnn::utils::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static_cast(get_compile_time_arg_val(5)); +>>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp + + constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); + using EDM_CONFIG_T = decltype(EDM_CONFIG); + using ChannelBufferT = erisc::datamover::ChannelBuffer; + std::array buffer_channels; + + // + std::array printed_receiver_done; + + // SENDER ARGS + uint32_t args_offset = 0; + uint32_t handshake_addr = get_arg_val(args_offset++); + + uint8_t const sender_channels_start = get_arg_val(args_offset++); + uint32_t const sender_num_channels = num_senders;//get_arg_val(args_offset++); + uint8_t num_senders_with_no_work = 0; + for (uint32_t channel = 0; channel < sender_num_channels; channel++) { + uint32_t const sender_buffer_address = get_arg_val(args_offset++); + uint32_t const sender_num_messages_to_send = get_arg_val(args_offset++); + // Each channel buffer is at buffer_base + (channel_id * sender_channel_size) + // Each channel currently constrained to the same buffer size + uint32_t const sender_channel_size = get_arg_val(args_offset++); + // The erisc's local l1 copy of the semaphore workers remotely increment + uint32_t const sender_semaphores_base_address = get_arg_val(args_offset++); + // worker's semaphore L1 address + const uint32_t worker_semaphore_address = get_arg_val(args_offset++); + const uint32_t sender_num_workers = get_arg_val(args_offset++); + const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); + args_offset += sender_num_workers; + new (&buffer_channels[sender_channels_start + channel]) ChannelBufferT( + sender_channels_start + channel, + sender_buffer_address, + sender_channel_size, + worker_semaphore_address, + sender_num_workers, + sender_num_messages_to_send, + (volatile tt_l1_ptr uint32_t *const)sender_semaphores_base_address, + (const WorkerXY *)workers_xy_list_addr, + true); + if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + if (sender_num_messages_to_send == 0) { + num_senders_with_no_work++; + } + } + } + + // Receiver args + uint8_t const receiver_channels_start = get_arg_val(args_offset++); + uint32_t const receiver_num_channels = num_receivers;//get_arg_val(args_offset++); + uint8_t num_receivers_with_no_work = 0; + for (uint32_t channel = 0; channel < receiver_num_channels; channel++) { + uint32_t const receiver_buffers_base_address = get_arg_val(args_offset++); + uint32_t const receiver_num_messages_to_send = get_arg_val(args_offset++); + // Each channel buffer is at buffer_base + (channel_id * sender_channel_size) + // Each channel currently constrained to the same buffer size + uint32_t const receiver_channel_size = get_arg_val(args_offset++); + uint32_t const receiver_semaphores_base_address = get_arg_val(args_offset++); + uint32_t const worker_semaphore_address = get_arg_val(args_offset++); + uint32_t const receiver_num_workers = get_arg_val(args_offset++); + const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); + args_offset += receiver_num_workers; + new (&buffer_channels[receiver_channels_start + channel]) ChannelBufferT( + receiver_channels_start + channel, + receiver_buffers_base_address, + receiver_channel_size, + worker_semaphore_address, + receiver_num_workers, + receiver_num_messages_to_send, + (volatile tt_l1_ptr uint32_t *const)receiver_semaphores_base_address, + (const WorkerXY *)workers_xy_list_addr, + false); + + if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + if (receiver_num_messages_to_send == 0) { + 
num_receivers_with_no_work++; + } + } + } + + // Handshake with other erisc to make sure it's safe to start sending/receiving + // Chose an arbitrary ordering mechanism to guarantee one of the erisc's will always be "sender" and the other + // will always be "receiver" (only for handshake purposes) + bool act_as_sender_in_handshake = + (sender_channels_start < receiver_channels_start || receiver_num_channels == 0) && sender_num_channels > 0; + erisc::datamover::eth_setup_handshake(handshake_addr, act_as_sender_in_handshake); + uint32_t eth_transaction_ack_word_addr = handshake_addr + 16; + uint32_t eth_transaction_complete_addr = handshake_addr + 32; + + constexpr uint32_t SWITCH_INTERVAL = 4000000; + uint32_t did_nothing_count = 0; + + uint32_t num_senders_complete = !enable_sender_side ? sender_num_channels : num_senders_with_no_work; + uint32_t num_receivers_complete = !enable_receiver_side ? receiver_num_channels : num_receivers_with_no_work; + bool senders_in_progress = num_senders_complete != sender_num_channels; + bool receivers_in_progress = num_receivers_complete != receiver_num_channels; + + auto send_recv_index = sender_receiver_index_t(sender_channels_start, receiver_channels_start, sender_num_channels, receiver_num_channels); + + while (senders_in_progress || receivers_in_progress) { + bool did_something_sender = false; + bool did_something_receiver = false; + + uint32_t num_receivers_complete_old = num_receivers_complete; + uint32_t num_senders_complete_old = num_senders_complete; + ////////////////////////////////////// + // SENDER + if constexpr (enable_sender_side) { + ChannelBufferT ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; + switch (current_sender.get_state()) { + case ChannelBufferT::STATE::WAITING_FOR_WORKER: + did_something_sender = + erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender, num_senders_complete); + senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; + break; + + case ChannelBufferT::STATE::READY_FOR_ETH_TRANSFER: + did_something_sender = erisc::datamover::sender_eth_send_data_sequence(current_sender); + break; + + case ChannelBufferT::STATE::SIGNALING_WORKER: + did_something_sender = erisc::datamover::sender_notify_workers_if_buffer_available_sequence( + current_sender, num_senders_complete); + senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; + break; + + case ChannelBufferT::STATE::WAITING_FOR_ETH: + did_something_sender = + erisc::datamover::sender_eth_check_receiver_ack_sequence(current_sender, num_senders_complete); + senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; + break; + + default: + break; + }; + } + + ////////////////////////////////////// + // RECEIVER + if constexpr (enable_receiver_side) { + ChannelBufferT ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; + + switch (current_receiver.get_state()) { + case ChannelBufferT::STATE::WAITING_FOR_ETH: + did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, num_receivers_complete, eth_transaction_ack_word_addr); + receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; + break; + + case ChannelBufferT::STATE::SIGNALING_WORKER: + did_something_receiver = + erisc::datamover::receiver_eth_notify_workers_payload_available_sequence(current_receiver); + break; + + case ChannelBufferT::STATE::WAITING_FOR_WORKER: + did_something_receiver 
= erisc::datamover::receiver_noc_read_worker_completion_check_sequence( + current_receiver, num_receivers_complete, eth_transaction_complete_addr); + receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; + break; + + default: + break; + }; + } + send_recv_index.increment(); + ////////////////////////////////////// + + // Enabling this block as is (with all the "did_something"s, seems to cause a loss of about + // 0.5 GBps in throughput) + if (did_something_sender || did_something_receiver) { + did_nothing_count = 0; + } else { + if (did_nothing_count++ > SWITCH_INTERVAL) { + did_nothing_count = 0; + run_routing(); + } + } + } + + for (uint32_t s = 0; s < num_senders + num_receivers; s++ ) { + auto const& channel = buffer_channels[s]; + // We need to explicitly check for channel send done because we may + // advance sender channel state as soon as we receive an ack. Since we + // may be the last active channel, and advance to done state just from ack + // from the receiver ("I got a payload"), then we need to wait for done + // at the very end here. Otherise if we invoke another erisc op back-to-back, + // we may mess up transaction state because it's possible for receiver of this + // op to send the completion done after that one has already started. + uint32_t wait_count = 0; + uint32_t wait_max = 50000; + while(!channel.eth_is_receiver_channel_send_done()) { + wait_count++; + if (wait_count > wait_max) { + + DEBUG_STATUS("STK"); + run_routing(); + wait_count = 0; + } + } + } + + DEBUG_STATUS("DONE"); +} diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp index 8248934c780..1fea3f29555 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp @@ -13,9 +13,15 @@ #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/hw/inc/wormhole/noc/noc.h" +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp using ttnn::ccl::EriscDataMoverBufferSharingMode; using ttnn::ccl::EriscDataMoverTerminationMode; using ttnn::ccl::EriscDataMoverWorkerSignal; +======= +using ttnn::utils::ccl::EriscDataMoverBufferSharingMode; +using ttnn::utils::ccl::EriscDataMoverTerminationMode; +using ttnn::utils::ccl::EriscDataMoverWorkerSignal; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp namespace erisc { namespace datamover { @@ -34,7 +40,11 @@ struct edm_worker_index { uint16_t worker_index = 0; }; +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp using ttnn::ccl::WorkerXY; +======= +using ttnn::utils::ccl::WorkerXY; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp /* * The `ChannelBuffer` is a building block of the Erisc Data Mover (EDM). For every concurrent transaction @@ -115,13 +125,21 @@ class ChannelBuffer final { is_sender_side(is_sender_side) { clear_local_semaphore(); +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { +======= + if (TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { +>>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel increment_worker_semaphores(); } } else { +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); +======= + ASSERT(TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp goto_state(STATE::DONE); } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index 23d8c41e252..a137ec25813 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -8,9 +8,14 @@ #include "dataflow_api.h" #include "debug/dprint.h" #include "eth_l1_address_map.h" +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" +======= +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp // Args Schema: // 1) handshake addr @@ -45,7 +50,11 @@ FORCE_INLINE void eth_setup_handshake2(std::uint32_t handshake_register_address, } } +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp using ttnn::ccl::WorkerXY; +======= +using ttnn::utils::ccl::WorkerXY; +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp template struct sender_receiver_index_t { @@ -118,11 +127,19 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); +<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = static_cast(get_compile_time_arg_val(4)); constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = static_cast(get_compile_time_arg_val(5)); +======= + constexpr ttnn::utils::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static_cast(get_compile_time_arg_val(4)); + + constexpr ttnn::utils::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static_cast(get_compile_time_arg_val(5)); +>>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); using EDM_CONFIG_T = decltype(EDM_CONFIG); diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp index df3a85a94f7..c6217818ae2 100644 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp @@ -8,19 +8,11 @@ #include "common/core_coord.h" #include "impl/buffers/buffer.hpp" #include "tensor/tensor.hpp" -<<<<<<< HEAD #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" -======= -#include "tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" -#include "tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "tt_dnn/op_library/ccl/ccl_common.hpp" ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN #include "tt_dnn/op_library/run_operation.hpp" @@ -30,21 +22,11 @@ namespace ttnn { -<<<<<<< HEAD namespace all_gather_op { using ccl::Topology; }; // namespace all_gather_op using ccl::EriscDatamoverBuilder; -======= -namespace utils { - -namespace all_gather_op { -using tt::tt_metal::ccl::Topology; -}; // namespace all_gather_op - -using tt::tt_metal::ccl::EriscDatamoverBuilder; ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN struct LineAllGather { @@ -61,43 +43,8 @@ struct LineAllGather { std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; -<<<<<<< HEAD }; -======= - - static constexpr auto attribute_names = std::forward_as_tuple( - "dim", - "num_links", - "ring_size", - "ring_index", - "receiver_device_id", - "sender_device_id", - "output_mem_config", - "topology"); - - const auto attribute_values() const { - return std::forward_as_tuple( - dim, num_links, ring_size, ring_index, receiver_device_id, sender_device_id, output_mem_config, topology); - } -}; - -// All Gather Variants -std::vector line_all_gather_impl( - const std::vector& input_tensors, - const uint32_t dim, - const uint32_t num_links, - const MemoryConfig& output_mem_config, - const all_gather_op::Topology topology); -std::vector line_all_gather( - const std::vector &input_tensors, - const uint32_t dim, - const uint32_t num_links = 1, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -} // namespace utils ->>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN - namespace operations { namespace ccl { diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index c6a37929fac..665706be9da 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -98,7 +98,7 @@ struct ArchDependentTypes { template <> struct ArchDependentTypes { - using workers_list_t = ccl::WorkerXY *; + using workers_list_t = ttnn::utils::ccl::WorkerXY *; static const workers_list_t WORKERS_LIST_UNINITIALIZED_VALUE; }; From 4d9777de918bc1ea02fee2f515b1780c101b7e03 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Thu, 11 Jul 2024 11:29:59 +0000 Subject: [PATCH 04/10] #9486: re-organize namespace --- ttnn/cpp/pybind11/operations/__init__.hpp | 8 +- .../ccl/all_gather/ccl_all_gather_pybind.hpp | 10 +- .../ccl/all_gather/device/all_gather_op.hpp | 76 ++++++------- ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 22 ++-- ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 8 +- .../ccl/ccl_host_datastructures.hpp | 10 +- .../ccl_line_all_gather_pybind.hpp | 106 ------------------ .../hetergeneous_data_structs.hpp | 2 +- 8 files changed, 67 insertions(+), 175 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 99402370b64..15e3d418563 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -62,12 +62,10 @@ void py_module(py::module& module) { auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); unary_backward::py_module(m_unary_backward); - - auto m_ccl_all_gather = module.def_submodule("ccl_all_gather", "collective communication operations"); - ccl_all_gather::py_module(m_ccl_all_gather); - auto m_ccl_line_all_gather = module.def_submodule("ccl_line_all_gather", "collective communication operations "); - ccl_line_all_gather::py_module(m_ccl_line_all_gather); + auto m_ccl = module.def_submodule("ccl", "collective communication operations"); + ccl::py_module_all_gather(m_ccl); + ccl::py_module_line_all_gather(m_ccl); auto m_ccl = module.def_submodule("ccl", "collective communication operations"); ccl::py_bind_all_gather(m_ccl); diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp index 97b1fc3b055..7ed2421a86f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp @@ -15,12 +15,12 @@ namespace py = pybind11; namespace ttnn { namespace operations { -namespace ccl_all_gather { +namespace ccl { namespace detail { template -void bind_ccl_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { bind_registered_operation( module, operation, @@ -43,8 +43,8 @@ void bind_ccl_all_gather(py::module& module, const ccl_operation_t& operation, c } // namespace detail -void py_module(py::module& module) { - detail::bind_ccl_all_gather( +void py_module_all_gather(py::module& module) { + detail::bind_all_gather( module, ttnn::all_gather, R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: 
int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor @@ -67,6 +67,6 @@ void py_module(py::module& module) { )doc"); } -} // namespace ccl_all_gather +} // namespace ccl } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 536e18451e2..c3469068f73 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -29,10 +29,10 @@ enum AllGatherMode { }; namespace all_gather_op { -using ttnn::utils::ccl::Topology; +using ccl::Topology; }; // namespace all_gather_op -using ttnn::utils::ccl::EriscDatamoverBuilder; +using ccl::EriscDatamoverBuilder; AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim); @@ -255,13 +255,13 @@ struct ShardedAllGatherConfig { switch(input_tensor.memory_config().memory_layout) { case TensorMemoryLayout::WIDTH_SHARDED: - this->shard_type = ttnn::utils::ccl::ShardType::Width; + this->shard_type = ccl::ShardType::Width; break; case TensorMemoryLayout::BLOCK_SHARDED: - this->shard_type = ttnn::utils::ccl::ShardType::Block; + this->shard_type = ccl::ShardType::Block; break; case TensorMemoryLayout::HEIGHT_SHARDED: - this->shard_type = ttnn::utils::ccl::ShardType::Height; + this->shard_type = ccl::ShardType::Height; break; case TensorMemoryLayout::INTERLEAVED: case TensorMemoryLayout::SINGLE_BANK: @@ -285,7 +285,7 @@ struct ShardedAllGatherConfig { return single_tile_shard_on_dim; } - ttnn::utils::ccl::ShardType get_shard_type() const { + ccl::ShardType get_shard_type() const { TT_ASSERT(is_sharding_enabled, "Tried getting sharding config for non-sharded tensor"); return shard_type; } @@ -293,7 +293,7 @@ struct ShardedAllGatherConfig { private: bool requires_post_all_gather_reshard; bool single_tile_shard_on_dim; - ttnn::utils::ccl::ShardType shard_type; + ccl::ShardType shard_type; bool is_sharding_enabled; }; @@ -312,14 +312,14 @@ struct ShardAddrGenArgGenerator { std::vector args; args.reserve(7 * this->args_struct.num_dest_cores * 2); - TT_ASSERT(this->args_struct.shard_size_in_bytes != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.total_chunks_per_core != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.shards_start_address != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32); - TT_ASSERT(this->args_struct.starting_core_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.starting_chunk_into_shard != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.intra_core_stride_in_shards != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); - TT_ASSERT(this->args_struct.num_dest_cores != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shard_size_in_bytes != ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.total_chunks_per_core != ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.shards_start_address != ccl::UNINITIALIZED_VALUE_U32); + TT_ASSERT(this->args_struct.starting_core_index != ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.starting_chunk_into_shard != ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.intra_core_stride_in_shards != ccl::UNINITIALIZED_VALUE_U16); + 
TT_ASSERT(this->args_struct.contiguous_chunks_before_stride != ccl::UNINITIALIZED_VALUE_U16); + TT_ASSERT(this->args_struct.num_dest_cores != ccl::UNINITIALIZED_VALUE_U16); TT_ASSERT(this->args_struct.dest_cores.size() != 0); args.push_back(this->args_struct.is_clockwise); @@ -360,7 +360,7 @@ struct ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.starting_core_index < this->args_struct.dest_cores.size()); } - ttnn::utils::ccl::ShardAddrGenArgs args_struct; + ccl::ShardAddrGenArgs args_struct; bool initialized; }; @@ -392,7 +392,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat } InputTensorShardAddrGenArgGenerator( Device const* device, - ttnn::utils::ccl::CclOpShardedTensorConfig *input_tensor_config, + ccl::CclOpShardedTensorConfig *input_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -425,7 +425,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat this->args_struct.dest_cores.reserve(dest_core_coords.size()); std::transform(dest_core_coords.begin(), dest_core_coords.end(), std::back_inserter(this->args_struct.dest_cores), [&device](CoreCoord const& core) { - return ttnn::utils::ccl::WorkerXY( + return ccl::WorkerXY( static_cast(device->worker_core_from_logical_core(core).x), static_cast(device->worker_core_from_logical_core(core).y) ); @@ -444,7 +444,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { static std::vector compute_worker_coord_worker_dest_cores ( - ttnn::utils::ccl::ShardType shard_type, + ccl::ShardType shard_type, std::vector const& global_shard_dest_cores, uint32_t input_num_shards, uint32_t output_num_shards, @@ -495,7 +495,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { static std::vector compute_worker_dest_cores ( - ttnn::utils::ccl::ShardType shard_type, + ccl::ShardType shard_type, Device const& device, CoreRangeSet const& shard_core_range, uint32_t input_num_shards, @@ -516,7 +516,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { dest_cores_of_worker.reserve(worker_coord_worker_dest_cores.size()); std::transform(worker_coord_worker_dest_cores.begin(), worker_coord_worker_dest_cores.end(), std::back_inserter(dest_cores_of_worker), [&device](CoreCoord const& core) { - return ttnn::utils::ccl::WorkerXY( + return ccl::WorkerXY( static_cast(device.worker_core_from_logical_core(core).x), static_cast(device.worker_core_from_logical_core(core).y) ); @@ -588,8 +588,8 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { OutputTensorShardAddrGenArgGenerator( AllGatherConfig const& all_gather_config, Device const* device, - ttnn::utils::ccl::CclOpShardedTensorConfig *input_tensor_config, - ttnn::utils::ccl::CclOpShardedTensorConfig *output_tensor_config, + ccl::CclOpShardedTensorConfig *input_tensor_config, + ccl::CclOpShardedTensorConfig *output_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -617,7 +617,7 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { uint32_t input_num_shards = sharded_tensor_num_cores; uint32_t output_num_shards = input_num_shards * ring_size; this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - ttnn::utils::ccl::ShardType::Width, + ccl::ShardType::Width, *device, tensor_shard_grid, input_num_shards, @@ -630,7 +630,7 @@ struct 
OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { TT_ASSERT(this->args_struct.dest_cores.size() > 0); std::vector const& global_shard_dest_cores = corerange_to_cores(tensor_shard_grid, std::nullopt, is_shard_orientation_row_major); CoreCoord const& dest_core_coord = global_shard_dest_cores.at(global_starting_dest_worker_index); - ttnn::utils::ccl::WorkerXY noc0_starting_dest_core_xy( + ccl::WorkerXY noc0_starting_dest_core_xy( static_cast(device->worker_core_from_logical_core(dest_core_coord).x), static_cast(device->worker_core_from_logical_core(dest_core_coord).y) ); @@ -655,17 +655,17 @@ struct FullWorkerGridShardAddrGenArgGenerator { args.reserve(12 + args_struct.total_num_cores); TT_ASSERT(args_struct.dest_cores.size() > 0, "dest_cores was uninitialized"); - TT_ASSERT(args_struct.tile_size_in_bytes != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32, "tile_size_in_bytes was uninitialized"); - TT_ASSERT(args_struct.shards_start_address != ttnn::utils::ccl::UNINITIALIZED_VALUE_U32, "shards_start_address was uninitialized"); - TT_ASSERT(args_struct.curr_core_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_core_index was uninitialized"); - TT_ASSERT(args_struct.total_num_cores != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "total_num_cores was uninitialized"); - TT_ASSERT(args_struct.curr_shard_tile_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_x was uninitialized"); - TT_ASSERT(args_struct.curr_shard_tile_y != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_y was uninitialized"); - TT_ASSERT(args_struct.curr_tile_index != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_tile_index was uninitialized"); - TT_ASSERT(args_struct.curr_shard != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "curr_shard was uninitialized"); - TT_ASSERT(args_struct.input_shard_num_tiles_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_x was uninitialized"); - TT_ASSERT(args_struct.input_shard_num_tiles_y != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_y was uninitialized"); - TT_ASSERT(args_struct.total_shards_x != ttnn::utils::ccl::UNINITIALIZED_VALUE_U16, "total_shards_x was uninitialized"); + TT_ASSERT(args_struct.tile_size_in_bytes != ccl::UNINITIALIZED_VALUE_U32, "tile_size_in_bytes was uninitialized"); + TT_ASSERT(args_struct.shards_start_address != ccl::UNINITIALIZED_VALUE_U32, "shards_start_address was uninitialized"); + TT_ASSERT(args_struct.curr_core_index != ccl::UNINITIALIZED_VALUE_U16, "curr_core_index was uninitialized"); + TT_ASSERT(args_struct.total_num_cores != ccl::UNINITIALIZED_VALUE_U16, "total_num_cores was uninitialized"); + TT_ASSERT(args_struct.curr_shard_tile_x != ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_x was uninitialized"); + TT_ASSERT(args_struct.curr_shard_tile_y != ccl::UNINITIALIZED_VALUE_U16, "curr_shard_tile_y was uninitialized"); + TT_ASSERT(args_struct.curr_tile_index != ccl::UNINITIALIZED_VALUE_U16, "curr_tile_index was uninitialized"); + TT_ASSERT(args_struct.curr_shard != ccl::UNINITIALIZED_VALUE_U16, "curr_shard was uninitialized"); + TT_ASSERT(args_struct.input_shard_num_tiles_x != ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_x was uninitialized"); + TT_ASSERT(args_struct.input_shard_num_tiles_y != ccl::UNINITIALIZED_VALUE_U16, "input_shard_num_tiles_y was uninitialized"); + TT_ASSERT(args_struct.total_shards_x != ccl::UNINITIALIZED_VALUE_U16, "total_shards_x was uninitialized"); args.push_back(args_struct.tile_size_in_bytes); 
args.push_back(args_struct.shards_start_address); @@ -717,7 +717,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { auto const& tensor_shard_grid = input_tensor.buffer()->shard_spec().grid(); this->args_struct.dest_cores = OutputTensorShardAddrGenArgGenerator::compute_worker_dest_cores ( - ttnn::utils::ccl::ShardType::Width, + ccl::ShardType::Width, *device, tensor_shard_grid, tensor_shard_grid.num_cores(), @@ -730,7 +730,7 @@ struct FullWorkerGridShardAddrGenArgGenerator { this->initialized = true; } - ttnn::utils::ccl::FullWorkerGridShardAddrGenArgs args_struct; + ccl::FullWorkerGridShardAddrGenArgs args_struct; bool initialized; }; diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 2c9dac84c8b..16ea367a731 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -45,7 +45,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( auto eth_sender_core = topology_config.eth_sender_cores.at(i); log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); auto eth_sender_kernel = - ttnn::utils::ccl::generate_edm_kernel(program, device, clockwise_edm_builders.at(i), eth_sender_core, sender_noc); + ccl::generate_edm_kernel(program, device, clockwise_edm_builders.at(i), eth_sender_core, sender_noc); log_trace( tt::LogOp, "RingIndex: {}. Link {}. Clockwise EDM Core (x={},y={})", @@ -59,7 +59,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( if (is_counter_clockwise_direction_edm_enabled) { log_trace(tt::LogOp, "EDM COUNTER CLOCKWISE KERNEL RT ARGS: "); auto eth_receiver_core = topology_config.eth_receiver_cores.at(i); - auto eth_receiver_kernel = ttnn::utils::ccl::generate_edm_kernel( + auto eth_receiver_kernel = ccl::generate_edm_kernel( program, device, counter_clockwise_edm_builders.at(i), eth_receiver_core, receiver_noc); log_trace( tt::LogOp, @@ -75,7 +75,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, - ttnn::utils::ccl::EriscDatamoverBuilder const& edm_builder, + ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id) { log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); @@ -110,29 +110,29 @@ KernelHandle generate_edm_kernel( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, - ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ttnn::utils::ccl::EriscDataMoverTerminationMode termination_mode) { + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ccl::EriscDataMoverTerminationMode termination_mode) { TT_ASSERT(num_channels > 0); std::vector edm_sem_addresses(num_channels, 0); std::vector edm_buffer_addresses(num_channels, 0); - uint32_t edm_sem_addr = ttnn::utils::ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); - uint32_t edm_buffer_addr = ttnn::utils::ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); + uint32_t edm_sem_addr = ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); + uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); - const uint32_t buffer_size = ttnn::utils::ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); + const uint32_t buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); for (std::size_t c = 0; c < num_channels; ++c) { 
edm_sem_addresses.at(c) = edm_sem_addr; - edm_sem_addr += ttnn::utils::ccl::EriscDatamoverConfig::semaphore_size; + edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; edm_buffer_addresses.at(c) = edm_buffer_addr; edm_buffer_addr += buffer_size; TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } - return ttnn::utils::ccl::EriscDatamoverBuilder( + return ccl::EriscDatamoverBuilder( buffer_size, - ttnn::utils::ccl::EriscDatamoverConfig::get_edm_handshake_address(), + ccl::EriscDatamoverConfig::get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode, diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 6a4bb07953c..9a71b4b3034 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -318,8 +318,8 @@ class RingReduceScatterTensorSlicer : public LegacyCclTensorSlicer { uint32_t max_slice_size_in_bytes, uint32_t half_cb_n_pages); - ttnn::utils::ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index) { - return ttnn::utils::ccl::InterleavedTensorWorkerSlice( + ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index) { + return ccl::InterleavedTensorWorkerSlice( this->flattened_tensor_shape, this->tensor_slice_shape, this->worker_slice_shapes.at(global_worker_index), @@ -452,7 +452,7 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, - ttnn::utils::ccl::EriscDatamoverBuilder const& edm_builder, + ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id); @@ -468,7 +468,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, - ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index 39bca5454ef..55066f63eea 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -167,8 +167,8 @@ class EriscDatamoverBuilder { uint32_t eth_buffer_size_bytes; uint32_t handshake_addr; uint32_t const num_channel_buffers; - ttnn::utils::ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; - ttnn::utils::ccl::EriscDataMoverTerminationMode const termination_mode; + ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; + ccl::EriscDataMoverTerminationMode const termination_mode; uint32_t num_senders; uint32_t num_receivers; @@ -187,9 +187,9 @@ class EriscDatamoverBuilder { uint32_t handshake_addr, std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, - ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ttnn::utils::ccl::EriscDataMoverTerminationMode termination_mode = - ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ccl::EriscDataMoverTerminationMode termination_mode = + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : 
local_semaphore_addresses(local_semaphore_addresses), local_buffer_addresses(local_buffer_addresses), eth_buffer_size_bytes(eth_buffer_size), diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp deleted file mode 100644 index eeb5d5a46e9..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp +++ /dev/null @@ -1,106 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include - -#include "ttnn/cpp/pybind11/decorators.hpp" -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp -#include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" -======= -#include "ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp" ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp -#include "ttnn/types.hpp" - -namespace py = pybind11; - -namespace ttnn { -namespace operations { -namespace ccl_line_all_gather { - -namespace detail { - -template -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp -void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { -======= -void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp - bind_registered_operation( - module, - operation, - doc, - ttnn::pybind_overload_t{ - [](const ccl_operation_t& self, - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links, - const std::optional& memory_config) -> ttnn::Tensor { - return self(input_tensor, dim, num_links, memory_config); - }, - py::arg("input_tensor"), - py::arg("dim"), - py::kw_only(), - py::arg("num_links") = 1, - py::arg("memory_config") = std::nullopt}); -} - -} // namespace detail - - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp -void py_bind_all_gather(py::module& module) { - detail::bind_all_gather( - module, - ttnn::all_gather, - R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - - Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. - - Args: - * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor - * :attr:`dim` (int) - - Keyword Args: - * :attr:`num_links` (int): Number of links to use for the all-gather operation. - * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - - Example: - - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) - >>> output = ttnn.all_gather(tensor, dim=0) - - )doc"); -======= -void py_module(py::module& module) { ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp - - detail::bind_ccl_line_all_gather( - module, - ttnn::line_all_gather, - R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - - Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. 
- - Args: - * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor - * :attr:`dim` (int) - - Keyword Args: - * :attr:`num_links` (int): Number of links to use for the all-gather operation. - * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - - Example: - - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) - >>> output = ttnn.line_all_gather(tensor, dim=0) - - )doc"); -} - -} // namespace ccl_line_all_gather -} // namespace operations -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index 665706be9da..c6a37929fac 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -98,7 +98,7 @@ struct ArchDependentTypes { template <> struct ArchDependentTypes { - using workers_list_t = ttnn::utils::ccl::WorkerXY *; + using workers_list_t = ccl::WorkerXY *; static const workers_list_t WORKERS_LIST_UNINITIALIZED_VALUE; }; From 75121b86d664a0ebebd2e6ff4f8d7fdb262bb860 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Fri, 12 Jul 2024 04:52:12 +0000 Subject: [PATCH 05/10] #9486: Move kernel files into kernels directory --- .../ccl/all_gather/all_gather_op.hpp | 16 + .../ccl/all_gather/all_gather_pybind.hpp | 22 + .../ccl/all_gather/ccl_all_gather_pybind.hpp | 72 -- .../all_gather/device/ccl_all_gather_op.hpp | 46 -- ttnn/cpp/ttnn/operations/ccl/edm/README.md | 0 .../ccl/edm/erisc_async_datamover.hpp | 748 ------------------ .../operations/ccl/edm/erisc_datamover.cpp | 343 -------- .../ccl/kernels/edm/erisc_datamover.cpp | 9 + .../device/ccl_line_all_gather_op.hpp | 64 -- 9 files changed, 47 insertions(+), 1273 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp delete mode 100644 ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp delete mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/README.md delete mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp delete mode 100644 ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp delete mode 100644 ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp index 2dd167b3db2..b5c3817ee82 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp @@ -4,17 +4,22 @@ #pragma once +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" ======= #include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" >>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp +======= +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory #include "ttnn/cpp/ttnn/multi_device.hpp" namespace ttnn { namespace operations { namespace ccl { +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp struct ExecuteAllGather { @@ -30,6 +35,9 @@ struct ExecuteAllGather { ======= >>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp struct ExecuteLineAllGather { +======= +struct ExecuteAllGather { +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory static inline const std::array input_tensor_schemas() { return {ttnn::TensorSchema{ 2, @@ -52,13 +60,21 @@ struct ExecuteLineAllGather { const uint32_t dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt) { +<<<<<<< HEAD return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); +======= + return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory } }; } // namespace ccl } // namespace operations +<<<<<<< HEAD constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); +======= +constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp index eeb5d5a46e9..aafbb457f0d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp @@ -8,27 +8,39 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp #include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" ======= #include "ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp" >>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp +======= +#include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory #include "ttnn/types.hpp" namespace py = pybind11; namespace ttnn { namespace operations { +<<<<<<< HEAD namespace ccl_line_all_gather { +======= +namespace ccl { +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory namespace detail { template +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { ======= void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { >>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp +======= +void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory bind_registered_operation( module, operation, @@ -51,8 +63,12 @@ void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operati } // namespace detail +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp void py_bind_all_gather(py::module& module) { +======= +void py_module_all_gather(py::module& module) { +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory detail::bind_all_gather( module, ttnn::all_gather, @@ -74,6 +90,7 @@ void py_bind_all_gather(py::module& module) { >>> output = ttnn.all_gather(tensor, dim=0) )doc"); +<<<<<<< HEAD ======= void py_module(py::module& module) { >>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp @@ -102,5 +119,10 @@ void py_module(py::module& module) { } } // namespace ccl_line_all_gather +======= +} + +} // namespace ccl +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp deleted file mode 100644 index 7ed2421a86f..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/ccl_all_gather_pybind.hpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include - -#include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp" -#include "ttnn/types.hpp" - -namespace py = pybind11; - -namespace ttnn { -namespace operations { -namespace ccl { - -namespace detail { - -template -void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { - bind_registered_operation( - module, - operation, - doc, - ttnn::pybind_overload_t{ - [](const ccl_operation_t& self, - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links, - const std::optional& memory_config) -> ttnn::Tensor { - return self(input_tensor, dim, num_links, memory_config); - }, - py::arg("input_tensor"), - py::arg("dim"), - py::kw_only(), - py::arg("num_links") = 1, - py::arg("memory_config") = std::nullopt}); -} - -} // namespace detail - - -void py_module_all_gather(py::module& module) { - detail::bind_all_gather( - module, - ttnn::all_gather, - R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - - Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. - - Args: - * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor - * :attr:`dim` (int) - - Keyword Args: - * :attr:`num_links` (int): Number of links to use for the all-gather operation. - * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - - Example: - - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) - >>> output = ttnn.all_gather(tensor, dim=0) - - )doc"); -} - -} // namespace ccl -} // namespace operations -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp deleted file mode 100644 index d1f975692be..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/ccl_all_gather_op.hpp +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" -#include "ttnn/cpp/ttnn/multi_device.hpp" - -namespace ttnn { -namespace operations { -namespace ccl { - -struct ExecuteAllGather { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... 
args) { - return std::forward_as_tuple(input_tensor); - } - - static ttnn::Tensor execute_on_main_thread( - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt) { - return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); - } -}; - -} // namespace ccl -} // namespace operations - -constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); - -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/README.md b/ttnn/cpp/ttnn/operations/ccl/edm/README.md deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp deleted file mode 100644 index 1fea3f29555..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp +++ /dev/null @@ -1,748 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include -#include - -#include "dataflow_api.h" -#include "debug/assert.h" -#include "eth_l1_address_map.h" -#include "ethernet/dataflow_api.h" -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "tt_metal/hw/inc/wormhole/noc/noc.h" - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp -using ttnn::ccl::EriscDataMoverBufferSharingMode; -using ttnn::ccl::EriscDataMoverTerminationMode; -using ttnn::ccl::EriscDataMoverWorkerSignal; -======= -using ttnn::utils::ccl::EriscDataMoverBufferSharingMode; -using ttnn::utils::ccl::EriscDataMoverTerminationMode; -using ttnn::utils::ccl::EriscDataMoverWorkerSignal; ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp - -namespace erisc { -namespace datamover { - -template -struct EriscDatamoverConfig { - static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = buffer_sharing_mode; - static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = termination_mode; -}; - -template -struct edm_worker_index {}; - -template <> -struct edm_worker_index { - uint16_t worker_index = 0; -}; - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp -using ttnn::ccl::WorkerXY; -======= -using ttnn::utils::ccl::WorkerXY; ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp - -/* - * The `ChannelBuffer` is a building block of the Erisc Data Mover (EDM). For every concurrent transaction - * channel managed by the EDM, there is a `ChannelBuffer` associated with the. The `ChannelBuffer` manages - * state for the transaction channel, holds information such as buffer and semaphore addresses, and has helper - * functions to more easily check semaphore and ack statuses and to send/receive data and/or semaphore updates. 
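// A minimal illustrative sketch (hedged, not part of the patch above) of the sender-side
// channel lifecycle this comment describes, assuming the transitions implemented by the
// sender_* helper functions further down: WAITING_FOR_WORKER -> READY_FOR_ETH_TRANSFER ->
// WAITING_FOR_ETH -> SIGNALING_WORKER -> back to WAITING_FOR_WORKER, or DONE once all
// messages have moved. `SenderState` and `advance_sender` are hypothetical names used only
// for illustration; they are not part of the EDM implementation.
enum class SenderState { WaitingForWorker, ReadyForEthTransfer, WaitingForEth, SignalingWorker, Done };

SenderState advance_sender(SenderState s,
                           bool worker_payload_ready,  // local semaphore reached the worker count
                           bool eth_payload_sent,      // payload and completion signal pushed over ethernet
                           bool receiver_acked,        // receiver ack (or done) observed on the channel
                           bool all_messages_moved) {  // num_messages_moved reached the target
    switch (s) {
        case SenderState::WaitingForWorker:
            return worker_payload_ready ? SenderState::ReadyForEthTransfer : s;
        case SenderState::ReadyForEthTransfer:
            return eth_payload_sent ? SenderState::WaitingForEth : s;
        case SenderState::WaitingForEth:
            return receiver_acked ? SenderState::SignalingWorker : s;
        case SenderState::SignalingWorker:
            return all_messages_moved ? SenderState::Done : SenderState::WaitingForWorker;
        case SenderState::Done:
            return s;
    }
    return s;
}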
- */ -// template -template -class ChannelBuffer final { - static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = EDM_CONFIG::BUFFER_SHARING_MODE; - static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = EDM_CONFIG::TERMINATION_MODE; - static_assert( - BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::NOT_SHARED || - BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN, - "The only BufferSharding modes supported are NOT_SHARED and ROUND_ROBIN"); - - public: - enum STATE : uint8_t { - DONE = 0, - - // For sender: means we are ready to tell the worker(s) that the buffer is available for writing into - // - SIGNALING_WORKER, - - // For sender: we are waiting for the payload to arrive in L1; we are checking local semaphore for worker - // completion For receiver: we are waiting for worker to complete pull of payload from L1; we are checking local - // semaphore for worker completion - WAITING_FOR_WORKER, - - // For sender: means workers have signalled (via semaphores) that the buffer payload is - // ready in L1 - // For receiver: - READY_FOR_ETH_TRANSFER, - - // For sender: means we are waiting for ack from receiver that payload was received - // For receiver: means we are waitinf for a payload from sender - WAITING_FOR_ETH, - }; - - // for default initialization in arrays - ChannelBuffer() : - local_semaphore_address(0), - worker_coords(0), - address(0), - size_in_bytes(0), - worker_semaphore_l1_address(0), - num_workers(0), - num_messages_moved(0), - channel_bytes_sent_address(0), - channel_bytes_acked_address(0), - total_num_messages_to_move(0), - state(STATE::DONE) {} - - ChannelBuffer( - uint32_t eth_transaction_channel, - size_t address, - size_t size_in_bytes, - uint32_t worker_semaphore_l1_address, - uint32_t num_workers, - uint32_t total_num_messages_to_move, - volatile tt_l1_ptr uint32_t *const local_semaphore_address, - tt_l1_ptr const WorkerXY *worker_coords, - bool is_sender_side) : - eth_transaction_channel(eth_transaction_channel), - local_semaphore_address(local_semaphore_address), - worker_coords(worker_coords), - address(address), - size_in_bytes(size_in_bytes), - worker_semaphore_l1_address(worker_semaphore_l1_address), - num_workers(num_workers), - num_messages_moved(0), - channel_bytes_sent_address(&erisc_info->channels[eth_transaction_channel].bytes_sent), - channel_bytes_acked_address(&erisc_info->channels[eth_transaction_channel].receiver_ack), - total_num_messages_to_move(total_num_messages_to_move), - state(is_sender_side ? STATE::WAITING_FOR_WORKER : STATE::WAITING_FOR_ETH), - is_sender_completion_pending(false), - is_sender_side(is_sender_side) { - clear_local_semaphore(); - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp - if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { -======= - if (TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { ->>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp - if (is_sender_side) { - // Tell the sender side workers that we're ready to accept data on this channel - increment_worker_semaphores(); - } - } else { -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp - ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); -======= - ASSERT(TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp - goto_state(STATE::DONE); - } - }; - - // Resets the semaphore in local L1, which workers write to remotely. - FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } - - // Increment the semaphore in the remote L1s of every worker associated with this ChannelBuffer - FORCE_INLINE void increment_worker_semaphores() { - if constexpr (BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::NOT_SHARED) { - // We have to be careful that the worker x/y matches for the `noc_index` - // active on the erisc - for (std::size_t i = 0; i < this->num_workers; i++) { - WorkerXY worker_xy = this->worker_coords[i]; - uint64_t worker_semaphore_address = - get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); - - noc_semaphore_inc(worker_semaphore_address, 1); - } - } else if (BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN) { - WorkerXY worker_xy = this->worker_coords[this->worker_index.worker_index]; - uint64_t worker_semaphore_address = - get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); - - noc_semaphore_inc(worker_semaphore_address, 1); - this->worker_index.worker_index++; - if (this->worker_index.worker_index >= this->num_workers) { - this->worker_index.worker_index = 0; - } - } else { - ASSERT(false); // Not implemented - } - } - - [[nodiscard]] FORCE_INLINE bool is_local_semaphore_full() const { - if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { - ASSERT(*(this->local_semaphore_address) <= this->num_workers); - } - return *(this->local_semaphore_address) == this->num_workers; - } - - [[nodiscard]] FORCE_INLINE bool is_active() const { - return this->num_messages_moved < this->total_num_messages_to_move; - } - - [[nodiscard]] STATE get_state() const { return this->state; } - - FORCE_INLINE void goto_state(STATE s) { this->state = s; } - - [[nodiscard]] FORCE_INLINE bool is_waiting_for_workers_core() const { - return this->state == STATE::WAITING_FOR_WORKER; - } - [[nodiscard]] FORCE_INLINE bool is_ready_to_signal_workers() const { - return this->state == STATE::SIGNALING_WORKER; - } - [[nodiscard]] FORCE_INLINE bool is_waiting_for_remote_eth_core() const { - return this->state == STATE::WAITING_FOR_ETH; - } - [[nodiscard]] FORCE_INLINE bool is_ready_for_eth_transfer() const { - return this->state == STATE::READY_FOR_ETH_TRANSFER; - } - [[nodiscard]] FORCE_INLINE bool is_done() const { return this->state == STATE::DONE; } - - [[nodiscard]] FORCE_INLINE uint32_t get_eth_transaction_channel() const { - ASSERT(this->eth_transaction_channel < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); - return this->eth_transaction_channel; - } - [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->address; } - [[nodiscard]] FORCE_INLINE std::size_t 
get_size_in_bytes() const { return this->size_in_bytes; } - [[nodiscard]] FORCE_INLINE std::size_t get_current_payload_size() const { return this->get_size_in_bytes(); } - - [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { return this->address; } - - FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } - FORCE_INLINE void increment_messages_moved() { this->num_messages_moved++; } - - [[nodiscard]] FORCE_INLINE bool all_messages_moved() { - return this->num_messages_moved == this->total_num_messages_to_move; - } - - FORCE_INLINE void set_send_completion_pending(bool value) { this->is_sender_completion_pending = value; } - [[nodiscard]] FORCE_INLINE bool is_send_completion_pending() const { return this->is_sender_completion_pending; } - - FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0; } - FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0; } - FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_address; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_address; } - - public: - uint32_t eth_transaction_channel; // - volatile tt_l1_ptr uint32_t *const local_semaphore_address; - WorkerXY const *const worker_coords; - std::size_t const address; - std::size_t const size_in_bytes; - // Even for multiple workers, this address will be the same - std::size_t const worker_semaphore_l1_address; - uint32_t const num_workers; - uint32_t num_messages_moved; - volatile tt_l1_ptr uint32_t *const channel_bytes_sent_address; - volatile tt_l1_ptr uint32_t *const channel_bytes_acked_address; - const uint32_t total_num_messages_to_move; - STATE state; - edm_worker_index worker_index; - bool is_sender_completion_pending; - bool is_sender_side; -}; - -template -class QueueIndexPointer { - public: - QueueIndexPointer(uint8_t queue_size) : ptr(0), size(queue_size), wrap_around(queue_size * 2) { - // FWASSERT(queue_size < 128); - } - - [[nodiscard("index was called without consuming the result. Did you mean to call it?")]] T index() const { - return this->ptr >= this->size ? this->ptr - this->size : this->ptr; - } - [[nodiscard("raw_index was called without consuming the result. Did you mean to call it?")]] inline T raw_index() - const { - return this->ptr; - } - [[nodiscard("distance was called without consuming the result. Did you mean to call it?")]] inline static T - distance(QueueIndexPointer ptr, QueueIndexPointer ackptr) { - // FWASSERT(ptr.size == ackptr.size); - return ackptr.ptr > ptr.ptr ? (ptr.wrap_around - ackptr.ptr) + ptr.ptr : ptr.ptr - ackptr.ptr; - } - [[nodiscard("full was called without consuming the result. Did you mean to call it?")]] inline static T full( - QueueIndexPointer ptr, QueueIndexPointer ackptr) { - // FWASSERT(ptr.size == ackptr.size); - return distance(ptr.ptr, ackptr.ptr) >= ptr.size; - } - [[nodiscard("empty was called without consuming the result. Did you mean to call it?")]] inline static T empty( - QueueIndexPointer ptr, QueueIndexPointer ackptr) { - // FWASSERT(ptr.size == ackptr.size); - return ptr.ptr == ackptr.ptr; - } - inline void increment() { this->ptr = this->next_pointer(); } - [[nodiscard( - "next_index was called without consuming the result. 
Did you mean to call it?")]] inline QueueIndexPointer - next_index() const { - return QueueIndexPointer(this->next_pointer(), this->size); - } - // Compares indices since the raw index is not visible to the user - inline bool operator==(const QueueIndexPointer &other) const { return this->ptr == other.ptr; } - inline bool operator!=(const QueueIndexPointer &other) const { return this->ptr != other.ptr; } - - private: - inline T next_pointer() { - T next_ptr = (this->ptr + 1); - next_ptr = next_ptr == wrap_around ? 0 : next_ptr; - return next_ptr; - } - QueueIndexPointer(T ptr, uint8_t queue_size) : ptr(ptr), size(queue_size), wrap_around(queue_size * 2) {} - T ptr; - uint8_t size; - uint8_t wrap_around; -}; - -FORCE_INLINE void eth_setup_handshake(std::uint32_t handshake_register_address, bool is_sender) { - reinterpret_cast(handshake_register_address)[4] = 1; - reinterpret_cast(handshake_register_address)[5] = 1; - reinterpret_cast(handshake_register_address)[6] = 0x1c0ffee1; - reinterpret_cast(handshake_register_address)[7] = 0x1c0ffee2; - - erisc_info->channels[0].receiver_ack = 0; - for (uint32_t i = 1; i < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS; i++) { - erisc_info->channels[i].bytes_sent = 0; - erisc_info->channels[i].receiver_ack = 0; - } - *(volatile tt_l1_ptr uint32_t *)handshake_register_address = 0; - if (is_sender) { - eth_wait_receiver_done(); - eth_send_bytes(handshake_register_address, handshake_register_address, 16); - eth_wait_for_receiver_done(); - } else { - eth_wait_for_bytes(16); - eth_receiver_channel_done(0); - } -} - -template -FORCE_INLINE void initialize_transaction_buffer_addresses( - uint32_t max_concurrent_transactions, - uint32_t first_buffer_base_address, - uint32_t num_bytes_per_send, - std::array &transaction_channel_buffer_addresses) { - uint32_t buffer_address = first_buffer_base_address; - for (uint32_t i = 0; i < max_concurrent_transactions; i++) { - transaction_channel_buffer_addresses[i] = buffer_address; - buffer_address += num_bytes_per_send; - } -} - -///////////////////////////////////////////// -// SENDER SIDE HELPERS -///////////////////////////////////////////// - -template -FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { - bool did_something = false; - if (sender_buffer_channel.eth_is_receiver_channel_send_done()) { - bool need_to_send_completion = sender_buffer_channel.is_send_completion_pending(); - if (!sender_buffer_channel.is_send_completion_pending() && !eth_txq_is_busy()) { - static constexpr std::size_t ETH_BYTES_TO_WORDS_SHIFT = 4; - eth_send_bytes_over_channel_payload_only( - sender_buffer_channel.get_buffer_address(), - sender_buffer_channel.get_remote_eth_buffer_address(), - sender_buffer_channel.get_current_payload_size(), - sender_buffer_channel.get_eth_transaction_channel(), - sender_buffer_channel.get_current_payload_size(), - sender_buffer_channel.get_current_payload_size() >> ETH_BYTES_TO_WORDS_SHIFT); - - sender_buffer_channel.set_send_completion_pending(true); - need_to_send_completion = true; - did_something = true; - } - - if (need_to_send_completion && !eth_txq_is_busy()) { - eth_send_payload_complete_signal_over_channel( - sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size()); - sender_buffer_channel.set_send_completion_pending(false); - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); - did_something = true; - } - } - - return did_something; -} - -template -FORCE_INLINE bool 
sender_notify_workers_if_buffer_available_sequence( - ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { - bool channel_done = false; - if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { - channel_done = sender_buffer_channel.all_messages_moved(); - } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { - // Nothing to do here because in this termination mode, we must check the signal in a different state - } else { - ASSERT(false); - } - - sender_buffer_channel.clear_local_semaphore(); - sender_buffer_channel.increment_worker_semaphores(); - - if (!channel_done) { - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); - } else { - sender_buffer_channel.goto_state(ChannelBuffer::DONE); - num_senders_complete++; - } - - return true; -} - -template -FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( - ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { - bool did_something = false; - - bool transimission_acked_by_receiver = sender_buffer_channel.eth_is_receiver_channel_send_acked() || - sender_buffer_channel.eth_is_receiver_channel_send_done(); - if (transimission_acked_by_receiver) { - eth_clear_sender_channel_ack(sender_buffer_channel.get_eth_transaction_channel()); - sender_buffer_channel.increment_messages_moved(); - sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); - sender_notify_workers_if_buffer_available_sequence(sender_buffer_channel, num_senders_complete); - did_something = true; - } - - return did_something; -} - -/* - * - */ -template -FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( - ChannelBuffer &sender_channel_buffer, uint32_t &num_senders_complete) { - bool did_something = false; - - if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { - if (*sender_channel_buffer.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY) { - sender_channel_buffer.clear_local_semaphore(); - sender_channel_buffer.goto_state(ChannelBuffer::DONE); - num_senders_complete++; - return true; - } - } - - bool read_finished = sender_channel_buffer.is_local_semaphore_full(); - if (read_finished) { - // We can clear the semaphore, and wait for space on receiver - // sender_channel_buffer.clear_local_semaphore(); - sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); - did_something = true; - - erisc::datamover::sender_eth_send_data_sequence(sender_channel_buffer); - } - - return did_something; -} - -///////////////////////////////////////////// -// RECEIVER SIDE HELPERS -///////////////////////////////////////////// - -/* - * - */ -template -FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { - buffer_channel.clear_local_semaphore(); - buffer_channel.increment_worker_semaphores(); - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); - - return true; -} - -/* - * If payload received, notify (send ack to) sender so sender knows it can free up its local buffer - * - */ -template -FORCE_INLINE bool receiver_eth_accept_payload_sequence( - ChannelBuffer &buffer_channel, - uint32_t &num_receivers_complete, - uint32_t eth_transaction_ack_word_addr) { - bool did_something = false; - - if (buffer_channel.eth_bytes_are_available_on_channel()) { - if (!eth_txq_is_busy()) { - eth_receiver_channel_ack(buffer_channel.get_eth_transaction_channel(), 
eth_transaction_ack_word_addr); - buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); - did_something = true; - - // FIXME: Decouple these so we can still signal workers even if eth command queue is busy - // Prefer sending eth ack first, but notify workers even if we have to come back to - // send the eth ack later - receiver_eth_notify_workers_payload_available_sequence(buffer_channel); - } - } - - return did_something; -} - -/* - * Does something if we are waiting for workers to complete their read and the read is complete: - * - increment messages moved (that transfer is done) - * - notifies sender it is safe to send next payload - * - clear local semaphore - */ -template -FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( - ChannelBuffer &buffer_channel, - uint32_t &num_receivers_complete, - uint32_t eth_transaction_complete_addr) { - bool did_something = false; - - bool workers_are_finished_reading = buffer_channel.is_local_semaphore_full(); - - if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { - // May have already gotten final termination signal by this point so check for that too - workers_are_finished_reading = - workers_are_finished_reading || - (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); - } - - bool can_notify_sender_of_buffer_available = workers_are_finished_reading; - if (can_notify_sender_of_buffer_available) { - if (!eth_txq_is_busy()) { - eth_receiver_channel_done(buffer_channel.get_eth_transaction_channel()); - buffer_channel.increment_messages_moved(); - - bool channel_done = false; - if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { - channel_done = buffer_channel.all_messages_moved(); - } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { - channel_done = (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); - } else { - ASSERT(false); - } - - if (!channel_done) { - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); - } else { - buffer_channel.goto_state(ChannelBuffer::DONE); - num_receivers_complete++; - } - - did_something = true; - } - } - - return did_something; -} - -//////////////////////////// -// DEPRECATED -//////////////////////////// -namespace deprecated { -// This namespace exists to support non-decoupled mode microbenchmarks until those are available -// in decoupled mode - -FORCE_INLINE bool sender_buffer_pool_full( - const QueueIndexPointer noc_reader_buffer_wrptr, - const QueueIndexPointer noc_reader_buffer_ackptr, - const QueueIndexPointer eth_sender_rdptr, - const QueueIndexPointer eth_sender_ackptr) { - return QueueIndexPointer::full(noc_reader_buffer_wrptr, eth_sender_ackptr); -} - -FORCE_INLINE bool sender_buffer_pool_empty( - const QueueIndexPointer noc_reader_buffer_wrptr, - const QueueIndexPointer noc_reader_buffer_ackptr, - const QueueIndexPointer eth_sender_rdptr, - const QueueIndexPointer eth_sender_ackptr) { - return QueueIndexPointer::empty(eth_sender_rdptr, noc_reader_buffer_wrptr); -} - -FORCE_INLINE bool sender_buffer_available_for_eth_send( - const QueueIndexPointer noc_reader_buffer_wrptr, - const QueueIndexPointer noc_reader_buffer_ackptr, - const QueueIndexPointer eth_sender_rdptr, - const QueueIndexPointer eth_sender_ackptr) { - return eth_sender_rdptr != noc_reader_buffer_ackptr; -} - -template -FORCE_INLINE bool sender_eth_send_data_sequence( - std::array 
&transaction_channel_sender_buffer_addresses, - std::array &transaction_channel_receiver_buffer_addresses, - uint32_t local_eth_l1_src_addr, - uint32_t remote_eth_l1_dst_addr, - uint32_t num_bytes, - uint32_t num_bytes_per_send, - uint32_t num_bytes_per_send_word_size, - QueueIndexPointer noc_reader_buffer_wrptr, - QueueIndexPointer noc_reader_buffer_ackptr, - QueueIndexPointer ð_sender_rdptr, - QueueIndexPointer ð_sender_ackptr) { - bool did_something = false; - bool data_ready_for_send = sender_buffer_available_for_eth_send( - noc_reader_buffer_wrptr, noc_reader_buffer_ackptr, eth_sender_rdptr, eth_sender_ackptr); - if (data_ready_for_send) { - bool consumer_ready_to_accept = eth_is_receiver_channel_send_done(eth_sender_rdptr.index()); - if (consumer_ready_to_accept) { - // Queue up another send - uint32_t sender_buffer_address = transaction_channel_sender_buffer_addresses[eth_sender_rdptr.index()]; - uint32_t receiver_buffer_address = transaction_channel_receiver_buffer_addresses[eth_sender_rdptr.index()]; - - eth_send_bytes_over_channel( - sender_buffer_address, - receiver_buffer_address, - num_bytes, - eth_sender_rdptr.index(), - num_bytes_per_send, - num_bytes_per_send_word_size); - eth_sender_rdptr.increment(); - did_something = true; - } - } - - return did_something; -} - -FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( - const QueueIndexPointer noc_reader_buffer_wrptr, - const QueueIndexPointer noc_reader_buffer_ackptr, - QueueIndexPointer ð_sender_rdptr, - QueueIndexPointer ð_sender_ackptr, - uint32_t &num_eth_sends_acked) { - bool did_something = false; - bool eth_sends_unacknowledged = QueueIndexPointer::distance(eth_sender_rdptr, eth_sender_ackptr) > 0; - if (eth_sends_unacknowledged) { - bool transimission_acked_by_receiver = eth_is_receiver_channel_send_acked(eth_sender_ackptr.index()) || - eth_is_receiver_channel_send_done(eth_sender_ackptr.index()); - if (transimission_acked_by_receiver) { - num_eth_sends_acked++; - eth_sender_ackptr.increment(); - - did_something = true; - } - } - - return did_something; -} - -FORCE_INLINE bool sender_is_noc_read_in_progress( - const QueueIndexPointer noc_reader_buffer_wrptr, - const QueueIndexPointer noc_reader_buffer_ackptr) { - return noc_reader_buffer_wrptr != noc_reader_buffer_ackptr; -} - -FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( - QueueIndexPointer &noc_reader_buffer_wrptr, - QueueIndexPointer &noc_reader_buffer_ackptr, - const uint8_t noc_index) { - bool did_something = false; - - bool noc_read_is_in_progress = sender_is_noc_read_in_progress(noc_reader_buffer_wrptr, noc_reader_buffer_ackptr); - if (noc_read_is_in_progress) { -#if EMULATE_DRAM_READ_CYCLES == 1 - bool read_finished = emulated_dram_read_cycles_finished(); -#else - bool read_finished = ncrisc_noc_reads_flushed(noc_index); -#endif - if (read_finished) { - noc_reader_buffer_ackptr.increment(); - did_something = true; - } - } - - return did_something; -} - -///////////////////////////////////////////// -// RECEIVER SIDE HELPERS -///////////////////////////////////////////// - -FORCE_INLINE bool receiver_is_noc_write_in_progress( - const QueueIndexPointer noc_writer_buffer_wrptr, - const QueueIndexPointer noc_writer_buffer_ackptr) { - return noc_writer_buffer_wrptr != noc_writer_buffer_ackptr; -} - -bool receiver_eth_accept_payload_sequence( - QueueIndexPointer noc_writer_buffer_wrptr, - QueueIndexPointer noc_writer_buffer_ackptr, - QueueIndexPointer ð_receiver_ptr, - QueueIndexPointer ð_receiver_ackptr, - uint32_t 
eth_channel_sync_ack_addr) { - bool did_something = false; - bool receive_pointers_full = QueueIndexPointer::full(eth_receiver_ptr, eth_receiver_ackptr); - - if (!receive_pointers_full) { - if (eth_bytes_are_available_on_channel(eth_receiver_ptr.index())) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)eth_receiver_ptr << "\n"; - eth_receiver_channel_ack(eth_receiver_ptr.index(), eth_channel_sync_ack_addr); - eth_receiver_ptr.increment(); - did_something = true; - } - } - - return did_something; -} - -FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( - QueueIndexPointer &noc_writer_buffer_wrptr, - QueueIndexPointer &noc_writer_buffer_ackptr, - const uint8_t noc_index) { - bool did_something = false; - - bool noc_write_is_in_progress = - receiver_is_noc_write_in_progress(noc_writer_buffer_wrptr, noc_writer_buffer_ackptr); - if (noc_write_is_in_progress) { -#if EMULATE_DRAM_READ_CYCLES == 1 - bool write_finished = emulated_dram_write_cycles_finished(); -#else - bool writes_finished = ncrisc_noc_nonposted_writes_sent(noc_index); -#endif - if (writes_finished) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_ackptr - // << "\n"; - noc_writer_buffer_ackptr.increment(); - - did_something = true; - } - } - - return did_something; -} - -FORCE_INLINE bool receiver_eth_send_ack_to_sender_sequence( - const QueueIndexPointer noc_writer_buffer_wrptr, - const QueueIndexPointer noc_writer_buffer_ackptr, - QueueIndexPointer ð_receiver_rdptr, - QueueIndexPointer ð_receiver_ackptr, - uint32_t &num_eth_sends_acked) { - bool did_something = false; - bool eth_sends_unacknowledged = eth_receiver_rdptr != eth_receiver_ackptr; - if (eth_sends_unacknowledged) { - // If data is done being sent out of this local l1 buffer and to the destination(s), - // then we can safely send the ack and increment the ackptr - bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index); - // bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_flushed(noc_index); - if (buffer_writes_flushed) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_wrptr - // << "\n"; - eth_receiver_channel_done(eth_receiver_ackptr.index()); - num_eth_sends_acked++; - eth_receiver_ackptr.increment(); - // DPRINT << "rx: Sending eth ack. ackptr incrementing to " << (uint32_t)eth_receiver_ackptr.index() << - // "\n"; - - did_something = true; - } - } - - return did_something; -} - -}; // namespace deprecated - -}; // namespace datamover -}; // namespace erisc diff --git a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp deleted file mode 100644 index a137ec25813..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp +++ /dev/null @@ -1,343 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "eth_l1_address_map.h" -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp - -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" -======= -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" ->>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp - -// Args Schema: -// 1) handshake addr -// 2) sender channels offset (indicates for the erisc channels, where the senders start -// so sender and receivers don't clash when paired with sender/receiver on the other -// end of the link.) -// 3) sender num channels (How many erisc channels to use. ALso how many buffers to instantiate) -// Informs how many times to iterate through the next group of args -// 4) sender_buffer_address -// 5) sender_num_messages_to_send -// 6) sender_channel_size -// 7) sender_semaphores_base_address -// 8) worker_semaphore_address -// 9) sender_num_workers -// Informs how many worker X/Y coords to accept in the next loop. Each X/Y pair is 2 uint16s -// 10) worker_coord(s) -// ... -// Repeat from step 2 for receiver side - -// Intended only for (performance) test use cases -FORCE_INLINE void eth_setup_handshake2(std::uint32_t handshake_register_address, bool is_sender) { - if (is_sender) { - DPRINT << "eth_send_bytes\n"; - eth_send_bytes(handshake_register_address, handshake_register_address, 16); - DPRINT << "eth_wait_for_receiver_done\n"; - eth_wait_for_receiver_done(); - } else { - DPRINT << "eth_wait_for_bytes\n"; - eth_wait_for_bytes(16); - DPRINT << "wait eth_receiver_done\n"; - eth_receiver_channel_done(0); - } -} - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp -using ttnn::ccl::WorkerXY; -======= -using ttnn::utils::ccl::WorkerXY; ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp - -template -struct sender_receiver_index_t { - static constexpr bool ZERO_SENDERS = num_senders == 0; - static constexpr bool ZERO_RECEIVERS = num_receivers == 0; - static constexpr bool NUM_SENDERS_IS_POW_2 = !ZERO_SENDERS && (((num_senders - 1) & num_senders) == 0); - static constexpr bool NUM_RECEIVERS_IS_POW_2 = !ZERO_RECEIVERS && (((num_receivers - 1) & num_receivers) == 0); - static constexpr uint16_t SENDER_INCR_MASK = !ZERO_SENDERS ? num_senders - 1 : 0; - static constexpr uint16_t RECEIVER_INCR_MASK = !ZERO_RECEIVERS ? 
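// Note: the sender and receiver cursors are packed into a single 16-bit word so that,
// when both channel counts are powers of two, one combined add-and-mask in increment()
// advances both round-robin indices at once; the else branch handles non power-of-two
// counts by wrapping each index separately, and real_index = start + index gives the
// absolute channel id. Illustrative example: with num_senders == 4, SENDER_INCR_MASK == 3
// and the sender index cycles 0, 1, 2, 3, 0, ...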
num_receivers - 1 : 0; - static constexpr uint16_t COMBINED_INCR_MASK = SENDER_INCR_MASK << 8 | RECEIVER_INCR_MASK; - static constexpr uint16_t COMBINED_INCR = (1 << 8) | 1; - union { - struct { - uint8_t sender; - uint8_t receiver; - }; - uint16_t combined; - } index; - union { - struct { - uint8_t sender; - uint8_t receiver; - }; - uint16_t combined; - } real_index; - union { - struct { - uint8_t sender; - uint8_t receiver; - }; - uint16_t combined; - } start; - - sender_receiver_index_t(uint8_t send_start, uint8_t receive_start, uint8_t num_send, uint8_t num_receive) { - start.sender = send_start; - start.receiver = receive_start; - index.sender = 0; - index.receiver = 0; - real_index.sender = send_start; - real_index.receiver = receive_start; - } - - FORCE_INLINE void increment() { - if constexpr (NUM_SENDERS_IS_POW_2 and NUM_RECEIVERS_IS_POW_2) { - index.combined = (index.combined + COMBINED_INCR) & COMBINED_INCR_MASK; - real_index.combined = start.combined + index.combined; - } else if constexpr (ZERO_RECEIVERS and NUM_SENDERS_IS_POW_2) { - index.sender = (index.sender + 1) & SENDER_INCR_MASK; - real_index.sender = start.sender + index.sender; - } else if constexpr (ZERO_SENDERS and NUM_RECEIVERS_IS_POW_2) { - index.receiver = (index.receiver + 1) & RECEIVER_INCR_MASK; - real_index.receiver = start.receiver + index.receiver; - } else { - index.combined += COMBINED_INCR; - index.sender = index.sender >= num_senders ? 0 : index.sender; - index.receiver = index.receiver >= num_receivers ? 0 : index.receiver; - real_index.combined = start.combined + index.combined; - } - } -}; - -void kernel_main() { - // COMPILE TIME ARGS - // If true, will enable this erisc's sender functionality - constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0; - - // If true, will enable this erisc's receiver functionality - constexpr bool enable_receiver_side = get_compile_time_arg_val(1) != 0; - - constexpr uint32_t num_senders = get_compile_time_arg_val(2); - constexpr uint32_t num_receivers = get_compile_time_arg_val(3); - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp - constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = - static_cast(get_compile_time_arg_val(4)); - - constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = - static_cast(get_compile_time_arg_val(5)); -======= - constexpr ttnn::utils::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = - static_cast(get_compile_time_arg_val(4)); - - constexpr ttnn::utils::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = - static_cast(get_compile_time_arg_val(5)); ->>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp - - constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); - using EDM_CONFIG_T = decltype(EDM_CONFIG); - using ChannelBufferT = erisc::datamover::ChannelBuffer; - std::array buffer_channels; - - // - std::array printed_receiver_done; - - // SENDER ARGS - uint32_t args_offset = 0; - uint32_t handshake_addr = get_arg_val(args_offset++); - - uint8_t const sender_channels_start = get_arg_val(args_offset++); - uint32_t const sender_num_channels = num_senders;//get_arg_val(args_offset++); - uint8_t num_senders_with_no_work = 0; - for (uint32_t channel = 0; channel < sender_num_channels; channel++) { - uint32_t const sender_buffer_address = get_arg_val(args_offset++); - uint32_t const sender_num_messages_to_send = get_arg_val(args_offset++); - // Each channel buffer is at buffer_base + (channel_id * sender_channel_size) - // Each channel currently constrained to the same buffer size - uint32_t const sender_channel_size = get_arg_val(args_offset++); - // The erisc's local l1 copy of the semaphore workers remotely increment - uint32_t const sender_semaphores_base_address = get_arg_val(args_offset++); - // worker's semaphore L1 address - const uint32_t worker_semaphore_address = get_arg_val(args_offset++); - const uint32_t sender_num_workers = get_arg_val(args_offset++); - const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); - args_offset += sender_num_workers; - new (&buffer_channels[sender_channels_start + channel]) ChannelBufferT( - sender_channels_start + channel, - sender_buffer_address, - sender_channel_size, - worker_semaphore_address, - sender_num_workers, - sender_num_messages_to_send, - (volatile tt_l1_ptr uint32_t *const)sender_semaphores_base_address, - (const WorkerXY *)workers_xy_list_addr, - true); - if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { - if (sender_num_messages_to_send == 0) { - num_senders_with_no_work++; - } - } - } - - // Receiver args - uint8_t const receiver_channels_start = get_arg_val(args_offset++); - uint32_t const receiver_num_channels = num_receivers;//get_arg_val(args_offset++); - uint8_t num_receivers_with_no_work = 0; - for (uint32_t channel = 0; channel < receiver_num_channels; channel++) { - uint32_t const receiver_buffers_base_address = get_arg_val(args_offset++); - uint32_t const receiver_num_messages_to_send = get_arg_val(args_offset++); - // Each channel buffer is at buffer_base + (channel_id * sender_channel_size) - // Each channel currently constrained to the same buffer size - uint32_t const receiver_channel_size = get_arg_val(args_offset++); - uint32_t const receiver_semaphores_base_address = get_arg_val(args_offset++); - uint32_t const worker_semaphore_address = get_arg_val(args_offset++); - uint32_t const receiver_num_workers = get_arg_val(args_offset++); - const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); - args_offset += receiver_num_workers; - new (&buffer_channels[receiver_channels_start + channel]) ChannelBufferT( - receiver_channels_start + channel, - receiver_buffers_base_address, - receiver_channel_size, - worker_semaphore_address, - receiver_num_workers, - receiver_num_messages_to_send, - (volatile tt_l1_ptr uint32_t *const)receiver_semaphores_base_address, - (const WorkerXY *)workers_xy_list_addr, - false); - - if constexpr (terminate_on_worker_signal == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { - if (receiver_num_messages_to_send == 0) { - 
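// A receiver channel configured with zero messages is counted as complete up front;
// num_receivers_with_no_work seeds num_receivers_complete below, so the main progress
// loop can reach its termination condition without ever servicing this channel.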
num_receivers_with_no_work++; - } - } - } - - // Handshake with other erisc to make sure it's safe to start sending/receiving - // Chose an arbitrary ordering mechanism to guarantee one of the erisc's will always be "sender" and the other - // will always be "receiver" (only for handshake purposes) - bool act_as_sender_in_handshake = - (sender_channels_start < receiver_channels_start || receiver_num_channels == 0) && sender_num_channels > 0; - erisc::datamover::eth_setup_handshake(handshake_addr, act_as_sender_in_handshake); - uint32_t eth_transaction_ack_word_addr = handshake_addr + 16; - uint32_t eth_transaction_complete_addr = handshake_addr + 32; - - constexpr uint32_t SWITCH_INTERVAL = 4000000; - uint32_t did_nothing_count = 0; - - uint32_t num_senders_complete = !enable_sender_side ? sender_num_channels : num_senders_with_no_work; - uint32_t num_receivers_complete = !enable_receiver_side ? receiver_num_channels : num_receivers_with_no_work; - bool senders_in_progress = num_senders_complete != sender_num_channels; - bool receivers_in_progress = num_receivers_complete != receiver_num_channels; - - auto send_recv_index = sender_receiver_index_t(sender_channels_start, receiver_channels_start, sender_num_channels, receiver_num_channels); - - while (senders_in_progress || receivers_in_progress) { - bool did_something_sender = false; - bool did_something_receiver = false; - - uint32_t num_receivers_complete_old = num_receivers_complete; - uint32_t num_senders_complete_old = num_senders_complete; - ////////////////////////////////////// - // SENDER - if constexpr (enable_sender_side) { - ChannelBufferT ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; - switch (current_sender.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_WORKER: - did_something_sender = - erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender, num_senders_complete); - senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; - break; - - case ChannelBufferT::STATE::READY_FOR_ETH_TRANSFER: - did_something_sender = erisc::datamover::sender_eth_send_data_sequence(current_sender); - break; - - case ChannelBufferT::STATE::SIGNALING_WORKER: - did_something_sender = erisc::datamover::sender_notify_workers_if_buffer_available_sequence( - current_sender, num_senders_complete); - senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; - break; - - case ChannelBufferT::STATE::WAITING_FOR_ETH: - did_something_sender = - erisc::datamover::sender_eth_check_receiver_ack_sequence(current_sender, num_senders_complete); - senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; - break; - - default: - break; - }; - } - - ////////////////////////////////////// - // RECEIVER - if constexpr (enable_receiver_side) { - ChannelBufferT ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; - - switch (current_receiver.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_ETH: - did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, num_receivers_complete, eth_transaction_ack_word_addr); - receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; - break; - - case ChannelBufferT::STATE::SIGNALING_WORKER: - did_something_receiver = - erisc::datamover::receiver_eth_notify_workers_payload_available_sequence(current_receiver); - break; - - case ChannelBufferT::STATE::WAITING_FOR_WORKER: - did_something_receiver 
= erisc::datamover::receiver_noc_read_worker_completion_check_sequence( - current_receiver, num_receivers_complete, eth_transaction_complete_addr); - receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; - break; - - default: - break; - }; - } - send_recv_index.increment(); - ////////////////////////////////////// - - // Enabling this block as is (with all the "did_something"s, seems to cause a loss of about - // 0.5 GBps in throughput) - if (did_something_sender || did_something_receiver) { - did_nothing_count = 0; - } else { - if (did_nothing_count++ > SWITCH_INTERVAL) { - did_nothing_count = 0; - run_routing(); - } - } - } - - for (uint32_t s = 0; s < num_senders + num_receivers; s++ ) { - auto const& channel = buffer_channels[s]; - // We need to explicitly check for channel send done because we may - // advance sender channel state as soon as we receive an ack. Since we - // may be the last active channel, and advance to done state just from ack - // from the receiver ("I got a payload"), then we need to wait for done - // at the very end here. Otherise if we invoke another erisc op back-to-back, - // we may mess up transaction state because it's possible for receiver of this - // op to send the completion done after that one has already started. - uint32_t wait_count = 0; - uint32_t wait_max = 50000; - while(!channel.eth_is_receiver_channel_send_done()) { - wait_count++; - if (wait_count > wait_max) { - - DEBUG_STATUS("STK"); - run_routing(); - wait_count = 0; - } - } - } - - DEBUG_STATUS("DONE"); -} diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index a137ec25813..6de2e2c5016 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -14,8 +14,17 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" ======= #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +<<<<<<< HEAD #include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" >>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp +======= +<<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" +======== +#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" +>>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp +>>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp +>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory // Args Schema: // 1) handshake addr diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp deleted file mode 100644 index 2dd167b3db2..00000000000 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp -#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" -======= -#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp -#include "ttnn/cpp/ttnn/multi_device.hpp" - -namespace ttnn { -namespace operations { -namespace ccl { - -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp -struct ExecuteAllGather { - - static ttnn::Tensor execute_on_main_thread( - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt) { - return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); - } -}; - -======= ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp -struct ExecuteLineAllGather { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - - static ttnn::Tensor execute_on_main_thread( - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt) { - return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); - } -}; - -} // namespace ccl -} // namespace operations - -constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); - -} // namespace ttnn From 6c3e9c2441acda7ae55cb7a58008ce0c6cd9de05 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Fri, 5 Jul 2024 07:51:18 +0000 Subject: [PATCH 06/10] #9486: Merge CCL reduce_scatter to TTNN --- .../test_reduce_scatter_nightly.py | 9 +- .../test_reduce_scatter_post_commit.py | 9 +- ttnn/CMakeLists.txt | 2 + .../host/reduce_scatter_full_worker_grid.cpp | 111 ++- .../ccl/reduce_scatter/reduce_scatter_op.cpp | 21 +- .../ccl/reduce_scatter/reduce_scatter_op.hpp | 18 +- .../ccl_reduce_scatter_pybind.hpp | 75 ++ .../device/ccl_reduce_scatter_op.hpp | 48 + .../host/reduce_scatter_full_worker_grid.cpp | 928 ++++++++++++++++++ ...interleaved_ring_reduce_scatter_reader.cpp | 356 +++++++ ...interleaved_ring_reduce_scatter_sender.cpp | 148 +++ .../device/reduce_scatter_op.cpp | 132 +++ .../device/reduce_scatter_op.hpp | 67 ++ 13 files changed, 1867 insertions(+), 57 deletions(-) rename tests/{tt_eager/python_api_testing/unit_testing/misc => ttnn/unit_tests/operations}/test_reduce_scatter_nightly.py (97%) rename tests/{tt_eager/python_api_testing/unit_testing/misc => ttnn/unit_tests/operations}/test_reduce_scatter_post_commit.py (97%) create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp create mode 100644 
ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py similarity index 97% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_nightly.py rename to tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py index 41faeee80bc..2d48781a582 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py @@ -6,6 +6,7 @@ import pytest from loguru import logger import tt_lib as ttl +import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 import itertools @@ -75,12 +76,12 @@ def run_reduce_scatter_test( # Run the op # for i in range(num_iters): - tt_out_tensors = ttl.tensor.reduce_scatter( + tt_out_tensors = ttnn.reduce_scatter( tt_input_tensors, - scatter_split_dim=scatter_dim, - reduce_op=math_op, + scatter_dim=scatter_dim, + math_op=math_op, num_links=num_links, - output_mem_config=mem_config, + memory_config=mem_config, ) for d in devices: diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py similarity index 97% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py rename to tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py index 81d87933e59..dcc477aa5a6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py @@ -6,6 +6,7 @@ import pytest from loguru import logger import tt_lib as ttl +import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 import itertools @@ -75,12 +76,12 @@ def run_reduce_scatter_test( # Run the op # for i in range(num_iters): - tt_out_tensors = ttl.tensor.reduce_scatter( + tt_out_tensors = ttnn.reduce_scatter( tt_input_tensors, - scatter_split_dim=scatter_dim, - reduce_op=math_op, + scatter_dim=scatter_dim, + math_op=math_op, num_links=num_links, - output_mem_config=mem_config, + memory_config=mem_config, ) for d in devices: diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 79dded7e42a..3f25f86d2cb 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -40,6 +40,8 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp index 2088728d2f9..eab8cdfd40e 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp @@ -34,9 +34,9 @@ using namespace tt::constants; // with that received chunk. It will forward the partially reduced chunk. // Reduces along rank -namespace tt { +namespace ttnn { -namespace tt_metal { +namespace utils { namespace ccl { namespace reduce_scatter_detail { @@ -379,9 +379,6 @@ static void add_worker_config_to_edm_builders( log_trace(tt::LogOp, "Adding receiver EDM channel"); ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel( - worker_receiver_semaphore_address, - // Since we are in worker signal EDM termination mode, we don't need to set the actual number of - // messages the EDM must forward as it will receive its finish signal from the worker instead 1, receiver_worker_coords, expected_message_size_bytes); @@ -394,14 +391,21 @@ static void add_worker_config_to_edm_builders( } static std::tuple build_reduce_scatter_worker( - tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::RingTopology const& topology_config, ttnn::ccl::CCLOpConfig const& op_config, ReduceScatterWorkerArgBuilder const& worker_arg_builder, std::vector& cw_edm_builders, std::vector& ccw_edm_builders, EdmInterfaceAddresses const& edm_interface_addresses, +======= + ttnn::utils::ccl::CCLOpConfig const& op_config, + ReduceScatterWorkerArgBuilder const& worker_arg_builder, + std::vector& cw_edm_builders, + std::vector& ccw_edm_builders, +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp CoreCoord const& worker_core, uint32_t num_edm_channels, uint32_t link, @@ -415,9 +419,15 @@ static std::tuple build_reduce_scatter_worker( log_trace(tt::LogOp, "Worker Define: {} = {}", key, value); } static std::string const& receiver_kernel_path = +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; static std::string const& sender_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; +======= + "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; + static std::string const& sender_kernel_path = + "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp // This will be configurable by sharded/non-sharded but present the same arg builder KernelHandle worker_receiver_kernel_id, worker_sender_kernel_id; @@ -439,13 +449,13 @@ static std::tuple build_reduce_scatter_worker( ? edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index) : edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index); - worker_receiver_kernel_id = tt_metal::CreateKernel( + worker_receiver_kernel_id = tt::tt_metal::CreateKernel( program, receiver_kernel_path, worker_core, - tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_receiver_kernel_ct_args(), worker_defines)); + tt::tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_receiver_kernel_ct_args(), worker_defines)); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_receiver_kernel_id, worker_core, @@ -463,18 +473,18 @@ static std::tuple build_reduce_scatter_worker( constexpr bool fp32_dest_acc_en = false; constexpr bool math_approx_mode = false; std::map eltwise_defines = ttnn::operations::binary::utils::get_defines(binary_math_op); - KernelHandle worker_reduce_kernel_id = tt_metal::CreateKernel( + KernelHandle worker_reduce_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp", worker_core, - tt_metal::ComputeConfig{ + tt::tt_metal::ComputeConfig{ .math_fidelity = MathFidelity::HiFi4, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args, .defines = eltwise_defines}); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_reduce_kernel_id, worker_core, @@ -496,13 +506,13 @@ static std::tuple build_reduce_scatter_worker( is_in_clockwise_direction ? edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index) : edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index); - worker_sender_kernel_id = tt_metal::CreateKernel( + worker_sender_kernel_id = tt::tt_metal::CreateKernel( program, sender_kernel_path, worker_core, - tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_kernel_ct_args(), worker_defines)); + tt::tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_kernel_ct_args(), worker_defines)); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_sender_kernel_id, worker_core, @@ -601,14 +611,18 @@ static uint32_t compute_maximum_worker_slice_in_bytes( } static bool is_cb_buffering_sufficient_to_avoid_deadlock( +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::InterleavedTensorWorkerSlice const& worker_slice, +======= + ttnn::utils::ccl::InterleavedTensorWorkerSlice const& worker_slice, +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp uint32_t cb_src0_size_pages, uint32_t cb_dst0_size_pages, uint32_t cb_short_circuit_size_pages, std::size_t edm_channel_buffer_size, uint32_t page_size) { uint32_t worker_size_pages_rounded_up = - round_up(worker_slice.worker_slice_shape.x * worker_slice.worker_slice_shape.y, cb_src0_size_pages / 2); + tt::round_up(worker_slice.worker_slice_shape.x * worker_slice.worker_slice_shape.y, cb_src0_size_pages / 2); uint32_t worker_slice_size_bytes = worker_size_pages_rounded_up * page_size; uint32_t available_buffering_capacity = compute_maximum_worker_slice_in_bytes( cb_src0_size_pages, cb_dst0_size_pages, cb_short_circuit_size_pages, edm_channel_buffer_size, page_size); @@ -627,40 +641,44 @@ static bool is_cb_buffering_sufficient_to_avoid_deadlock( static std::tuple create_worker_circular_buffers( Tensor const& input_tensor, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::CCLOpConfig const& op_config, +======= + ttnn::utils::ccl::CCLOpConfig const& op_config, +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp CoreRangeSet const& worker_core_range, uint32_t worker_pages_per_transfer, - tt_metal::Program& program) { - tt::DataFormat df = tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::Program& program) { + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); uint32_t page_size_bytes = op_config.get_page_size(); // Input 0 CB - uint32_t src0_cb_index = CB::c_in0; - tt_metal::CircularBufferConfig cb_src0_config = - tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src0_cb_index, df}}) + uint32_t src0_cb_index = tt::CB::c_in0; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src0_cb_index, df}}) .set_page_size(src0_cb_index, page_size_bytes); CBHandle cb_src0_workers = CreateCircularBuffer(program, worker_core_range, cb_src0_config); // Input 1 CB - uint32_t src1_cb_index = CB::c_in1; - tt_metal::CircularBufferConfig cb_src1_config = - tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src1_cb_index, df}}) + uint32_t src1_cb_index = tt::CB::c_in1; + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src1_cb_index, df}}) .set_page_size(src1_cb_index, page_size_bytes); CBHandle cb_src1_workers = CreateCircularBuffer(program, worker_core_range, cb_src1_config); // Dataflow Writer Kernel input CB - uint32_t cb_dst0_index = CB::c_out0; - tt_metal::CircularBufferConfig cb_dst0_config = - tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{cb_dst0_index, df}}) + uint32_t cb_dst0_index = tt::CB::c_out0; + tt::tt_metal::CircularBufferConfig cb_dst0_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{cb_dst0_index, df}}) .set_page_size(cb_dst0_index, page_size_bytes); CBHandle cb_dst0_sender_workers = CreateCircularBuffer(program, worker_core_range, cb_dst0_config); // From reader -> writer kernel (I think I need this because sharing the cb_dst0_sender_workers as output // of reader kernel (first output) and math kernel (all 
subsequent outputs) doesn't seem to work because // it seems like the math kernels hold some of the CB state in local variables) - uint32_t cb_short_circuit_index = CB::c_out1; - tt_metal::CircularBufferConfig cb_short_circuit_config = - tt_metal::CircularBufferConfig( + uint32_t cb_short_circuit_index = tt::CB::c_out1; + tt::tt_metal::CircularBufferConfig cb_short_circuit_config = + tt::tt_metal::CircularBufferConfig( (worker_pages_per_transfer * page_size_bytes) * 2, {{cb_short_circuit_index, df}}) .set_page_size(cb_short_circuit_index, page_size_bytes); CBHandle cb_short_circuit_sender_workers = @@ -679,7 +697,11 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::Topology topology) { +======= + ttnn::utils::ccl::Topology topology) { +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); TT_ASSERT( input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == @@ -691,15 +713,24 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( /////////////// Constants/Configuration /// Constants/Configuration +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; auto const& op_config =ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); +======= + ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::utils::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + auto const& op_config =ttnn::utils::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + std::unique_ptr input_tensor_config = + ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); + std::unique_ptr output_tensor_config = + ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; uint32_t input_tensor_num_units_per_scatter_dim = - per_step_dim_size / constants::TILE_WIDTH; // TODO: find the divisibility based on layout + per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); bool enable_bidirectional = false; @@ -727,21 +758,21 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } ////////////////// - tt_metal::Program program{}; + tt::tt_metal::Program program{}; const auto& device = local_chip_tensor.device(); auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); - auto dim_slice_factors = Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); + auto dim_slice_factors = tt::tt_metal::Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); dim_slice_factors[-1] = ring_size; CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); // Semaphores && CBs - auto worker_receiver_semaphore_address = tt_metal::CreateSemaphore(program, worker_core_range, 0); - auto worker_sender_semaphore_address = tt_metal::CreateSemaphore(program, worker_core_range, 0); + auto worker_receiver_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); + auto worker_sender_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); uint32_t cb_num_pages = (cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes() / op_config.get_page_size()) * 2; @@ -801,7 +832,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // build the worker kernels - tt_metal::ComputeConfig compute_config; + tt::tt_metal::ComputeConfig compute_config; for (std::size_t link = 0; link < num_links; link++) { uint32_t global_worker_index = link * num_edm_channels; log_trace(tt::LogOp, "=============================================="); @@ -853,7 +884,11 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // Generate the EDM kernels +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( +======= + ttnn::utils::ccl::generate_edm_kernels_for_ring_or_linear_topology( +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp program, device, topology_config, @@ -889,5 +924,5 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // namespace reduce_scatter_detail } // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace utils +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp index 8af69597b2b..ecacfca26a6 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp @@ -2,7 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp #include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" +======= +#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" @@ -11,8 +15,8 @@ #include "ttnn/operations/eltwise/binary/binary.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn { +namespace utils { void ReduceScatter::validate(const std::vector& input_tensors) const { for (auto const& t : input_tensors) { @@ -25,13 +29,13 @@ void ReduceScatter::validate(const std::vector& input_tensors) const { } } -std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { +std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { auto shape = input_tensors[0].get_legacy_shape(); TT_ASSERT( shape[this->scatter_dim] % this->ring_size == 0, "The size of the scatter dimension must be a multiple of the ring size"); shape[this->scatter_dim] /= this->ring_size; - return std::vector(input_tensors.size(), shape); + return std::vector(input_tensors.size(), shape); } std::vector ReduceScatter::create_output_tensors(const std::vector& input_tensors) const { @@ -72,11 +76,16 @@ std::vector reduce_scatter_impl( output_tensors.reserve(input_tensors.size()); std::vector ops; ops.reserve(input_tensors.size()); +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp bool is_ring = topology ==ttnn::ccl::Topology::Ring; for (uint32_t i = 0; i < input_tensors.size(); ++i) { bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; +======= + bool is_ring = topology ==ttnn::utils::ccl::Topology::Ring; +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp + for (uint32_t i = 0; i < input_tensors.size(); ++i) { std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? 
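// In a ring topology is_last_chip_in_* is always false, so every device is given both a
// receiver and a sender neighbour; only in a line topology do the end chips leave one
// direction unset (std::nullopt), meaning no neighbour is configured on that side.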
std::nullopt @@ -119,5 +128,5 @@ std::vector reduce_scatter( input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config,ttnn::ccl::Topology::Ring); } -}; // namespace tt_metal -}; // namespace tt +}; // namespace utils +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp index 357902e319b..2440d5babb9 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp @@ -11,8 +11,8 @@ #include "ttnn/operations/eltwise/binary/binary.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn { +namespace utils { struct ReduceScatter { const ttnn::operations::binary::BinaryOpType binary_op_type; @@ -26,7 +26,7 @@ struct ReduceScatter { const ttnn::ccl::Topology topology; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; @@ -51,9 +51,17 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp ttnn::ccl::Topology topology); +======= +<<<<<<< HEAD:tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp + ttnn::utils::ccl::Topology topology); +======= + tt::tt_metal::ccl::Topology topology); +>>>>>>> bdb9766ed5... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp } }; // namespace ccl -}; // namespace tt_metal -}; // namespace tt +}; // namespace utils +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp new file mode 100644 index 00000000000..40bf407725d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace ccl_reduce_scatter { + +namespace detail { + +template +void bind_ccl_reduce_scatter(py::module& module, const ccl_operation_t& operation, const char* doc) { + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const std::vector& input_tensors, + const uint32_t scatter_dim, + ReduceOpMath math_op, + const uint32_t num_links, + const ttnn::MemoryConfig& memory_config) -> std::vector { + return self(input_tensors, scatter_dim, math_op, num_links, memory_config); + }, + py::arg("input_tensors"), + py::arg("scatter_dim"), + py::arg("math_op"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt}); +} + +} // namespace detail + + +void py_module(py::module& module) { + + detail::bind_ccl_reduce_scatter( + module, + ttnn::reduce_scatter, + R"doc(reduce_scatter(input_tensor: std::vector, scatter_dim: int, math_op: ReduceOpMath, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> std::vector + + Performs an reduce_scatter operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.reduce_scatter(tensor, dim=0) + + )doc"); +} + +} // namespace ccl_reduce_scatter +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp new file mode 100644 index 00000000000..bee76abb0b1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" +#include "ttnn/cpp/ttnn/multi_device.hpp" + +namespace ttnn { +namespace operations { +namespace ccl { + +struct ExecuteReduceScatter { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, + true, + false, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const std::vector& input_tensors, Args&&... 
args) { + return std::forward_as_tuple(input_tensors.at(0)); + } + + static std::vector execute_on_main_thread( + const std::vector& input_tensors, + const uint32_t scatter_dim, + ReduceOpMath math_op, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + MemoryConfig out_memory_config = memory_config.value_or(input_tensors.at(0).memory_config()); + return utils::reduce_scatter(input_tensors, scatter_dim, math_op, num_links, out_memory_config); + } +}; + +} // namespace ccl +} // namespace operations + +constexpr auto reduce_scatter = ttnn::register_operation("ttnn::reduce_scatter"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp new file mode 100644 index 00000000000..eab8cdfd40e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp @@ -0,0 +1,928 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/// + +#include "common/core_coord.h" +#include "eth_l1_address_map.h" +#include "impl/buffers/buffer.hpp" +#include "impl/kernels/data_types.hpp" +#include "tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/buffers/circular_buffer_types.hpp" + +#include "ttnn/operations/eltwise/binary/binary.hpp" + +// Includes that need to be moved to CCL datastructures header +#include + +using namespace tt::constants; + +// Notes on abbreviations: +// cw = clockwise +// ccw = counter-clockwise +// edm = erisc data mover + +// How this reduce_scatter op works: +// For each chip, we have a element range of the input tensor shape that will eventually scatter +// out to it. For all other chunks outside that range, the chip will forward the chunk to the next chip. +// While forwarding the data, the chip will also reduce it with the local input tensor chunk corresponding +// with that received chunk. It will forward the partially reduced chunk. +// Reduces along rank + +namespace ttnn { + +namespace utils { + +namespace ccl { +namespace reduce_scatter_detail { +struct WorkerTransferInfo { + WorkerTransferInfo( + std::vector pages_per_full_chunk_per_worker, uint32_t num_links, uint32_t num_workers) : + pages_per_full_chunk_per_worker(pages_per_full_chunk_per_worker), + num_links(num_links), + num_workers(num_workers) {} + + uint32_t get_num_pages_per_full_chunk(uint32_t link, uint32_t worker_idx) const { + return pages_per_full_chunk_per_worker.at(link * num_workers + worker_idx); + } + + std::vector pages_per_full_chunk_per_worker; + uint32_t num_links; + uint32_t num_workers; +}; + +static std::size_t decide_number_of_edm_channels( + ttnn::ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { + return ccl_op_config.is_input_sharded() ? std::min( + ccl_op_config.get_shard_grid_size(), + std::min(max_num_workers, enable_bidirectional ? 8 : 4)) + : std::min(max_num_workers, enable_bidirectional ? 
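// Caps the number of EDM channels (and hence workers) per link: at most 8 when
// bidirectional sends are enabled, otherwise 4, further limited by max_num_workers and,
// for sharded inputs, by the shard grid size. E.g. an interleaved input with
// max_num_workers == 8 and enable_bidirectional == false yields 4 channels per link.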
8 : 4); +} + +struct ReduceScatterWorkerArgBuilder { + ReduceScatterWorkerArgBuilder( + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingTopology const& topology_config, + ttnn::ccl::InterleavedTensorWorkerSlice const& worker_input_slice, + WorkerTransferInfo const& worker_transfer_info, + uint32_t worker_idx, + uint32_t link, + uint32_t cb_num_pages_per_packet, + uint32_t worker_sender_semaphore_address, + uint32_t worker_receiver_semaphore_address) : + op_config(op_config), + topology_config(topology_config), + worker_input_slice(worker_input_slice), + worker_transfer_info(worker_transfer_info), + cb_num_pages_per_packet(cb_num_pages_per_packet), + worker_sender_semaphore_address(worker_sender_semaphore_address), + worker_receiver_semaphore_address(worker_receiver_semaphore_address) { +#ifndef SEND_MATH_TERMINATE_SIGNAL + // This algorithm assumes that the worker slices are sized such that they start at the same x offsets for each + // new row they slice into (as they stride through the tensor) + std::size_t num_slice_iterations = + worker_input_slice.compute_num_worker_slice_iterations(worker_transfer_info.num_workers); + std::size_t worker_slice_num_pages = + worker_input_slice.worker_slice_shape.x * worker_input_slice.worker_slice_shape.y; + std::size_t pages_per_full_chunk = worker_transfer_info.get_num_pages_per_full_chunk(link, worker_idx); + std::size_t num_filler_pages_per_slice = pages_per_full_chunk - (worker_slice_num_pages % pages_per_full_chunk); + this->total_num_math_pages = (worker_input_slice.get_worker_slice_num_pages() + num_filler_pages_per_slice) * + num_slice_iterations * (topology_config.ring_size - 1); + + log_trace(tt::LogOp, "ReduceScatterWorkerArgBuilder: total_num_math_pages: {}", this->total_num_math_pages); +#endif + } + + std::vector generate_reduce_op_kernel_ct_args() const { + log_trace(tt::LogOp, "Reduce Scatter Worker CT Args: None"); + return {}; + } + + std::vector generate_reduce_op_kernel_rt_args( + uint32_t link, uint32_t worker_index, uint32_t ring_size) const { + log_trace(tt::LogOp, "generate_reduce_op_kernel_rt_args"); + + auto const& args = std::vector{total_num_math_pages, 1, 0}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Worker RT Args:"); + log_trace(tt::LogOp, "\tblock_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\ttotal_num_math_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tacc_to_dst: {}", args.at(i++)); + + return args; + } + + std::vector generate_receiver_kernel_ct_args() const { + auto const& args = std::vector{ + static_cast(this->op_config.is_input_sharded() ? 1 : 0), + static_cast( + this->op_config.get_input_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Receiver Worker CT Args:"); + log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); + log_trace(tt::LogOp, "\tsrc_is_dram: {}", args.at(i++)); + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_receiver_kernel_rt_args( + ttnn::ccl::WorkerXY edm_core, + uint32_t edm_core_semaphore_address, + uint32_t edm_core_buffer_address, + uint32_t link, + uint32_t worker_index, + bool is_in_clockwise_direction) const { + TT_ASSERT(edm_core_semaphore_address > 0); + TT_ASSERT(edm_core_buffer_address > 0); + auto const& local_input_tensor = this->op_config.get_input_tensor(0); + uint32_t starting_ring_index = + is_in_clockwise_direction ? (this->topology_config.ring_index == 0 ? 
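// starting_ring_index picks the neighbouring chip's ring position, i.e. the source of the
// first received chunk: ring_index - 1 (wrapping 0 -> ring_size - 1) when operating in the
// clockwise direction, ring_index + 1 (wrapping ring_size - 1 -> 0) otherwise.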
this->topology_config.ring_size - 1 + : this->topology_config.ring_index - 1) + : (this->topology_config.ring_index == this->topology_config.ring_size - 1 + ? 0 + : this->topology_config.ring_index + 1); + auto args = std::vector{ + static_cast(local_input_tensor.buffer()->address()), + static_cast(this->topology_config.ring_size), // num_transfers + static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), + static_cast(this->op_config.get_page_size()), + static_cast(starting_ring_index), + static_cast(this->topology_config.ring_size), + static_cast(this->worker_receiver_semaphore_address), + static_cast(is_in_clockwise_direction ? 1 : 0), + static_cast(this->cb_num_pages_per_packet), + static_cast(edm_core.x), + static_cast(edm_core.y), + static_cast(edm_core_semaphore_address), + static_cast(edm_core_buffer_address), + + static_cast(worker_transfer_info.num_workers), + + static_cast(this->worker_input_slice.tensor_shape.x), + static_cast(this->worker_input_slice.tensor_shape.y), + + static_cast(this->worker_input_slice.tensor_slice_shape.x), + static_cast(this->worker_input_slice.tensor_slice_shape.y), + + static_cast(this->worker_input_slice.worker_slice_shape.x), + static_cast(this->worker_input_slice.worker_slice_shape.y), + + static_cast(this->worker_input_slice.worker_slice_offset.x), + static_cast(this->worker_input_slice.worker_slice_offset.y), + + this->total_num_math_pages}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Receiver Worker RT Args:"); + log_trace(tt::LogOp, "\tsrc_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); + log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tmy_ring_idx: {}", args.at(i++)); + log_trace(tt::LogOp, "\tring_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tsem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\tis_clockwise_direction: {}", args.at(i++)); + log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); + + log_trace(tt::LogOp, "\tedm_core_noc0_core_x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_noc0_core_y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_semaphore_address: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_buffer_address: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); + + log_trace(tt::LogOp, "\tinput_tensor_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tinput_tensor_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\ttensor_slice_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\ttensor_slice_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.y={}", args.at(i++)); + log_trace(tt::LogOp, "\ttotal_num_math_pages={}", args.at(i++)); + + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_sender_kernel_ct_args() const { + auto const& args = std::vector{ + static_cast(this->op_config.is_input_sharded() ? 1 : 0), + static_cast( + this->op_config.get_output_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 
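// Same shape as the receiver CT args (is_sharded flag plus a DRAM flag), but the DRAM
// check is made against the output tensor here: the sender kernel writes the reduced
// slice to the local output buffer (dst_addr below), while the receiver reads from the
// local input tensor.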
1 : 0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Sender Worker CT Args:"); + log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); + log_trace(tt::LogOp, "\tdst_is_dram: {}", args.at(i++)); + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_sender_kernel_rt_args( + ttnn::ccl::WorkerXY edm_core, + uint32_t edm_core_semaphore_address, + uint32_t edm_core_buffer_address, + uint32_t link, + uint32_t worker_index, + bool is_clockwise) const { + TT_ASSERT(edm_core_semaphore_address > 0); + TT_ASSERT(edm_core_buffer_address > 0); + auto const& local_output_tensor = this->op_config.get_output_tensor(0); + auto const& args = std::vector{ + static_cast(local_output_tensor.buffer()->address()), + static_cast(edm_core_buffer_address), + static_cast(edm_core_semaphore_address), + static_cast(edm_core.x), + static_cast(edm_core.y), + static_cast(this->topology_config.ring_size - 1), // num_transfers), + + static_cast(this->op_config.get_page_size()), + static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), + + static_cast(this->worker_sender_semaphore_address), + static_cast(this->cb_num_pages_per_packet), + + static_cast(worker_transfer_info.num_workers), + + // For sender side, all worker slice info is the same except for the tensor shape + // and for sender side specifically, there is only one tensor_slice_shape for the output + // tensor (as opposed to `ring_size` tensor_slice_shapes for the input tensor), so we can + // directly use it as the output tensor shape + static_cast(this->worker_input_slice.tensor_slice_shape.x), + static_cast(this->worker_input_slice.tensor_slice_shape.y), + static_cast(this->worker_input_slice.worker_slice_shape.x), + static_cast(this->worker_input_slice.worker_slice_shape.y), + static_cast(this->worker_input_slice.worker_slice_offset.x), + static_cast(this->worker_input_slice.worker_slice_offset.y), + + total_num_math_pages}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Sender Worker RT Args:"); + log_trace(tt::LogOp, "\tdst_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_l1_base_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_l1_sem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_noc_x: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_noc_y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); + log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\twriter_send_sem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); + + log_trace(tt::LogOp, "\toutput_tensor_shape.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\toutput_tensor_shape.y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.y: {}", args.at(i++)); + + log_trace(tt::LogOp, "\ttotal_num_math_pages={}", args.at(i++)); + + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + ttnn::ccl::RingTopology const topology_config; + ttnn::ccl::CCLOpConfig const op_config; + ttnn::ccl::InterleavedTensorWorkerSlice const worker_input_slice; + WorkerTransferInfo const worker_transfer_info; + 
uint32_t cb_num_pages_per_packet; + uint32_t worker_sender_semaphore_address; + uint32_t worker_receiver_semaphore_address; + + uint32_t total_num_math_pages; + bool src_is_dram; + bool dst_is_dram; +}; + +struct EdmInterfaceAddresses { + std::unordered_map worker_sender_edm_semaphore_addresses; + std::unordered_map worker_sender_edm_buffer_addresses; + std::unordered_map worker_receiver_edm_semaphore_addresses; + std::unordered_map worker_receiver_edm_buffer_addresses; +}; + +// Future work: split this up further: +// 1) assign workers to EDM channel (with buffer sharing mode specified too) +// 2) Compute the semaphore and buffer addresses (for each EDM channel and worker) +// For now - the mapping between workers and EDM channels is 1:1 +static void add_worker_config_to_edm_builders( + Device* device, + ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented + ttnn::ccl::CCLOpConfig const& op_config, + std::vector const& worker_cores, + uint32_t num_channels_per_edm, + + std::vector& clockwise_edm_builders, + std::vector& counter_clockwise_edm_builders, + + uint32_t worker_sender_semaphore_address, + uint32_t worker_receiver_semaphore_address, + uint32_t link, + uint32_t ring_size, + std::function is_buffer_in_clockwise_direction_fn, + + EdmInterfaceAddresses& edm_interface_addresses) { + for (uint32_t c = 0; c < num_channels_per_edm; ++c) { + uint32_t global_worker_idx = c + num_channels_per_edm * link; + uint32_t num_workers_per_eth_buffer = 1; + + std::vector sender_worker_coords; + std::vector receiver_worker_coords; + for (uint32_t w = c * num_workers_per_eth_buffer; w < (c + 1) * num_workers_per_eth_buffer; ++w) { + sender_worker_coords.push_back(ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(worker_cores.at(w)).x, + device->worker_core_from_logical_core(worker_cores.at(w)).y)); + receiver_worker_coords.push_back(ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(worker_cores.at(w)).x, + device->worker_core_from_logical_core(worker_cores.at(w)).y)); + } + + // Get the expected message size in bytes for this worker + uint32_t expected_message_size_bytes = tensor_slicer.get_worker_slice_size_bytes(global_worker_idx); + + bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear + if (sender_enabled) { + auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? clockwise_edm_builders.at(link) + : counter_clockwise_edm_builders.at(link); + log_trace(tt::LogOp, "Adding sender EDM channel"); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = + sender_edm_builder.add_sender_channel( + worker_sender_semaphore_address, + 1, // cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), + sender_worker_coords, + expected_message_size_bytes); + edm_interface_addresses.worker_sender_edm_semaphore_addresses.insert( + {global_worker_idx, sender_channel_buffer_info.eth_semaphore_l1_address}); + edm_interface_addresses.worker_sender_edm_buffer_addresses.insert( + {global_worker_idx, sender_channel_buffer_info.eth_buffer_l1_address}); + } + + bool receiver_enabled = true; //(!is_linear || !is_first_chip_in_chain); + if (receiver_enabled) { + auto& receiver_edm_builder = is_buffer_in_clockwise_direction_fn(c) + ? 
counter_clockwise_edm_builders.at(link) + : clockwise_edm_builders.at(link); + log_trace(tt::LogOp, "Adding receiver EDM channel"); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = + receiver_edm_builder.add_receiver_channel( + 1, + receiver_worker_coords, + expected_message_size_bytes); + edm_interface_addresses.worker_receiver_edm_semaphore_addresses.insert( + {global_worker_idx, receiver_channel_buffer_info.eth_semaphore_l1_address}); + edm_interface_addresses.worker_receiver_edm_buffer_addresses.insert( + {global_worker_idx, receiver_channel_buffer_info.eth_buffer_l1_address}); + } + } +} + +static std::tuple build_reduce_scatter_worker( + tt::tt_metal::Program& program, + Device const* device, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::RingTopology const& topology_config, + ttnn::ccl::CCLOpConfig const& op_config, + ReduceScatterWorkerArgBuilder const& worker_arg_builder, + std::vector& cw_edm_builders, + std::vector& ccw_edm_builders, + EdmInterfaceAddresses const& edm_interface_addresses, +======= + ttnn::utils::ccl::CCLOpConfig const& op_config, + ReduceScatterWorkerArgBuilder const& worker_arg_builder, + std::vector& cw_edm_builders, + std::vector& ccw_edm_builders, +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + CoreCoord const& worker_core, + uint32_t num_edm_channels, + uint32_t link, + uint32_t ring_size, + uint32_t worker_index, + std::map const& worker_defines, + ttnn::operations::binary::BinaryOpType binary_math_op) { + + TT_ASSERT(worker_defines.size() > 0); + for (auto const& [key, value] : worker_defines) { + log_trace(tt::LogOp, "Worker Define: {} = {}", key, value); + } + static std::string const& receiver_kernel_path = +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; + static std::string const& sender_kernel_path = + "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; +======= + "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; + static std::string const& sender_kernel_path = + "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + + // This will be configurable by sharded/non-sharded but present the same arg builder + KernelHandle worker_receiver_kernel_id, worker_sender_kernel_id; + + bool is_in_clockwise_direction = true; // TODO: bidirectional + uint32_t global_worker_index = link * num_edm_channels + worker_index; + { + CoreCoord const& receiver_edm = is_in_clockwise_direction ? topology_config.eth_receiver_cores.at(link) + : topology_config.eth_sender_cores.at(link); + ttnn::ccl::WorkerXY receiver_edm_noc_coord =ttnn::ccl::WorkerXY( + device->ethernet_core_from_logical_core(receiver_edm).x, + device->ethernet_core_from_logical_core(receiver_edm).y); + const uint32_t edm_core_semaphore_address = + is_in_clockwise_direction + ? 
edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index) + : edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index); + const uint32_t edm_core_buffer_address = + is_in_clockwise_direction + ? edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index) + : edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index); + + worker_receiver_kernel_id = tt::tt_metal::CreateKernel( + program, + receiver_kernel_path, + worker_core, + tt::tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_receiver_kernel_ct_args(), worker_defines)); + + tt::tt_metal::SetRuntimeArgs( + program, + worker_receiver_kernel_id, + worker_core, + worker_arg_builder.generate_receiver_kernel_rt_args( + receiver_edm_noc_coord, + edm_core_semaphore_address, + edm_core_buffer_address, + link, + worker_index, + is_in_clockwise_direction)); + } + + { + vector compute_kernel_args = {}; + constexpr bool fp32_dest_acc_en = false; + constexpr bool math_approx_mode = false; + std::map eltwise_defines = ttnn::operations::binary::utils::get_defines(binary_math_op); + KernelHandle worker_reduce_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp", + worker_core, + tt::tt_metal::ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_kernel_args, + .defines = eltwise_defines}); + + tt::tt_metal::SetRuntimeArgs( + program, + worker_reduce_kernel_id, + worker_core, + worker_arg_builder.generate_reduce_op_kernel_rt_args(link, worker_index, ring_size)); + } + + { + CoreCoord sender_edm = is_in_clockwise_direction ? topology_config.eth_sender_cores.at(link) + : topology_config.eth_receiver_cores.at(link); + ttnn::ccl::WorkerXY const sender_edm_noc_coord =ttnn::ccl::WorkerXY( + device->ethernet_core_from_logical_core(sender_edm).x, + device->ethernet_core_from_logical_core(sender_edm).y); + TT_ASSERT(sender_edm_noc_coord.y == 0 || sender_edm_noc_coord.y == 6); + const uint32_t edm_core_semaphore_address = + is_in_clockwise_direction + ? edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index) + : edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index); + const uint32_t edm_core_buffer_address = + is_in_clockwise_direction + ? 
edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index) + : edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index); + worker_sender_kernel_id = tt::tt_metal::CreateKernel( + program, + sender_kernel_path, + worker_core, + tt::tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_kernel_ct_args(), worker_defines)); + + tt::tt_metal::SetRuntimeArgs( + program, + worker_sender_kernel_id, + worker_core, + worker_arg_builder.generate_sender_kernel_rt_args( + sender_edm_noc_coord, + edm_core_semaphore_address, + edm_core_buffer_address, + link, + worker_index, + is_in_clockwise_direction)); + } + + return {worker_receiver_kernel_id, worker_sender_kernel_id}; +} + +static CoreRangeSet select_worker_cores( + ttnn::ccl::CCLOpConfig const& op_config, std::size_t num_links, std::size_t num_edm_channels) { + switch (op_config.get_topology()) { + case ttnn::ccl::Topology::Linear: + return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); + case ttnn::ccl::Topology::Ring: + return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); + default: TT_ASSERT(false, "Unsupported topology"); return CoreRangeSet({}); + }; +} + +static WorkerTransferInfo compute_num_edm_messages_per_channel( + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented + std::vector const& cw_per_link_edm_builders, + std::vector const& ccw_per_link_edm_builders, + std::size_t const num_edm_channels, + std::size_t const num_links, + std::size_t const ring_size) { + uint32_t const page_size_in_bytes = op_config.get_page_size(); + TT_ASSERT(num_edm_channels > 0); + TT_ASSERT(num_links > 0); + TT_ASSERT(page_size_in_bytes > 0); + log_trace(tt::LogOp, "WorkerTransferInfo"); + std::size_t total_num_workers = num_edm_channels * num_links; + + auto get_iter_begin = [num_edm_channels](auto& vec, std::size_t link) -> auto { + return vec.begin() + (link * num_edm_channels); + }; + + auto get_iter_end = [num_edm_channels, num_links](auto& vec, std::size_t link) -> auto { + bool last_link = link == num_links - 1; + TT_ASSERT( + (!last_link && ((link + 1) * num_edm_channels < vec.size())) || + (last_link && ((link + 1) * num_edm_channels == vec.size()))); + return last_link ? 
vec.end() : vec.begin() + ((link + 1) * num_edm_channels); + }; + + // Pages per EDM channel + std::size_t total_num_edm_channels = num_links * num_edm_channels; + log_trace(tt::LogOp, "total_num_edm_channels: {}", total_num_edm_channels); + + std::vector num_pages_per_full_chunk(total_num_edm_channels * num_links, 0); + + for (std::size_t link = 0; link < num_links; link++) { + std::size_t edm_channel_size_in_bytes = cw_per_link_edm_builders.at(link).get_eth_buffer_size_bytes(); + std::size_t num_pages_per_edm_buffer = edm_channel_size_in_bytes / page_size_in_bytes; + log_trace( + tt::LogOp, + "link {}, edm_channel_size_in_bytes: {}, page_size_in_bytes: {}, num_pages_per_edm_buffer: {}", + link, + edm_channel_size_in_bytes, + page_size_in_bytes, + num_pages_per_edm_buffer); + + std::fill( + get_iter_begin(num_pages_per_full_chunk, link), + get_iter_end(num_pages_per_full_chunk, link), + num_pages_per_edm_buffer); + } + + log_trace(tt::LogOp, "-- num_pages_per_full_chunk:"); + for (std::size_t l = 0; l < num_links; l++) { + for (std::size_t w = 0; w < num_edm_channels; w++) { + log_trace( + tt::LogOp, "\t\t(link={},worker={}): {}", l, w, num_pages_per_full_chunk.at(l * num_edm_channels + w)); + } + } + + return WorkerTransferInfo(num_pages_per_full_chunk, num_links, num_edm_channels); +} + +static uint32_t compute_maximum_worker_slice_in_bytes( + uint32_t cb_src0_size_pages, + uint32_t cb_dst0_size_pages, + uint32_t cb_short_circuit_size_pages, + std::size_t edm_channel_buffer_size, + uint32_t page_size) { + return std::min(cb_short_circuit_size_pages, cb_src0_size_pages + cb_dst0_size_pages) * page_size + + edm_channel_buffer_size; +} + +static bool is_cb_buffering_sufficient_to_avoid_deadlock( +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::InterleavedTensorWorkerSlice const& worker_slice, +======= + ttnn::utils::ccl::InterleavedTensorWorkerSlice const& worker_slice, +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + uint32_t cb_src0_size_pages, + uint32_t cb_dst0_size_pages, + uint32_t cb_short_circuit_size_pages, + std::size_t edm_channel_buffer_size, + uint32_t page_size) { + uint32_t worker_size_pages_rounded_up = + tt::round_up(worker_slice.worker_slice_shape.x * worker_slice.worker_slice_shape.y, cb_src0_size_pages / 2); + uint32_t worker_slice_size_bytes = worker_size_pages_rounded_up * page_size; + uint32_t available_buffering_capacity = compute_maximum_worker_slice_in_bytes( + cb_src0_size_pages, cb_dst0_size_pages, cb_short_circuit_size_pages, edm_channel_buffer_size, page_size); + log_trace(tt::LogOp, "worker_slice.worker_slice_shape.x: {}", worker_slice.worker_slice_shape.x); + log_trace(tt::LogOp, "worker_slice.worker_slice_shape.y: {}", worker_slice.worker_slice_shape.y); + log_trace(tt::LogOp, "worker_slice_size_bytes: {}", worker_slice_size_bytes); + log_trace(tt::LogOp, "worker_size_pages_rounded_up: {}", worker_size_pages_rounded_up); + log_trace(tt::LogOp, "cb_src0_size_pages: {}", cb_src0_size_pages); + log_trace(tt::LogOp, "cb_dst0_size_pages: {}", cb_dst0_size_pages); + log_trace(tt::LogOp, "page_size: {}", page_size); + log_trace(tt::LogOp, "edm_channel_buffer_size: {}", edm_channel_buffer_size); + log_trace(tt::LogOp, "available_buffering_capacity: {}", available_buffering_capacity); + + return available_buffering_capacity >= worker_slice_size_bytes; +} + +static std::tuple create_worker_circular_buffers( + Tensor const& input_tensor, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::CCLOpConfig const& op_config, +======= + ttnn::utils::ccl::CCLOpConfig const& op_config, +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + CoreRangeSet const& worker_core_range, + uint32_t worker_pages_per_transfer, + tt::tt_metal::Program& program) { + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t page_size_bytes = op_config.get_page_size(); + + // Input 0 CB + uint32_t src0_cb_index = tt::CB::c_in0; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, worker_core_range, cb_src0_config); + + // Input 1 CB + uint32_t src1_cb_index = tt::CB::c_in1; + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src1_cb_index, df}}) + .set_page_size(src1_cb_index, page_size_bytes); + CBHandle cb_src1_workers = CreateCircularBuffer(program, worker_core_range, cb_src1_config); + + // Dataflow Writer Kernel input CB + uint32_t cb_dst0_index = tt::CB::c_out0; + tt::tt_metal::CircularBufferConfig cb_dst0_config = + tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{cb_dst0_index, df}}) + .set_page_size(cb_dst0_index, page_size_bytes); + CBHandle cb_dst0_sender_workers = CreateCircularBuffer(program, worker_core_range, cb_dst0_config); + + // From reader -> writer kernel (I think I need this because sharing the cb_dst0_sender_workers as output + // of reader kernel (first output) and math kernel (all subsequent outputs) doesn't seem to work because + // it seems like the math kernels hold some of the CB state in local variables) + uint32_t cb_short_circuit_index = tt::CB::c_out1; + tt::tt_metal::CircularBufferConfig cb_short_circuit_config = + tt::tt_metal::CircularBufferConfig( + (worker_pages_per_transfer * page_size_bytes) * 2, {{cb_short_circuit_index, df}}) + .set_page_size(cb_short_circuit_index, page_size_bytes); + CBHandle cb_short_circuit_sender_workers = + CreateCircularBuffer(program, worker_core_range, cb_short_circuit_config); + + return {cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers}; +} + +operation::ProgramWithCallbacks reduce_scatter_with_workers( + const std::vector& input_tensors, + const std::vector& output_tensors, + ttnn::operations::binary::BinaryOpType reduce_op, + const uint32_t scatter_split_dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::Topology topology) { +======= + ttnn::utils::ccl::Topology topology) { +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); + TT_ASSERT( + input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == + output_tensors.at(0).get_legacy_shape()[scatter_split_dim] * ring_size, + "Input and output tensor shapes must match"); + TT_ASSERT( + input_tensors.at(0).buffer()->num_pages() % ring_size == 0, + "Reduce scatter current only supports even divisibility of input tensor(s) across ranks"); + + /////////////// Constants/Configuration + /// Constants/Configuration +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + auto const& op_config =ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + std::unique_ptr input_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); + std::unique_ptr output_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); +======= + ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::utils::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + auto const& op_config =ttnn::utils::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + std::unique_ptr input_tensor_config = + ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); + std::unique_ptr output_tensor_config = + ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; + uint32_t input_tensor_num_units_per_scatter_dim = + per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout + TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); + uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); + bool enable_bidirectional = false; + auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional); + log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); + auto edm_termination_mode =ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + auto const& edm_builder = create_erisc_datamover_builder( + num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); + TT_ASSERT(num_edm_channels > 0); + + Tensor const& local_chip_tensor = input_tensors.at(0); + Tensor const& local_chip_output_tensor = output_tensors.at(0); + + std::map worker_defines; + std::vector worker_receiver_kernels; + std::vector worker_sender_kernels; + std::vector cw_per_link_edm_builders(num_links, edm_builder); + std::vector ccw_per_link_edm_builders(num_links, edm_builder); + + bool rm = local_chip_tensor.get_layout() == Layout::ROW_MAJOR; + if (rm) { + worker_defines["RM_INTERLEAVED"] = "1"; + } else { + worker_defines["TILE_INTERLEAVED"] = "1"; + } + + ////////////////// + tt::tt_metal::Program program{}; + const auto& device = local_chip_tensor.device(); + + auto const& topology_config = + ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); + 
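+    // Worked example of the sizing above (illustrative numbers only): for a tile-layout input with
+    // shape[scatter_dim] = 2048 on a ring of 8 devices, per_step_dim_size = 2048 / 8 = 256 elements,
+    // input_tensor_num_units_per_scatter_dim = 256 / TILE_WIDTH = 8 tile columns, and therefore
+    // max_num_workers = std::min(8, 8) = 8 worker cores per link; a narrower scatter dim simply yields
+    // fewer workers so that each worker still owns at least one unit of its slice.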
+ auto dim_slice_factors = tt::tt_metal::Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); + dim_slice_factors[-1] = ring_size; + + CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); + auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); + + // Semaphores && CBs + auto worker_receiver_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); + auto worker_sender_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); + + uint32_t cb_num_pages = + (cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes() / op_config.get_page_size()) * 2; + uint32_t cb_num_pages_per_packet = cb_num_pages / 2; + log_trace(tt::LogOp, "cb_num_pages: {}", cb_num_pages); + auto const& [cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers] = + create_worker_circular_buffers(local_chip_tensor, op_config, worker_core_range, cb_num_pages, program); + + uint32_t max_worker_slice_in_bytes = compute_maximum_worker_slice_in_bytes( + cb_num_pages, + cb_num_pages, + cb_num_pages, + cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), + op_config.get_page_size()); + auto tensor_slicer =ttnn::ccl::RingReduceScatterTensorSlicer( + local_chip_tensor, + local_chip_output_tensor, + scatter_split_dim, + ring_index, + ring_size, + num_edm_channels * num_links, + max_worker_slice_in_bytes, + cb_num_pages / 2); + + // Not per buffer because the buffer sharing mode may cause some buffers to share EDM transfers + WorkerTransferInfo const& worker_transfer_info = compute_num_edm_messages_per_channel( + op_config, + tensor_slicer, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + num_edm_channels, + num_links, + ring_size); + + // Configure the EDM builders + EdmInterfaceAddresses edm_interface_addresses; + for (std::size_t link = 0; link < num_links; link++) { + add_worker_config_to_edm_builders( + device, + tensor_slicer, + op_config, + worker_cores, + num_edm_channels, + + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + + worker_sender_semaphore_address, + worker_receiver_semaphore_address, + link, + ring_size, + [enable_bidirectional, num_edm_channels](uint32_t x) { + return enable_bidirectional ? 
(x % num_edm_channels == 0) : true; + }, + + edm_interface_addresses); + } + + // build the worker kernels + tt::tt_metal::ComputeConfig compute_config; + for (std::size_t link = 0; link < num_links; link++) { + uint32_t global_worker_index = link * num_edm_channels; + log_trace(tt::LogOp, "=============================================="); + log_trace(tt::LogOp, "------------------ Link: {} ------------------", link); + for (std::size_t worker = 0; worker < num_edm_channels; worker++) { + std::size_t global_worker_index = worker + link * num_edm_channels; + log_trace(tt::LogOp, "------ Worker: {} (global ID={})", worker, global_worker_index); + // This will be configurable by sharded/non-sharded but present the same arg builder + auto const& worker_slice = tensor_slicer.get_worker_slice(global_worker_index); + auto worker_arg_builder = ReduceScatterWorkerArgBuilder( + op_config, + topology_config, + worker_slice, + worker_transfer_info, + worker, + link, + cb_num_pages_per_packet, + worker_sender_semaphore_address, + worker_receiver_semaphore_address); + + log_trace(tt::LogOp, "worker_cores.at(global_worker_index): {}", worker_cores.at(global_worker_index)); + auto [receiver_kernel_id, sender_kernel_id] = build_reduce_scatter_worker( + program, + device, + topology_config, + op_config, + worker_arg_builder, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + edm_interface_addresses, + worker_cores.at(global_worker_index), + num_edm_channels, + link, + ring_size, + worker, + worker_defines, + reduce_op); + worker_receiver_kernels.push_back(receiver_kernel_id); + worker_sender_kernels.push_back(sender_kernel_id); + + TT_ASSERT(is_cb_buffering_sufficient_to_avoid_deadlock( + worker_slice, + cb_num_pages, + cb_num_pages, + cb_num_pages, + cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), + op_config.get_page_size())); + } + } + + // Generate the EDM kernels +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp + ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( +======= + ttnn::utils::ccl::generate_edm_kernels_for_ring_or_linear_topology( +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + program, + device, + topology_config, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + receiver_device_id, + sender_device_id); + + uint32_t total_num_workers = worker_cores.size(); + auto override_runtime_arguments_callback = + [topology_config, worker_receiver_kernels, worker_sender_kernels, worker_cores, total_num_workers, ring_index]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors.at(0); + const auto& output = output_tensors.at(0); + TT_ASSERT(worker_sender_kernels.size() == worker_receiver_kernels.size()); + for (uint32_t i = 0; i < worker_sender_kernels.size(); ++i) { + auto& worker_receiver_runtime_args = + GetRuntimeArgs(program, worker_receiver_kernels.at(i), worker_cores.at(i)); + worker_receiver_runtime_args.at(0) = input.buffer()->address(); + + auto& worker_sender_runtime_args = + GetRuntimeArgs(program, worker_sender_kernels.at(i), worker_cores.at(i)); + worker_sender_runtime_args.at(0) = output.buffer()->address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace reduce_scatter_detail +} // namespace ccl +} // namespace utils +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp new file mode 100644 index 00000000000..850bbfadca4 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp @@ -0,0 +1,356 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "debug/assert.h" +#include "tensix_types.h" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +using ttnn::ccl::coord_t; +using ttnn::ccl::WorkerXY; + +struct reduce_scatter_reader_common_args_t { + reduce_scatter_reader_common_args_t(uint32_t& arg_idx) : + src_addr(get_arg_val(arg_idx++)), + num_transfers(get_arg_val(arg_idx++)), + full_chunk_num_pages(get_arg_val(arg_idx++)), + page_size(get_arg_val(arg_idx++)), + + my_ring_idx(get_arg_val(arg_idx++)), + ring_size(get_arg_val(arg_idx++)), + sem_addr(get_arg_val(arg_idx++)), + + is_clockwise_direction(get_arg_val(arg_idx++) == 1), + half_cb_n_pages(get_arg_val(arg_idx++)), + edm_core_noc0_core_x(get_arg_val(arg_idx++)), + edm_core_noc0_core_y(get_arg_val(arg_idx++)), + edm_core_semaphore_address(get_arg_val(arg_idx++)), + edm_core_buffer_address(get_arg_val(arg_idx++)), + num_concurrent_workers(get_arg_val(arg_idx++)), + + input_tensor_shape(ttnn::ccl::coord_from_args(arg_idx)), + tensor_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), + worker_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), + worker_slice_offset(ttnn::ccl::coord_from_args(arg_idx)), + total_eltwise_kernel_num_pages(get_arg_val(arg_idx++)) + { + ASSERT(full_chunk_num_pages > 0); + ASSERT(page_size > 0); + ASSERT(ring_size > 0); + ASSERT(half_cb_n_pages > 0); + } + + const uint32_t src_addr; + const uint32_t num_transfers; + const uint32_t full_chunk_num_pages; + const uint32_t page_size; + uint32_t my_ring_idx; + const uint32_t ring_size; + const uint32_t sem_addr; + + const bool is_clockwise_direction; + + const uint32_t half_cb_n_pages; + const uint32_t edm_core_noc0_core_x; + const uint32_t edm_core_noc0_core_y; + const uint32_t edm_core_semaphore_address; + const uint32_t edm_core_buffer_address; + const uint32_t num_concurrent_workers; + + coord_t input_tensor_shape; + coord_t tensor_slice_shape; + coord_t worker_slice_shape; + coord_t worker_slice_offset; + uint32_t total_eltwise_kernel_num_pages; +}; +#ifdef RM_INTERLEAVED +constexpr bool rm_interleaved_addr_gen_mode = true; +#else +constexpr bool rm_interleaved_addr_gen_mode = false; +#endif + +template +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGenFast; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGenFast; +}; + +template +struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { + using src_addr_gen_t = typename interleaved_addr_gen_t::type; + + reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : + reduce_scatter_reader_common_args_t(arg_idx) { + this->s = { + .bank_base_address = this->src_addr, + .page_size = page_size +#if defined TILE_INTERLEAVED + , + .data_format = in0_df +#endif + }; + } + + src_addr_gen_t s; + + void dprint() const { + DPRINT << "RSR args:" + << "\n\tsrc_addr=" << src_addr << "\n\tnum_transfers=" << num_transfers << "\n\tpage_size=" << page_size + << "\n\tfull_chunk_num_pages=" << 
full_chunk_num_pages << "\n\tmy_ring_idx=" << my_ring_idx + << "\n\tsem_addr=" << sem_addr << "\n\tis_clockwise_direction=" << (uint32_t)is_clockwise_direction + << "\n\thalf_cb_n_pages=" << half_cb_n_pages << "\n\tring_size=" << ring_size + << "\n\tedm_core_noc0_core_x=" << edm_core_noc0_core_x + << "\n\tedm_core_noc0_core_y=" << edm_core_noc0_core_y + << "\n\tedm_core_semaphore_address=" << edm_core_semaphore_address + << "\n\tedm_core_buffer_address=" << edm_core_buffer_address << "\n"; + } +}; + +template +struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { + reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : + reduce_scatter_reader_common_args_t(arg_idx), + shard_num_pages(get_arg_val(arg_idx++)), + num_l1_cores(get_arg_val(arg_idx++)), + l1_cores_ptr(reinterpret_cast(get_arg_addr(arg_idx))) { + arg_idx += this->num_l1_cores; + } + + const uint32_t shard_num_pages; + const uint32_t num_l1_cores; + const WorkerXY* const l1_cores_ptr; + + void dprint() const {} +}; + +using advance_to_next_transfer_slice_result_t = std::tuple< + uint32_t, // ring_index + uint32_t // slice_base_page_offset + >; +template +advance_to_next_transfer_slice_result_t advance_to_next_transfer_slice( + uint32_t const ring_size, + uint32_t const curr_ring_idx, + uint32_t const slice_base_page_offset, + coord_t const& input_tensor_shape, + coord_t const& tensor_slice_shape, + bool const is_clockwise_direction) { + bool const sliced_only_on_width = tensor_slice_shape.x < input_tensor_shape.x && tensor_slice_shape.y == input_tensor_shape.y; + uint32_t single_ring_idx_stride = + sliced_only_on_width ? tensor_slice_shape.x : tensor_slice_shape.y * input_tensor_shape.x; + uint32_t n_minus_one_ring_indices_stride = sliced_only_on_width + ? 
tensor_slice_shape.x * (ring_size - 1) + : tensor_slice_shape.y * input_tensor_shape.x * (ring_size - 1); + + if constexpr (!is_sharded) { + if (is_clockwise_direction) { + if (curr_ring_idx == 0) { + return advance_to_next_transfer_slice_result_t{ + ring_size - 1, + slice_base_page_offset + n_minus_one_ring_indices_stride, + }; + } else { + return advance_to_next_transfer_slice_result_t{ + curr_ring_idx - 1, + slice_base_page_offset - single_ring_idx_stride, + }; + } + } else { + if (curr_ring_idx == ring_size - 1) { + return advance_to_next_transfer_slice_result_t{ + 0, + slice_base_page_offset - n_minus_one_ring_indices_stride, + }; + } else { + return advance_to_next_transfer_slice_result_t{ + curr_ring_idx + 1, + slice_base_page_offset + single_ring_idx_stride, + }; + } + } + } +} + +void kernel_main() { + constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; + + // Currently meaningless when `is_sharded=true` + constexpr bool src_is_dram = get_compile_time_arg_val(1) == 1; + + uint32_t arg_idx = 0; + + constexpr uint32_t to_dm_sender_short_circuit_cb = tt::CB::c_out1; + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + constexpr uint32_t cb_id_in1 = tt::CB::c_in1; + const DataFormat in0_df = get_dataformat(cb_id_in0); + auto args = reduce_scatter_reader_unique_args_t(arg_idx, in0_df); + + ASSERT(args.half_cb_n_pages >= args.full_chunk_num_pages); + + bool width_sliced = args.tensor_slice_shape.x <= args.input_tensor_shape.x; + + volatile tt_l1_ptr uint32_t* receiver_read_semaphore_addr_ptr = + reinterpret_cast(args.sem_addr); + const uint64_t eth_receiver_l1_base_noc_addr = + get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_buffer_address); + const uint64_t eth_receiver_l1_semaphore_noc_addr = + get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_semaphore_address); + + uint32_t total_cb_pages_pushed = 0; + uint32_t total_cb_pages_pushed_to_math = 0; + + // For the first timestep, there is no other input to reduce with, so we just send it straight to the input CB + // of the output data movement kernel - short-circuiting past the (reducer) math kernel + // For tile => shape in tiles + // For RM => shape in elements + uint32_t start_ring_index = args.my_ring_idx; + while (args.worker_slice_offset.x < args.tensor_slice_shape.x && + args.worker_slice_offset.y < args.tensor_slice_shape.y) { + // Need to reset back to the start ring index because the last iteration of the tranfers read chunks + // loop won't increment after the last iteration since the increment is within the loop body + args.my_ring_idx = start_ring_index; + uint32_t curr_ring_slice_start_page_offset = + width_sliced ? 
args.tensor_slice_shape.x * start_ring_index + : args.tensor_slice_shape.y * start_ring_index * args.input_tensor_shape.x; + + auto const& next_slice_offset = advance_slice_row_major( + args.worker_slice_offset, args.worker_slice_shape, args.tensor_slice_shape, args.num_concurrent_workers); + bool last_slice_of_worker = next_slice_offset.x >= args.tensor_slice_shape.x || + next_slice_offset.y >= args.tensor_slice_shape.y; + + const uint32_t worker_relative_start_offset_into_slice = + args.worker_slice_offset.x + (args.worker_slice_offset.y * args.input_tensor_shape.x); + const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + uint32_t curr_tile_id = starting_tile_id; + + coord_t valid_worker_slice_shape = coord_t( + std::min(args.worker_slice_shape.x, args.tensor_slice_shape.x - args.worker_slice_offset.x), + std::min(args.worker_slice_shape.y, args.tensor_slice_shape.y - args.worker_slice_offset.y)); + + bool last_page_of_worker = false; + uint32_t const worker_slice_n_pages = valid_worker_slice_shape.x * valid_worker_slice_shape.y; + ASSERT( + (args.num_transfers - 1) * worker_slice_n_pages + total_cb_pages_pushed_to_math <= + args.total_eltwise_kernel_num_pages); + { + coord_t offset_into_worker_slice = {0, 0}; + for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { + uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); + ASSERT(!last_page_of_worker); + read_chunk_from_output_tensor_v2( + curr_tile_id, + offset_into_worker_slice, + valid_worker_slice_shape, + // In tiles for tile layout + args.input_tensor_shape, + to_dm_sender_short_circuit_cb, + args.s, + n_pages, + args.page_size, + last_page_of_worker); + total_cb_pages_pushed += n_pages; + if (n_pages < args.half_cb_n_pages) { + uint32_t num_filler_pages = args.half_cb_n_pages - n_pages; + push_filler_pages_to_cb(to_dm_sender_short_circuit_cb, num_filler_pages); + ASSERT(args.half_cb_n_pages > n_pages); + ASSERT(p + n_pages == worker_slice_n_pages); + total_cb_pages_pushed += num_filler_pages; + } + } + } + + for (uint32_t i = 1; i < args.num_transfers; ++i) { + bool last_transfer = i == args.num_transfers - 1; + coord_t offset_into_worker_slice = {0, 0}; + std::tie(args.my_ring_idx, curr_ring_slice_start_page_offset) = advance_to_next_transfer_slice( + args.ring_size, + args.my_ring_idx, + curr_ring_slice_start_page_offset, + args.input_tensor_shape, + args.tensor_slice_shape, + args.is_clockwise_direction); + ASSERT(last_page_of_worker); + last_page_of_worker = false; + curr_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + + for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { + uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); + ASSERT(n_pages > 0); + // Fetch from input tensor + read_chunk_from_output_tensor_v2( + curr_tile_id, + offset_into_worker_slice, + valid_worker_slice_shape, + // In tiles for tile layout + args.input_tensor_shape, + cb_id_in1, + args.s, + n_pages, + args.page_size, + last_page_of_worker); + uint64_t eth_receiver_l1_curr_noc_addr = eth_receiver_l1_base_noc_addr; + + // Fetch from EDM + noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1); + noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0); + fetch_chunk(cb_id_in0, n_pages, args.page_size, eth_receiver_l1_base_noc_addr); + total_cb_pages_pushed_to_math += n_pages; + total_cb_pages_pushed += n_pages; + + bool last_worker_message_to_edm = 
last_transfer && last_slice_of_worker && (p + n_pages >= worker_slice_n_pages);
+                if (!last_worker_message_to_edm) {
+                    noc_semaphore_inc(
+                        eth_receiver_l1_semaphore_noc_addr,
+                        ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
+                }
+                if (n_pages < args.half_cb_n_pages) {
+                    uint32_t num_filler_pages = args.half_cb_n_pages - n_pages;
+                    push_filler_pages_to_cb(cb_id_in0, num_filler_pages);
+                    push_filler_pages_to_cb(cb_id_in1, num_filler_pages);
+                    total_cb_pages_pushed_to_math += num_filler_pages;
+                    total_cb_pages_pushed += num_filler_pages;
+                }
+            }
+            ASSERT(last_page_of_worker);
+        }
+
+        args.worker_slice_offset = next_slice_offset;
+    }
+
+    ASSERT(args.total_eltwise_kernel_num_pages >= total_cb_pages_pushed_to_math);
+    DEBUG_STATUS("DRN1");
+    // The host code currently doesn't know how to accurately count the exact number of pages pushed through the
+    // math reduce op, so it instead provides a known safe lower bound which may be more than actually required by
+    // the op. It passes this number to sender and receiver, who will push/pop junk pages to/from the math op to
+    // ensure it will complete.
+    for (; total_cb_pages_pushed_to_math < args.total_eltwise_kernel_num_pages; total_cb_pages_pushed_to_math++) {
+        push_filler_pages_to_cb(cb_id_in0, 1);
+        push_filler_pages_to_cb(cb_id_in1, 1);
+    }
+
+    noc_semaphore_inc(
+        eth_receiver_l1_semaphore_noc_addr,
+        ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY);
+    DEBUG_STATUS("DONE");
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp
new file mode 100644
index 00000000000..ac8647cb584
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp
@@ -0,0 +1,148 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" + +using ttnn::ccl::coord_t; + +void kernel_main() { + constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; + constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; + + uint32_t arg_idx = 0; + uint32_t const dst_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_l1_base_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_l1_sem_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_noc_x = get_arg_val(arg_idx++); + uint32_t const eth_sender_noc_y = get_arg_val(arg_idx++); + uint32_t const num_transfers = get_arg_val(arg_idx++); + uint32_t const page_size = get_arg_val(arg_idx++); + uint32_t const full_chunk_num_pages = get_arg_val(arg_idx++); + uint32_t const writer_send_sem_addr = get_arg_val(arg_idx++); + uint32_t const half_cb_n_pages = get_arg_val(arg_idx++); + uint32_t const num_concurrent_workers = get_arg_val(arg_idx++); + + coord_t const& output_tensor_shape = ttnn::ccl::coord_from_args(arg_idx); + coord_t const& worker_slice_shape = ttnn::ccl::coord_from_args(arg_idx); + coord_t worker_slice_base_offset = ttnn::ccl::coord_from_args(arg_idx); + + uint32_t total_eltwise_kernel_num_pages = get_arg_val(arg_idx++); + + // Argument validation + ASSERT(half_cb_n_pages >= full_chunk_num_pages); + ASSERT(full_chunk_num_pages > 0); + ASSERT(page_size > 0); + ASSERT(half_cb_n_pages > 0); + + constexpr uint32_t cb_id_in0 = tt::CB::c_out0; + constexpr uint32_t cb_id_in_short_circuit = tt::CB::c_out1; + const DataFormat in0_df = get_dataformat(cb_id_in0); +#ifdef RM_INTERLEAVED + InterleavedAddrGen d = { + .bank_base_address = dst_addr + output_start_addr_offset, .page_size = page_size}; +#elif defined TILE_INTERLEAVED + + InterleavedAddrGenFast d = { + .bank_base_address = dst_addr, .page_size = page_size, .data_format = in0_df}; +#endif + + // Used to wait until eth sender has space available + volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr = + reinterpret_cast(writer_send_sem_addr); + // This is different per writer core + const uint64_t eth_l1_sender_base_noc_addr = + get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_base_addr); + // Used to signal eth sender that data is available. This is different per writer core + const uint64_t eth_l1_sender_semaphore_addr = + get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr); + + uint32_t total_lifetime_cb_pages_popped_from_math = 0; + while (worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) { + // First phase - we only forward messages to EDM + coord_t valid_worker_slice_shape = coord_t( + std::min(worker_slice_shape.x, output_tensor_shape.x - worker_slice_base_offset.x), + std::min(worker_slice_shape.y, output_tensor_shape.y - worker_slice_base_offset.y)); + uint32_t const num_pages_to_write = valid_worker_slice_shape.x * valid_worker_slice_shape.y; + + ASSERT(total_lifetime_cb_pages_popped_from_math + num_pages_to_write <= total_eltwise_kernel_num_pages); + for (uint32_t i = 0; i < num_transfers; ++i) { + const uint32_t cb_in = i == 0 ? 
cb_id_in_short_circuit : cb_id_in0; + for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { + uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); + ASSERT(n_pages > 0); + noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); + noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); + send_chunk(cb_in, n_pages, page_size, eth_l1_sender_base_noc_addr); + noc_semaphore_inc( + eth_l1_sender_semaphore_addr, + ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + if (i != 0) { + total_lifetime_cb_pages_popped_from_math += n_pages; + } + if (n_pages < half_cb_n_pages) { + uint32_t num_filler_pages = half_cb_n_pages - n_pages; + + ASSERT(p + n_pages == num_pages_to_write); + pop_filler_pages_from_cb(cb_in, num_filler_pages); + if (i != 0) { + total_lifetime_cb_pages_popped_from_math += num_filler_pages; + } + } + } + } + + // write the final reduced chunk for this chip out to the output tensor + // Second phase - Dump the local output to the output tensor + uint32_t curr_ring_slice_start_page_offset = 0; + const uint32_t worker_relative_start_offset_into_slice = + worker_slice_base_offset.x + (worker_slice_base_offset.y * output_tensor_shape.x); + auto current_worker_slice_offset = worker_slice_base_offset; + const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + uint32_t curr_tile_id = starting_tile_id; + + bool last_page_of_worker = false; + for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { + ASSERT(curr_tile_id < output_tensor_shape.x * output_tensor_shape.y); + ASSERT(!last_page_of_worker); + uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); + ASSERT(n_pages <= half_cb_n_pages); + ASSERT(full_chunk_num_pages <= half_cb_n_pages); + write_chunk_v2( + curr_tile_id, + current_worker_slice_offset, + valid_worker_slice_shape, + output_tensor_shape, // In tiles for tile layout + cb_id_in0, + d, + n_pages, + page_size, + last_page_of_worker); + total_lifetime_cb_pages_popped_from_math += n_pages; + if (n_pages < half_cb_n_pages) { + uint32_t num_filler_pages = half_cb_n_pages - n_pages; + ASSERT(p + n_pages == num_pages_to_write); + pop_filler_pages_from_cb(cb_id_in0, num_filler_pages); + total_lifetime_cb_pages_popped_from_math += num_filler_pages; + } + } + + worker_slice_base_offset = advance_slice_row_major( + worker_slice_base_offset, worker_slice_shape, output_tensor_shape, num_concurrent_workers); + } + + ASSERT(total_lifetime_cb_pages_popped_from_math <= total_eltwise_kernel_num_pages); + for (; total_lifetime_cb_pages_popped_from_math < total_eltwise_kernel_num_pages; + total_lifetime_cb_pages_popped_from_math++) { + pop_filler_pages_from_cb(cb_id_in0, 1); + } + + noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); + noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); + noc_semaphore_inc( + eth_l1_sender_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); +} diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp new file mode 100644 index 00000000000..ecacfca26a6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp +#include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" +======= +#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp + +#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "tt_metal/host_api.hpp" + +#include "ttnn/operations/eltwise/binary/binary.hpp" + + +namespace ttnn { +namespace utils { + +void ReduceScatter::validate(const std::vector& input_tensors) const { + for (auto const& t : input_tensors) { + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] / this->ring_size > 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] % this->ring_size == 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); + } +} + +std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { + auto shape = input_tensors[0].get_legacy_shape(); + TT_ASSERT( + shape[this->scatter_dim] % this->ring_size == 0, + "The size of the scatter dimension must be a multiple of the ring size"); + shape[this->scatter_dim] /= this->ring_size; + return std::vector(input_tensors.size(), shape); +} + +std::vector ReduceScatter::create_output_tensors(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + if (this->output_mem_config.is_sharded()) { + TT_FATAL(false, "Sharded output is not supported for ReduceScatter"); + } else { + return operation::generic_create_output_tensors( + *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config); + } +} + +operation::ProgramWithCallbacks ReduceScatter::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + return ccl::reduce_scatter_detail::reduce_scatter_with_workers( + input_tensors, + output_tensors, + this->binary_op_type, + this->scatter_dim, + this->num_links, + this->ring_size, + this->ring_index, + this->receiver_device_id, + this->sender_device_id, + this->topology); +} + +std::vector reduce_scatter_impl( + const std::vector& input_tensors, + const ttnn::operations::binary::BinaryOpType binary_op_type, + const uint32_t scatter_dim, + const uint32_t num_links, + const MemoryConfig& output_mem_config, + const ttnn::ccl::Topology topology) { + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + + std::vector output_tensors; + output_tensors.reserve(input_tensors.size()); + std::vector ops; + ops.reserve(input_tensors.size()); +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp + bool is_ring = topology ==ttnn::ccl::Topology::Ring; + for (uint32_t i = 0; i < input_tensors.size(); ++i) { + bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); + bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; +======= + bool is_ring = topology ==ttnn::utils::ccl::Topology::Ring; +>>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp + + for (uint32_t i = 0; i < input_tensors.size(); ++i) { + std::optional receiver_device_id = + is_last_chip_in_clockwise_direction + ? std::nullopt + : std::optional(input_tensors[(i + 1) % input_tensors.size()].device()->id()); + std::optional sender_device_id = + is_last_chip_in_counter_clockwise_direction + ? std::nullopt + : std::optional(input_tensors[i == 0 ? input_tensors.size() - 1 : i - 1].device()->id()); + ops.emplace_back(ReduceScatter{ + binary_op_type, + scatter_dim, + num_links, + static_cast(input_tensors.size()), + i, + receiver_device_id, + sender_device_id, + output_mem_config, + topology}); + output_tensors.push_back(operation::run(ops[i], {input_tensors.at(i)}).at(0)); + } + return output_tensors; +} + +static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(ReduceOpMath reduce_op) { + switch (reduce_op) { + case ReduceOpMath::SUM: return ttnn::operations::binary::BinaryOpType::ADD; + + default: TT_FATAL("Reduce scatter only support reduce_op_type SUM"); return ttnn::operations::binary::BinaryOpType::ADD; + } +} + +std::vector reduce_scatter( + const std::vector& input_tensors, + const uint32_t scatter_dim, + ReduceOpMath math_op, + const uint32_t num_links, + const MemoryConfig& output_mem_config) { + ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); + return reduce_scatter_impl( + input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config,ttnn::ccl::Topology::Ring); +} + +}; // namespace utils +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp new file mode 100644 index 00000000000..2440d5babb9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/experimental/tt_dnn/op_library/run_operation.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" + +#include "ttnn/operations/eltwise/binary/binary.hpp" + +namespace ttnn { +namespace utils { + +struct ReduceScatter { + const ttnn::operations::binary::BinaryOpType binary_op_type; + const uint32_t scatter_dim; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const std::optional receiver_device_id; + const std::optional sender_device_id; + const MemoryConfig output_mem_config; + const ttnn::ccl::Topology topology; + + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; +}; + +std::vector reduce_scatter( + const std::vector &input_tensors, + const uint32_t scatter_split_dim, + ReduceOpMath reduce_op = ReduceOpMath::SUM, + const uint32_t num_links = 1, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +namespace ccl { +namespace reduce_scatter_detail { +operation::ProgramWithCallbacks reduce_scatter_with_workers( + const std::vector& input_tensors, + const std::vector& output_tensors, + ttnn::operations::binary::BinaryOpType reduce_op, + const uint32_t scatter_split_dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, +<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp + ttnn::ccl::Topology topology); +======= +<<<<<<< HEAD:tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp + ttnn::utils::ccl::Topology topology); +======= + tt::tt_metal::ccl::Topology topology); +>>>>>>> bdb9766ed5... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +>>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +} +}; // namespace ccl + +}; // namespace utils +}; // namespace ttnn From b9eb5e05d5a39369968b2fb2fe6dec4b3f11cd73 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Fri, 5 Jul 2024 08:22:33 +0000 Subject: [PATCH 07/10] #9486: Move pytests to TTNN --- tests/scripts/t3000/run_t3000_frequent_tests.sh | 2 +- .../device/ccl_reduce_scatter_op.hpp | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index 1c6b59afcdf..c2139f10158 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -71,7 +71,7 @@ run_t3000_tteager_tests() { echo "LOG_METAL: Running run_t3000_tteager_tests" pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k post_commit ; fail+=$? - pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py ; fail+=$? + pytest -n auto tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py ; fail+=$? 
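Editor's note on the device wiring in `reduce_scatter_impl` shown above: each tensor in the input vector corresponds to one ring position, and its clockwise receiver / counter-clockwise sender are simply the adjacent positions, with the chain ends dropped when the topology is linear rather than a ring. The sketch below restates that index arithmetic in isolation; `ring_neighbors` is an illustrative name, not a function from the patch, and it returns positions rather than device ids.

#include <cstdint>
#include <optional>
#include <utility>

// Illustrative helper (not part of the patch): clockwise receiver and
// counter-clockwise sender index for position i. In a ring both neighbours
// always exist; in a linear chain the last position has no clockwise
// receiver and the first has no counter-clockwise sender.
std::pair<std::optional<uint32_t>, std::optional<uint32_t>> ring_neighbors(
    uint32_t i, uint32_t ring_size, bool is_ring) {
    const bool last_in_cw_direction = !is_ring && (i == ring_size - 1);
    const bool last_in_ccw_direction = !is_ring && (i == 0);
    std::optional<uint32_t> receiver = last_in_cw_direction
        ? std::nullopt
        : std::optional<uint32_t>((i + 1) % ring_size);
    std::optional<uint32_t> sender = last_in_ccw_direction
        ? std::nullopt
        : std::optional<uint32_t>(i == 0 ? ring_size - 1 : i - 1);
    return {receiver, sender};
}

For instance, with ring_size = 4 and a linear topology, position 0 gets receiver 1 and no sender, while position 3 gets sender 2 and no receiver; with a ring topology every position gets both.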
# distributed layernorm WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/operations/test_distributed_layernorm.py ; fail+=$? diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp index bee76abb0b1..4e863ae816b 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp @@ -12,22 +12,6 @@ namespace operations { namespace ccl { struct ExecuteReduceScatter { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const std::vector& input_tensors, Args&&... args) { - return std::forward_as_tuple(input_tensors.at(0)); - } static std::vector execute_on_main_thread( const std::vector& input_tensors, From 97a9c96058fc928d07bb6b393efc8f17572f4867 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Sat, 13 Jul 2024 08:25:41 +0000 Subject: [PATCH 08/10] #0: Fix issues --- ttnn/cpp/pybind11/operations/__init__.hpp | 2 + .../tt_dnn/op_library/CMakeLists.txt | 2 - .../host/reduce_scatter_full_worker_grid.cpp | 41 ++----------------- .../device/reduce_scatter_op.cpp | 11 +++-- .../device/reduce_scatter_op.hpp | 8 ---- ...e_scatter_op.hpp => reduce_scatter_op.hpp} | 0 ...r_pybind.hpp => reduce_scatter_pybind.hpp} | 12 +++--- 7 files changed, 18 insertions(+), 58 deletions(-) rename ttnn/cpp/ttnn/operations/ccl/reduce_scatter/{device/ccl_reduce_scatter_op.hpp => reduce_scatter_op.hpp} (100%) rename ttnn/cpp/ttnn/operations/ccl/reduce_scatter/{ccl_reduce_scatter_pybind.hpp => reduce_scatter_pybind.hpp} (86%) diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 15e3d418563..5b74c2637f2 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -9,6 +9,7 @@ #include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" #include "ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp" +#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp" #include "pybind11/operations/copy.hpp" #include "pybind11/operations/core.hpp" #include "pybind11/operations/creation.hpp" @@ -66,6 +67,7 @@ void py_module(py::module& module) { auto m_ccl = module.def_submodule("ccl", "collective communication operations"); ccl::py_module_all_gather(m_ccl); ccl::py_module_line_all_gather(m_ccl); + ccl::py_module_reduce_scatter(m_ccl); auto m_ccl = module.def_submodule("ccl", "collective communication operations"); ccl::py_bind_all_gather(m_ccl); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt index 84cbae70773..e1a03233f64 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt @@ -4,8 +4,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/auto_format.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_transfer/data_transfer_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout_conversion/layout_conversion_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/reduce_scatter_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp 
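Editor's note as a quick check on `ReduceScatter::validate` and `compute_output_shapes` from the new device op above: every rank keeps the full tensor shape except along the scatter dimension, which shrinks by a factor of `ring_size` and must divide evenly. A minimal, self-contained restatement, using an illustrative helper name that does not exist in the patch:

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: reduce_scatter leaves every dimension unchanged except
// the scatter dimension, of which each rank keeps 1/ring_size (divisibility
// is enforced, mirroring validate / compute_output_shapes in the patch).
std::vector<uint32_t> reduce_scatter_output_shape(
    std::vector<uint32_t> shape, uint32_t scatter_dim, uint32_t ring_size) {
    assert(shape[scatter_dim] % ring_size == 0 &&
           "scatter dim must be divisible by ring size");
    shape[scatter_dim] /= ring_size;
    return shape;
}

For example, scattering a [1, 1, 32, 256] tile-aligned tensor across a ring of 8 devices would leave each device a [1, 1, 32, 32] slice of the summed result (SUM being the only math op the patch accepts in convert_reduce_type_to_eltwise_type).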
${CMAKE_CURRENT_SOURCE_DIR}/sharded/sharded_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded/multi_core/sharded_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded_partial/sharded_op_partial.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp index eab8cdfd40e..b39054ba7c8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp @@ -379,6 +379,9 @@ static void add_worker_config_to_edm_builders( log_trace(tt::LogOp, "Adding receiver EDM channel"); ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel( + worker_receiver_semaphore_address, + // Since we are in worker signal EDM termination mode, we don't need to set the actual number of + // messages the EDM must forward as it will receive its finish signal from the worker instead 1, receiver_worker_coords, expected_message_size_bytes); @@ -393,19 +396,12 @@ static void add_worker_config_to_edm_builders( static std::tuple build_reduce_scatter_worker( tt::tt_metal::Program& program, Device const* device, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::RingTopology const& topology_config, ttnn::ccl::CCLOpConfig const& op_config, ReduceScatterWorkerArgBuilder const& worker_arg_builder, std::vector& cw_edm_builders, std::vector& ccw_edm_builders, EdmInterfaceAddresses const& edm_interface_addresses, -======= - ttnn::utils::ccl::CCLOpConfig const& op_config, - ReduceScatterWorkerArgBuilder const& worker_arg_builder, - std::vector& cw_edm_builders, - std::vector& ccw_edm_builders, ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp CoreCoord const& worker_core, uint32_t num_edm_channels, uint32_t link, @@ -419,15 +415,9 @@ static std::tuple build_reduce_scatter_worker( log_trace(tt::LogOp, "Worker Define: {} = {}", key, value); } static std::string const& receiver_kernel_path = -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; - static std::string const& sender_kernel_path = - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; -======= "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; static std::string const& sender_kernel_path = "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp // This will be configurable by sharded/non-sharded but present the same arg builder KernelHandle worker_receiver_kernel_id, worker_sender_kernel_id; @@ -611,11 +601,7 @@ static uint32_t compute_maximum_worker_slice_in_bytes( } static bool is_cb_buffering_sufficient_to_avoid_deadlock( -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::InterleavedTensorWorkerSlice const& worker_slice, -======= - ttnn::utils::ccl::InterleavedTensorWorkerSlice const& worker_slice, ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp uint32_t cb_src0_size_pages, uint32_t cb_dst0_size_pages, uint32_t cb_short_circuit_size_pages, @@ -641,11 +627,7 @@ static bool is_cb_buffering_sufficient_to_avoid_deadlock( static std::tuple create_worker_circular_buffers( Tensor const& input_tensor, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::CCLOpConfig const& op_config, -======= - ttnn::utils::ccl::CCLOpConfig const& op_config, ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp CoreRangeSet const& worker_core_range, uint32_t worker_pages_per_transfer, tt::tt_metal::Program& program) { @@ -697,11 +679,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::Topology topology) { -======= - ttnn::utils::ccl::Topology topology) { ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); TT_ASSERT( input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == @@ -713,21 +691,12 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( /////////////// Constants/Configuration /// Constants/Configuration -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; auto const& op_config =ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); -======= - ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::utils::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; - auto const& op_config =ttnn::utils::ccl::CCLOpConfig(input_tensors, output_tensors, topology); - std::unique_ptr input_tensor_config = - ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); - std::unique_ptr output_tensor_config = - ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; uint32_t input_tensor_num_units_per_scatter_dim = per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout @@ -884,11 +853,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // Generate the EDM kernels -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( -======= - ttnn::utils::ccl::generate_edm_kernels_for_ring_or_linear_topology( ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp program, device, topology_config, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index ecacfca26a6..927d2c529ec 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -2,11 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp -#include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" -======= #include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" @@ -76,6 +72,7 @@ std::vector reduce_scatter_impl( output_tensors.reserve(input_tensors.size()); std::vector ops; ops.reserve(input_tensors.size()); +<<<<<<< HEAD <<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp bool is_ring = topology ==ttnn::ccl::Topology::Ring; for (uint32_t i = 0; i < input_tensors.size(); ++i) { @@ -85,7 +82,13 @@ std::vector reduce_scatter_impl( bool is_ring = topology ==ttnn::utils::ccl::Topology::Ring; >>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +======= + bool is_ring = topology == ccl::Topology::Ring; +>>>>>>> a98abddcea... #0: Fix issues for (uint32_t i = 0; i < input_tensors.size(); ++i) { + bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); + bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; + std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? 
std::nullopt diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp index 2440d5babb9..79d4c86e199 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -51,15 +51,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp ttnn::ccl::Topology topology); -======= -<<<<<<< HEAD:tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp - ttnn::utils::ccl::Topology topology); -======= - tt::tt_metal::ccl::Topology topology); ->>>>>>> bdb9766ed5... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp } }; // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp similarity index 100% rename from ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp rename to ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp similarity index 86% rename from ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp rename to ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp index 40bf407725d..7ad64dd7c67 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/ccl_reduce_scatter_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp @@ -8,19 +8,19 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/ccl/reduce_scatter/device/ccl_reduce_scatter_op.hpp" +#include "ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp" #include "ttnn/types.hpp" namespace py = pybind11; namespace ttnn { namespace operations { -namespace ccl_reduce_scatter { +namespace ccl { namespace detail { template -void bind_ccl_reduce_scatter(py::module& module, const ccl_operation_t& operation, const char* doc) { +void bind_reduce_scatter(py::module& module, const ccl_operation_t& operation, const char* doc) { bind_registered_operation( module, operation, @@ -45,9 +45,9 @@ void bind_ccl_reduce_scatter(py::module& module, const ccl_operation_t& operatio } // namespace detail -void py_module(py::module& module) { +void py_module_reduce_scatter(py::module& module) { - detail::bind_ccl_reduce_scatter( + detail::bind_reduce_scatter( module, ttnn::reduce_scatter, R"doc(reduce_scatter(input_tensor: std::vector, scatter_dim: int, math_op: ReduceOpMath, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> std::vector @@ -70,6 +70,6 @@ void py_module(py::module& module) { )doc"); } -} // namespace ccl_reduce_scatter +} // namespace ccl } // namespace operations } // namespace ttnn From 0fcc51d22b1d5d5436c2d3b5fbfd458078174b7c Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Sat, 20 Jul 2024 05:01:26 +0000 Subject: [PATCH 09/10] #0: Rebased #0: Rebased --- ...st_all_gather_sharded_indexing_helpers.cpp 
| 2 +- .../ops/ccl/test_all_gather_utils.cpp | 2 +- ttnn/cpp/pybind11/operations/__init__.hpp | 6 +- .../host/reduce_scatter_full_worker_grid.cpp | 928 ------------------ ...interleaved_ring_reduce_scatter_reader.cpp | 356 ------- ...interleaved_ring_reduce_scatter_sender.cpp | 148 --- .../ccl/reduce_scatter/reduce_scatter_op.cpp | 132 --- .../ccl/reduce_scatter/reduce_scatter_op.hpp | 67 -- .../csrc/tt_lib_bindings_tensor_dm_ops.cpp | 23 - .../ccl/all_gather/all_gather_op.hpp | 50 - .../ccl/all_gather/all_gather_pybind.hpp | 56 -- .../ccl/kernels/edm/erisc_async_datamover.hpp | 21 +- .../ccl/kernels/edm/erisc_datamover.cpp | 26 - .../device/line_all_gather_op.cpp | 56 -- .../host/reduce_scatter_full_worker_grid.cpp | 3 - .../device/reduce_scatter_op.cpp | 36 +- .../device/reduce_scatter_op.hpp | 21 +- .../ccl/reduce_scatter/reduce_scatter_op.hpp | 2 +- .../reduce_scatter/reduce_scatter_pybind.hpp | 2 +- 19 files changed, 31 insertions(+), 1906 deletions(-) delete mode 100644 ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp delete mode 100644 ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp delete mode 100644 ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp delete mode 100644 ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp delete mode 100644 ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp diff --git a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp index fe2c6e7b4bc..3fa4d26db90 100644 --- a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp +++ b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp @@ -289,4 +289,4 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT } } } -} \ No newline at end of file +} diff --git a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp index ad1a0555f4e..35335223344 100644 --- a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp +++ b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp @@ -1364,4 +1364,4 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 6); ASSERT_EQ(contiguous_chunk_count, 1); -} \ No newline at end of file +} diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 5b74c2637f2..5fa97d07a74 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -64,14 +64,10 @@ void py_module(py::module& module) { auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); unary_backward::py_module(m_unary_backward); - auto m_ccl = module.def_submodule("ccl", "collective communication operations"); - ccl::py_module_all_gather(m_ccl); - ccl::py_module_line_all_gather(m_ccl); - ccl::py_module_reduce_scatter(m_ccl); - auto m_ccl = module.def_submodule("ccl", "collective communication operations"); ccl::py_bind_all_gather(m_ccl); ccl::py_bind_line_all_gather(m_ccl); + ccl::py_bind_reduce_scatter(m_ccl); auto m_complex_unary = module.def_submodule("complex_unary", "complex_unary operations"); complex_unary::py_module(m_complex_unary); diff --git 
a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp deleted file mode 100644 index eab8cdfd40e..00000000000 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp +++ /dev/null @@ -1,928 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 -/// - -#include "common/core_coord.h" -#include "eth_l1_address_map.h" -#include "impl/buffers/buffer.hpp" -#include "impl/kernels/data_types.hpp" -#include "tensor/tensor_impl.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" -#include "tt_metal/impl/buffers/circular_buffer_types.hpp" - -#include "ttnn/operations/eltwise/binary/binary.hpp" - -// Includes that need to be moved to CCL datastructures header -#include - -using namespace tt::constants; - -// Notes on abbreviations: -// cw = clockwise -// ccw = counter-clockwise -// edm = erisc data mover - -// How this reduce_scatter op works: -// For each chip, we have a element range of the input tensor shape that will eventually scatter -// out to it. For all other chunks outside that range, the chip will forward the chunk to the next chip. -// While forwarding the data, the chip will also reduce it with the local input tensor chunk corresponding -// with that received chunk. It will forward the partially reduced chunk. -// Reduces along rank - -namespace ttnn { - -namespace utils { - -namespace ccl { -namespace reduce_scatter_detail { -struct WorkerTransferInfo { - WorkerTransferInfo( - std::vector pages_per_full_chunk_per_worker, uint32_t num_links, uint32_t num_workers) : - pages_per_full_chunk_per_worker(pages_per_full_chunk_per_worker), - num_links(num_links), - num_workers(num_workers) {} - - uint32_t get_num_pages_per_full_chunk(uint32_t link, uint32_t worker_idx) const { - return pages_per_full_chunk_per_worker.at(link * num_workers + worker_idx); - } - - std::vector pages_per_full_chunk_per_worker; - uint32_t num_links; - uint32_t num_workers; -}; - -static std::size_t decide_number_of_edm_channels( - ttnn::ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { - return ccl_op_config.is_input_sharded() ? std::min( - ccl_op_config.get_shard_grid_size(), - std::min(max_num_workers, enable_bidirectional ? 8 : 4)) - : std::min(max_num_workers, enable_bidirectional ? 
8 : 4); -} - -struct ReduceScatterWorkerArgBuilder { - ReduceScatterWorkerArgBuilder( - ttnn::ccl::CCLOpConfig const& op_config, - ttnn::ccl::RingTopology const& topology_config, - ttnn::ccl::InterleavedTensorWorkerSlice const& worker_input_slice, - WorkerTransferInfo const& worker_transfer_info, - uint32_t worker_idx, - uint32_t link, - uint32_t cb_num_pages_per_packet, - uint32_t worker_sender_semaphore_address, - uint32_t worker_receiver_semaphore_address) : - op_config(op_config), - topology_config(topology_config), - worker_input_slice(worker_input_slice), - worker_transfer_info(worker_transfer_info), - cb_num_pages_per_packet(cb_num_pages_per_packet), - worker_sender_semaphore_address(worker_sender_semaphore_address), - worker_receiver_semaphore_address(worker_receiver_semaphore_address) { -#ifndef SEND_MATH_TERMINATE_SIGNAL - // This algorithm assumes that the worker slices are sized such that they start at the same x offsets for each - // new row they slice into (as they stride through the tensor) - std::size_t num_slice_iterations = - worker_input_slice.compute_num_worker_slice_iterations(worker_transfer_info.num_workers); - std::size_t worker_slice_num_pages = - worker_input_slice.worker_slice_shape.x * worker_input_slice.worker_slice_shape.y; - std::size_t pages_per_full_chunk = worker_transfer_info.get_num_pages_per_full_chunk(link, worker_idx); - std::size_t num_filler_pages_per_slice = pages_per_full_chunk - (worker_slice_num_pages % pages_per_full_chunk); - this->total_num_math_pages = (worker_input_slice.get_worker_slice_num_pages() + num_filler_pages_per_slice) * - num_slice_iterations * (topology_config.ring_size - 1); - - log_trace(tt::LogOp, "ReduceScatterWorkerArgBuilder: total_num_math_pages: {}", this->total_num_math_pages); -#endif - } - - std::vector generate_reduce_op_kernel_ct_args() const { - log_trace(tt::LogOp, "Reduce Scatter Worker CT Args: None"); - return {}; - } - - std::vector generate_reduce_op_kernel_rt_args( - uint32_t link, uint32_t worker_index, uint32_t ring_size) const { - log_trace(tt::LogOp, "generate_reduce_op_kernel_rt_args"); - - auto const& args = std::vector{total_num_math_pages, 1, 0}; - - std::size_t i = 0; - log_trace(tt::LogOp, "Reduce Scatter Worker RT Args:"); - log_trace(tt::LogOp, "\tblock_size: {}", args.at(i++)); - log_trace(tt::LogOp, "\ttotal_num_math_pages: {}", args.at(i++)); - log_trace(tt::LogOp, "\tacc_to_dst: {}", args.at(i++)); - - return args; - } - - std::vector generate_receiver_kernel_ct_args() const { - auto const& args = std::vector{ - static_cast(this->op_config.is_input_sharded() ? 1 : 0), - static_cast( - this->op_config.get_input_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0)}; - - std::size_t i = 0; - log_trace(tt::LogOp, "Reduce Scatter Receiver Worker CT Args:"); - log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); - log_trace(tt::LogOp, "\tsrc_is_dram: {}", args.at(i++)); - TT_ASSERT(args.size() == i, "Missed some args"); - - return args; - } - - std::vector generate_receiver_kernel_rt_args( - ttnn::ccl::WorkerXY edm_core, - uint32_t edm_core_semaphore_address, - uint32_t edm_core_buffer_address, - uint32_t link, - uint32_t worker_index, - bool is_in_clockwise_direction) const { - TT_ASSERT(edm_core_semaphore_address > 0); - TT_ASSERT(edm_core_buffer_address > 0); - auto const& local_input_tensor = this->op_config.get_input_tensor(0); - uint32_t starting_ring_index = - is_in_clockwise_direction ? (this->topology_config.ring_index == 0 ? 
this->topology_config.ring_size - 1 - : this->topology_config.ring_index - 1) - : (this->topology_config.ring_index == this->topology_config.ring_size - 1 - ? 0 - : this->topology_config.ring_index + 1); - auto args = std::vector{ - static_cast(local_input_tensor.buffer()->address()), - static_cast(this->topology_config.ring_size), // num_transfers - static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), - static_cast(this->op_config.get_page_size()), - static_cast(starting_ring_index), - static_cast(this->topology_config.ring_size), - static_cast(this->worker_receiver_semaphore_address), - static_cast(is_in_clockwise_direction ? 1 : 0), - static_cast(this->cb_num_pages_per_packet), - static_cast(edm_core.x), - static_cast(edm_core.y), - static_cast(edm_core_semaphore_address), - static_cast(edm_core_buffer_address), - - static_cast(worker_transfer_info.num_workers), - - static_cast(this->worker_input_slice.tensor_shape.x), - static_cast(this->worker_input_slice.tensor_shape.y), - - static_cast(this->worker_input_slice.tensor_slice_shape.x), - static_cast(this->worker_input_slice.tensor_slice_shape.y), - - static_cast(this->worker_input_slice.worker_slice_shape.x), - static_cast(this->worker_input_slice.worker_slice_shape.y), - - static_cast(this->worker_input_slice.worker_slice_offset.x), - static_cast(this->worker_input_slice.worker_slice_offset.y), - - this->total_num_math_pages}; - - std::size_t i = 0; - log_trace(tt::LogOp, "Reduce Scatter Receiver Worker RT Args:"); - log_trace(tt::LogOp, "\tsrc_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); - log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); - log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); - log_trace(tt::LogOp, "\tmy_ring_idx: {}", args.at(i++)); - log_trace(tt::LogOp, "\tring_size: {}", args.at(i++)); - log_trace(tt::LogOp, "\tsem_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\tis_clockwise_direction: {}", args.at(i++)); - log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); - - log_trace(tt::LogOp, "\tedm_core_noc0_core_x: {}", args.at(i++)); - log_trace(tt::LogOp, "\tedm_core_noc0_core_y: {}", args.at(i++)); - log_trace(tt::LogOp, "\tedm_core_semaphore_address: {}", args.at(i++)); - log_trace(tt::LogOp, "\tedm_core_buffer_address: {}", args.at(i++)); - log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); - - log_trace(tt::LogOp, "\tinput_tensor_shape.x={}", args.at(i++)); - log_trace(tt::LogOp, "\tinput_tensor_shape.y={}", args.at(i++)); - log_trace(tt::LogOp, "\ttensor_slice_shape.x={}", args.at(i++)); - log_trace(tt::LogOp, "\ttensor_slice_shape.y={}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_shape.x={}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_shape.y={}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_offset.x={}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_offset.y={}", args.at(i++)); - log_trace(tt::LogOp, "\ttotal_num_math_pages={}", args.at(i++)); - - TT_ASSERT(args.size() == i, "Missed some args"); - - return args; - } - - std::vector generate_sender_kernel_ct_args() const { - auto const& args = std::vector{ - static_cast(this->op_config.is_input_sharded() ? 1 : 0), - static_cast( - this->op_config.get_output_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 
1 : 0)}; - - std::size_t i = 0; - log_trace(tt::LogOp, "Reduce Scatter Sender Worker CT Args:"); - log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); - log_trace(tt::LogOp, "\tdst_is_dram: {}", args.at(i++)); - TT_ASSERT(args.size() == i, "Missed some args"); - - return args; - } - - std::vector generate_sender_kernel_rt_args( - ttnn::ccl::WorkerXY edm_core, - uint32_t edm_core_semaphore_address, - uint32_t edm_core_buffer_address, - uint32_t link, - uint32_t worker_index, - bool is_clockwise) const { - TT_ASSERT(edm_core_semaphore_address > 0); - TT_ASSERT(edm_core_buffer_address > 0); - auto const& local_output_tensor = this->op_config.get_output_tensor(0); - auto const& args = std::vector{ - static_cast(local_output_tensor.buffer()->address()), - static_cast(edm_core_buffer_address), - static_cast(edm_core_semaphore_address), - static_cast(edm_core.x), - static_cast(edm_core.y), - static_cast(this->topology_config.ring_size - 1), // num_transfers), - - static_cast(this->op_config.get_page_size()), - static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), - - static_cast(this->worker_sender_semaphore_address), - static_cast(this->cb_num_pages_per_packet), - - static_cast(worker_transfer_info.num_workers), - - // For sender side, all worker slice info is the same except for the tensor shape - // and for sender side specifically, there is only one tensor_slice_shape for the output - // tensor (as opposed to `ring_size` tensor_slice_shapes for the input tensor), so we can - // directly use it as the output tensor shape - static_cast(this->worker_input_slice.tensor_slice_shape.x), - static_cast(this->worker_input_slice.tensor_slice_shape.y), - static_cast(this->worker_input_slice.worker_slice_shape.x), - static_cast(this->worker_input_slice.worker_slice_shape.y), - static_cast(this->worker_input_slice.worker_slice_offset.x), - static_cast(this->worker_input_slice.worker_slice_offset.y), - - total_num_math_pages}; - - std::size_t i = 0; - log_trace(tt::LogOp, "Reduce Scatter Sender Worker RT Args:"); - log_trace(tt::LogOp, "\tdst_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\teth_sender_l1_base_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\teth_sender_l1_sem_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\teth_sender_noc_x: {}", args.at(i++)); - log_trace(tt::LogOp, "\teth_sender_noc_y: {}", args.at(i++)); - log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); - log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); - log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); - log_trace(tt::LogOp, "\twriter_send_sem_addr: {}", args.at(i++)); - log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); - log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); - - log_trace(tt::LogOp, "\toutput_tensor_shape.x: {}", args.at(i++)); - log_trace(tt::LogOp, "\toutput_tensor_shape.y: {}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_shape.x: {}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_shape.y: {}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_offset.x: {}", args.at(i++)); - log_trace(tt::LogOp, "\tworker_slice_offset.y: {}", args.at(i++)); - - log_trace(tt::LogOp, "\ttotal_num_math_pages={}", args.at(i++)); - - TT_ASSERT(args.size() == i, "Missed some args"); - - return args; - } - - ttnn::ccl::RingTopology const topology_config; - ttnn::ccl::CCLOpConfig const op_config; - ttnn::ccl::InterleavedTensorWorkerSlice const worker_input_slice; - WorkerTransferInfo const worker_transfer_info; - 
uint32_t cb_num_pages_per_packet; - uint32_t worker_sender_semaphore_address; - uint32_t worker_receiver_semaphore_address; - - uint32_t total_num_math_pages; - bool src_is_dram; - bool dst_is_dram; -}; - -struct EdmInterfaceAddresses { - std::unordered_map worker_sender_edm_semaphore_addresses; - std::unordered_map worker_sender_edm_buffer_addresses; - std::unordered_map worker_receiver_edm_semaphore_addresses; - std::unordered_map worker_receiver_edm_buffer_addresses; -}; - -// Future work: split this up further: -// 1) assign workers to EDM channel (with buffer sharing mode specified too) -// 2) Compute the semaphore and buffer addresses (for each EDM channel and worker) -// For now - the mapping between workers and EDM channels is 1:1 -static void add_worker_config_to_edm_builders( - Device* device, - ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented - ttnn::ccl::CCLOpConfig const& op_config, - std::vector const& worker_cores, - uint32_t num_channels_per_edm, - - std::vector& clockwise_edm_builders, - std::vector& counter_clockwise_edm_builders, - - uint32_t worker_sender_semaphore_address, - uint32_t worker_receiver_semaphore_address, - uint32_t link, - uint32_t ring_size, - std::function is_buffer_in_clockwise_direction_fn, - - EdmInterfaceAddresses& edm_interface_addresses) { - for (uint32_t c = 0; c < num_channels_per_edm; ++c) { - uint32_t global_worker_idx = c + num_channels_per_edm * link; - uint32_t num_workers_per_eth_buffer = 1; - - std::vector sender_worker_coords; - std::vector receiver_worker_coords; - for (uint32_t w = c * num_workers_per_eth_buffer; w < (c + 1) * num_workers_per_eth_buffer; ++w) { - sender_worker_coords.push_back(ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(worker_cores.at(w)).x, - device->worker_core_from_logical_core(worker_cores.at(w)).y)); - receiver_worker_coords.push_back(ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(worker_cores.at(w)).x, - device->worker_core_from_logical_core(worker_cores.at(w)).y)); - } - - // Get the expected message size in bytes for this worker - uint32_t expected_message_size_bytes = tensor_slicer.get_worker_slice_size_bytes(global_worker_idx); - - bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear - if (sender_enabled) { - auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? clockwise_edm_builders.at(link) - : counter_clockwise_edm_builders.at(link); - log_trace(tt::LogOp, "Adding sender EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = - sender_edm_builder.add_sender_channel( - worker_sender_semaphore_address, - 1, // cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), - sender_worker_coords, - expected_message_size_bytes); - edm_interface_addresses.worker_sender_edm_semaphore_addresses.insert( - {global_worker_idx, sender_channel_buffer_info.eth_semaphore_l1_address}); - edm_interface_addresses.worker_sender_edm_buffer_addresses.insert( - {global_worker_idx, sender_channel_buffer_info.eth_buffer_l1_address}); - } - - bool receiver_enabled = true; //(!is_linear || !is_first_chip_in_chain); - if (receiver_enabled) { - auto& receiver_edm_builder = is_buffer_in_clockwise_direction_fn(c) - ? 
counter_clockwise_edm_builders.at(link) - : clockwise_edm_builders.at(link); - log_trace(tt::LogOp, "Adding receiver EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = - receiver_edm_builder.add_receiver_channel( - 1, - receiver_worker_coords, - expected_message_size_bytes); - edm_interface_addresses.worker_receiver_edm_semaphore_addresses.insert( - {global_worker_idx, receiver_channel_buffer_info.eth_semaphore_l1_address}); - edm_interface_addresses.worker_receiver_edm_buffer_addresses.insert( - {global_worker_idx, receiver_channel_buffer_info.eth_buffer_l1_address}); - } - } -} - -static std::tuple build_reduce_scatter_worker( - tt::tt_metal::Program& program, - Device const* device, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::RingTopology const& topology_config, - ttnn::ccl::CCLOpConfig const& op_config, - ReduceScatterWorkerArgBuilder const& worker_arg_builder, - std::vector& cw_edm_builders, - std::vector& ccw_edm_builders, - EdmInterfaceAddresses const& edm_interface_addresses, -======= - ttnn::utils::ccl::CCLOpConfig const& op_config, - ReduceScatterWorkerArgBuilder const& worker_arg_builder, - std::vector& cw_edm_builders, - std::vector& ccw_edm_builders, ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - CoreCoord const& worker_core, - uint32_t num_edm_channels, - uint32_t link, - uint32_t ring_size, - uint32_t worker_index, - std::map const& worker_defines, - ttnn::operations::binary::BinaryOpType binary_math_op) { - - TT_ASSERT(worker_defines.size() > 0); - for (auto const& [key, value] : worker_defines) { - log_trace(tt::LogOp, "Worker Define: {} = {}", key, value); - } - static std::string const& receiver_kernel_path = -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; - static std::string const& sender_kernel_path = - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; -======= - "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; - static std::string const& sender_kernel_path = - "ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - - // This will be configurable by sharded/non-sharded but present the same arg builder - KernelHandle worker_receiver_kernel_id, worker_sender_kernel_id; - - bool is_in_clockwise_direction = true; // TODO: bidirectional - uint32_t global_worker_index = link * num_edm_channels + worker_index; - { - CoreCoord const& receiver_edm = is_in_clockwise_direction ? topology_config.eth_receiver_cores.at(link) - : topology_config.eth_sender_cores.at(link); - ttnn::ccl::WorkerXY receiver_edm_noc_coord =ttnn::ccl::WorkerXY( - device->ethernet_core_from_logical_core(receiver_edm).x, - device->ethernet_core_from_logical_core(receiver_edm).y); - const uint32_t edm_core_semaphore_address = - is_in_clockwise_direction - ? 
edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index) - : edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index); - const uint32_t edm_core_buffer_address = - is_in_clockwise_direction - ? edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index) - : edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index); - - worker_receiver_kernel_id = tt::tt_metal::CreateKernel( - program, - receiver_kernel_path, - worker_core, - tt::tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_receiver_kernel_ct_args(), worker_defines)); - - tt::tt_metal::SetRuntimeArgs( - program, - worker_receiver_kernel_id, - worker_core, - worker_arg_builder.generate_receiver_kernel_rt_args( - receiver_edm_noc_coord, - edm_core_semaphore_address, - edm_core_buffer_address, - link, - worker_index, - is_in_clockwise_direction)); - } - - { - vector compute_kernel_args = {}; - constexpr bool fp32_dest_acc_en = false; - constexpr bool math_approx_mode = false; - std::map eltwise_defines = ttnn::operations::binary::utils::get_defines(binary_math_op); - KernelHandle worker_reduce_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp", - worker_core, - tt::tt_metal::ComputeConfig{ - .math_fidelity = MathFidelity::HiFi4, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_approx_mode = math_approx_mode, - .compile_args = compute_kernel_args, - .defines = eltwise_defines}); - - tt::tt_metal::SetRuntimeArgs( - program, - worker_reduce_kernel_id, - worker_core, - worker_arg_builder.generate_reduce_op_kernel_rt_args(link, worker_index, ring_size)); - } - - { - CoreCoord sender_edm = is_in_clockwise_direction ? topology_config.eth_sender_cores.at(link) - : topology_config.eth_receiver_cores.at(link); - ttnn::ccl::WorkerXY const sender_edm_noc_coord =ttnn::ccl::WorkerXY( - device->ethernet_core_from_logical_core(sender_edm).x, - device->ethernet_core_from_logical_core(sender_edm).y); - TT_ASSERT(sender_edm_noc_coord.y == 0 || sender_edm_noc_coord.y == 6); - const uint32_t edm_core_semaphore_address = - is_in_clockwise_direction - ? edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index) - : edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index); - const uint32_t edm_core_buffer_address = - is_in_clockwise_direction - ? 
edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index) - : edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index); - worker_sender_kernel_id = tt::tt_metal::CreateKernel( - program, - sender_kernel_path, - worker_core, - tt::tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_kernel_ct_args(), worker_defines)); - - tt::tt_metal::SetRuntimeArgs( - program, - worker_sender_kernel_id, - worker_core, - worker_arg_builder.generate_sender_kernel_rt_args( - sender_edm_noc_coord, - edm_core_semaphore_address, - edm_core_buffer_address, - link, - worker_index, - is_in_clockwise_direction)); - } - - return {worker_receiver_kernel_id, worker_sender_kernel_id}; -} - -static CoreRangeSet select_worker_cores( - ttnn::ccl::CCLOpConfig const& op_config, std::size_t num_links, std::size_t num_edm_channels) { - switch (op_config.get_topology()) { - case ttnn::ccl::Topology::Linear: - return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); - case ttnn::ccl::Topology::Ring: - return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); - default: TT_ASSERT(false, "Unsupported topology"); return CoreRangeSet({}); - }; -} - -static WorkerTransferInfo compute_num_edm_messages_per_channel( - ttnn::ccl::CCLOpConfig const& op_config, - ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented - std::vector const& cw_per_link_edm_builders, - std::vector const& ccw_per_link_edm_builders, - std::size_t const num_edm_channels, - std::size_t const num_links, - std::size_t const ring_size) { - uint32_t const page_size_in_bytes = op_config.get_page_size(); - TT_ASSERT(num_edm_channels > 0); - TT_ASSERT(num_links > 0); - TT_ASSERT(page_size_in_bytes > 0); - log_trace(tt::LogOp, "WorkerTransferInfo"); - std::size_t total_num_workers = num_edm_channels * num_links; - - auto get_iter_begin = [num_edm_channels](auto& vec, std::size_t link) -> auto { - return vec.begin() + (link * num_edm_channels); - }; - - auto get_iter_end = [num_edm_channels, num_links](auto& vec, std::size_t link) -> auto { - bool last_link = link == num_links - 1; - TT_ASSERT( - (!last_link && ((link + 1) * num_edm_channels < vec.size())) || - (last_link && ((link + 1) * num_edm_channels == vec.size()))); - return last_link ? 
vec.end() : vec.begin() + ((link + 1) * num_edm_channels); - }; - - // Pages per EDM channel - std::size_t total_num_edm_channels = num_links * num_edm_channels; - log_trace(tt::LogOp, "total_num_edm_channels: {}", total_num_edm_channels); - - std::vector num_pages_per_full_chunk(total_num_edm_channels * num_links, 0); - - for (std::size_t link = 0; link < num_links; link++) { - std::size_t edm_channel_size_in_bytes = cw_per_link_edm_builders.at(link).get_eth_buffer_size_bytes(); - std::size_t num_pages_per_edm_buffer = edm_channel_size_in_bytes / page_size_in_bytes; - log_trace( - tt::LogOp, - "link {}, edm_channel_size_in_bytes: {}, page_size_in_bytes: {}, num_pages_per_edm_buffer: {}", - link, - edm_channel_size_in_bytes, - page_size_in_bytes, - num_pages_per_edm_buffer); - - std::fill( - get_iter_begin(num_pages_per_full_chunk, link), - get_iter_end(num_pages_per_full_chunk, link), - num_pages_per_edm_buffer); - } - - log_trace(tt::LogOp, "-- num_pages_per_full_chunk:"); - for (std::size_t l = 0; l < num_links; l++) { - for (std::size_t w = 0; w < num_edm_channels; w++) { - log_trace( - tt::LogOp, "\t\t(link={},worker={}): {}", l, w, num_pages_per_full_chunk.at(l * num_edm_channels + w)); - } - } - - return WorkerTransferInfo(num_pages_per_full_chunk, num_links, num_edm_channels); -} - -static uint32_t compute_maximum_worker_slice_in_bytes( - uint32_t cb_src0_size_pages, - uint32_t cb_dst0_size_pages, - uint32_t cb_short_circuit_size_pages, - std::size_t edm_channel_buffer_size, - uint32_t page_size) { - return std::min(cb_short_circuit_size_pages, cb_src0_size_pages + cb_dst0_size_pages) * page_size + - edm_channel_buffer_size; -} - -static bool is_cb_buffering_sufficient_to_avoid_deadlock( -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::InterleavedTensorWorkerSlice const& worker_slice, -======= - ttnn::utils::ccl::InterleavedTensorWorkerSlice const& worker_slice, ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - uint32_t cb_src0_size_pages, - uint32_t cb_dst0_size_pages, - uint32_t cb_short_circuit_size_pages, - std::size_t edm_channel_buffer_size, - uint32_t page_size) { - uint32_t worker_size_pages_rounded_up = - tt::round_up(worker_slice.worker_slice_shape.x * worker_slice.worker_slice_shape.y, cb_src0_size_pages / 2); - uint32_t worker_slice_size_bytes = worker_size_pages_rounded_up * page_size; - uint32_t available_buffering_capacity = compute_maximum_worker_slice_in_bytes( - cb_src0_size_pages, cb_dst0_size_pages, cb_short_circuit_size_pages, edm_channel_buffer_size, page_size); - log_trace(tt::LogOp, "worker_slice.worker_slice_shape.x: {}", worker_slice.worker_slice_shape.x); - log_trace(tt::LogOp, "worker_slice.worker_slice_shape.y: {}", worker_slice.worker_slice_shape.y); - log_trace(tt::LogOp, "worker_slice_size_bytes: {}", worker_slice_size_bytes); - log_trace(tt::LogOp, "worker_size_pages_rounded_up: {}", worker_size_pages_rounded_up); - log_trace(tt::LogOp, "cb_src0_size_pages: {}", cb_src0_size_pages); - log_trace(tt::LogOp, "cb_dst0_size_pages: {}", cb_dst0_size_pages); - log_trace(tt::LogOp, "page_size: {}", page_size); - log_trace(tt::LogOp, "edm_channel_buffer_size: {}", edm_channel_buffer_size); - log_trace(tt::LogOp, "available_buffering_capacity: {}", available_buffering_capacity); - - return available_buffering_capacity >= worker_slice_size_bytes; -} - -static std::tuple create_worker_circular_buffers( - Tensor const& input_tensor, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::CCLOpConfig const& op_config, -======= - ttnn::utils::ccl::CCLOpConfig const& op_config, ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - CoreRangeSet const& worker_core_range, - uint32_t worker_pages_per_transfer, - tt::tt_metal::Program& program) { - tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); - uint32_t page_size_bytes = op_config.get_page_size(); - - // Input 0 CB - uint32_t src0_cb_index = tt::CB::c_in0; - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src0_cb_index, df}}) - .set_page_size(src0_cb_index, page_size_bytes); - CBHandle cb_src0_workers = CreateCircularBuffer(program, worker_core_range, cb_src0_config); - - // Input 1 CB - uint32_t src1_cb_index = tt::CB::c_in1; - tt::tt_metal::CircularBufferConfig cb_src1_config = - tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src1_cb_index, df}}) - .set_page_size(src1_cb_index, page_size_bytes); - CBHandle cb_src1_workers = CreateCircularBuffer(program, worker_core_range, cb_src1_config); - - // Dataflow Writer Kernel input CB - uint32_t cb_dst0_index = tt::CB::c_out0; - tt::tt_metal::CircularBufferConfig cb_dst0_config = - tt::tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{cb_dst0_index, df}}) - .set_page_size(cb_dst0_index, page_size_bytes); - CBHandle cb_dst0_sender_workers = CreateCircularBuffer(program, worker_core_range, cb_dst0_config); - - // From reader -> writer kernel (I think I need this because sharing the cb_dst0_sender_workers as output - // of reader kernel (first output) and math kernel (all subsequent outputs) doesn't seem to work because - // it seems like the math kernels hold some of the CB state in local variables) - uint32_t cb_short_circuit_index = tt::CB::c_out1; - tt::tt_metal::CircularBufferConfig cb_short_circuit_config = - tt::tt_metal::CircularBufferConfig( - (worker_pages_per_transfer * page_size_bytes) * 2, {{cb_short_circuit_index, df}}) - .set_page_size(cb_short_circuit_index, page_size_bytes); - CBHandle cb_short_circuit_sender_workers = - CreateCircularBuffer(program, worker_core_range, cb_short_circuit_config); - - return {cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers}; -} - -operation::ProgramWithCallbacks reduce_scatter_with_workers( - const std::vector& input_tensors, - const std::vector& output_tensors, - ttnn::operations::binary::BinaryOpType reduce_op, - const uint32_t scatter_split_dim, - const uint32_t num_links, - const uint32_t ring_size, - const uint32_t ring_index, - const std::optional receiver_device_id, - const std::optional sender_device_id, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::Topology topology) { -======= - ttnn::utils::ccl::Topology topology) { ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); - TT_ASSERT( - input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == - output_tensors.at(0).get_legacy_shape()[scatter_split_dim] * ring_size, - "Input and output tensor shapes must match"); - TT_ASSERT( - input_tensors.at(0).buffer()->num_pages() % ring_size == 0, - "Reduce scatter current only supports even divisibility of input tensor(s) across ranks"); - - /////////////// Constants/Configuration - /// Constants/Configuration -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; - auto const& op_config =ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); - std::unique_ptr input_tensor_config = - ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); - std::unique_ptr output_tensor_config = - ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); -======= - ttnn::utils::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode =ttnn::utils::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; - auto const& op_config =ttnn::utils::ccl::CCLOpConfig(input_tensors, output_tensors, topology); - std::unique_ptr input_tensor_config = - ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); - std::unique_ptr output_tensor_config = - ttnn::utils::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; - uint32_t input_tensor_num_units_per_scatter_dim = - per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout - TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); - uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); - bool enable_bidirectional = false; - auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional); - log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - auto edm_termination_mode =ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; - auto const& edm_builder = create_erisc_datamover_builder( - num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); - TT_ASSERT(num_edm_channels > 0); - - Tensor const& local_chip_tensor = input_tensors.at(0); - Tensor const& local_chip_output_tensor = output_tensors.at(0); - - std::map worker_defines; - std::vector worker_receiver_kernels; - std::vector worker_sender_kernels; - std::vector cw_per_link_edm_builders(num_links, edm_builder); - std::vector ccw_per_link_edm_builders(num_links, edm_builder); - - bool rm = local_chip_tensor.get_layout() == Layout::ROW_MAJOR; - if (rm) { - worker_defines["RM_INTERLEAVED"] = "1"; - } else { - worker_defines["TILE_INTERLEAVED"] = "1"; - } - - ////////////////// - tt::tt_metal::Program program{}; - const auto& device = local_chip_tensor.device(); - - auto const& topology_config = - ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); - 
- auto dim_slice_factors = tt::tt_metal::Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); - dim_slice_factors[-1] = ring_size; - - CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); - auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); - - // Semaphores && CBs - auto worker_receiver_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); - auto worker_sender_semaphore_address = tt::tt_metal::CreateSemaphore(program, worker_core_range, 0); - - uint32_t cb_num_pages = - (cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes() / op_config.get_page_size()) * 2; - uint32_t cb_num_pages_per_packet = cb_num_pages / 2; - log_trace(tt::LogOp, "cb_num_pages: {}", cb_num_pages); - auto const& [cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers] = - create_worker_circular_buffers(local_chip_tensor, op_config, worker_core_range, cb_num_pages, program); - - uint32_t max_worker_slice_in_bytes = compute_maximum_worker_slice_in_bytes( - cb_num_pages, - cb_num_pages, - cb_num_pages, - cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), - op_config.get_page_size()); - auto tensor_slicer =ttnn::ccl::RingReduceScatterTensorSlicer( - local_chip_tensor, - local_chip_output_tensor, - scatter_split_dim, - ring_index, - ring_size, - num_edm_channels * num_links, - max_worker_slice_in_bytes, - cb_num_pages / 2); - - // Not per buffer because the buffer sharing mode may cause some buffers to share EDM transfers - WorkerTransferInfo const& worker_transfer_info = compute_num_edm_messages_per_channel( - op_config, - tensor_slicer, - cw_per_link_edm_builders, - ccw_per_link_edm_builders, - num_edm_channels, - num_links, - ring_size); - - // Configure the EDM builders - EdmInterfaceAddresses edm_interface_addresses; - for (std::size_t link = 0; link < num_links; link++) { - add_worker_config_to_edm_builders( - device, - tensor_slicer, - op_config, - worker_cores, - num_edm_channels, - - cw_per_link_edm_builders, - ccw_per_link_edm_builders, - - worker_sender_semaphore_address, - worker_receiver_semaphore_address, - link, - ring_size, - [enable_bidirectional, num_edm_channels](uint32_t x) { - return enable_bidirectional ? 
(x % num_edm_channels == 0) : true; - }, - - edm_interface_addresses); - } - - // build the worker kernels - tt::tt_metal::ComputeConfig compute_config; - for (std::size_t link = 0; link < num_links; link++) { - uint32_t global_worker_index = link * num_edm_channels; - log_trace(tt::LogOp, "=============================================="); - log_trace(tt::LogOp, "------------------ Link: {} ------------------", link); - for (std::size_t worker = 0; worker < num_edm_channels; worker++) { - std::size_t global_worker_index = worker + link * num_edm_channels; - log_trace(tt::LogOp, "------ Worker: {} (global ID={})", worker, global_worker_index); - // This will be configurable by sharded/non-sharded but present the same arg builder - auto const& worker_slice = tensor_slicer.get_worker_slice(global_worker_index); - auto worker_arg_builder = ReduceScatterWorkerArgBuilder( - op_config, - topology_config, - worker_slice, - worker_transfer_info, - worker, - link, - cb_num_pages_per_packet, - worker_sender_semaphore_address, - worker_receiver_semaphore_address); - - log_trace(tt::LogOp, "worker_cores.at(global_worker_index): {}", worker_cores.at(global_worker_index)); - auto [receiver_kernel_id, sender_kernel_id] = build_reduce_scatter_worker( - program, - device, - topology_config, - op_config, - worker_arg_builder, - cw_per_link_edm_builders, - ccw_per_link_edm_builders, - edm_interface_addresses, - worker_cores.at(global_worker_index), - num_edm_channels, - link, - ring_size, - worker, - worker_defines, - reduce_op); - worker_receiver_kernels.push_back(receiver_kernel_id); - worker_sender_kernels.push_back(sender_kernel_id); - - TT_ASSERT(is_cb_buffering_sufficient_to_avoid_deadlock( - worker_slice, - cb_num_pages, - cb_num_pages, - cb_num_pages, - cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), - op_config.get_page_size())); - } - } - - // Generate the EDM kernels -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( -======= - ttnn::utils::ccl::generate_edm_kernels_for_ring_or_linear_topology( ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - program, - device, - topology_config, - cw_per_link_edm_builders, - ccw_per_link_edm_builders, - receiver_device_id, - sender_device_id); - - uint32_t total_num_workers = worker_cores.size(); - auto override_runtime_arguments_callback = - [topology_config, worker_receiver_kernels, worker_sender_kernels, worker_cores, total_num_workers, ring_index]( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector& output_tensors) { - const auto& input = input_tensors.at(0); - const auto& output = output_tensors.at(0); - TT_ASSERT(worker_sender_kernels.size() == worker_receiver_kernels.size()); - for (uint32_t i = 0; i < worker_sender_kernels.size(); ++i) { - auto& worker_receiver_runtime_args = - GetRuntimeArgs(program, worker_receiver_kernels.at(i), worker_cores.at(i)); - worker_receiver_runtime_args.at(0) = input.buffer()->address(); - - auto& worker_sender_runtime_args = - GetRuntimeArgs(program, worker_sender_kernels.at(i), worker_cores.at(i)); - worker_sender_runtime_args.at(0) = output.buffer()->address(); - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; -} - -} // namespace reduce_scatter_detail -} // namespace ccl -} // namespace utils -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp deleted file mode 100644 index 850bbfadca4..00000000000 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp +++ /dev/null @@ -1,356 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include - -#include "dataflow_api.h" -#include "debug/assert.h" -#include "tensix_types.h" -#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" - -using ttnn::ccl::coord_t; -using ttnn::ccl::WorkerXY; - -struct reduce_scatter_reader_common_args_t { - reduce_scatter_reader_common_args_t(uint32_t& arg_idx) : - src_addr(get_arg_val(arg_idx++)), - num_transfers(get_arg_val(arg_idx++)), - full_chunk_num_pages(get_arg_val(arg_idx++)), - page_size(get_arg_val(arg_idx++)), - - my_ring_idx(get_arg_val(arg_idx++)), - ring_size(get_arg_val(arg_idx++)), - sem_addr(get_arg_val(arg_idx++)), - - is_clockwise_direction(get_arg_val(arg_idx++) == 1), - half_cb_n_pages(get_arg_val(arg_idx++)), - edm_core_noc0_core_x(get_arg_val(arg_idx++)), - edm_core_noc0_core_y(get_arg_val(arg_idx++)), - edm_core_semaphore_address(get_arg_val(arg_idx++)), - edm_core_buffer_address(get_arg_val(arg_idx++)), - num_concurrent_workers(get_arg_val(arg_idx++)), - - input_tensor_shape(ttnn::ccl::coord_from_args(arg_idx)), - tensor_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), - worker_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), - worker_slice_offset(ttnn::ccl::coord_from_args(arg_idx)), - total_eltwise_kernel_num_pages(get_arg_val(arg_idx++)) - { - ASSERT(full_chunk_num_pages > 0); - ASSERT(page_size > 0); - ASSERT(ring_size > 0); - ASSERT(half_cb_n_pages > 0); - } - - const uint32_t src_addr; - const uint32_t num_transfers; - const uint32_t full_chunk_num_pages; - const uint32_t page_size; - uint32_t my_ring_idx; - const uint32_t ring_size; - const uint32_t sem_addr; - - const bool is_clockwise_direction; - - const uint32_t half_cb_n_pages; - const uint32_t edm_core_noc0_core_x; - const uint32_t edm_core_noc0_core_y; - const uint32_t edm_core_semaphore_address; - const uint32_t edm_core_buffer_address; - const uint32_t num_concurrent_workers; - - coord_t input_tensor_shape; - coord_t tensor_slice_shape; - coord_t worker_slice_shape; - coord_t worker_slice_offset; - uint32_t total_eltwise_kernel_num_pages; -}; -#ifdef RM_INTERLEAVED -constexpr bool rm_interleaved_addr_gen_mode = true; -#else -constexpr bool rm_interleaved_addr_gen_mode = false; -#endif - -template -struct interleaved_addr_gen_t { - using type = InterleavedAddrGen; -}; -template <> -struct interleaved_addr_gen_t { - using type = InterleavedAddrGen; -}; -template <> -struct interleaved_addr_gen_t { - using type = InterleavedAddrGen; -}; -template <> -struct interleaved_addr_gen_t { - using type = InterleavedAddrGenFast; -}; -template <> -struct interleaved_addr_gen_t { - using type = InterleavedAddrGenFast; -}; - -template -struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { - using src_addr_gen_t = typename interleaved_addr_gen_t::type; - - reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : - reduce_scatter_reader_common_args_t(arg_idx) { - this->s = { - .bank_base_address = this->src_addr, - .page_size = page_size -#if defined TILE_INTERLEAVED - , - .data_format = in0_df -#endif - }; - } - - src_addr_gen_t s; - - void dprint() const { - DPRINT << "RSR args:" - << "\n\tsrc_addr=" << src_addr << "\n\tnum_transfers=" << num_transfers << "\n\tpage_size=" << page_size - << "\n\tfull_chunk_num_pages=" << 
full_chunk_num_pages << "\n\tmy_ring_idx=" << my_ring_idx - << "\n\tsem_addr=" << sem_addr << "\n\tis_clockwise_direction=" << (uint32_t)is_clockwise_direction - << "\n\thalf_cb_n_pages=" << half_cb_n_pages << "\n\tring_size=" << ring_size - << "\n\tedm_core_noc0_core_x=" << edm_core_noc0_core_x - << "\n\tedm_core_noc0_core_y=" << edm_core_noc0_core_y - << "\n\tedm_core_semaphore_address=" << edm_core_semaphore_address - << "\n\tedm_core_buffer_address=" << edm_core_buffer_address << "\n"; - } -}; - -template -struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { - reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : - reduce_scatter_reader_common_args_t(arg_idx), - shard_num_pages(get_arg_val(arg_idx++)), - num_l1_cores(get_arg_val(arg_idx++)), - l1_cores_ptr(reinterpret_cast(get_arg_addr(arg_idx))) { - arg_idx += this->num_l1_cores; - } - - const uint32_t shard_num_pages; - const uint32_t num_l1_cores; - const WorkerXY* const l1_cores_ptr; - - void dprint() const {} -}; - -using advance_to_next_transfer_slice_result_t = std::tuple< - uint32_t, // ring_index - uint32_t // slice_base_page_offset - >; -template -advance_to_next_transfer_slice_result_t advance_to_next_transfer_slice( - uint32_t const ring_size, - uint32_t const curr_ring_idx, - uint32_t const slice_base_page_offset, - coord_t const& input_tensor_shape, - coord_t const& tensor_slice_shape, - bool const is_clockwise_direction) { - bool const sliced_only_on_width = tensor_slice_shape.x < input_tensor_shape.x && tensor_slice_shape.y == input_tensor_shape.y; - uint32_t single_ring_idx_stride = - sliced_only_on_width ? tensor_slice_shape.x : tensor_slice_shape.y * input_tensor_shape.x; - uint32_t n_minus_one_ring_indices_stride = sliced_only_on_width - ? 
tensor_slice_shape.x * (ring_size - 1) - : tensor_slice_shape.y * input_tensor_shape.x * (ring_size - 1); - - if constexpr (!is_sharded) { - if (is_clockwise_direction) { - if (curr_ring_idx == 0) { - return advance_to_next_transfer_slice_result_t{ - ring_size - 1, - slice_base_page_offset + n_minus_one_ring_indices_stride, - }; - } else { - return advance_to_next_transfer_slice_result_t{ - curr_ring_idx - 1, - slice_base_page_offset - single_ring_idx_stride, - }; - } - } else { - if (curr_ring_idx == ring_size - 1) { - return advance_to_next_transfer_slice_result_t{ - 0, - slice_base_page_offset - n_minus_one_ring_indices_stride, - }; - } else { - return advance_to_next_transfer_slice_result_t{ - curr_ring_idx + 1, - slice_base_page_offset + single_ring_idx_stride, - }; - } - } - } -} - -void kernel_main() { - constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; - - // Currently meaningless when `is_sharded=true` - constexpr bool src_is_dram = get_compile_time_arg_val(1) == 1; - - uint32_t arg_idx = 0; - - constexpr uint32_t to_dm_sender_short_circuit_cb = tt::CB::c_out1; - constexpr uint32_t cb_id_in0 = tt::CB::c_in0; - constexpr uint32_t cb_id_in1 = tt::CB::c_in1; - const DataFormat in0_df = get_dataformat(cb_id_in0); - auto args = reduce_scatter_reader_unique_args_t(arg_idx, in0_df); - - ASSERT(args.half_cb_n_pages >= args.full_chunk_num_pages); - - bool width_sliced = args.tensor_slice_shape.x <= args.input_tensor_shape.x; - - volatile tt_l1_ptr uint32_t* receiver_read_semaphore_addr_ptr = - reinterpret_cast(args.sem_addr); - const uint64_t eth_receiver_l1_base_noc_addr = - get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_buffer_address); - const uint64_t eth_receiver_l1_semaphore_noc_addr = - get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_semaphore_address); - - uint32_t total_cb_pages_pushed = 0; - uint32_t total_cb_pages_pushed_to_math = 0; - - // For the first timestep, there is no other input to reduce with, so we just send it straight to the input CB - // of the output data movement kernel - short-circuiting past the (reducer) math kernel - // For tile => shape in tiles - // For RM => shape in elements - uint32_t start_ring_index = args.my_ring_idx; - while (args.worker_slice_offset.x < args.tensor_slice_shape.x && - args.worker_slice_offset.y < args.tensor_slice_shape.y) { - // Need to reset back to the start ring index because the last iteration of the tranfers read chunks - // loop won't increment after the last iteration since the increment is within the loop body - args.my_ring_idx = start_ring_index; - uint32_t curr_ring_slice_start_page_offset = - width_sliced ? 
args.tensor_slice_shape.x * start_ring_index - : args.tensor_slice_shape.y * start_ring_index * args.input_tensor_shape.x; - - auto const& next_slice_offset = advance_slice_row_major( - args.worker_slice_offset, args.worker_slice_shape, args.tensor_slice_shape, args.num_concurrent_workers); - bool last_slice_of_worker = next_slice_offset.x >= args.tensor_slice_shape.x || - next_slice_offset.y >= args.tensor_slice_shape.y; - - const uint32_t worker_relative_start_offset_into_slice = - args.worker_slice_offset.x + (args.worker_slice_offset.y * args.input_tensor_shape.x); - const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; - uint32_t curr_tile_id = starting_tile_id; - - coord_t valid_worker_slice_shape = coord_t( - std::min(args.worker_slice_shape.x, args.tensor_slice_shape.x - args.worker_slice_offset.x), - std::min(args.worker_slice_shape.y, args.tensor_slice_shape.y - args.worker_slice_offset.y)); - - bool last_page_of_worker = false; - uint32_t const worker_slice_n_pages = valid_worker_slice_shape.x * valid_worker_slice_shape.y; - ASSERT( - (args.num_transfers - 1) * worker_slice_n_pages + total_cb_pages_pushed_to_math <= - args.total_eltwise_kernel_num_pages); - { - coord_t offset_into_worker_slice = {0, 0}; - for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { - uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); - ASSERT(!last_page_of_worker); - read_chunk_from_output_tensor_v2( - curr_tile_id, - offset_into_worker_slice, - valid_worker_slice_shape, - // In tiles for tile layout - args.input_tensor_shape, - to_dm_sender_short_circuit_cb, - args.s, - n_pages, - args.page_size, - last_page_of_worker); - total_cb_pages_pushed += n_pages; - if (n_pages < args.half_cb_n_pages) { - uint32_t num_filler_pages = args.half_cb_n_pages - n_pages; - push_filler_pages_to_cb(to_dm_sender_short_circuit_cb, num_filler_pages); - ASSERT(args.half_cb_n_pages > n_pages); - ASSERT(p + n_pages == worker_slice_n_pages); - total_cb_pages_pushed += num_filler_pages; - } - } - } - - for (uint32_t i = 1; i < args.num_transfers; ++i) { - bool last_transfer = i == args.num_transfers - 1; - coord_t offset_into_worker_slice = {0, 0}; - std::tie(args.my_ring_idx, curr_ring_slice_start_page_offset) = advance_to_next_transfer_slice( - args.ring_size, - args.my_ring_idx, - curr_ring_slice_start_page_offset, - args.input_tensor_shape, - args.tensor_slice_shape, - args.is_clockwise_direction); - ASSERT(last_page_of_worker); - last_page_of_worker = false; - curr_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; - - for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { - uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); - ASSERT(n_pages > 0); - // Fetch from input tensor - read_chunk_from_output_tensor_v2( - curr_tile_id, - offset_into_worker_slice, - valid_worker_slice_shape, - // In tiles for tile layout - args.input_tensor_shape, - cb_id_in1, - args.s, - n_pages, - args.page_size, - last_page_of_worker); - uint64_t eth_receiver_l1_curr_noc_addr = eth_receiver_l1_base_noc_addr; - - // Fetch from EDM - noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1); - noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0); - fetch_chunk(cb_id_in0, n_pages, args.page_size, eth_receiver_l1_base_noc_addr); - total_cb_pages_pushed_to_math += n_pages; - total_cb_pages_pushed += n_pages; - - bool last_worker_message_to_edm = 
last_transfer && last_slice_of_worker && (p + n_pages >= worker_slice_n_pages); - if (!last_worker_message_to_edm) { - noc_semaphore_inc( - eth_receiver_l1_semaphore_noc_addr, - ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); - } - if (n_pages < args.half_cb_n_pages) { - uint32_t num_filler_pages = args.half_cb_n_pages - n_pages; - push_filler_pages_to_cb(cb_id_in0, num_filler_pages); - push_filler_pages_to_cb(cb_id_in1, num_filler_pages); - total_cb_pages_pushed_to_math += num_filler_pages; - total_cb_pages_pushed += num_filler_pages; - } - } - ASSERT(last_page_of_worker); - } - - args.worker_slice_offset = next_slice_offset; - } - - ASSERT(args.total_eltwise_kernel_num_pages >= total_cb_pages_pushed_to_math); - DEBUG_STATUS("DRN1"); - // The host code currently doesn't know how to accuractly count the exact number of pages pushed through the - // math reduce op so it instead provides a known safe lower bound which may be more than actually required by the - // op. It passes this number to sender and receiver, who will push/pop junk pages to/from the math op to ensure - // it will complete - for (; total_cb_pages_pushed_to_math < args.total_eltwise_kernel_num_pages; total_cb_pages_pushed_to_math++) { - push_filler_pages_to_cb(cb_id_in0, 1); - push_filler_pages_to_cb(cb_id_in1, 1); - } - - noc_semaphore_inc( - eth_receiver_l1_semaphore_noc_addr, - ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); - DEBUG_STATUS("DONE"); -} diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp deleted file mode 100644 index ac8647cb584..00000000000 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp +++ /dev/null @@ -1,148 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" -#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" - -using ttnn::ccl::coord_t; - -void kernel_main() { - constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; - constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; - - uint32_t arg_idx = 0; - uint32_t const dst_addr = get_arg_val(arg_idx++); - uint32_t const eth_sender_l1_base_addr = get_arg_val(arg_idx++); - uint32_t const eth_sender_l1_sem_addr = get_arg_val(arg_idx++); - uint32_t const eth_sender_noc_x = get_arg_val(arg_idx++); - uint32_t const eth_sender_noc_y = get_arg_val(arg_idx++); - uint32_t const num_transfers = get_arg_val(arg_idx++); - uint32_t const page_size = get_arg_val(arg_idx++); - uint32_t const full_chunk_num_pages = get_arg_val(arg_idx++); - uint32_t const writer_send_sem_addr = get_arg_val(arg_idx++); - uint32_t const half_cb_n_pages = get_arg_val(arg_idx++); - uint32_t const num_concurrent_workers = get_arg_val(arg_idx++); - - coord_t const& output_tensor_shape = ttnn::ccl::coord_from_args(arg_idx); - coord_t const& worker_slice_shape = ttnn::ccl::coord_from_args(arg_idx); - coord_t worker_slice_base_offset = ttnn::ccl::coord_from_args(arg_idx); - - uint32_t total_eltwise_kernel_num_pages = get_arg_val(arg_idx++); - - // Argument validation - ASSERT(half_cb_n_pages >= full_chunk_num_pages); - ASSERT(full_chunk_num_pages > 0); - ASSERT(page_size > 0); - ASSERT(half_cb_n_pages > 0); - - constexpr uint32_t cb_id_in0 = tt::CB::c_out0; - constexpr uint32_t cb_id_in_short_circuit = tt::CB::c_out1; - const DataFormat in0_df = get_dataformat(cb_id_in0); -#ifdef RM_INTERLEAVED - InterleavedAddrGen d = { - .bank_base_address = dst_addr + output_start_addr_offset, .page_size = page_size}; -#elif defined TILE_INTERLEAVED - - InterleavedAddrGenFast d = { - .bank_base_address = dst_addr, .page_size = page_size, .data_format = in0_df}; -#endif - - // Used to wait until eth sender has space available - volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr = - reinterpret_cast(writer_send_sem_addr); - // This is different per writer core - const uint64_t eth_l1_sender_base_noc_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_base_addr); - // Used to signal eth sender that data is available. This is different per writer core - const uint64_t eth_l1_sender_semaphore_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr); - - uint32_t total_lifetime_cb_pages_popped_from_math = 0; - while (worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) { - // First phase - we only forward messages to EDM - coord_t valid_worker_slice_shape = coord_t( - std::min(worker_slice_shape.x, output_tensor_shape.x - worker_slice_base_offset.x), - std::min(worker_slice_shape.y, output_tensor_shape.y - worker_slice_base_offset.y)); - uint32_t const num_pages_to_write = valid_worker_slice_shape.x * valid_worker_slice_shape.y; - - ASSERT(total_lifetime_cb_pages_popped_from_math + num_pages_to_write <= total_eltwise_kernel_num_pages); - for (uint32_t i = 0; i < num_transfers; ++i) { - const uint32_t cb_in = i == 0 ? 
cb_id_in_short_circuit : cb_id_in0; - for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { - uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); - ASSERT(n_pages > 0); - noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); - noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); - send_chunk(cb_in, n_pages, page_size, eth_l1_sender_base_noc_addr); - noc_semaphore_inc( - eth_l1_sender_semaphore_addr, - ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); - if (i != 0) { - total_lifetime_cb_pages_popped_from_math += n_pages; - } - if (n_pages < half_cb_n_pages) { - uint32_t num_filler_pages = half_cb_n_pages - n_pages; - - ASSERT(p + n_pages == num_pages_to_write); - pop_filler_pages_from_cb(cb_in, num_filler_pages); - if (i != 0) { - total_lifetime_cb_pages_popped_from_math += num_filler_pages; - } - } - } - } - - // write the final reduced chunk for this chip out to the output tensor - // Second phase - Dump the local output to the output tensor - uint32_t curr_ring_slice_start_page_offset = 0; - const uint32_t worker_relative_start_offset_into_slice = - worker_slice_base_offset.x + (worker_slice_base_offset.y * output_tensor_shape.x); - auto current_worker_slice_offset = worker_slice_base_offset; - const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; - uint32_t curr_tile_id = starting_tile_id; - - bool last_page_of_worker = false; - for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { - ASSERT(curr_tile_id < output_tensor_shape.x * output_tensor_shape.y); - ASSERT(!last_page_of_worker); - uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); - ASSERT(n_pages <= half_cb_n_pages); - ASSERT(full_chunk_num_pages <= half_cb_n_pages); - write_chunk_v2( - curr_tile_id, - current_worker_slice_offset, - valid_worker_slice_shape, - output_tensor_shape, // In tiles for tile layout - cb_id_in0, - d, - n_pages, - page_size, - last_page_of_worker); - total_lifetime_cb_pages_popped_from_math += n_pages; - if (n_pages < half_cb_n_pages) { - uint32_t num_filler_pages = half_cb_n_pages - n_pages; - ASSERT(p + n_pages == num_pages_to_write); - pop_filler_pages_from_cb(cb_id_in0, num_filler_pages); - total_lifetime_cb_pages_popped_from_math += num_filler_pages; - } - } - - worker_slice_base_offset = advance_slice_row_major( - worker_slice_base_offset, worker_slice_shape, output_tensor_shape, num_concurrent_workers); - } - - ASSERT(total_lifetime_cb_pages_popped_from_math <= total_eltwise_kernel_num_pages); - for (; total_lifetime_cb_pages_popped_from_math < total_eltwise_kernel_num_pages; - total_lifetime_cb_pages_popped_from_math++) { - pop_filler_pages_from_cb(cb_id_in0, 1); - } - - noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); - noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); - noc_semaphore_inc( - eth_l1_sender_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); -} diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp deleted file mode 100644 index ecacfca26a6..00000000000 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp -#include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" -======= -#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp - -#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" -#include "tt_metal/host_api.hpp" - -#include "ttnn/operations/eltwise/binary/binary.hpp" - - -namespace ttnn { -namespace utils { - -void ReduceScatter::validate(const std::vector& input_tensors) const { - for (auto const& t : input_tensors) { - TT_FATAL( - t.get_legacy_shape()[this->scatter_dim] / this->ring_size > 0, - "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); - TT_FATAL( - t.get_legacy_shape()[this->scatter_dim] % this->ring_size == 0, - "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); - } -} - -std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { - auto shape = input_tensors[0].get_legacy_shape(); - TT_ASSERT( - shape[this->scatter_dim] % this->ring_size == 0, - "The size of the scatter dimension must be a multiple of the ring size"); - shape[this->scatter_dim] /= this->ring_size; - return std::vector(input_tensors.size(), shape); -} - -std::vector ReduceScatter::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - if (this->output_mem_config.is_sharded()) { - TT_FATAL(false, "Sharded output is not supported for ReduceScatter"); - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config); - } -} - -operation::ProgramWithCallbacks ReduceScatter::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - return ccl::reduce_scatter_detail::reduce_scatter_with_workers( - input_tensors, - output_tensors, - this->binary_op_type, - this->scatter_dim, - this->num_links, - this->ring_size, - this->ring_index, - this->receiver_device_id, - this->sender_device_id, - this->topology); -} - -std::vector reduce_scatter_impl( - const std::vector& input_tensors, - const ttnn::operations::binary::BinaryOpType binary_op_type, - const uint32_t scatter_dim, - const uint32_t num_links, - const MemoryConfig& output_mem_config, - const ttnn::ccl::Topology topology) { - TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); - - std::vector output_tensors; - output_tensors.reserve(input_tensors.size()); - std::vector ops; - ops.reserve(input_tensors.size()); -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp - bool is_ring = topology ==ttnn::ccl::Topology::Ring; - for (uint32_t i = 0; i < input_tensors.size(); ++i) { - bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); - bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; -======= - bool is_ring = topology ==ttnn::utils::ccl::Topology::Ring; ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp - - for (uint32_t i = 0; i < input_tensors.size(); ++i) { - std::optional receiver_device_id = - is_last_chip_in_clockwise_direction - ? std::nullopt - : std::optional(input_tensors[(i + 1) % input_tensors.size()].device()->id()); - std::optional sender_device_id = - is_last_chip_in_counter_clockwise_direction - ? std::nullopt - : std::optional(input_tensors[i == 0 ? input_tensors.size() - 1 : i - 1].device()->id()); - ops.emplace_back(ReduceScatter{ - binary_op_type, - scatter_dim, - num_links, - static_cast(input_tensors.size()), - i, - receiver_device_id, - sender_device_id, - output_mem_config, - topology}); - output_tensors.push_back(operation::run(ops[i], {input_tensors.at(i)}).at(0)); - } - return output_tensors; -} - -static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(ReduceOpMath reduce_op) { - switch (reduce_op) { - case ReduceOpMath::SUM: return ttnn::operations::binary::BinaryOpType::ADD; - - default: TT_FATAL("Reduce scatter only support reduce_op_type SUM"); return ttnn::operations::binary::BinaryOpType::ADD; - } -} - -std::vector reduce_scatter( - const std::vector& input_tensors, - const uint32_t scatter_dim, - ReduceOpMath math_op, - const uint32_t num_links, - const MemoryConfig& output_mem_config) { - ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); - return reduce_scatter_impl( - input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config,ttnn::ccl::Topology::Ring); -} - -}; // namespace utils -}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp deleted file mode 100644 index 2440d5babb9..00000000000 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/experimental/tt_dnn/op_library/run_operation.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" -#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" -#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" - -#include "ttnn/operations/eltwise/binary/binary.hpp" - -namespace ttnn { -namespace utils { - -struct ReduceScatter { - const ttnn::operations::binary::BinaryOpType binary_op_type; - const uint32_t scatter_dim; - const uint32_t num_links; - const uint32_t ring_size; - const uint32_t ring_index; - const std::optional receiver_device_id; - const std::optional sender_device_id; - const MemoryConfig output_mem_config; - const ttnn::ccl::Topology topology; - - void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector &input_tensors, std::vector &output_tensors) const; -}; - -std::vector reduce_scatter( - const std::vector &input_tensors, - const uint32_t scatter_split_dim, - ReduceOpMath reduce_op = ReduceOpMath::SUM, - const uint32_t num_links = 1, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -namespace ccl { -namespace reduce_scatter_detail { -operation::ProgramWithCallbacks reduce_scatter_with_workers( - const std::vector& input_tensors, - const std::vector& output_tensors, - ttnn::operations::binary::BinaryOpType reduce_op, - const uint32_t scatter_split_dim, - const uint32_t num_links, - const uint32_t ring_size, - const uint32_t ring_index, - const std::optional receiver_device_id, - const std::optional sender_device_id, -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp - ttnn::ccl::Topology topology); -======= -<<<<<<< HEAD:tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp - ttnn::utils::ccl::Topology topology); -======= - tt::tt_metal::ccl::Topology topology); ->>>>>>> bdb9766ed5... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp ->>>>>>> 8170cf2cca... 
#9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp -} -}; // namespace ccl - -}; // namespace utils -}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index 1199deabf86..28a24e7dd1a 100644 --- a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -19,7 +19,6 @@ #include "ttnn/experimental/tt_dnn/op_library/non_zero_indices/non_zero_indices_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/sharded/sharded_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/sharded_partial/sharded_op_partial.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" namespace tt::tt_metal::detail{ @@ -447,28 +446,6 @@ namespace tt::tt_metal::detail{ R"doc(Converts a partial tensor from sharded_to_interleaved memory layout)doc" ); - // ---------- Multi-Device ops ---------- - - // Reduce Scatter - m_tensor.def("reduce_scatter", &reduce_scatter, - py::arg("input_tensors"), py::arg("scatter_split_dim"), py::arg("reduce_op"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - R"doc( - Performs reduce scatter across chips, where the input tensors are sliced along the scatter dim, and pairwise reduced as they propagate and reduce through the cluster. - - For example, a reduce scatter on a ring of rank 8 and input tensor shapes (per rank) of [1,1,1024,8096] and scatter_dim=3, will split each input tensor - on width into 8 parts of size [1,1,1024,1024]. Each of those parts will reduce with the corresponding chunk from the other ranks. All chips will collectively - reduce the first incoming [1,1,1024,1024] chunk with their local first [1,1,1024,1024] chunk and be forwarded. The second incoming [1,1,1024,1024] chunk will - be reduced with the second local [1,1,1024,1024] chunk and be forwarded and so on. Each rank in the ring will start on a different offset into the chunk such - that by the end, they will finish with a different reduced chunk offset from the original tensor shape. - - .. csv-table:: - :header: "Argument", "Description", "Data type", "Valid range", "Required" - - "scatter_split_dim", "Dimension to evenly slice input tensor along for each rank", "int", "0..3", "Yes" - "reduce_op", "reduction math operation", " ReduceOpMath", "SUM", "No" - "num_links", "Number of ethernet links to allow the op to use to send data chip to chip for the operation. Default=1", "int", "1..max_num_links", "No" - "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" - )doc"); } } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp index b5c3817ee82..94c84f41a38 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp @@ -4,23 +4,13 @@ #pragma once -<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" -======= -#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" ->>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp -======= -#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory #include "ttnn/cpp/ttnn/multi_device.hpp" namespace ttnn { namespace operations { namespace ccl { -<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp struct ExecuteAllGather { static ttnn::Tensor execute_on_main_thread( @@ -32,49 +22,9 @@ struct ExecuteAllGather { } }; -======= ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp -struct ExecuteLineAllGather { -======= -struct ExecuteAllGather { ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::ROW_MAJOR_LAYOUT, ttnn::TILE_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { - return std::forward_as_tuple(input_tensor); - } - - static ttnn::Tensor execute_on_main_thread( - const ttnn::Tensor& input_tensor, - const uint32_t dim, - const uint32_t num_links = 1, - const std::optional& memory_config = std::nullopt) { -<<<<<<< HEAD - return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); -======= - return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory - } -}; - } // namespace ccl } // namespace operations -<<<<<<< HEAD -constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); -======= constexpr auto all_gather = ttnn::register_operation("ttnn::all_gather"); ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp index aafbb457f0d..c3aa4015c49 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp @@ -8,39 +8,19 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp #include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" -======= -#include "ttnn/operations/ccl/line_all_gather/device/ccl_line_all_gather_op.hpp" ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp -======= -#include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory #include "ttnn/types.hpp" namespace py = pybind11; namespace ttnn { namespace operations { -<<<<<<< HEAD -namespace ccl_line_all_gather { -======= namespace ccl { ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory namespace detail { template -<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp -void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { -======= -void bind_ccl_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { ->>>>>>> 60a6703d2e... 
#9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp -======= void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory bind_registered_operation( module, operation, @@ -63,12 +43,7 @@ void bind_all_gather(py::module& module, const ccl_operation_t& operation, const } // namespace detail -<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp void py_bind_all_gather(py::module& module) { -======= -void py_module_all_gather(py::module& module) { ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory detail::bind_all_gather( module, ttnn::all_gather, @@ -90,39 +65,8 @@ void py_module_all_gather(py::module& module) { >>> output = ttnn.all_gather(tensor, dim=0) )doc"); -<<<<<<< HEAD -======= -void py_module(py::module& module) { ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN:ttnn/cpp/ttnn/operations/ccl/line_all_gather/ccl_line_all_gather_pybind.hpp - - detail::bind_ccl_line_all_gather( - module, - ttnn::line_all_gather, - R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - - Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. - - Args: - * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor - * :attr:`dim` (int) - - Keyword Args: - * :attr:`num_links` (int): Number of links to use for the all-gather operation. - * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. - - Example: - - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) - >>> output = ttnn.line_all_gather(tensor, dim=0) - - )doc"); -} - -} // namespace ccl_line_all_gather -======= } } // namespace ccl ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp index 1fea3f29555..6c45db409de 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp @@ -13,15 +13,9 @@ #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/hw/inc/wormhole/noc/noc.h" -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp using ttnn::ccl::EriscDataMoverBufferSharingMode; using ttnn::ccl::EriscDataMoverTerminationMode; using ttnn::ccl::EriscDataMoverWorkerSignal; -======= -using ttnn::utils::ccl::EriscDataMoverBufferSharingMode; -using ttnn::utils::ccl::EriscDataMoverTerminationMode; -using ttnn::utils::ccl::EriscDataMoverWorkerSignal; ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp namespace erisc { namespace datamover { @@ -40,11 +34,7 @@ struct edm_worker_index { uint16_t worker_index = 0; }; -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp using ttnn::ccl::WorkerXY; -======= -using ttnn::utils::ccl::WorkerXY; ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp /* * The `ChannelBuffer` is a building block of the Erisc Data Mover (EDM). 
For every concurrent transaction @@ -125,25 +115,16 @@ class ChannelBuffer final { is_sender_side(is_sender_side) { clear_local_semaphore(); -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { -======= - if (TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel increment_worker_semaphores(); } } else { -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); -======= - ASSERT(TERMINATION_MODE != ttnn::utils::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp goto_state(STATE::DONE); } - }; - + } // Resets the semaphore in local L1, which workers write to remotely. FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index 6de2e2c5016..23d8c41e252 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -8,23 +8,9 @@ #include "dataflow_api.h" #include "debug/dprint.h" #include "eth_l1_address_map.h" -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" -======= -#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -<<<<<<< HEAD -#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp -======= -<<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp -#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" -======== -#include "ttnn/cpp/ttnn/operations/ccl/edm/erisc_async_datamover.hpp" ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp ->>>>>>>> af98ddace6... #9486: Move kernel files into kernels directory:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp ->>>>>>> af98ddace6... #9486: Move kernel files into kernels directory // Args Schema: // 1) handshake addr @@ -59,11 +45,7 @@ FORCE_INLINE void eth_setup_handshake2(std::uint32_t handshake_register_address, } } -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp using ttnn::ccl::WorkerXY; -======= -using ttnn::utils::ccl::WorkerXY; ->>>>>>> f290d934d9... 
#9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp template struct sender_receiver_index_t { @@ -136,19 +118,11 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); -<<<<<<< HEAD:ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = static_cast(get_compile_time_arg_val(4)); constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = static_cast(get_compile_time_arg_val(5)); -======= - constexpr ttnn::utils::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = - static_cast(get_compile_time_arg_val(4)); - - constexpr ttnn::utils::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = - static_cast(get_compile_time_arg_val(5)); ->>>>>>> f290d934d9... #9486: Move CCL common to TTNN:ttnn/cpp/ttnn/operations/ccl/edm/erisc_datamover.cpp constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); using EDM_CONFIG_T = decltype(EDM_CONFIG); diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp index 2a900f153ef..4c32817bdfb 100644 --- a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp @@ -14,12 +14,6 @@ namespace ttnn { -<<<<<<< HEAD -======= -namespace utils { - - ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN void LineAllGather::validate(const std::vector &input_tensors) const { TT_FATAL(input_tensors.size() == 1); const auto& input_tensor = input_tensors[0]; @@ -90,51 +84,6 @@ operation::ProgramWithCallbacks LineAllGather::create_program(const std::vector< }; } -<<<<<<< HEAD -======= - - -std::vector line_all_gather_impl(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config, const all_gather_op::Topology topology) { - - TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); - - std::vector output_tensors = std::vector(input_tensors.size()); - - bool is_ring = topology == all_gather_op::Topology::Ring; - uint32_t num_inputs = static_cast(input_tensors.size()); - for (uint32_t i = 0; i < input_tensors.size(); ++i) { - output_tensors[i] = Tensor(operation::get_workers_for_op_output({input_tensors[i]})); - // Extract these tensors in the main thread, since they're used to get the sender and receiver device ids - // Dont get the device in the main thread, since it can cause stalls in async mode. - const Tensor& tensor_on_receiver = input_tensors[(i + 1) % num_inputs]; - const Tensor& tensor_on_sender = input_tensors[i == 0 ? num_inputs - 1 : i - 1]; - // Package output in vector, to populate it with launch_op - std::vector output_for_curr_device = {output_tensors[i]}; - operation::launch_op( - [is_ring, dim, num_links, i, num_inputs, output_mem_config, topology] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (num_inputs - 1); - bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; - - std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? 
- std::nullopt : - std::optional(input_tensors.at(1).device()->id()); - std::optional sender_device_id = is_last_chip_in_counter_clockwise_direction ? - std::nullopt : - std::optional(input_tensors.at(2).device()->id()); - return operation::run(LineAllGather{dim, num_links, num_inputs, i, receiver_device_id, sender_device_id, output_mem_config,topology}, {input_tensors.at(0)}); - }, - {input_tensors[i], tensor_on_receiver, tensor_on_sender}, output_for_curr_device); - } - return output_tensors; -} - -std::vector line_all_gather(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config) { - return line_all_gather_impl(input_tensors, dim, num_links, output_mem_config, all_gather_op::Topology::Linear); -} - -} // namespace utils - ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN namespace operations { namespace ccl { @@ -168,13 +117,8 @@ Tensor line_all_gather( } return operation::run( -<<<<<<< HEAD ttnn::LineAllGather{ dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), ttnn::all_gather_op::Topology::Linear}, -======= - ttnn::utils::LineAllGather{ - dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), ttnn::utils::all_gather_op::Topology::Linear}, ->>>>>>> 60a6703d2e... #9486: Move CCL kernel files to TTNN {input_tensor}); }, {input_tensor}, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp index b39054ba7c8..909566ece2a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp @@ -36,8 +36,6 @@ using namespace tt::constants; namespace ttnn { -namespace utils { - namespace ccl { namespace reduce_scatter_detail { struct WorkerTransferInfo { @@ -889,5 +887,4 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // namespace reduce_scatter_detail } // namespace ccl -} // namespace utils } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 927d2c529ec..11271380809 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -12,7 +12,6 @@ namespace ttnn { -namespace utils { void ReduceScatter::validate(const std::vector& input_tensors) const { for (auto const& t : input_tensors) { @@ -59,6 +58,16 @@ operation::ProgramWithCallbacks ReduceScatter::create_program( this->topology); } +static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(ReduceOpMath reduce_op) { + switch (reduce_op) { + case ReduceOpMath::SUM: return ttnn::operations::binary::BinaryOpType::ADD; + + default: TT_FATAL("Reduce scatter only support reduce_op_type SUM"); return ttnn::operations::binary::BinaryOpType::ADD; + } +} + +namespace operations{ +namespace ccl{ std::vector reduce_scatter_impl( const std::vector& input_tensors, const ttnn::operations::binary::BinaryOpType binary_op_type, @@ -72,19 +81,7 @@ std::vector reduce_scatter_impl( output_tensors.reserve(input_tensors.size()); std::vector ops; ops.reserve(input_tensors.size()); 
-<<<<<<< HEAD -<<<<<<< HEAD:ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp bool is_ring = topology ==ttnn::ccl::Topology::Ring; - for (uint32_t i = 0; i < input_tensors.size(); ++i) { - bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); - bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; -======= - bool is_ring = topology ==ttnn::utils::ccl::Topology::Ring; ->>>>>>> 8170cf2cca... #9486: Merge CCL reduce_scatter to TTNN:ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp - -======= - bool is_ring = topology == ccl::Topology::Ring; ->>>>>>> a98abddcea... #0: Fix issues for (uint32_t i = 0; i < input_tensors.size(); ++i) { bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; @@ -97,7 +94,7 @@ std::vector reduce_scatter_impl( is_last_chip_in_counter_clockwise_direction ? std::nullopt : std::optional(input_tensors[i == 0 ? input_tensors.size() - 1 : i - 1].device()->id()); - ops.emplace_back(ReduceScatter{ + ops.emplace_back(ttnn::ReduceScatter{ binary_op_type, scatter_dim, num_links, @@ -112,14 +109,6 @@ std::vector reduce_scatter_impl( return output_tensors; } -static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(ReduceOpMath reduce_op) { - switch (reduce_op) { - case ReduceOpMath::SUM: return ttnn::operations::binary::BinaryOpType::ADD; - - default: TT_FATAL("Reduce scatter only support reduce_op_type SUM"); return ttnn::operations::binary::BinaryOpType::ADD; - } -} - std::vector reduce_scatter( const std::vector& input_tensors, const uint32_t scatter_dim, @@ -130,6 +119,7 @@ std::vector reduce_scatter( return reduce_scatter_impl( input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config,ttnn::ccl::Topology::Ring); } +} // namespace ccl +} // namespace operations -}; // namespace utils }; // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp index 79d4c86e199..b5052e28da1 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -12,7 +12,6 @@ #include "ttnn/operations/eltwise/binary/binary.hpp" namespace ttnn { -namespace utils { struct ReduceScatter { const ttnn::operations::binary::BinaryOpType binary_op_type; @@ -32,13 +31,6 @@ struct ReduceScatter { const std::vector &input_tensors, std::vector &output_tensors) const; }; -std::vector reduce_scatter( - const std::vector &input_tensors, - const uint32_t scatter_split_dim, - ReduceOpMath reduce_op = ReduceOpMath::SUM, - const uint32_t num_links = 1, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - namespace ccl { namespace reduce_scatter_detail { operation::ProgramWithCallbacks reduce_scatter_with_workers( @@ -55,5 +47,16 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } }; // namespace ccl -}; // namespace utils +namespace operations{ +namespace ccl{ + std::vector reduce_scatter( + const std::vector &input_tensors, + const uint32_t scatter_split_dim, + ReduceOpMath reduce_op = ReduceOpMath::SUM, + const uint32_t num_links = 1, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +} // namespace ccl +} // namespace operations + + }; // 
namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp index 4e863ae816b..e6f3f9dfb83 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_op.hpp @@ -20,7 +20,7 @@ struct ExecuteReduceScatter { const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt) { MemoryConfig out_memory_config = memory_config.value_or(input_tensors.at(0).memory_config()); - return utils::reduce_scatter(input_tensors, scatter_dim, math_op, num_links, out_memory_config); + return ttnn::operations::ccl::reduce_scatter(input_tensors, scatter_dim, math_op, num_links, out_memory_config); } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp index 7ad64dd7c67..f4c5482635a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp @@ -45,7 +45,7 @@ void bind_reduce_scatter(py::module& module, const ccl_operation_t& operation, c } // namespace detail -void py_module_reduce_scatter(py::module& module) { +void py_bind_reduce_scatter(py::module& module) { detail::bind_reduce_scatter( module, From de1d373fdd73ea3a831b2a366960e221a849f7d8 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Sat, 20 Jul 2024 10:38:24 +0000 Subject: [PATCH 10/10] #9486: Replace ttdnn op with ttnn --- models/demos/t3000/falcon40b/tt/falcon_mlp.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py index 8259bef8700..b4e9614e6f0 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py @@ -124,12 +124,12 @@ def fwd_decode(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.expe hidden_states ) # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh - hidden_states = ttnn.experimental.tensor.reduce_scatter( + hidden_states = ttnn.reduce_scatter( hidden_states, - scatter_split_dim=3, - reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM, + scatter_dim=3, + math_op=ttnn.experimental.tensor.ReduceOpMath.SUM, num_links=1, # only unidirectional supported for now - output_mem_config=self.model_config["DEFAULT_MEMCFG"], + memory_config=self.model_config["DEFAULT_MEMCFG"], ) hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse @@ -198,12 +198,12 @@ def fwd_prefill(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.exp self.output ) # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh - hidden_states = ttnn.experimental.tensor.reduce_scatter( + hidden_states = ttnn.reduce_scatter( hidden_states, - scatter_split_dim=3, - reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM, + scatter_dim=3, + math_op=ttnn.experimental.tensor.ReduceOpMath.SUM, num_links=1, # only one link supported for now - output_mem_config=self.model_config["DEFAULT_MEMCFG"], + memory_config=self.model_config["DEFAULT_MEMCFG"], ) hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse
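# Illustrative sketch of the migrated call site, kept as Python comments/code so it is
# not mistaken for hunk content. The falcon_mlp.py hunks above replace the old tt_lib
# binding ttnn.experimental.tensor.reduce_scatter(...) with the TTNN registered op
# ttnn.reduce_scatter(...); as used in this patch the keyword arguments map as
#   scatter_split_dim -> scatter_dim, reduce_op -> math_op, output_mem_config -> memory_config.
# `per_device_tensors` and `mem_cfg` below are placeholder names introduced only for
# this sketch and do not appear in the patch.
reduced = ttnn.reduce_scatter(
    per_device_tensors,                                  # list with one tensor per device in the ring
    scatter_dim=3,                                       # split the last dimension across devices
    math_op=ttnn.experimental.tensor.ReduceOpMath.SUM,   # SUM is the only reduce op this patch supports
    num_links=1,                                         # only one ethernet link supported for now
    memory_config=mem_cfg,
)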