Skip to content

Commit

Permalink
#9486: Replace ttdnn op with ttnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Jul 22, 2024
1 parent 635950e commit 21309a6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions models/demos/t3000/falcon40b/tt/falcon_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def fwd_decode(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.expe
hidden_states
) # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh

hidden_states = ttnn.experimental.tensor.reduce_scatter(
hidden_states = ttnn.reduce_scatter(
hidden_states,
scatter_split_dim=3,
reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
scatter_dim=3,
math_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
num_links=1, # only unidirectional supported for now
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse
Expand Down Expand Up @@ -198,12 +198,12 @@ def fwd_prefill(self, x: List[ttnn.experimental.tensor.Tensor]) -> List[ttnn.exp
self.output
) # Workaround for reduce_scatter only taking a vector of tensors and not device_mesh

hidden_states = ttnn.experimental.tensor.reduce_scatter(
hidden_states = ttnn.reduce_scatter(
hidden_states,
scatter_split_dim=3,
reduce_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
scatter_dim=3,
math_op=ttnn.experimental.tensor.ReduceOpMath.SUM,
num_links=1, # only one link supported for now
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
memory_config=self.model_config["DEFAULT_MEMCFG"],
)

hidden_states = ttnn.aggregate_as_tensor(hidden_states) # Workaround reverse
Expand Down

0 comments on commit 21309a6

Please sign in to comment.