Skip to content

Commit

Permalink
#13136: Fix api after rebase
Browse files (browse the repository at this point in the history)
  • Loading branch information
Aswinmcw committed Oct 1, 2024
1 parent 9309ca1 commit fc1e2be
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions models/demos/wormhole/llama31_8b_N300/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def run_llama_demo_n300(user_input, batch_size, device_mesh, instruct_mode, is_c
# Compile
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
tt_out_gathered = ttnn.line_all_gather(tt_out, dim=3, num_links=1)
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
ttnn.deallocate(tt_out)
tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True)
ttnn.deallocate(tt_out_gathered)
Expand All @@ -325,7 +325,7 @@ def run_llama_demo_n300(user_input, batch_size, device_mesh, instruct_mode, is_c

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
tt_out_gathered = ttnn.line_all_gather(tt_out, dim=3, num_links=1)
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
ttnn.deallocate(tt_out)
tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True)
ttnn.deallocate(tt_out_gathered)
Expand Down
4 changes: 2 additions & 2 deletions models/demos/wormhole/llama31_8b_N300/tt/llama_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def forward_decode(
dense_outputs.append(dense_out)

# All reduce
dense_out_gathered = ttnn.line_all_gather(dense_out, dim=1, num_links=1)
dense_out_gathered = ttnn.all_gather(dense_out, dim=1, num_links=1, topology=ttnn.Topology.Linear)
dense_out_reduced = ttnn.experimental.fast_reduce_nc(
dense_out_gathered, dims=[1], output=None, compute_kernel_config=None
)
Expand Down Expand Up @@ -504,7 +504,7 @@ def forward_prefill(self, xs_11SH, rot_mats, transformation_mats, user_id: int =
attn_output_11SH.deallocate(True)

# All reduce
dense_out_gathered = ttnn.line_all_gather(output_11SH, dim=1, num_links=1)
dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear)
dense_out_reduced = ttnn.experimental.fast_reduce_nc(
dense_out_gathered, dims=[1], output=None, compute_kernel_config=None
)
Expand Down
2 changes: 1 addition & 1 deletion models/demos/wormhole/llama31_8b_N300/tt/llama_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
w2_out = ttnn.reshape(w2_out, [1, 1, seq_len, -1])

# All reduce
w2_out_gathered = ttnn.line_all_gather(w2_out, dim=1, num_links=1)
w2_out_gathered = ttnn.all_gather(w2_out, dim=1, num_links=1, topology=ttnn.Topology.Linear)
w2_out_reduced = ttnn.experimental.fast_reduce_nc(
w2_out_gathered, dims=[1], output=None, compute_kernel_config=None
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ mesh_tensor = ttnn.from_torch(
)

# Execute Line All-Gather on the tensor
output_tensor = ttnn.line_all_gather(mesh_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device)
output_tensor = ttnn.all_gather(mesh_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device, topology=ttnn.Topology.Linear)
```


Expand Down
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_all_gather_nightly.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_line_all_gather_on_t3000_nightly(
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation=ttnn.line_all_gather,
all_gather_topology=ttnn.Topology.Linear,
enable_async=enable_async,
num_iters=num_iters,
)
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_line_all_gather_on_t3000_nightly_two_link(
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation=ttnn.line_all_gather,
all_gather_topology=ttnn.Topology.Linear,
num_iters=num_iters,
enable_async=enable_async,
)
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ void py_bind_all_gather(pybind11::module& module) {
* :attr:`mesh_device` (MeshDevice):
Device mesh to perform the line-all-gather operation on.
Mesh Tensor Programming Guide : https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md
Keyword Args:
* :attr:`num_links` (int): Number of links to use for the all-gather operation.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Tensor all_gather(
const std::optional<size_t> user_defined_num_buffers_per_channel,
const ttnn::ccl::Topology topology) {

TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology");
TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology");
const auto mesh_view = mesh_device.get_view();
std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

Expand Down

0 comments on commit fc1e2be

Please sign in to comment.