#0: Fix failing test case for width sharded non-32 multiple output width
sankarmanoj-tt committed Dec 20, 2024
1 parent 2336024 commit 8db5c45
Showing 5 changed files with 21 additions and 12 deletions.
ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp (1 change: 1 addition & 0 deletions)

@@ -155,6 +155,7 @@ Result conv2d(
         opt_conv_op_block_config.act_block_w_ntiles,
         opt_conv_op_block_config.out_subblock_w_ntiles,
         parallel_config,
+        output_parallel_config,
         device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,
[second changed file: path not shown]

@@ -368,7 +368,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
     uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w;
     // writer of conv op partially removes padding on the width
     // it removes the padding done for block width but it doesn't remove padding done for tiled width
-    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, input_num_cores * TILE_WIDTH);
+    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, output_num_cores * TILE_WIDTH);
     TT_FATAL(
         output_channels_padded_to_tile_width <= weight_matrix_width,
         "output_channels_padded_to_tile_width {} should be less than or equal to weight_matrix_width {}",
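The one-line change above is the heart of the fix: the padded output width must be a multiple of the tile width times the number of cores in the output shard grid, not the input grid. A minimal sketch of the arithmetic, where the core counts and channel count are hypothetical and this round_up is a stand-in for the helper used in the source:

```cpp
#include <cstdint>

// Stand-in for the round_up helper used in the diff: the smallest multiple of
// `multiple` that is >= `value`.
constexpr uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

constexpr uint32_t TILE_WIDTH = 32;
constexpr uint32_t output_channels = 96;   // hypothetical, not a multiple of 8 * 32
constexpr uint32_t input_num_cores = 8;    // hypothetical input shard grid
constexpr uint32_t output_num_cores = 3;   // hypothetical output shard grid

// Before the fix: padding tracked the input grid and over-padded.
constexpr uint32_t padded_before = round_up(output_channels, input_num_cores * TILE_WIDTH);
// After the fix: padding tracks the output grid.
constexpr uint32_t padded_after = round_up(output_channels, output_num_cores * TILE_WIDTH);

// 256 vs 96: with a 96-column weight matrix, the old value trips the
// TT_FATAL(output_channels_padded_to_tile_width <= weight_matrix_width) check,
// while the new value passes and gives each of the 3 output cores exactly one
// 32-wide tile (96 / 3 = 32).
static_assert(padded_before == 256);
static_assert(padded_after == 96);
```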
ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp (26 changes: 16 additions & 10 deletions)

@@ -188,7 +188,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     T * device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -231,9 +232,11 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t window_h = weights_shape[2];
     uint32_t window_w = weights_shape[3];

-    uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
-    uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH);
-    uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment);
+    uint32_t input_num_cores_channels = get_num_cores_channels_from_parallel_config(input_parallel_config);
+    uint32_t output_num_cores_channels = get_num_cores_channels_from_parallel_config(output_parallel_config);
+
+    uint32_t out_channels_padded = tt::round_up(out_channels, output_num_cores_channels * tt::constants::TILE_WIDTH);
+    uint32_t in_channels_padded = tt::round_up(in_channels, input_num_cores_channels * input_channels_alignment);
     uint32_t out_channel_padding = out_channels_padded - out_channels;

     tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
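In the hunk above, the two paddings now diverge: output channels are padded to whole tiles per output core, while input channels keep following the input grid and its alignment. A small sketch with hypothetical values (round_up as in the previous example, and input_channels_alignment assumed to be 32 for illustration):

```cpp
#include <cstdint>

constexpr uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

constexpr uint32_t TILE_WIDTH = 32;                // tt::constants::TILE_WIDTH
constexpr uint32_t input_channels_alignment = 32;  // assumed value for this sketch

constexpr uint32_t in_channels = 320, out_channels = 96;  // hypothetical
constexpr uint32_t input_num_cores_channels = 8;          // hypothetical
constexpr uint32_t output_num_cores_channels = 3;         // hypothetical

// Output channels are padded against the output grid: 96 -> 96 (one tile per core).
constexpr uint32_t out_channels_padded =
    round_up(out_channels, output_num_cores_channels * TILE_WIDTH);
// Input channels still follow the input grid and its alignment: 320 -> 512.
constexpr uint32_t in_channels_padded =
    round_up(in_channels, input_num_cores_channels * input_channels_alignment);

static_assert(out_channels_padded == 96 && in_channels_padded == 512);
```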
@@ -258,12 +261,12 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

     // for conv op, pad the weights to block shape
-    if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
+    if (input_parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
-    } else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
+    } else if(input_parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded(
-            weight_tensor_, num_cores_channels, weights_bias_dtype);
+            weight_tensor_, input_num_cores_channels, weights_bias_dtype);
     } else {
         weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
             weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
@@ -289,7 +292,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         bias_tensor_ = bias_tensor.value();
         bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_);
         if(!is_bias_tensor_is_on_device) {
-            bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, parallel_config, device, out_channels, is_non_tile_mul_width);
+            bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, input_parallel_config, device, out_channels, is_non_tile_mul_width);
             bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
         }
     }
@@ -377,6 +380,7 @@ ttnn::Tensor prepare_conv_weights(
         opt_conv_op_block_config.act_block_w_ntiles,
         opt_conv_op_block_config.out_subblock_w_ntiles,
         parallel_config,
+        parallel_config,
         device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,
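Note that prepare_conv_weights passes the same parallel_config for both the input and output slots, presumably because only one sharding configuration is available on this path; the new parameter changes behavior only for callers such as conv2d and conv_transpose2d that supply a distinct output_parallel_config.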
@@ -550,7 +554,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     Device* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
@@ -565,7 +570,8 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const ParallelConfig& parallel_config,
+    const ParallelConfig& input_parallel_config,
+    const ParallelConfig& output_parallel_config,
     MeshDevice* device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
[fourth changed file: path not shown]

@@ -77,7 +77,8 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     DataType weights_bias_dtype,
     uint32_t weight_block_h_ntiles,
     uint32_t weight_block_w_ntiles,
-    const sliding_window::ParallelConfig& parallel_config,
+    const sliding_window::ParallelConfig& input_parallel_config,
+    const sliding_window::ParallelConfig& output_parallel_config,
     T * device,
     uint32_t groups,
     uint32_t act_block_h_ntiles,
[fifth changed file: path not shown]

@@ -279,6 +279,7 @@ Result conv_transpose2d(
         opt_conv_op_block_config.act_block_w_ntiles,
         opt_conv_op_block_config.out_subblock_w_ntiles,
         parallel_config,
+        output_parallel_config,
         device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,