From 4ca86ace4ae877d4e5a626af65dd4cd4df217a88 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Thu, 21 Nov 2024 07:50:11 +0000 Subject: [PATCH] Add negative dim support --- .../operations/ccl/test_all_gather_nightly.py | 8 +++---- .../ccl/test_reduce_scatter_nightly.py | 16 ++++++------- .../operations/ccl/all_gather/all_gather.cpp | 8 +++---- .../operations/ccl/all_gather/all_gather.hpp | 4 ++-- .../ccl/all_gather/all_gather_pybind.cpp | 16 ++++++------- .../ccl/all_gather/device/all_gather_op.cpp | 17 +++++++++++-- .../ccl/all_gather/device/all_gather_op.hpp | 4 ++-- .../device/reduce_scatter_op.cpp | 24 ++++++++++++++----- .../device/reduce_scatter_op.hpp | 4 ++-- .../ccl/reduce_scatter/reduce_scatter.cpp | 4 ++-- .../ccl/reduce_scatter/reduce_scatter.hpp | 4 ++-- .../reduce_scatter/reduce_scatter_pybind.cpp | 12 +++++----- 12 files changed, 73 insertions(+), 48 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py index ad1d7a63abee..3313d73880cf 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py @@ -22,15 +22,15 @@ [ (4, 1, [4, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), (8, 1, [8, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), - (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), + (8, 1, [8, 1, 256, 32], -4, ttnn.TILE_LAYOUT), (8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # (4, 2, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), (8, 1, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), - (4, 1, [8, 5, 13, 384], 3, ttnn.ROW_MAJOR_LAYOUT), - (8, 1, [8, 5, 13, 512], 3, ttnn.ROW_MAJOR_LAYOUT), + (4, 1, [8, 5, 13, 384], -1, ttnn.ROW_MAJOR_LAYOUT), + (8, 1, [8, 5, 13, 512], -1, ttnn.ROW_MAJOR_LAYOUT), (4, 1, [8, 5, 32, 384], 3, ttnn.TILE_LAYOUT), (8, 1, [8, 5, 32, 512], 3, ttnn.TILE_LAYOUT), - (4, 1, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT), + (4, 1, [1, 1, 32, 16384], -1, ttnn.TILE_LAYOUT), ], ) 
@pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py index 17eee1079727..5a00f7883ab2 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py @@ -24,18 +24,18 @@ ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 64], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 32], -1, ttnn.TILE_LAYOUT), + ([1, 1, 32, 64], -1, ttnn.TILE_LAYOUT), ([1, 1, 64, 64], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 128], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 128], -1, ttnn.TILE_LAYOUT), ([1, 1, 32, 256], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), - ([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT), + ([1, 1, 128, 1024], -1, ttnn.TILE_LAYOUT), ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT), ([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT), + ([1, 1, 2048, 8192], -1, ttnn.TILE_LAYOUT), ], ) @pytest.mark.parametrize( @@ -99,12 +99,12 @@ def test_reduce_scatter_t3k_8chip_nightly( [ ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT), + ([1, 4, 2048, 1024], -1, ttnn.TILE_LAYOUT), ([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 2048], -1, ttnn.TILE_LAYOUT), ([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT), + ([1, 1, 128, 8192], -1, ttnn.TILE_LAYOUT), ([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT), ([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT), # These shapes result in some workers with no work, which is currently diff --git 
a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp index 63983bd9f01e..a3e61d497115 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp @@ -9,19 +9,19 @@ namespace ttnn::operations::ccl { ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t num_links, const std::optional& memory_config, const std::optional num_workers, const std::optional num_buffers_per_channel, const ttnn::ccl::Topology topology) { return ttnn::operations::ccl::all_gather( - input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology); + input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology); } ttnn::Tensor ExecuteAllGather::invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, @@ -30,7 +30,7 @@ ttnn::Tensor ExecuteAllGather::invoke( const std::optional num_buffers_per_channel, const ttnn::ccl::Topology topology) { return ttnn::operations::ccl::all_gather( - input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology); + input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology); } } // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp index 1816d4c083da..3dbe3896f94d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp @@ -14,7 +14,7 @@ namespace ccl { struct ExecuteAllGather { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const 
int16_t gather_dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const std::optional num_workers = std::nullopt, @@ -23,7 +23,7 @@ struct ExecuteAllGather { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links = 1, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp index 8937ced1230c..b67ebe3a3420 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp @@ -29,16 +29,16 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t num_links, const std::optional& memory_config, const std::optional num_workers, const std::optional num_buffers_per_channel, const ttnn::ccl::Topology topology) -> ttnn::Tensor { - return self(input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology); + return self(input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology); }, py::arg("input_tensor"), - py::arg("dim"), + py::arg("gather_dim"), py::kw_only(), py::arg("num_links") = 1, py::arg("memory_config") = std::nullopt, @@ -49,7 +49,7 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, @@ -57,10 +57,10 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, const std::optional num_workers, const 
std::optional num_buffers_per_channel, const ttnn::ccl::Topology topology) -> ttnn::Tensor { - return self(input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology); + return self(input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology); }, py::arg("input_tensor"), - py::arg("dim"), + py::arg("gather_dim"), py::arg("cluster_axis"), py::arg("mesh_device"), py::kw_only(), @@ -84,7 +84,7 @@ void py_bind_all_gather(pybind11::module& module) { Args: input_tensor (ttnn.Tensor): multi-device tensor. - dim (int): Dimension to perform operation. + gather_dim (int): Dimension to perform operation. cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on. mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on. * cluster_axis and mesh_device parameters are applicable only for Linear Topology. 
@@ -113,7 +113,7 @@ void py_bind_all_gather(pybind11::module& module) { memory_config=mem_config, mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2))) >>> ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device) - >>> output = ttnn.all_gather(ttnn_tensor, dim=0, topology=ttnn.Topology.Ring) + >>> output = ttnn.all_gather(ttnn_tensor, gather_dim=0, topology=ttnn.Topology.Ring) )doc"); } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 4957355cf7ee..a556fc1967b9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -174,7 +174,7 @@ namespace operations { namespace ccl { Tensor all_gather( - const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { + const Tensor& input_tensor, const int16_t gather_dim, const uint32_t num_links, const std::optional& memory_config, const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch"); auto devices = input_tensor.get_workers(); @@ -185,6 +185,13 @@ Tensor all_gather( if (num_devices == 2){ ccl_topology = ttnn::ccl::Topology::Linear; } + + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t dim = (gather_dim < 0) ? 
rank + gather_dim : gather_dim; + // Range-check the caller's gather_dim: the wrapped dim would let values below -rank through. + TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, gather_dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( [dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology]( @@ -205,7 +212,7 @@ Tensor all_gather( Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, @@ -218,6 +225,12 @@ Tensor all_gather( const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t dim = (gather_dim < 0) ? rank + gather_dim : gather_dim; + // Range-check the caller's gather_dim: the wrapped dim would let values below -rank through. + TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, gather_dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index b0a162f2a1fa..8776c44e8d7c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -200,7 +200,7 @@ namespace ccl { Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const std::optional user_defined_num_workers = std::nullopt, @@ -209,7 +209,7 @@ Tensor all_gather( Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int16_t gather_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t 
num_links = 1, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 0924001d0067..18c33ef6e370 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -107,7 +107,7 @@ namespace operations{ namespace ccl{ Tensor reduce_scatter( const Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const MemoryConfig& output_mem_config, @@ -126,9 +126,15 @@ Tensor reduce_scatter( ccl_topology = ttnn::ccl::Topology::Linear; } + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim; + // Range-check the caller's scatter_dim: the wrapped dim would let values below -rank through. + TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, scatter_dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel]( + [binary_op_type, dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -139,7 +145,7 @@ Tensor reduce_scatter( ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct( input_tensor, binary_op_type, - scatter_dim, + dim, num_links, output_mem_config, user_defined_num_workers, @@ -158,7 +164,7 @@ Tensor reduce_scatter( Tensor reduce_scatter( const Tensor &input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, 
ttnn::operations::reduction::ReduceType reduce_op, @@ -174,10 +180,16 @@ Tensor reduce_scatter( const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim; + // Range-check the caller's scatter_dim: the wrapped dim would let values below -rank through. + TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, scatter_dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [scatter_dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology]( + [dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -206,7 +218,7 @@ Tensor reduce_scatter( return operation::run( ttnn::ReduceScatter{ binary_op_type, - scatter_dim, + dim, num_links, num_devices, device_index, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp index f26107cda305..c5d1b877ad9f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -69,7 +69,7 @@ namespace operations{ namespace ccl{ Tensor reduce_scatter( const Tensor &input_tensor, - const uint32_t scatter_split_dim, + const int16_t scatter_split_dim, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, const uint32_t num_links = 1, const MemoryConfig &output_mem_config = 
operation::DEFAULT_OUTPUT_MEMORY_CONFIG, @@ -79,7 +79,7 @@ Tensor reduce_scatter( Tensor reduce_scatter( const ttnn::Tensor &input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp index ea28f4bd9324..eb8f13c35d90 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp @@ -10,7 +10,7 @@ namespace ttnn::operations::ccl { ttnn::Tensor ExecuteReduceScatter::invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const std::optional& memory_config, @@ -23,7 +23,7 @@ ttnn::Tensor ExecuteReduceScatter::invoke( } ttnn::Tensor ExecuteReduceScatter::invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType math_op, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp index b7acc80e7943..1762b3fa6be4 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp @@ -17,7 +17,7 @@ namespace ccl { struct ExecuteReduceScatter { static ttnn::Tensor invoke( const Tensor &input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, @@ -29,7 +29,7 @@ struct ExecuteReduceScatter { static 
ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp index bfac2f9a1d1e..518babc671c1 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp @@ -26,7 +26,7 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const ttnn::MemoryConfig& memory_config, @@ -48,7 +48,7 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int16_t scatter_dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType math_op, @@ -86,7 +86,7 @@ void py_bind_reduce_scatter(pybind11::module& module) { Args: input_tensor (ttnn.Tensor): multi-device tensor - dim (int): Dimension to perform operation + scatter_dim (int): Dimension to perform operation cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on. mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on. * cluster_axis and mesh_device parameters are applicable only for Linear Topology. 
@@ -107,8 +107,8 @@ void py_bind_reduce_scatter(pybind11::module& module) { >>> full_tensor = torch.randn([1, 1, 256, 256], dtype=torch.bfloat16) >>> num_devices = 8 - >>> dim = 3 - >>> input_tensors = torch.chunk(full_tensor, num_devices, dim) + >>> scatter_dim = 3 + >>> input_tensors = torch.chunk(full_tensor, num_devices, scatter_dim) >>> physical_device_ids = ttnn.get_t3k_physical_device_ids_ring() >>> mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 8), physical_device_ids=physical_device_ids[:8]) >>> tt_input_tensors = [] @@ -116,7 +116,7 @@ void py_bind_reduce_scatter(pybind11::module& module) { tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], mem_config)) >>> input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) - >>> output = ttnn.reduce_scatter(input_tensor_mesh, dim=0, topology=ttnn.Topology.Linear) + >>> output = ttnn.reduce_scatter(input_tensor_mesh, scatter_dim=0, topology=ttnn.Topology.Linear) )doc"); }