Skip to content

Commit

Permalink
Add support for pretty printing Conv2dConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed Dec 14, 2024
1 parent 2ba5a59 commit c33b67e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ void py_bind_conv2d(py::module& module) {
py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);

py_conv_config.def("__repr__", [](const Conv2dConfig &config) { return fmt::format("{}", config);} );

py::class_<OptimizedConvParallelizationConfig>(module, "OptimizedConvParallelizationConfig")
.def(
py::init<CoreCoord, uint32_t, uint32_t, uint32_t, uint32_t>(),
Expand Down
7 changes: 6 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_sh
}

if (conv_config.override_sharding_config) {
TT_FATAL(conv_config.core_grid.has_value(), "Error");
TT_FATAL(conv_config.core_grid.has_value(), "Core grid must be provided when overriding sharding config");
// override parallel config
auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED
? block_shard_orientation
Expand Down Expand Up @@ -898,5 +898,10 @@ template std::tuple<ttnn::Tensor, ParallelConfig, ParallelConfig, bool, bool> sh
bool is_mm_conv,
bool is_non_tile_mul_width);

std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config) {
tt::stl::reflection::operator<<(os, config);
return os;
}

} // namespace operations
} // namespace ttnn
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,5 +226,7 @@ OptimizedConvBlockConfig get_opt_block_config(
Layout input_tensor_layout,
Conv2dConfig& conv_config);

std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config);

} // namespace operations::conv
} // namespace ttnn

0 comments on commit c33b67e

Please sign in to comment.