Skip to content

Commit

Permalink
#0: Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt committed Dec 20, 2024
1 parent 75f8e4e commit 1007506
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,6 @@ def test_conv2d_localrun(device, input_spec):
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294
[1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421
[1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443
[1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447
[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ ParallelConfig determine_parallel_config(
return pconfig;
}

static ParallelConfig determine_output_parallel_config(
ParallelConfig determine_output_parallel_config(
const ParallelConfig& input_parallel_config,
const CoreCoord& compute_grid_size,
uint32_t out_channels,
Expand Down
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(
// Converts convolution weights to grouped layout with padded zeros
Tensor convert_conv_weight_tensor_to_grouped_layout(const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype);

sliding_window::ParallelConfig determine_output_parallel_config(
const sliding_window::ParallelConfig& input_parallel_config,
const CoreCoord& compute_grid_size,
uint32_t out_channels,
bool is_mm_conv);

std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config);

} // namespace operations::conv
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ ttnn::Tensor prepare_conv_weights(
shard_orientation,
!use_non_tile_height);

ParallelConfig output_parallel_config =
determine_output_parallel_config(parallel_config, device->compute_with_storage_grid_size(), out_channels, mm_conv);

bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);
std::optional<const ttnn::Tensor> bias_tensor = std::nullopt;
ttnn::Tensor weight_tensor_on_device = weight_tensor;
Expand All @@ -380,7 +383,7 @@ ttnn::Tensor prepare_conv_weights(
opt_conv_op_block_config.act_block_w_ntiles,
opt_conv_op_block_config.out_subblock_w_ntiles,
parallel_config,
parallel_config,
output_parallel_config,
device,
groups,
opt_conv_op_block_config.act_block_h_ntiles,
Expand Down

0 comments on commit 1007506

Please sign in to comment.