#6206: Enable reshard op in Resnet50
yan-zaretskiy committed Mar 14, 2024
1 parent ff18321 commit 926415b
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions models/demos/resnet/tt/metalResnetBlock50.py
@@ -1465,14 +1465,15 @@ def __init__(
             {
                 tt_lib.tensor.CoreRange(
                     tt_lib.tensor.CoreCoord(0, 0),
-                    tt_lib.tensor.CoreCoord(11, 7),
+                    tt_lib.tensor.CoreCoord(10, 7),
                 ),
                 tt_lib.tensor.CoreRange(
                     tt_lib.tensor.CoreCoord(0, 8),
                     tt_lib.tensor.CoreCoord(3, 8),
                 ),
             }
         )
+        self.n_fold_cores = 92
 
         self.shard_grid = tt_lib.tensor.CoreRangeSet(
             {
@@ -2060,7 +2061,7 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor:
         NHW_even = _nearest_y(NHW // stride_h, self.first_conv_num_cores_nhw * 32)
 
         shard_spec = tt_lib.tensor.ShardSpec(
-            self.fold_grid, [NHW // 100, x.shape[3]], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False
+            self.fold_grid, [NHW // self.n_fold_cores, x.shape[3]], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False
         )
         x = torch2tt_tensor(
             x,
@@ -2076,25 +2077,25 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor:
         # fold for unity stride on device
         x = tt_lib.tensor.fold(x, stride_h=stride_h, stride_w=1)
 
-        # non-optimal resharding via the interleaved round trip, because
-        # direct resharding from 100 to 98 cores breaks the reshard op
-        x = tt_lib.tensor.sharded_to_interleaved(
+        shard_shape = [
+            NHW_even // self.first_conv_num_cores_nhw,
+            x.get_legacy_shape()[3],
+        ]
+
+        x = tt_lib.tensor.reshard(
             x,
-            output_mem_config=tt_lib.tensor.MemoryConfig(
-                tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.L1
+            tt_lib.tensor.MemoryConfig(
+                tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
+                tt_lib.tensor.BufferType.L1,
+                tt_lib.tensor.ShardSpec(
+                    self.shard_grid,
+                    shard_shape,
+                    tt_lib.tensor.ShardOrientation.ROW_MAJOR,
+                    False,
+                ),
             ),
         )
 
-        x = tt_lib.tensor.interleaved_to_sharded(
-            x,
-            self.shard_grid,
-            [
-                NHW_even // self.first_conv_num_cores_nhw,
-                x.get_legacy_shape()[3],
-            ],
-            tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
-            tt_lib.tensor.ShardOrientation.ROW_MAJOR,
-        )
         return x
 
     def forward(self, x: tt_lib.tensor) -> tt_lib.tensor:
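Side note on where the new self.n_fold_cores = 92 in the first hunk comes from: it is the number of cores covered by the two CoreRange entries of the fold grid. A minimal sketch of that arithmetic in plain Python, assuming CoreRange/CoreCoord bounds are inclusive (which the paired CoreCoord(0, 0)/CoreCoord(10, 7) usage suggests):

# Cores covered by CoreRange(CoreCoord(0, 0), CoreCoord(10, 7)):
# an 11 x 8 block, assuming inclusive coordinates.
full_block = (10 - 0 + 1) * (7 - 0 + 1)  # 88 cores
# Cores covered by CoreRange(CoreCoord(0, 8), CoreCoord(3, 8)): a 4 x 1 strip on row 8.
extra_strip = (3 - 0 + 1) * (8 - 8 + 1)  # 4 cores
assert full_block + extra_strip == 92  # matches the new self.n_fold_cores

The old grid, CoreCoord(0, 0) to CoreCoord(11, 7) plus the same 4-core strip, covered 100 cores, which is why the removed shard shape divided by the literal 100.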

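Reading the last hunk as a whole, the commit collapses the old two-hop workaround (sharded to interleaved, then interleaved back to sharded) into a single direct reshard onto the conv shard grid. A schematic of the two flows, not a drop-in snippet: the interleaved_l1_config, height_sharded_l1_config, and layout/orientation names below are placeholders standing in for the concrete arguments already shown in the diff.

# Old flow: spill the folded, height-sharded tensor to interleaved L1,
# then shard it again onto self.shard_grid (two data movements).
x = tt_lib.tensor.sharded_to_interleaved(x, output_mem_config=interleaved_l1_config)
x = tt_lib.tensor.interleaved_to_sharded(
    x, self.shard_grid, shard_shape, height_sharded_layout, row_major_orientation
)

# New flow: one reshard op straight from the 92-core fold grid to the conv
# shard grid, with the target MemoryConfig carrying the ShardSpec.
x = tt_lib.tensor.reshard(x, height_sharded_l1_config)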
