Skip to content

Commit

Permalink
#4438: Make Resnet use the new fold op instead of plain Python implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-zaretskiy committed Mar 1, 2024
1 parent 0cf156a commit f1953b6
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 170 deletions.
89 changes: 44 additions & 45 deletions models/demos/resnet/tt/metalResnetBlock50.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
import torch
import torch.nn as nn
import math
from loguru import logger
from models.demos.resnet.utils import fold_bn_to_conv_weights_bias
from models.utility_functions import tt2torch_tensor
from models.utility_functions import tt2torch_tensor, torch2tt_tensor
from tt_lib.utils import pad_weight

from models.utility_functions import is_wormhole_b0, is_grayskull
from models.utility_functions import is_grayskull
from tt_lib.fused_ops.average_pool import run_avg_pool_on_device_wrapper as TtAvgPool
from tt_lib.fused_ops.max_pool import run_max_pool_on_device_wrapper as TtMaxPool
from tt_lib.fused_ops.max_pool import compute_max_pool_shape
Expand All @@ -28,12 +27,11 @@
SlidingWindowOpParamsWithParallelConfig,
)
from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import TTPyMaxPool
from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_untilize_with_halo import TTPyUntilizeWithHalo

from models.utility_functions import (
_nearest_32,
pad_and_fold_conv_activation_for_unity_stride,
pad_and_fold_conv_filters_for_unity_stride,
pad_and_fold_conv_activation_for_unity_stride,
)

hardcoded_matmul_config_linear = {
Expand Down Expand Up @@ -1456,6 +1454,19 @@ def __init__(

self.first_conv_num_cores_nhw = 98
if sharded:
self.shard_grid = tt_lib.tensor.CoreRangeSet(
{
tt_lib.tensor.CoreRange(
tt_lib.tensor.CoreCoord(0, 0),
tt_lib.tensor.CoreCoord(11, 7),
),
tt_lib.tensor.CoreRange(
tt_lib.tensor.CoreCoord(0, 8),
tt_lib.tensor.CoreCoord(1, 8),
),
}
)

self.folded_conv1_params = [self.inplanes, 16, 4, 4, 1, 1, 0, 0, 1, groups]
first_conv_output_padded_nhw_size = _nearest_y(112 * 112 * batch_size, 98 * 32)
first_conv_output_channels = 64
Expand Down Expand Up @@ -1986,56 +1997,44 @@ def _make_layer(

def preprocessing(self, x: torch.Tensor) -> tt_lib.tensor:
if self.sharded:
x = pad_and_fold_conv_activation_for_unity_stride(x, 3, 3, 2, 2)
# NCHW -> NHWC
x = torch.permute(x, (0, 2, 3, 1))
x = x.reshape(
1,
1,
x.shape[0] * x.shape[1] * x.shape[2],
x.shape[3],
)
input_size_to_shard_evenly = _nearest_y(x.shape[2], self.first_conv_num_cores_nhw * 32)
x = torch.nn.functional.pad(x, (0, 0, 0, input_size_to_shard_evenly - x.shape[2], 0, 0))

x = tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16)
else:
extra_padding_for_32B_alignment = 25
x = torch.nn.functional.pad(x, (3, 4 + extra_padding_for_32B_alignment, 3, 3, 0, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16)
return x
# pad to 230x230x4
C = _nearest_y(x.shape[3], 4)
x = torch.nn.functional.pad(x, (0, C - x.shape[3], 3, 3, 3, 3))

def forward(self, x: tt_lib.tensor) -> tt_lib.tensor:
if self.sharded:
untilize_with_halo_input_shard_height = (int)(x.shape()[2] / self.first_conv_num_cores_nhw)
# fold for unity stride on device
x = torch2tt_tensor(x, self.device, tt_layout=tt_lib.tensor.Layout.ROW_MAJOR)
x = tt_lib.tensor.fold(x, 2, 2)

shard_grid = tt_lib.tensor.CoreRangeSet(
{
tt_lib.tensor.CoreRange(
tt_lib.tensor.CoreCoord(0, 0),
tt_lib.tensor.CoreCoord(11, 7),
),
tt_lib.tensor.CoreRange(
tt_lib.tensor.CoreCoord(0, 8),
tt_lib.tensor.CoreCoord(1, 8),
),
}
)
shard_spec = tt_lib.tensor.ShardSpec(
shard_grid,
# reshape to (1, 1, NHW, C) and pad the NHW dim for sharding
x = x.reshape(1, 1, -1, x.shape()[-1])

_, _, NHW, C = x.shape()
input_size_to_shard_evenly = _nearest_y(NHW, self.first_conv_num_cores_nhw * 32)
padded_shape = [1, 1, input_size_to_shard_evenly, C]
x = tt_lib.tensor.pad(x, padded_shape, [0, 0, 0, 0], 0)

x = tt_lib.tensor.interleaved_to_sharded(
x,
self.shard_grid,
[
untilize_with_halo_input_shard_height,
x.shape()[2] // self.first_conv_num_cores_nhw,
x.shape()[3],
],
tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
tt_lib.tensor.ShardOrientation.ROW_MAJOR,
False,
)
mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec
)
x = x.to(self.device, mem_config)

else:
extra_padding_for_32B_alignment = 25
x = torch.nn.functional.pad(x, (3, 4 + extra_padding_for_32B_alignment, 3, 3, 0, 1))
x = torch.permute(x, (0, 2, 3, 1))
x = tt_lib.tensor.Tensor(x, tt_lib.tensor.DataType.BFLOAT16)
return x

def forward(self, x: tt_lib.tensor) -> tt_lib.tensor:
if not self.sharded:
original_A_cl_host_shape = x.shape()
x = x.reshape(x.shape()[0], x.shape()[1], 1, x.shape()[2] * x.shape()[3])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import tt_lib as ttl

from models.utility_functions import skip_for_wormhole_b0, torch2tt_tensor
from models.utility_functions import skip_for_wormhole_b0, torch2tt_tensor, tt2torch_tensor


def fold_torch(input_tensor, stride_h, stride_w):
Expand All @@ -30,6 +30,7 @@ def fold_torch(input_tensor, stride_h, stride_w):
((10, 6, 8, 32), 3, 1),
((10, 6, 8, 32), 1, 2),
((10, 6, 8, 32), 1, 1),
((1, 6, 6, 4), 2, 2),
],
)
def test_fold(act_shape, stride_h, stride_w, device):
Expand All @@ -42,11 +43,9 @@ def test_fold(act_shape, stride_h, stride_w, device):
torch_input,
device,
ttl.tensor.Layout.ROW_MAJOR,
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED),
)

