Skip to content

Commit

Permalink
WIP
Browse files: browse the repository at this point in the history
  • Loading branch information
yan-zaretskiy committed Feb 26, 2024
1 parent e98d1b1 commit ff0a18f
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions models/demos/resnet/tt/metalResnetBlock50.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,12 +8,11 @@
import torch
import torch.nn as nn
import math
from loguru import logger
from models.demos.resnet.utils import fold_bn_to_conv_weights_bias
from models.utility_functions import tt2torch_tensor
from tt_lib.utils import pad_weight

from models.utility_functions import is_wormhole_b0, is_grayskull
from models.utility_functions import is_grayskull
from tt_lib.fused_ops.average_pool import run_avg_pool_on_device_wrapper as TtAvgPool
from tt_lib.fused_ops.max_pool import run_max_pool_on_device_wrapper as TtMaxPool
from tt_lib.fused_ops.max_pool import compute_max_pool_shape
Expand All @@ -28,11 +27,9 @@
SlidingWindowOpParamsWithParallelConfig,
)
from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import TTPyMaxPool
from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_untilize_with_halo import TTPyUntilizeWithHalo

from models.utility_functions import (
_nearest_32,
pad_and_fold_conv_activation_for_unity_stride,
pad_and_fold_conv_filters_for_unity_stride,
)

Expand Down Expand Up @@ -1985,18 +1982,27 @@ def _make_layer(

def preprocessing(self, x: torch.Tensor) -> tt_lib.tensor:
if self.sharded:
x = pad_and_fold_conv_activation_for_unity_stride(x, 3, 3, 2, 2)
# NCWH -> NWHC
x = torch.permute(x, (0, 2, 3, 1))
x = x.reshape(
1,
1,
x.shape[0] * x.shape[1] * x.shape[2],
x.shape[3],
)
input_size_to_shard_evenly = _nearest_y(x.shape[2], self.first_conv_num_cores_nhw * 32)
x = torch.nn.functional.pad(x, (0, 0, 0, input_size_to_shard_evenly - x.shape[2], 0, 0))

# pad to 230x230x4
C = _nearest_y(x.shape[3], 4)
x = torch.nn.functional.pad(x, (0, C - x.shape[3], 3, 3, 3, 3))

x = tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16)
x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
x = x.to(self.device, tt_lib.tensor.MemoryConfig(tt_lib.tensor.TensorMemoryLayout.INTERLEAVED))

# fold for unity stride on device
x = tt_lib.tensor.fold(x, 2, 2)

N, H, W, C = x.shape()

x = x.reshape(1, 1, N * H * W, C)

input_size_to_shard_evenly = _nearest_y(N * H * W, self.first_conv_num_cores_nhw * 32)
padded_shape = [1, 1, input_size_to_shard_evenly, C]
x = tt_lib.tensor.pad(x, padded_shape, [0, 0, 0, 0], 0, use_multicore=True)
else:
extra_padding_for_32B_alignment = 25
x = torch.nn.functional.pad(x, (3, 4 + extra_padding_for_32B_alignment, 3, 3, 0, 1))
Expand Down

0 comments on commit ff0a18f

Please sign in to comment.