diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py
index e5c21983e44..fa5a11b8c82 100644
--- a/models/demos/resnet/tt/metalResnetBlock50.py
+++ b/models/demos/resnet/tt/metalResnetBlock50.py
@@ -1465,7 +1465,7 @@ def __init__(
             {
                 tt_lib.tensor.CoreRange(
                     tt_lib.tensor.CoreCoord(0, 0),
-                    tt_lib.tensor.CoreCoord(11, 7),
+                    tt_lib.tensor.CoreCoord(10, 7),
                 ),
                 tt_lib.tensor.CoreRange(
                     tt_lib.tensor.CoreCoord(0, 8),
@@ -1473,6 +1473,7 @@ def __init__(
                 ),
             }
         )
+        self.n_fold_cores = 92
 
         self.shard_grid = tt_lib.tensor.CoreRangeSet(
             {
@@ -2060,7 +2061,7 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor:
         NHW_even = _nearest_y(NHW // stride_h, self.first_conv_num_cores_nhw * 32)
 
         shard_spec = tt_lib.tensor.ShardSpec(
-            self.fold_grid, [NHW // 100, x.shape[3]], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False
+            self.fold_grid, [NHW // self.n_fold_cores, x.shape[3]], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False
         )
         x = torch2tt_tensor(
             x,
@@ -2076,25 +2077,25 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor:
 
         # fold for unity stride on device
         x = tt_lib.tensor.fold(x, stride_h=stride_h, stride_w=1)
 
-        # non-optimal resharding via the interleaved round trip, because
-        # direct resharding from 100 to 98 cores breaks the reshard op
-        x = tt_lib.tensor.sharded_to_interleaved(
+        shard_shape = [
+            NHW_even // self.first_conv_num_cores_nhw,
+            x.get_legacy_shape()[3],
+        ]
+
+        x = tt_lib.tensor.reshard(
             x,
-            output_mem_config=tt_lib.tensor.MemoryConfig(
-                tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.L1
+            tt_lib.tensor.MemoryConfig(
+                tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
+                tt_lib.tensor.BufferType.L1,
+                tt_lib.tensor.ShardSpec(
+                    self.shard_grid,
+                    shard_shape,
+                    tt_lib.tensor.ShardOrientation.ROW_MAJOR,
+                    False,
+                ),
             ),
         )
-        x = tt_lib.tensor.interleaved_to_sharded(
-            x,
-            self.shard_grid,
-            [
-                NHW_even // self.first_conv_num_cores_nhw,
-                x.get_legacy_shape()[3],
-            ],
-            tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
-            tt_lib.tensor.ShardOrientation.ROW_MAJOR,
-        )
         return x
 
     def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: