Skip to content

Commit

Permalink
#8131: resnet-50 fix for b20.
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed May 9, 2024
1 parent e18e0e6 commit a287711
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions models/experimental/resnet/tt/ttnn_functional_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,12 +635,8 @@ def __call__(self, input_tensor) -> ttnn.Tensor:
if is_wormhole_b0() and self.batch_size == 20:
# TODO: fix the need to do the reshard here
x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(
x,
ttnn.ROW_MAJOR_LAYOUT,
memory_config=self.max_pool.max_pool.input_sharded_memory_config,
)
# x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config)
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config)
x = self.max_pool(x)

x = ttnn.reshape(x, (1, 1, 56 * 56 * self.batch_size, 64))
Expand Down

0 comments on commit a287711

Please sign in to comment.