diff --git a/models/experimental/resnet/tt/ttnn_functional_resnet50.py b/models/experimental/resnet/tt/ttnn_functional_resnet50.py index 3234e8f7f7fa..ade604947723 100644 --- a/models/experimental/resnet/tt/ttnn_functional_resnet50.py +++ b/models/experimental/resnet/tt/ttnn_functional_resnet50.py @@ -635,12 +635,8 @@ def __call__(self, input_tensor) -> ttnn.Tensor: if is_wormhole_b0() and self.batch_size == 20: # TODO: fix the need to do the reshard here x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG) - x = ttnn.to_layout( - x, - ttnn.ROW_MAJOR_LAYOUT, - memory_config=self.max_pool.max_pool.input_sharded_memory_config, - ) - # x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config) + x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config) x = self.max_pool(x) x = ttnn.reshape(x, (1, 1, 56 * 56 * self.batch_size, 64))