Skip to content

Commit

Permalink
#13368: fix mlp reshard
Browse files Browse the repository at this point in the history
  • Loading branch information
mtairum committed Oct 29, 2024
1 parent 8c28a0a commit 9efcb57
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions models/demos/llama3/tt/llama_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
mode == "decode"
): # TODO Add a check for a match between FF1/FF3 and FF2 memory configs. If they match avoid doing the reshard
# Reshard w2_in to a different core_grid configuration. Avoid using ttnn.reshard() due to incompatibility with trace mode
w2_in = ttnn.sharded_to_interleaved(w2_in, ttnn.L1_MEMORY_CONFIG)
w2_in = ttnn.interleaved_to_sharded(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])
w2_in = ttnn.reshard(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])

ttnn.deallocate(w3_out)
ttnn.deallocate(w1_out)
Expand Down

0 comments on commit 9efcb57

Please sign in to comment.