tt_out = ttl.tensor.fold(tt_input, stride_h, stride_w)
tt_out = tt_out.cpu()
actual = tt_out.to_torch()
actual = tt2torch_tensor(tt_out)

torch.testing.assert_allclose(actual, expected)
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
#include "tt_dnn/op_library/fold/fold_op.hpp"
#include "tt_dnn/op_library/math.hpp"

namespace {
uint32_t single_row_all_channels_size(const Tensor &x) { return x.shape()[2] * x.shape()[3] * x.element_size(); }
} // namespace

namespace tt::tt_metal {
operation::ProgramWithCallbacks fold_single_core(
const Tensor &input, const Tensor &output, uint8_t stride_h, uint8_t stride_w) {
Expand Down Expand Up @@ -68,9 +64,8 @@ operation::ProgramWithCallbacks fold_single_core(
src_log2_unit_size,
};

uint32_t dst_unit_size_is_power_of_two = is_power_of_two_at_least_32(pixel_size * stride_h * stride_w);
uint32_t dst_log2_unit_size =
src_unit_size_is_power_of_two ? (std::uint32_t)log2(pixel_size * stride_h * stride_w) : 0;
uint32_t dst_unit_size_is_power_of_two = is_power_of_two_at_least_32(aligned_dst_pixel_size);
uint32_t dst_log2_unit_size = dst_unit_size_is_power_of_two ? (std::uint32_t)log2(aligned_dst_pixel_size) : 0;

std::vector<uint32_t> writer_compile_time_args = {
cb_dst0_index,
Expand Down
10 changes: 4 additions & 6 deletions tt_eager/tt_dnn/op_library/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
//
// SPDX-License-Identifier: Apache-2.0

namespace tt {

namespace tt_metal {
#include "tt_metal/common/assert.hpp"
namespace tt::tt_metal {

// Returns true when `val` has at most one bit set.
//
// For any positive power of two, clearing the lowest set bit via
// `val & (val - 1)` leaves zero, so the comparison below succeeds.
// Note that zero also passes this bit test and therefore yields true;
// callers that must reject zero have to check for it separately.
template <typename T>
bool is_power_of_two(T val) {
    return (val & (val - 1)) == T(0);
}

template <typename T>
Expand All @@ -22,6 +22,4 @@ bool is_power_of_two_at_least_32(T val) {
return is_power_of_two_at_least(val, T(32));
}

} // namespace metal

} // namespace tt
} // namespace tt::tt_metal
Loading

0 comments on commit f1953b6

Please sign in to comment.