Skip to content

Commit

Permalink
#4957: remove construct_2d_input_tensor, instead use pytorch to pad i…
Browse files Browse the repository at this point in the history
…nput tensor
  • Loading branch information
vtangTT committed Feb 28, 2024
1 parent 9fd6760 commit 4840421
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def test_generate_all_configs_and_references(
input_padded_height = input_h + 2 * pad_h
# Generate following configs by tracing conv -
logger.info("Trace conv and generate following configs - pad_metadata and data_top_left_indices.")
(
pad_metadata,
data_top_left_indices,
input_padded_tensor,
) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape, input_pyt_tensor.reshape(-1).tolist()
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape
)

logger.info("Generate input tensor")
input_padded_pyt_tensor = torch.nn.functional.pad(input_pyt_tensor, (pad_w, pad_w, pad_h, pad_h), value=0)
input_padded_pyt_tensor = input_padded_pyt_tensor.permute(0, 2, 3, 1)
input_padded_tensor = input_padded_pyt_tensor.reshape(-1).tolist()
# run trace conv reference to validate pad_metadata and data_top_left_indices
logger.info("Validate pad_metadata and data_top_left_indices.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,16 @@ def test_generate_all_configs_and_references(
input_padded_height = input_h + 2 * pad_h
# Generate following configs by tracing conv -
logger.info("Trace conv and generate follwing configs - pad_metadata and data_top_left_indices.")
(
pad_metadata,
data_top_left_indices,
input_padded_tensor,
) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape, input_pyt_tensor.reshape(-1).tolist()
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape
)
# # print("Data top left indices - ", data_top_left_indices)
# # print("Pad meta data -", pad_metadata)

logger.info("Generate input tensor")
input_padded_pyt_tensor = torch.nn.functional.pad(input_pyt_tensor, (pad_w, pad_w, pad_h, pad_h), value=0)
input_padded_pyt_tensor = input_padded_pyt_tensor.permute(0, 2, 3, 1)
input_padded_tensor = input_padded_pyt_tensor.reshape(-1).tolist()
# run trace conv reference to validate pad_metadata and data_top_left_indices

# Generate more configs -
Expand All @@ -349,7 +349,6 @@ def test_generate_all_configs_and_references(
num_cores_nhw,
filter_h,
filter_w,
# input_c,
)
# print("req_conv_input_shard_start_end-", req_conv_input_shard_start_end)
# print("tensor_metadata-", tensor_metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,17 +593,10 @@ def set_op_configs(

input_padded_width = input_w + 2 * pad_w

dummy = torch.rand(batch_size * input_h * input_w, dtype=torch.bfloat16)
dummy = torch.reshape(dummy, input_nchw_shape)

# TODO: We should remove C from input_nchw_shape since none of the specs depend on it
# TODO: Pass sliding_window_op_params instead of conv_param?
(
pad_metadata,
data_top_left_indices,
_,
) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape, dummy.reshape(-1).tolist()
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape
)

req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,8 @@ def set_op_configs(self, sliding_window_op_params_hash, reader_patterns_cache):

input_padded_width = input_w + 2 * pad_w

dummy = torch.rand(batch_size * input_h * input_w, dtype=torch.bfloat16)
dummy = torch.reshape(dummy, input_nchw_shape)

(
pad_metadata,
data_top_left_indices,
_,
) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
(1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1),
input_nchw_shape,
dummy.reshape(-1).tolist(),
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
(1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1), input_nchw_shape
)

req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,8 @@ def set_op_configs(
# output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups
sliding_window_op_all_params = [1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1]
input_nchw_shape = [input_n, 1, input_h, input_w]

dummy = torch.rand(input_n * input_h * input_w, dtype=torch.bfloat16)
dummy = torch.reshape(dummy, input_nchw_shape)
(
pad_metadata,
data_top_left_indices,
_,
) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
sliding_window_op_all_params, input_nchw_shape, dummy.reshape(-1).tolist()
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
sliding_window_op_all_params, input_nchw_shape
)
sliding_window_output_shard_nhw_size = get_sliding_window_op_output_shard_nhw_size(
num_cores_nhw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

import torch
import numpy as np
from loguru import logger

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_allclose_and_pcc


def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
conv_params, input_nchw_shape, input_tensor, pad_val: torch.int16 = 0x0
conv_params, input_nchw_shape # , input_tensor, pad_val: torch.int16 = 0x0
):
if pad_val == 0xF7FF:
pad_val = -1.03e34 ## TODO: how to do this in python properly???
assert len(conv_params) == 10
output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups = [
conv_params[i] for i in range(10)
Expand All @@ -22,9 +19,6 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
assert len(input_nchw_shape) == 4
input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)]

input_tensor_nchw = np.reshape(input_tensor, input_nchw_shape)
input_tensor_nhwc = np.transpose(input_tensor_nchw, (0, 2, 3, 1))
input_tensor_nhwc = np.reshape(input_tensor_nhwc, (np.prod(input_nchw_shape)))
# image 1 data
# 1 2 3 4 5 6 7 8
# 9 10 11 12 13 14 15 16
Expand Down Expand Up @@ -54,41 +48,33 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
# We encode above shown padded tensor into pad_metadata (list of boolean - true if padding location)
# pad_meta_data: [true, true, ..., false, ...]
index = 0
input_idx = 0
input_pad_idx = 0
padded_input_h = input_h + (2 * pad_h)
padded_input_w = input_w + (2 * pad_w)
pad_metadata = np.full(input_n * padded_input_h * padded_input_w, False, dtype=bool)
input_padded_tensor = np.full(
input_n * input_c * padded_input_h * padded_input_w, pad_val, dtype=type(input_tensor_nhwc[0])
)
for n in range(input_n):
for h in range(padded_input_h):
for w in range(padded_input_w):
if h < pad_h or h >= (input_h + pad_h) or w < pad_w or w >= (input_w + pad_w):
pad_metadata[index] = True
input_pad_idx += input_c
else:
for c in range(input_c):
input_padded_tensor[input_pad_idx + c] = input_tensor_nhwc[input_idx]
input_idx += 1
input_pad_idx += input_c
index += 1

# TODO: add support for dilation > 1
output_h = ((int)((padded_input_h - filter_h) / stride_h)) + 1
output_w = ((int)((padded_input_w - filter_w) / stride_w)) + 1
# generate a list of input indices corresponding to the top left position of sliding window
# the index refers to the location in the padded tensor
data_top_left_indices = []
# data_top_left_indices = []
index = 0
data_top_left_indices = np.full(input_n * output_h * output_w, 0, dtype=int)
for n in range(input_n):
for oh in range(output_h):
for ow in range(output_w):
ih = oh * stride_h
iw = ow * stride_w
channel_idx = (n * padded_input_h * padded_input_w) + (ih * padded_input_w) + iw
data_top_left_indices.append(channel_idx)
return pad_metadata.tolist(), data_top_left_indices, input_padded_tensor.tolist()
data_top_left_indices[index] = channel_idx
index += 1
return pad_metadata.tolist(), data_top_left_indices.tolist()


def validate_input_padded_tensor_and_data_top_left_indices_and_pad_metadata(
Expand Down

0 comments on commit 4840421

Please sign in to comment.