diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_configs_for_untilize_with_halo_and_conv.py b/tests/tt_eager/python_api_testing/unit_testing/test_configs_for_untilize_with_halo_and_conv.py
index 10e47f8b9860..d0cd33fdf6a4 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/test_configs_for_untilize_with_halo_and_conv.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/test_configs_for_untilize_with_halo_and_conv.py
@@ -119,14 +119,14 @@ def test_generate_all_configs_and_references(
     input_padded_height = input_h + 2 * pad_h
 
     # Generate following configs by tracing conv
-    logger.info("Trace conv and generate following configs - pad_metadata and data_top_left_indices.")
-    (
-        pad_metadata,
-        data_top_left_indices,
-        input_padded_tensor,
-    ) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-        conv_params, input_nchw_shape, input_pyt_tensor.reshape(-1).tolist()
+    pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
+        conv_params, input_nchw_shape
     )
+    logger.info("Generate input tensor")
+    input_padded_pyt_tensor = torch.nn.functional.pad(input_pyt_tensor, (pad_w, pad_w, pad_h, pad_h), value=0)
+    input_padded_pyt_tensor = input_padded_pyt_tensor.permute(0, 2, 3, 1)
+    input_padded_tensor = input_padded_pyt_tensor.reshape(-1).tolist()
 
     # run trace conv reference to validate pad_metadata and data_top_left_indices
     logger.info("Validate pad_metadata and data_top_left_indices.")
diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_untilize_with_halo_v2.py b/tests/tt_eager/python_api_testing/unit_testing/test_untilize_with_halo_v2.py
index 7625b32c63ed..f4bf5b24195e 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/test_untilize_with_halo_v2.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/test_untilize_with_halo_v2.py
@@ -324,16 +324,16 @@ def test_generate_all_configs_and_references(
     input_padded_height = input_h + 2 * pad_h
 
     # Generate following configs by tracing conv
-    logger.info("Trace conv and generate follwing configs - pad_metadata and data_top_left_indices.")
-    (
-        pad_metadata,
-        data_top_left_indices,
-        input_padded_tensor,
-    ) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-        conv_params, input_nchw_shape, input_pyt_tensor.reshape(-1).tolist()
+    pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
+        conv_params, input_nchw_shape
     )
     # # print("Data top left indices - ", data_top_left_indices)
     # # print("Pad meta data -", pad_metadata)
+    logger.info("Generate input tensor")
+    input_padded_pyt_tensor = torch.nn.functional.pad(input_pyt_tensor, (pad_w, pad_w, pad_h, pad_h), value=0)
+    input_padded_pyt_tensor = input_padded_pyt_tensor.permute(0, 2, 3, 1)
+    input_padded_tensor = input_padded_pyt_tensor.reshape(-1).tolist()
 
     # run trace conv reference to validate pad_metadata and data_top_left_indices
     # Generate more configs -
@@ -349,7 +349,6 @@ def test_generate_all_configs_and_references(
         num_cores_nhw,
         filter_h,
         filter_w,
-        # input_c,
     )
     # print("req_conv_input_shard_start_end-", req_conv_input_shard_start_end)
     # print("tensor_metadata-", tensor_metadata)
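Note on the test-side change above: the padded, channels-last input that the helper used to return is now built directly with torch.nn.functional.pad. A minimal standalone sketch (not part of the patch; the toy shapes below are made up) showing that pad + NHWC permute + flatten reproduces the element-by-element padded layout the removed helper code produced:

```python
import torch

# Toy NCHW shape and padding (hypothetical values, for illustration only).
n, c, h, w = 1, 2, 3, 3
pad_h, pad_w = 1, 1

x = torch.arange(n * c * h * w, dtype=torch.float32).reshape(n, c, h, w)

# F.pad pads the last dimensions first: (left, right, top, bottom) -> W then H.
x_padded = torch.nn.functional.pad(x, (pad_w, pad_w, pad_h, pad_h), value=0)
flat = x_padded.permute(0, 2, 3, 1).reshape(-1).tolist()  # NHWC, row-major

# Reference: the element-by-element NHWC padding the helper previously performed.
ref = []
for ni in range(n):
    for hi in range(h + 2 * pad_h):
        for wi in range(w + 2 * pad_w):
            for ci in range(c):
                inside = pad_h <= hi < h + pad_h and pad_w <= wi < w + pad_w
                ref.append(float(x[ni, ci, hi - pad_h, wi - pad_w]) if inside else 0.0)

assert flat == ref
```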
diff --git a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py
index 6c729bebc2c9..8223aa40e999 100644
--- a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py
+++ b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py
@@ -593,17 +593,10 @@ def set_op_configs(
         input_padded_width = input_w + 2 * pad_w
 
-        dummy = torch.rand(batch_size * input_h * input_w, dtype=torch.bfloat16)
-        dummy = torch.reshape(dummy, input_nchw_shape)
-
         # TODO: We should remove C from input_nchw_shape since none of the specs depend on it
         # TODO: Pass sliding_window_op_params instead of conv_param?
-        (
-            pad_metadata,
-            data_top_left_indices,
-            _,
-        ) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-            conv_params, input_nchw_shape, dummy.reshape(-1).tolist()
+        pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
+            conv_params, input_nchw_shape
         )
 
         req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata(
diff --git a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_max_pool.py b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_max_pool.py
index 7781860ff95e..a96e419ade2c 100644
--- a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_max_pool.py
+++ b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_max_pool.py
@@ -227,17 +227,8 @@ def set_op_configs(self, sliding_window_op_params_hash, reader_patterns_cache):
         input_padded_width = input_w + 2 * pad_w
 
-        dummy = torch.rand(batch_size * input_h * input_w, dtype=torch.bfloat16)
-        dummy = torch.reshape(dummy, input_nchw_shape)
-
-        (
-            pad_metadata,
-            data_top_left_indices,
-            _,
-        ) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-            (1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1),
-            input_nchw_shape,
-            dummy.reshape(-1).tolist(),
+        pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
+            (1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1), input_nchw_shape
         )
 
         req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata(
diff --git a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_untilize_with_halo.py b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_untilize_with_halo.py
index 6a6babd8b9c0..5389ce3a72a7 100644
--- a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_untilize_with_halo.py
+++ b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_untilize_with_halo.py
@@ -110,15 +110,8 @@ def set_op_configs(
         # output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups
         sliding_window_op_all_params = [1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1]
         input_nchw_shape = [input_n, 1, input_h, input_w]
-
-        dummy = torch.rand(input_n * input_h * input_w, dtype=torch.bfloat16)
-        dummy = torch.reshape(dummy, input_nchw_shape)
-        (
-            pad_metadata,
-            data_top_left_indices,
-            _,
-        ) = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-            sliding_window_op_all_params, input_nchw_shape, dummy.reshape(-1).tolist()
+        pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
+            sliding_window_op_all_params, input_nchw_shape
         )
         sliding_window_output_shard_nhw_size = get_sliding_window_op_output_shard_nhw_size(
             num_cores_nhw,
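For the three op wrappers above, the call site shrinks to shapes and params only; no dummy activation tensor is materialized just to derive the configs. A usage sketch of the updated helper (the values are illustrative, and the import is assumed to follow the module path of the file changed in this diff):

```python
# Hypothetical example values; import path assumed from the changed module's location.
from tt_eager.tt_dnn.op_library.sliding_window_op_infra.untilize_with_halo_config_generation_and_validation import (
    trace_conv_to_generate_data_top_left_indices_and_pad_metadata,
)

# (output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups)
conv_params = (1, 1, 3, 3, 1, 1, 1, 1, 1, 1)
input_nchw_shape = [1, 1, 8, 8]

# Only shapes/params go in, two lists come out.
pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
    conv_params, input_nchw_shape
)

assert len(pad_metadata) == 1 * (8 + 2 * 1) * (8 + 2 * 1)  # one bool per padded NHW location
assert len(data_top_left_indices) == 1 * 8 * 8  # one index per conv output position
```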
diff --git a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/untilize_with_halo_config_generation_and_validation.py b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/untilize_with_halo_config_generation_and_validation.py
index 99e3baf1aacd..33b1795cc1c3 100644
--- a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/untilize_with_halo_config_generation_and_validation.py
+++ b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/untilize_with_halo_config_generation_and_validation.py
@@ -4,16 +4,13 @@
 
 import torch
 import numpy as np
-from loguru import logger
 
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_allclose_and_pcc
 
 
 def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
-    conv_params, input_nchw_shape, input_tensor, pad_val: torch.int16 = 0x0
+    conv_params, input_nchw_shape  # , input_tensor, pad_val: torch.int16 = 0x0
 ):
-    if pad_val == 0xF7FF:
-        pad_val = -1.03e34  ## TODO: how to do this in python properly???
     assert len(conv_params) == 10
     output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups = [
         conv_params[i] for i in range(10)
@@ -22,9 +19,6 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
     assert len(input_nchw_shape) == 4
     input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)]
 
-    input_tensor_nchw = np.reshape(input_tensor, input_nchw_shape)
-    input_tensor_nhwc = np.transpose(input_tensor_nchw, (0, 2, 3, 1))
-    input_tensor_nhwc = np.reshape(input_tensor_nhwc, (np.prod(input_nchw_shape)))
     # image 1 data
     # 1 2 3 4 5 6 7 8
     # 9 10 11 12 13 14 15 16
@@ -54,25 +48,14 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
     # We encode above shown padded tensor into pad_metadata (list of boolean - true if padding location)
     # pad_meta_data: [true, true, ..., false, ...]
     index = 0
-    input_idx = 0
-    input_pad_idx = 0
     padded_input_h = input_h + (2 * pad_h)
     padded_input_w = input_w + (2 * pad_w)
     pad_metadata = np.full(input_n * padded_input_h * padded_input_w, False, dtype=bool)
-    input_padded_tensor = np.full(
-        input_n * input_c * padded_input_h * padded_input_w, pad_val, dtype=type(input_tensor_nhwc[0])
-    )
     for n in range(input_n):
         for h in range(padded_input_h):
             for w in range(padded_input_w):
                 if h < pad_h or h >= (input_h + pad_h) or w < pad_w or w >= (input_w + pad_w):
                     pad_metadata[index] = True
-                    input_pad_idx += input_c
-                else:
-                    for c in range(input_c):
-                        input_padded_tensor[input_pad_idx + c] = input_tensor_nhwc[input_idx]
-                        input_idx += 1
-                    input_pad_idx += input_c
                 index += 1
 
     # TODO: add support for dilation > 1
@@ -80,15 +63,18 @@ def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(
     output_w = ((int)((padded_input_w - filter_w) / stride_w)) + 1
     # generate a list of input indices corresponding to the top left position of sliding window
    # the index refers to the location in the padded tensor
-    data_top_left_indices = []
+    # data_top_left_indices = []
+    index = 0
+    data_top_left_indices = np.full(input_n * output_h * output_w, 0, dtype=int)
     for n in range(input_n):
         for oh in range(output_h):
             for ow in range(output_w):
                 ih = oh * stride_h
                 iw = ow * stride_w
                 channel_idx = (n * padded_input_h * padded_input_w) + (ih * padded_input_w) + iw
-                data_top_left_indices.append(channel_idx)
-    return pad_metadata.tolist(), data_top_left_indices, input_padded_tensor.tolist()
+                data_top_left_indices[index] = channel_idx
+                index += 1
+    return pad_metadata.tolist(), data_top_left_indices.tolist()
 
 
 def validate_input_padded_tensor_and_data_top_left_indices_and_pad_metadata(
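Side note on the refactored helper: both outputs are now pure functions of conv_params and input_nchw_shape, so they could equally be computed without Python loops. A hedged NumPy sketch of an equivalent computation (not the implementation in this patch; trace_conv_vectorized is a hypothetical name):

```python
import numpy as np

def trace_conv_vectorized(conv_params, input_nchw_shape):
    # Illustrative, vectorized equivalent of the loop-based helper above.
    _, _, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, _, _ = conv_params
    input_n, _, input_h, input_w = input_nchw_shape
    ph, pw = input_h + 2 * pad_h, input_w + 2 * pad_w

    # pad_metadata: True wherever the padded (h, w) location falls in the padding halo.
    h_idx, w_idx = np.meshgrid(np.arange(ph), np.arange(pw), indexing="ij")
    is_pad = (h_idx < pad_h) | (h_idx >= input_h + pad_h) | (w_idx < pad_w) | (w_idx >= input_w + pad_w)
    pad_metadata = np.tile(is_pad.reshape(-1), input_n)

    # data_top_left_indices: flat index (into the padded NHW volume) of each window's top-left element.
    out_h = (ph - filter_h) // stride_h + 1
    out_w = (pw - filter_w) // stride_w + 1
    n = np.arange(input_n)[:, None, None]
    ih = (np.arange(out_h) * stride_h)[None, :, None]
    iw = (np.arange(out_w) * stride_w)[None, None, :]
    indices = n * ph * pw + ih * pw + iw
    return pad_metadata.tolist(), indices.reshape(-1).tolist()
```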