Skip to content

Commit

Permalink
#4957: optimizing construct_2d_padded_tensor_list - shaves off ~8 mins from post-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
vtangTT committed Feb 26, 2024
1 parent a9685d4 commit 2c082f7
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,26 @@ def construct_2d_padded_tensor_list(input_tensor, input_nchw_shape, pad_metadata
if pad_val == 0xF7FF:
pad_val = -1.03e34 ## TODO: how to do this in python properly???
# Construct the padded tensor using pad_metadata
input_padded_tensor = []
input_tensor_idx = 0
assert len(input_nchw_shape) == 4
input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)]
# Permute input tensor from nchw shape to nhwc shape
input_tensor_nchw = np.reshape(input_tensor, input_nchw_shape)
input_tensor_nhwc = np.transpose(input_tensor_nchw, (0, 2, 3, 1))
input_tensor_nhwc = np.reshape(input_tensor_nhwc, (np.prod(input_nchw_shape)))

# input_padded_tensor = np.full(len(pad_metadata)*input_c, pad_val, dtype=float)
input_padded_tensor = np.full(len(pad_metadata) * input_c, pad_val, dtype=type(input_tensor_nhwc[0]))
index = 0
for i in range(len(pad_metadata)):
for c in range(input_c):
if pad_metadata[i]:
input_padded_tensor.append(pad_val)
else:
if not pad_metadata[i]:
assert input_tensor_idx < len(input_tensor_nhwc)
input_padded_tensor.append(input_tensor_nhwc[input_tensor_idx])
input_padded_tensor[index] = input_tensor_nhwc[input_tensor_idx]
input_tensor_idx += 1
return input_padded_tensor
index += 1

return input_padded_tensor.tolist()


def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(conv_params, input_nchw_shape):
Expand Down

0 comments on commit 2c082f7

Please sign in to comment.