From fc1e2be64baec5ac6488303b12370e78f559c9d5 Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Tue, 1 Oct 2024 05:16:05 +0000 Subject: [PATCH] #13136: Fix api after rebase --- models/demos/wormhole/llama31_8b_N300/demo/demo.py | 4 ++-- models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py | 4 ++-- models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py | 2 +- .../Programming Mesh of Devices with TT-NN.md | 2 +- tests/ttnn/unit_tests/operations/test_all_gather_nightly.py | 4 ++-- ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp | 2 ++ .../ttnn/operations/ccl/all_gather/device/all_gather_op.cpp | 2 +- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/models/demos/wormhole/llama31_8b_N300/demo/demo.py b/models/demos/wormhole/llama31_8b_N300/demo/demo.py index 76f1a1c9fdd3..c99db8278159 100644 --- a/models/demos/wormhole/llama31_8b_N300/demo/demo.py +++ b/models/demos/wormhole/llama31_8b_N300/demo/demo.py @@ -309,7 +309,7 @@ def run_llama_demo_n300(user_input, batch_size, device_mesh, instruct_mode, is_c # Compile decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat) - tt_out_gathered = ttnn.line_all_gather(tt_out, dim=3, num_links=1) + tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) @@ -325,7 +325,7 @@ def run_llama_demo_n300(user_input, batch_size, device_mesh, instruct_mode, is_c decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat) - tt_out_gathered = ttnn.line_all_gather(tt_out, dim=3, num_links=1) + tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) diff --git a/models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py b/models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py index 91cc160a4734..14283898a7f7 100644 --- a/models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py +++ b/models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py @@ -358,7 +358,7 @@ def forward_decode( dense_outputs.append(dense_out) # All reduce - dense_out_gathered = ttnn.line_all_gather(dense_out, dim=1, num_links=1) + dense_out_gathered = ttnn.all_gather(dense_out, dim=1, num_links=1, topology=ttnn.Topology.Linear) dense_out_reduced = ttnn.experimental.fast_reduce_nc( dense_out_gathered, dims=[1], output=None, compute_kernel_config=None ) @@ -504,7 +504,7 @@ def forward_prefill(self, xs_11SH, rot_mats, transformation_mats, user_id: int = attn_output_11SH.deallocate(True) # All reduce - dense_out_gathered = ttnn.line_all_gather(output_11SH, dim=1, num_links=1) + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) dense_out_reduced = ttnn.experimental.fast_reduce_nc( dense_out_gathered, dims=[1], output=None, compute_kernel_config=None ) diff --git a/models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py b/models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py index 5610add88336..f2c33b5214b6 100644 --- a/models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py +++ b/models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py @@ -148,7 +148,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w2_out = ttnn.reshape(w2_out, [1, 1, seq_len, -1]) # All reduce - w2_out_gathered = ttnn.line_all_gather(w2_out, dim=1, num_links=1) + w2_out_gathered = ttnn.all_gather(w2_out, dim=1, num_links=1, topology=ttnn.Topology.Linear) w2_out_reduced = ttnn.experimental.fast_reduce_nc( w2_out_gathered, dims=[1], output=None, compute_kernel_config=None ) diff --git a/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md b/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md index bb10360833bf..e76a59392529 100644 --- a/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md +++ b/tech_reports/Programming Mesh of Devices/Programming Mesh of Devices with TT-NN.md @@ -332,7 +332,7 @@ mesh_tensor = ttnn.from_torch( ) # Execute Line All-Gather on the tensor -output_tensor = ttnn.line_all_gather(mesh_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device) +output_tensor = ttnn.all_gather(mesh_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device, topology=ttnn.Topology.Linear) ``` diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py index 279f629cfb9e..511e13e30ad3 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py @@ -73,7 +73,7 @@ def test_line_all_gather_on_t3000_nightly( mem_config, use_program_cache, function_level_defaults, - all_gather_operation=ttnn.line_all_gather, + all_gather_topology=ttnn.Topology.Linear, enable_async=enable_async, num_iters=num_iters, ) @@ -132,7 +132,7 @@ def test_line_all_gather_on_t3000_nightly_two_link( mem_config, use_program_cache, function_level_defaults, - all_gather_operation=ttnn.line_all_gather, + all_gather_topology=ttnn.Topology.Linear, num_iters=num_iters, enable_async=enable_async, ) diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp index 12512cb024dc..8a6dabdcc7da 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp @@ -92,6 +92,8 @@ void py_bind_all_gather(pybind11::module& module) { * :attr:`mesh_device` (MeshDevice): Device mesh to perform the line-all-gather operation on. + Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md + Keyword Args: * :attr:`num_links` (int): Number of links to use for the all-gather operation. * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index f9e409046017..a9dc1cc23b91 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -233,7 +233,7 @@ Tensor all_gather( const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { - TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology"); + TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology"); const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();