#9486: Replace ttl.tensor.all_gather with ttnn.all_gather (#9517)
* #9486: Replace ttl all_gather calls with ttnn.all_gather

* #0: correct arguments
ayerofieiev-tt authored Jun 18, 2024
1 parent ca0904f commit 8c96c4f
Showing 12 changed files with 48 additions and 46 deletions.
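
Every model-side change below follows the same pattern: the deprecated ttl.tensor.all_gather / ttnn.experimental.tensor.all_gather call becomes ttnn.all_gather, and the output_mem_config keyword becomes memory_config. A minimal sketch of the new call pattern; gather_across_devices is a hypothetical helper and the config keys stand in for each model's own model_config:

import ttnn

def gather_across_devices(tensors, model_config):
    # `tensors` is the list of per-device tensors previously passed to
    # ttl.tensor.all_gather; only the op name and keyword names change.
    return ttnn.all_gather(
        tensors,
        dim=3,                                           # gather along the last dimension
        num_links=model_config["ALL_GATHER_NUM_LINKS"],
        memory_config=model_config["DEFAULT_MEMCFG"],    # previously output_mem_config=
    )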
8 changes: 4 additions & 4 deletions models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -517,11 +517,11 @@ def fwd_prefill(
output_mem_config=self.model_config["CONCAT_HEADS_OUTPUT_MEMCFG"],
)

- attn_output = ttnn.experimental.tensor.all_gather(
+ attn_output = ttnn.all_gather(
attn_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

for i in range(len(attn_output)):
@@ -807,11 +807,11 @@ def fwd_decode(
attn_output[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)

- attn_output = ttnn.experimental.tensor.all_gather(
+ attn_output = ttnn.all_gather(
attn_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

for i in range(len(attn_output)):
12 changes: 6 additions & 6 deletions models/demos/t3000/falcon40b/tt/falcon_decoder.py
@@ -241,11 +241,11 @@ def fwd_prefill(
replicated_hidden_states[i], self.model_config["BFP8_DTYPE"]
)

- replicated_hidden_states = ttnn.experimental.tensor.all_gather(
+ replicated_hidden_states = ttnn.all_gather(
replicated_hidden_states,
- num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
dim=3,
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

if self.model_config["LN_INPUT_DTYPE"] != self.model_config["BFP8_DTYPE"]:
@@ -362,11 +362,11 @@ def fwd_decode(
hidden_states[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)
)
- replicated_hidden_states = ttnn.experimental.tensor.all_gather(
+ replicated_hidden_states = ttnn.all_gather(
replicated_hidden_states,
- num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
dim=3,
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(replicated_hidden_states)):
replicated_hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded(
8 changes: 4 additions & 4 deletions models/demos/t3000/falcon40b/tt/falcon_mlp.py
@@ -125,11 +125,11 @@ def fwd_decode(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.expe
hidden_states[i] = ttnn.experimental.tensor.sharded_to_interleaved(
hidden_states[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)
- hidden_states = ttnn.experimental.tensor.all_gather(
+ hidden_states = ttnn.all_gather(
hidden_states,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(hidden_states)):
hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded(
@@ -169,11 +169,11 @@ def fwd_prefill(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.exp
if should_deallocate_ln_tensors:
x[i].deallocate(True)

- hidden_states = ttnn.experimental.tensor.all_gather(
+ hidden_states = ttnn.all_gather(
hidden_states,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

for i in range(len(hidden_states)):
8 changes: 4 additions & 4 deletions models/demos/t3000/falcon40b/tt/falcon_model.py
@@ -339,11 +339,11 @@ def fwd_prefill(
for i in range(len(layer_output)):
layer_output[i] = ttnn.experimental.tensor.typecast(layer_output[i], self.model_config["BFP8_DTYPE"])

- layer_output = ttnn.experimental.tensor.all_gather(
+ layer_output = ttnn.all_gather(
layer_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

if self.model_config["LN_INPUT_DTYPE"] != self.model_config["BFP8_DTYPE"]:
@@ -399,11 +399,11 @@ def fwd_decode(
layer_output[i] = ttnn.experimental.tensor.sharded_to_interleaved(
layer_output[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)
- layer_output = ttnn.experimental.tensor.all_gather(
+ layer_output = ttnn.all_gather(
layer_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(layer_output)):
layer_output[i] = ttnn.experimental.tensor.interleaved_to_sharded(
8 changes: 4 additions & 4 deletions models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -649,11 +649,11 @@ def attn_selfout(
if self.emulated:
attn_output = tt_all_gather_torch(attn_output, dim=-1)
else:
- attn_output = tt_lib.tensor.all_gather(
+ attn_output = ttnn.all_gather(
attn_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

for i in range(len(attn_output)):
@@ -911,11 +911,11 @@ def prefill_attn_selfout(self, attn_output: List[tt_lib.tensor.Tensor]) -> List[
if self.emulated:
attn_output = tt_all_gather_torch(attn_output, dim=-1)
else:
- attn_output = tt_lib.tensor.all_gather(
+ attn_output = ttnn.all_gather(
attn_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DRAM_MEMCFG"],
memory_config=self.model_config["DRAM_MEMCFG"],
)

for i in range(len(attn_output)):
16 changes: 8 additions & 8 deletions models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py
@@ -305,11 +305,11 @@ def decode_forward(
if self.emulated:
xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
else:
- xs_replicated = tt_lib.tensor.all_gather(
+ xs_replicated = ttnn.all_gather(
xs_replicated,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

for i in range(self.num_devices):
@@ -360,11 +360,11 @@ def decode_forward(
if self.emulated:
attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
else:
- attn_resid_replicated = tt_lib.tensor.all_gather(
+ attn_resid_replicated = ttnn.all_gather(
attn_resid_replicated,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

for i in range(self.num_devices):
@@ -480,11 +480,11 @@ def prefill_forward(
if self.emulated:
xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
else:
- xs_replicated = tt_lib.tensor.all_gather(
+ xs_replicated = ttnn.all_gather(
xs_replicated,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DRAM_MEMCFG"],
memory_config=self.model_config["DRAM_MEMCFG"],
)

attn_norm_interleaved = self.sharded_rmsnorm(xs_replicated, self.norm_eps, self.attn_norm_list)
@@ -515,11 +515,11 @@ def prefill_forward(
if self.emulated:
attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
else:
- attn_resid_replicated = tt_lib.tensor.all_gather(
+ attn_resid_replicated = ttnn.all_gather(
attn_resid_replicated,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

ffn_norm_interleaved = self.sharded_rmsnorm(attn_resid_replicated, self.norm_eps, self.ffn_norm_list)
6 changes: 3 additions & 3 deletions models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
@@ -210,7 +210,7 @@ def prefill_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.T
if self.emulated:
hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
else:
- hidden_states = tt_lib.tensor.all_gather(
+ hidden_states = ttnn.all_gather(
hidden_states,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -270,11 +270,11 @@ def decode_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.Te
if self.emulated:
hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
else:
- hidden_states = tt_lib.tensor.all_gather(
+ hidden_states = ttnn.all_gather(
hidden_states,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

# Put AllGather results in L1 Sharded
8 changes: 4 additions & 4 deletions models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -368,11 +368,11 @@ def decode_forward(
if self.emulated:
xs = tt_all_gather_torch(xs, dim=-1)
else:
- xs = tt_lib.tensor.all_gather(
+ xs = ttnn.all_gather(
xs,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["L1_MEMCFG"],
memory_config=self.model_config["L1_MEMCFG"],
)

## Duplicate layernorm
@@ -492,11 +492,11 @@ def prefill_forward(
if self.emulated:
xs = tt_all_gather_torch(xs, dim=-1)
else:
- xs = tt_lib.tensor.all_gather(
+ xs = ttnn.all_gather(
xs,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DRAM_MEMCFG"],
memory_config=self.model_config["DRAM_MEMCFG"],
)

## Duplicate layernorm
8 changes: 4 additions & 4 deletions models/experimental/llama2_70b/tt/llama_decoder_optimized.py
@@ -298,7 +298,7 @@ def decode_forward(
# if self.emulated:
# xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
# else:
- # xs_replicated = tt_lib.tensor.all_gather(
+ # xs_replicated = ttnn.all_gather(
# xs_replicated,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -351,7 +351,7 @@ def decode_forward(
# if self.emulated:
# attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
# else:
- # attn_resid_replicated = tt_lib.tensor.all_gather(
+ # attn_resid_replicated = ttnn.all_gather(
# attn_resid_replicated,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -468,7 +468,7 @@ def prefill_forward(
# if self.emulated:
# xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
# else:
- # xs_replicated = tt_lib.tensor.all_gather(
+ # xs_replicated = ttnn.all_gather(
# xs_replicated,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -497,7 +497,7 @@ def prefill_forward(
# if self.emulated:
# attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
# else:
- # attn_resid_replicated = tt_lib.tensor.all_gather(
+ # attn_resid_replicated = ttnn.all_gather(
# attn_resid_replicated,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
4 changes: 2 additions & 2 deletions models/experimental/llama2_70b/tt/llama_model_optimized.py
@@ -340,7 +340,7 @@ def decode_forward(
# if self.emulated:
# xs = tt_all_gather_torch(xs, dim=-1)
# else:
- # xs = tt_lib.tensor.all_gather(
+ # xs = ttnn.all_gather(
# xs,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -456,7 +456,7 @@ def prefill_forward(
# if self.emulated:
# xs = tt_all_gather_torch(xs, dim=-1)
# else:
- # xs = tt_lib.tensor.all_gather(
+ # xs = ttnn.all_gather(
# xs,
# dim=3,
# num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -6,6 +6,7 @@
import pytest
from loguru import logger
import tt_lib as ttl
+ import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
import itertools
@@ -109,7 +110,7 @@ def run_all_gather_on_t3000_impl(
tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))

for i in range(num_iters):
- tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
+ tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=mem_config)

for d in devices:
ttl.device.Synchronize(d)
@@ -779,7 +780,7 @@ def test_all_gather_post_commit_sharded(
for i, t in enumerate(input_tensors):
tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(tensor_layout).to(devices[i], input_mem_config))

- tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=output_mem_config)
+ tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=output_mem_config)
for d in devices:
ttl.device.Synchronize(d)
torch.set_printoptions(sci_mode=False)
@@ -833,7 +834,7 @@ def test_all_gather_fp32(
for i, t in enumerate(input_tensors):
tt_input_tensors.append(ttl.tensor.Tensor(t, ttl.tensor.DataType.FLOAT32).to(layout).to(devices[i], mem_config))

- tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
+ tt_out_tensors = ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=mem_config)

for i, t in enumerate(tt_out_tensors):
tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
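
In the test harness the substitution is the same, but dim stays positional while num_links and the memory config must now be passed by keyword. A hedged sketch of the old and new call sites, wrapped in a hypothetical gather_test_outputs helper with the argument names taken from the surrounding test code:

import ttnn

def gather_test_outputs(tt_input_tensors, dim, num_links, mem_config):
    # Old: ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
    # New: num_links and memory_config are spelled out as keywords.
    return ttnn.all_gather(tt_input_tensors, dim, num_links=num_links, memory_config=mem_config)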
1 change: 1 addition & 0 deletions ttnn/cpp/pybind11/operations/ccl.hpp
@@ -35,6 +35,7 @@ void bind_ccl_operation(py::module& module, const ccl_operation_t& operation, co
},
py::arg("input_tensor"),
py::arg("dim"),
+ py::kw_only(),
py::arg("num_links") = 1,
py::arg("memory_config") = std::nullopt});
}
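
The added py::kw_only() marker makes every argument after dim keyword-only on the Python side, which is why the calls above write num_links= and memory_config= explicitly. A pure-Python mirror of the resulting signature, purely as an illustration (not the real binding):

def all_gather(input_tensor, dim, *, num_links=1, memory_config=None):
    # Mirrors py::arg("input_tensor"), py::arg("dim"), py::kw_only(),
    # py::arg("num_links") = 1, py::arg("memory_config") = std::nullopt.
    return input_tensor  # placeholder body; the real op is bound from C++

tensors = ["t0", "t1"]                  # stand-in for per-device tensors
all_gather(tensors, 3, num_links=2)     # OK: keyword-only args passed by name
# all_gather(tensors, 3, 2)             # TypeError: takes 2 positional arguments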
