From 8c96c4f31b63055b8d69d9b79896ff5e3550f594 Mon Sep 17 00:00:00 2001
From: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com>
Date: Tue, 18 Jun 2024 15:15:12 -0400
Subject: [PATCH] #9486: Replace ttl.tensor.all_gather with ttnn.all_gather
 (#9517)

* #9486: change ttl all_gather call with ttnn

* #0: correct arguments
---
 .../demos/t3000/falcon40b/tt/falcon_attention.py |  8 ++++----
 .../demos/t3000/falcon40b/tt/falcon_decoder.py   | 12 ++++++------
 models/demos/t3000/falcon40b/tt/falcon_mlp.py    |  8 ++++----
 models/demos/t3000/falcon40b/tt/falcon_model.py  |  8 ++++----
 .../llama2_70b/tt/llama_attention_optimized.py   |  8 ++++----
 .../llama2_70b/tt/llama_decoder_optimized.py     | 16 ++++++++--------
 .../t3000/llama2_70b/tt/llama_mlp_optimized.py   |  6 +++---
 .../t3000/llama2_70b/tt/llama_model_optimized.py |  8 ++++----
 .../llama2_70b/tt/llama_decoder_optimized.py     |  8 ++++----
 .../llama2_70b/tt/llama_model_optimized.py       |  4 ++--
 .../unit_testing/misc/test_all_gather.py         |  7 ++++---
 ttnn/cpp/pybind11/operations/ccl.hpp             |  1 +
 12 files changed, 48 insertions(+), 46 deletions(-)

diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py
index 81da88afa39..743e96c89b0 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_attention.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -517,11 +517,11 @@ def fwd_prefill(
             output_mem_config=self.model_config["CONCAT_HEADS_OUTPUT_MEMCFG"],
         )
 
-        attn_output = ttnn.experimental.tensor.all_gather(
+        attn_output = ttnn.all_gather(
             attn_output,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
 
         for i in range(len(attn_output)):
@@ -807,11 +807,11 @@ def fwd_decode(
                 attn_output[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
             )
 
-        attn_output = ttnn.experimental.tensor.all_gather(
+        attn_output = ttnn.all_gather(
             attn_output,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
 
         for i in range(len(attn_output)):
diff --git a/models/demos/t3000/falcon40b/tt/falcon_decoder.py b/models/demos/t3000/falcon40b/tt/falcon_decoder.py
index 9de02f5cd34..61ec15f4afd 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_decoder.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_decoder.py
@@ -241,11 +241,11 @@ def fwd_prefill(
                 replicated_hidden_states[i], self.model_config["BFP8_DTYPE"]
             )
 
-        replicated_hidden_states = ttnn.experimental.tensor.all_gather(
+        replicated_hidden_states = ttnn.all_gather(
             replicated_hidden_states,
-            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
             dim=3,
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
 
         if self.model_config["LN_INPUT_DTYPE"] != self.model_config["BFP8_DTYPE"]:
@@ -362,11 +362,11 @@ def fwd_decode(
                     hidden_states[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
                 )
             )
-        replicated_hidden_states = ttnn.experimental.tensor.all_gather(
+        replicated_hidden_states = ttnn.all_gather(
             replicated_hidden_states,
-            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
             dim=3,
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
         for i in range(len(replicated_hidden_states)):
             replicated_hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded(
diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py
index 674a883ab8b..039641126bb 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py
@@ -125,11 +125,11 @@ def fwd_decode(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.expe
             hidden_states[i] = ttnn.experimental.tensor.sharded_to_interleaved(
                 hidden_states[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
             )
-        hidden_states = ttnn.experimental.tensor.all_gather(
+        hidden_states = ttnn.all_gather(
             hidden_states,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
         for i in range(len(hidden_states)):
             hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded(
@@ -169,11 +169,11 @@ def fwd_prefill(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.exp
             if should_deallocate_ln_tensors:
                 x[i].deallocate(True)
 
-        hidden_states = ttnn.experimental.tensor.all_gather(
+        hidden_states = ttnn.all_gather(
             hidden_states,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
 
         for i in range(len(hidden_states)):
diff --git a/models/demos/t3000/falcon40b/tt/falcon_model.py b/models/demos/t3000/falcon40b/tt/falcon_model.py
index b918ff989b1..9b40f157ddb 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_model.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_model.py
@@ -339,11 +339,11 @@ def fwd_prefill(
         for i in range(len(layer_output)):
             layer_output[i] = ttnn.experimental.tensor.typecast(layer_output[i], self.model_config["BFP8_DTYPE"])
 
-        layer_output = ttnn.experimental.tensor.all_gather(
+        layer_output = ttnn.all_gather(
             layer_output,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
 
         if self.model_config["LN_INPUT_DTYPE"] != self.model_config["BFP8_DTYPE"]:
@@ -399,11 +399,11 @@ def fwd_decode(
             layer_output[i] = ttnn.experimental.tensor.sharded_to_interleaved(
                 layer_output[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
             )
-        layer_output = ttnn.experimental.tensor.all_gather(
+        layer_output = ttnn.all_gather(
             layer_output,
             dim=3,
             num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-            output_mem_config=self.model_config["DEFAULT_MEMCFG"],
+            memory_config=self.model_config["DEFAULT_MEMCFG"],
         )
         for i in range(len(layer_output)):
             layer_output[i] = ttnn.experimental.tensor.interleaved_to_sharded(
diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
index 98111d9dbcb..699f7990e21 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -649,11 +649,11 @@ def attn_selfout(
         if self.emulated:
             attn_output = tt_all_gather_torch(attn_output, dim=-1)
         else:
-            attn_output = tt_lib.tensor.all_gather(
+            attn_output = ttnn.all_gather(
                 attn_output,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         for i in range(len(attn_output)):
@@ -911,11 +911,11 @@ def prefill_attn_selfout(self, attn_output: List[tt_lib.tensor.Tensor]) -> List[
         if self.emulated:
             attn_output = tt_all_gather_torch(attn_output, dim=-1)
         else:
-            attn_output = tt_lib.tensor.all_gather(
+            attn_output = ttnn.all_gather(
                 attn_output,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["DRAM_MEMCFG"],
+                memory_config=self.model_config["DRAM_MEMCFG"],
             )
 
         for i in range(len(attn_output)):
diff --git a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py
index dc61e92664b..c7a58fad1e0 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py
@@ -305,11 +305,11 @@ def decode_forward(
         if self.emulated:
             xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
         else:
-            xs_replicated = tt_lib.tensor.all_gather(
+            xs_replicated = ttnn.all_gather(
                 xs_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         for i in range(self.num_devices):
@@ -360,11 +360,11 @@ def decode_forward(
         if self.emulated:
             attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         else:
-            attn_resid_replicated = tt_lib.tensor.all_gather(
+            attn_resid_replicated = ttnn.all_gather(
                 attn_resid_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         for i in range(self.num_devices):
@@ -480,11 +480,11 @@ def prefill_forward(
         if self.emulated:
             xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
         else:
-            xs_replicated = tt_lib.tensor.all_gather(
+            xs_replicated = ttnn.all_gather(
                 xs_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["DRAM_MEMCFG"],
+                memory_config=self.model_config["DRAM_MEMCFG"],
             )
 
         attn_norm_interleaved = self.sharded_rmsnorm(xs_replicated, self.norm_eps, self.attn_norm_list)
@@ -515,11 +515,11 @@ def prefill_forward(
         if self.emulated:
             attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         else:
-            attn_resid_replicated = tt_lib.tensor.all_gather(
+            attn_resid_replicated = ttnn.all_gather(
                 attn_resid_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         ffn_norm_interleaved = self.sharded_rmsnorm(attn_resid_replicated, self.norm_eps, self.ffn_norm_list)
diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
index 32e38ccbb50..368369092e6 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
@@ -210,7 +210,7 @@ def prefill_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.T
         if self.emulated:
             hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
         else:
-            hidden_states = tt_lib.tensor.all_gather(
+            hidden_states = ttnn.all_gather(
                 hidden_states,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -270,11 +270,11 @@ def decode_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.Te
         if self.emulated:
             hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
         else:
-            hidden_states = tt_lib.tensor.all_gather(
+            hidden_states = ttnn.all_gather(
                 hidden_states,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         # Put AllGather results in L1 Sharded
diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index 74c669447e0..8a770b214d9 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -368,11 +368,11 @@ def decode_forward(
         if self.emulated:
             xs = tt_all_gather_torch(xs, dim=-1)
         else:
-            xs = tt_lib.tensor.all_gather(
+            xs = ttnn.all_gather(
                 xs,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["L1_MEMCFG"],
+                memory_config=self.model_config["L1_MEMCFG"],
             )
 
         ## Duplicate layernorm
@@ -492,11 +492,11 @@ def prefill_forward(
         if self.emulated:
             xs = tt_all_gather_torch(xs, dim=-1)
         else:
-            xs = tt_lib.tensor.all_gather(
+            xs = ttnn.all_gather(
                 xs,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                output_mem_config=self.model_config["DRAM_MEMCFG"],
+                memory_config=self.model_config["DRAM_MEMCFG"],
             )
 
         ## Duplicate layernorm
diff --git a/models/experimental/llama2_70b/tt/llama_decoder_optimized.py b/models/experimental/llama2_70b/tt/llama_decoder_optimized.py
index 5316a6390a8..6e9fdc7f274 100644
--- a/models/experimental/llama2_70b/tt/llama_decoder_optimized.py
+++ b/models/experimental/llama2_70b/tt/llama_decoder_optimized.py
@@ -298,7 +298,7 @@ def decode_forward(
         # if self.emulated:
         #     xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
         # else:
-        #     xs_replicated = tt_lib.tensor.all_gather(
+        #     xs_replicated = ttnn.all_gather(
         #         xs_replicated,
         #         dim=3,
         #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -351,7 +351,7 @@ def decode_forward(
         # if self.emulated:
         #     attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         # else:
-        #     attn_resid_replicated = tt_lib.tensor.all_gather(
+        #     attn_resid_replicated = ttnn.all_gather(
         #         attn_resid_replicated,
         #         dim=3,
         #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -468,7 +468,7 @@ def prefill_forward(
         # if self.emulated:
        #     xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
        # else:
-        #     xs_replicated = tt_lib.tensor.all_gather(
+        #     xs_replicated = ttnn.all_gather(
        #         xs_replicated,
        #         dim=3,
        #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -497,7 +497,7 @@ def prefill_forward(
         # if self.emulated:
         #     attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         # else:
-        #     attn_resid_replicated = tt_lib.tensor.all_gather(
+        #     attn_resid_replicated = ttnn.all_gather(
         #         attn_resid_replicated,
         #         dim=3,
         #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
diff --git a/models/experimental/llama2_70b/tt/llama_model_optimized.py b/models/experimental/llama2_70b/tt/llama_model_optimized.py
index bfb08035880..a7af31f8612 100644
--- a/models/experimental/llama2_70b/tt/llama_model_optimized.py
+++ b/models/experimental/llama2_70b/tt/llama_model_optimized.py
@@ -340,7 +340,7 @@ def decode_forward(
         # if self.emulated:
         #     xs = tt_all_gather_torch(xs, dim=-1)
         # else:
-        #     xs = tt_lib.tensor.all_gather(
+        #     xs = ttnn.all_gather(
         #         xs,
         #         dim=3,
         #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -456,7 +456,7 @@ def prefill_forward(
         # if self.emulated:
         #     xs = tt_all_gather_torch(xs, dim=-1)
         # else:
-        #     xs = tt_lib.tensor.all_gather(
+        #     xs = ttnn.all_gather(
         #         xs,
         #         dim=3,
         #         num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py
index bfbd33511f6..502e7d03cb4 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py
@@ -6,6 +6,7 @@
 import pytest
 from loguru import logger
 import tt_lib as ttl
+import ttnn
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
 from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
 import itertools
@@ -109,7 +110,7 @@ def run_all_gather_on_t3000_impl(
         tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))
 
     for i in range(num_iters):
-        tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
+        tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=mem_config)
 
         for d in devices:
             ttl.device.Synchronize(d)
@@ -779,7 +780,7 @@ def test_all_gather_post_commit_sharded(
     for i, t in enumerate(input_tensors):
         tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(tensor_layout).to(devices[i], input_mem_config))
 
-    tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=output_mem_config)
+    tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=output_mem_config)
     for d in devices:
         ttl.device.Synchronize(d)
     torch.set_printoptions(sci_mode=False)
@@ -833,7 +834,7 @@ def test_all_gather_fp32(
     for i, t in enumerate(input_tensors):
         tt_input_tensors.append(ttl.tensor.Tensor(t, ttl.tensor.DataType.FLOAT32).to(layout).to(devices[i], mem_config))
 
-    tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
+    tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=mem_config)
 
     for i, t in enumerate(tt_out_tensors):
         tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
diff --git a/ttnn/cpp/pybind11/operations/ccl.hpp b/ttnn/cpp/pybind11/operations/ccl.hpp
index fa68102a988..cb4c680d2cc 100644
--- a/ttnn/cpp/pybind11/operations/ccl.hpp
+++ b/ttnn/cpp/pybind11/operations/ccl.hpp
@@ -35,6 +35,7 @@ void bind_ccl_operation(py::module& module, const ccl_operation_t& operation, co
             },
             py::arg("input_tensor"),
             py::arg("dim"),
+            py::kw_only(),
            py::arg("num_links") = 1,
            py::arg("memory_config") = std::nullopt});
 }
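
Note for reviewers (illustrative only, not part of the applied diff): the patch is a mechanical rename of the collective call, plus keyword-style arguments. A minimal before/after sketch of the call-site pattern, using argument names and model_config keys taken from the hunks above; the wrapper function itself is hypothetical:

    import ttnn

    def gather_activations(activations, model_config):
        # Before this patch (tt_lib / ttnn.experimental.tensor), roughly:
        #   activations = ttl.tensor.all_gather(
        #       activations, dim, num_links, output_mem_config=mem_config)
        #
        # After this patch: ttnn.all_gather, and since the binding now inserts
        # py::kw_only() ahead of num_links, both num_links and memory_config
        # are expected to be passed as keyword arguments.
        return ttnn.all_gather(
            activations,
            dim=3,
            num_links=model_config["ALL_GATHER_NUM_LINKS"],
            memory_config=model_config["DEFAULT_MEMCFG"],
        )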