diff --git a/.clang-format-ignore b/.clang-format-ignore index bcf448b7fc2..2125ca6ce0a 100644 --- a/.clang-format-ignore +++ b/.clang-format-ignore @@ -70,18 +70,6 @@ ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.c ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp -ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp -ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp -ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp -ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp -ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp -ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp -ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp -ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp -ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp -ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp -ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp -ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 690cfcbde9a..7f0e355b594 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -26,8 +26,8 @@ using namespace tt; namespace ttnn { namespace operations::conv { -using sliding_window::SlidingWindowConfig; using sliding_window::ParallelConfig; +using sliding_window::SlidingWindowConfig; namespace conv2d { @@ -55,7 +55,8 @@ Result conv2d( const std::optional& compute_config_, const std::optional& memory_config) { const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); - const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_height = + ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; const uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; @@ -105,26 +106,28 @@ Result conv2d( uint32_t round_up_size = !use_non_tile_height ? 
tt::constants::TILE_HEIGHT : 1; uint32_t nhw_out = batch_size * output_height * output_width; uint32_t out_channels_padded = tt::round_up( - out_channels, - get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH); - if(is_non_tile_mul_width) { + out_channels, get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH); + if (is_non_tile_mul_width) { out_channels_padded = tt::round_up(out_channels, 32); } MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{1, 1, nhw_out, out_channels_padded}), output_parallel_config, round_up_size); - ParallelConfig largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; + ParallelConfig largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() + ? output_parallel_config + : parallel_config; - OptimizedConvParallelizationConfig opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( - conv_out_memory_config, - get_num_cores_nhw_from_parallel_config(largest_parallel_config), - get_num_cores_channels_from_parallel_config(largest_parallel_config)); + OptimizedConvParallelizationConfig opt_conv_op_parallel_config = + determine_conv_op_parallel_config_from_conv_output_mem_config( + conv_out_memory_config, + get_num_cores_nhw_from_parallel_config(largest_parallel_config), + get_num_cores_channels_from_parallel_config(largest_parallel_config)); uint32_t in_channels_padded = tt::round_up( in_channels, get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); - if(is_non_tile_mul_width){ + if (is_non_tile_mul_width) { in_channels_padded = tt::round_up(in_channels, conv_config.input_channels_alignment); } @@ -191,8 +194,8 @@ Result conv2d( if (bypass_halo) { if (input_tensor_post_tm.layout() == Layout::TILE) { - input_tensor_post_tm = ttnn::to_layout( - input_tensor_post_tm, Layout::ROW_MAJOR, std::nullopt, std::nullopt, device); + input_tensor_post_tm = + ttnn::to_layout(input_tensor_post_tm, Layout::ROW_MAJOR, std::nullopt, std::nullopt, device); } } else { Tensor halo_output = ttnn::halo( @@ -207,7 +210,7 @@ Result conv2d( !use_non_tile_height); if (conv_config.deallocate_activation) { - input_tensor_post_tm.deallocate(/*force*/true); + input_tensor_post_tm.deallocate(/*force*/ true); } input_tensor_post_tm = std::move(halo_output); @@ -281,7 +284,7 @@ Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - Device * device, + Device* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -295,15 +298,32 @@ Result Conv2dOperation::invoke( std::optional bias_tensor, const std::optional& conv_config_, const std::optional& compute_config_, - const std::optional& memory_config){ - return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), memory_config); + const std::optional& memory_config) { + return conv2d( + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + std::move(bias_tensor), + std::move(conv_config_), + 
std::move(compute_config_), + memory_config); } Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - MeshDevice * device, + MeshDevice* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -317,11 +337,27 @@ Result Conv2dOperation::invoke( std::optional bias_tensor, const std::optional& conv_config_, const std::optional& compute_config_, - const std::optional& memory_config){ - return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), memory_config); + const std::optional& memory_config) { + return conv2d( + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + std::move(bias_tensor), + std::move(conv_config_), + std::move(compute_config_), + memory_config); } - } // namespace conv2d -} // namespace operations +} // namespace operations::conv } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index e39a1f2257b..60e2574c831 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -23,7 +23,7 @@ template Result conv2d( const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - T * device, + T* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -39,13 +39,12 @@ Result conv2d( const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); - -struct Conv2dOperation{ +struct Conv2dOperation { static Result invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - Device * device, + Device* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -65,7 +64,7 @@ struct Conv2dOperation{ uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - MeshDevice * device, + MeshDevice* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -85,6 +84,6 @@ struct Conv2dOperation{ } // namespace operations::conv } // namespace ttnn -namespace ttnn{ - constexpr auto conv2d = ttnn::register_operation<"ttnn::conv2d", operations::conv::conv2d::Conv2dOperation>(); +namespace ttnn { +constexpr auto conv2d = ttnn::register_operation<"ttnn::conv2d", operations::conv::conv2d::Conv2dOperation>(); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index aa855f24200..b4649e34fb0 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -2,8 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 - - #include "common/constants.hpp" #include "ttnn/cpp/pybind11/decorators.hpp" @@ -21,7 +19,6 @@ namespace operations::conv { namespace conv2d { void py_bind_conv2d(py::module& module) { - bind_registered_operation( module, ttnn::conv2d, @@ -46,25 +43,44 @@ void py_bind_conv2d(py::module& module) { +-------------------+-------------------------------+---------------+-------------+----------+ )doc", ttnn::pybind_overload_t{ - [](const decltype(ttnn::conv2d)& self, const ttnn::Tensor& input_tensor, - const ttnn::Tensor& weight_tensor, - ttnn::Device* device, - uint32_t 
in_channels, - uint32_t out_channels, - uint32_t batch_size, - uint32_t input_height, - uint32_t input_width, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t groups, - std::optional bias_tensor, - const std::optional& conv_config, - const std::optional& compute_config, - const std::optional& memory_config, - const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); + [](const decltype(ttnn::conv2d)& self, + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight_tensor, + ttnn::Device* device, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + std::optional bias_tensor, + const std::optional& conv_config, + const std::optional& compute_config, + const std::optional& memory_config, + const uint8_t& queue_id) -> Result { + return self( + queue_id, + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + bias_tensor, + conv_config, + compute_config, + memory_config); }, py::kw_only(), py::arg("input_tensor"), @@ -87,25 +103,44 @@ void py_bind_conv2d(py::module& module) { py::arg("queue_id") = 0}, ttnn::pybind_overload_t{ - [](const decltype(ttnn::conv2d)& self, const ttnn::Tensor& input_tensor, - const ttnn::Tensor& weight_tensor, - ttnn::MeshDevice* device, - uint32_t in_channels, - uint32_t out_channels, - uint32_t batch_size, - uint32_t input_height, - uint32_t input_width, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - uint32_t groups, - std::optional bias_tensor, - const std::optional& conv_config, - const std::optional& compute_config, - const std::optional& memory_config, - const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); + [](const decltype(ttnn::conv2d)& self, + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight_tensor, + ttnn::MeshDevice* device, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + uint32_t groups, + std::optional bias_tensor, + const std::optional& conv_config, + const std::optional& compute_config, + const std::optional& memory_config, + const uint8_t& queue_id) -> Result { + return self( + queue_id, + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + dilation, + groups, + bias_tensor, + conv_config, + compute_config, + memory_config); }, py::kw_only(), py::arg("input_tensor"), @@ -125,8 +160,7 @@ void py_bind_conv2d(py::module& module) { py::arg("conv_config") = std::nullopt, py::arg("compute_config") = std::nullopt, py::arg("memory_config") = std::nullopt, - py::arg("queue_id") = 0} - ); + py::arg("queue_id") = 0}); module.def( "prepare_conv_weights", @@ 
-150,7 +184,6 @@ void py_bind_conv2d(py::module& module) { py::arg("conv_config") = std::nullopt, py::arg("compute_config") = std::nullopt); - module.def( "prepare_conv_weights", prepare_conv_weights, @@ -250,8 +283,17 @@ void py_bind_conv2d(py::module& module) { ShardOrientation block_shard_orientation, bool enable_channels_padding, bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { - return determine_parallel_config( - shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, enable_channels_padding, is_out_tiled); + return determine_parallel_config( + shard_layout, + batch_size, + input_channels, + output_height, + output_width, + output_channels, + compute_grid_size, + block_shard_orientation, + enable_channels_padding, + is_out_tiled); }, py::arg("shard_layout"), py::arg("batch_size"), @@ -271,50 +313,66 @@ void py_bind_conv2d(py::module& module) { py::arg("parallel_config"), py::arg("tile_size")); - auto py_conv_config = py::class_(module, "Conv2dConfig"); py_conv_config.def( - py::init, std::optional, bool, Layout, bool, bool, bool, bool>(), - py::kw_only(), - py::arg("dtype") = DataType::BFLOAT16, - py::arg("weights_dtype") = DataType::BFLOAT16, - py::arg("activation") = "", - py::arg("input_channels_alignment") = 32, - py::arg("deallocate_activation") = false, - py::arg("reallocate_halo_output") = false, - py::arg("act_block_h_override") = 0, - py::arg("act_block_w_div") = 1, - py::arg("reshard_if_not_optimal") = false, - py::arg("override_sharding_config") = false, - py::arg("shard_layout") = std::nullopt, - py::arg("core_grid") = std::nullopt, - py::arg("transpose_shards") = true, - py::arg("output_layout") = Layout::TILE, - py::arg("enable_act_double_buffer") = false, - py::arg("enable_weights_double_buffer") = false, - py::arg("enable_split_reader") = false, - py::arg("enable_subblock_padding") = false - ); - py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype); - py_conv_config.def_readwrite("weights_dtype", &Conv2dConfig::weights_dtype); - py_conv_config.def_readwrite("activation", &Conv2dConfig::activation); - py_conv_config.def_readwrite("input_channels_alignment", &Conv2dConfig::input_channels_alignment); - py_conv_config.def_readwrite("deallocate_activation", &Conv2dConfig::deallocate_activation); - py_conv_config.def_readwrite("reallocate_halo_output", &Conv2dConfig::reallocate_halo_output); - py_conv_config.def_readwrite("act_block_h_override", &Conv2dConfig::act_block_h_override); - py_conv_config.def_readwrite("act_block_w_div", &Conv2dConfig::act_block_w_div); - py_conv_config.def_readwrite("reshard_if_not_optimal", &Conv2dConfig::reshard_if_not_optimal); - py_conv_config.def_readwrite("override_sharding_config", &Conv2dConfig::override_sharding_config); - py_conv_config.def_readwrite("shard_layout", &Conv2dConfig::shard_layout); - py_conv_config.def_readwrite("core_grid", &Conv2dConfig::core_grid); - py_conv_config.def_readwrite("transpose_shards", &Conv2dConfig::transpose_shards); - py_conv_config.def_readwrite("output_layout", &Conv2dConfig::output_layout); - py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer); - py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer); - py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader); - py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding); + py::init< 
+ DataType, + DataType, + string, + uint32_t, + bool, + bool, + uint32_t, + uint32_t, + bool, + bool, + std::optional, + std::optional, + bool, + Layout, + bool, + bool, + bool, + bool>(), + py::kw_only(), + py::arg("dtype") = DataType::BFLOAT16, + py::arg("weights_dtype") = DataType::BFLOAT16, + py::arg("activation") = "", + py::arg("input_channels_alignment") = 32, + py::arg("deallocate_activation") = false, + py::arg("reallocate_halo_output") = false, + py::arg("act_block_h_override") = 0, + py::arg("act_block_w_div") = 1, + py::arg("reshard_if_not_optimal") = false, + py::arg("override_sharding_config") = false, + py::arg("shard_layout") = std::nullopt, + py::arg("core_grid") = std::nullopt, + py::arg("transpose_shards") = true, + py::arg("output_layout") = Layout::TILE, + py::arg("enable_act_double_buffer") = false, + py::arg("enable_weights_double_buffer") = false, + py::arg("enable_split_reader") = false, + py::arg("enable_subblock_padding") = false); + py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype); + py_conv_config.def_readwrite("weights_dtype", &Conv2dConfig::weights_dtype); + py_conv_config.def_readwrite("activation", &Conv2dConfig::activation); + py_conv_config.def_readwrite("input_channels_alignment", &Conv2dConfig::input_channels_alignment); + py_conv_config.def_readwrite("deallocate_activation", &Conv2dConfig::deallocate_activation); + py_conv_config.def_readwrite("reallocate_halo_output", &Conv2dConfig::reallocate_halo_output); + py_conv_config.def_readwrite("act_block_h_override", &Conv2dConfig::act_block_h_override); + py_conv_config.def_readwrite("act_block_w_div", &Conv2dConfig::act_block_w_div); + py_conv_config.def_readwrite("reshard_if_not_optimal", &Conv2dConfig::reshard_if_not_optimal); + py_conv_config.def_readwrite("override_sharding_config", &Conv2dConfig::override_sharding_config); + py_conv_config.def_readwrite("shard_layout", &Conv2dConfig::shard_layout); + py_conv_config.def_readwrite("core_grid", &Conv2dConfig::core_grid); + py_conv_config.def_readwrite("transpose_shards", &Conv2dConfig::transpose_shards); + py_conv_config.def_readwrite("output_layout", &Conv2dConfig::output_layout); + py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer); + py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer); + py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader); + py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding); - py_conv_config.def("__repr__", [](const Conv2dConfig &config) { return fmt::format("{}", config);} ); + py_conv_config.def("__repr__", [](const Conv2dConfig& config) { return fmt::format("{}", config); }); py::class_(module, "OptimizedConvParallelizationConfig") .def( @@ -325,13 +383,13 @@ void py_bind_conv2d(py::module& module) { py::arg("num_cores_c") = 1, py::arg("per_core_out_matrix_height").noconvert(), py::arg("per_core_out_matrix_width").noconvert()) - .def_property_readonly("grid_size", [](OptimizedConvParallelizationConfig const& c) { return c.grid_size; }) + .def_property_readonly("grid_size", [](const OptimizedConvParallelizationConfig& c) { return c.grid_size; }) .def_property_readonly( - "num_cores_nhw", [](OptimizedConvParallelizationConfig const& c) { return c.num_cores_nhw; }) + "num_cores_nhw", [](const OptimizedConvParallelizationConfig& c) { return c.num_cores_nhw; }) .def_property_readonly( "per_core_out_matrix_height", - 
[](OptimizedConvParallelizationConfig const& c) { return c.per_core_out_matrix_height; }) - .def_property_readonly("per_core_out_matrix_width", [](OptimizedConvParallelizationConfig const& c) { + [](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_height; }) + .def_property_readonly("per_core_out_matrix_width", [](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_width; }); @@ -344,16 +402,15 @@ void py_bind_conv2d(py::module& module) { py::arg("out_subblock_h_ntiles").noconvert(), py::arg("out_subblock_w_ntiles").noconvert()) .def_property_readonly( - "act_block_h_ntiles", [](OptimizedConvBlockConfig const& c) { return c.act_block_h_ntiles; }) + "act_block_h_ntiles", [](const OptimizedConvBlockConfig& c) { return c.act_block_h_ntiles; }) .def_property_readonly( - "act_block_w_ntiles", [](OptimizedConvBlockConfig const& c) { return c.act_block_w_ntiles; }) + "act_block_w_ntiles", [](const OptimizedConvBlockConfig& c) { return c.act_block_w_ntiles; }) .def_property_readonly( - "out_subblock_h_ntiles", [](OptimizedConvBlockConfig const& c) { return c.out_subblock_h_ntiles; }) + "out_subblock_h_ntiles", [](const OptimizedConvBlockConfig& c) { return c.out_subblock_h_ntiles; }) .def_property_readonly( - "out_subblock_w_ntiles", [](OptimizedConvBlockConfig const& c) { return c.out_subblock_w_ntiles; }); - + "out_subblock_w_ntiles", [](const OptimizedConvBlockConfig& c) { return c.out_subblock_w_ntiles; }); } } // namespace conv2d -} // namespace operations +} // namespace operations::conv } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index acd3453ecf5..8b399bc5ca5 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -23,25 +23,30 @@ using namespace tt; namespace ttnn { namespace operations::conv { -using sliding_window::SlidingWindowConfig; using sliding_window::ParallelConfig; - +using sliding_window::SlidingWindowConfig; uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; - while (num % divisor != 0) divisor = divisor - 1; + while (num % divisor != 0) { + divisor = divisor - 1; + } return divisor; } uint32_t find_closest_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { uint32_t divisor = start_divisor; - while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + while (num1 % divisor != 0 or num2 % divisor != 0) { + divisor = divisor - 1; + } return divisor; } uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { uint32_t divisor = start_divisor; - while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + while (num1 % divisor != 0 or num2 % divisor != 0) { + divisor = divisor - 1; + } return divisor; } @@ -73,9 +78,10 @@ Tensor convert_conv_weight_tensor_to_tiled_layout( const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, - std::optional output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); - } + std::optional output_dtype) { + return tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( + std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); +} // Converts convolution weights to tilized 2d matrix layout with special block height padding // Returns a new tensor 
with layout=Tile @@ -83,13 +89,16 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, - std::optional output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); - } + std::optional output_dtype) { + return tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( + std::move(conv_weight_tensor), in1_block_h, in1_block_w, output_dtype); +} // Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout(const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype){ - return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(std::move(conv_weight_tensor), num_groups, output_dtype); +Tensor convert_conv_weight_tensor_to_grouped_layout( + const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype) { + return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout( + std::move(conv_weight_tensor), num_groups, output_dtype); } ParallelConfig determine_parallel_config_non_tile_mul_width( @@ -101,7 +110,6 @@ ParallelConfig determine_parallel_config_non_tile_mul_width( uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation) { - uint32_t effective_tile_height = 1; uint32_t effective_tile_width = 1; CoreRangeSet grid; @@ -117,14 +125,11 @@ ParallelConfig determine_parallel_config_non_tile_mul_width( uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw; CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); grid = CoreRangeSet({core_range}); - auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? block_shard_orientation : ShardOrientation::ROW_MAJOR; - ParallelConfig pconfig = { - .grid = grid, - .shard_scheme = shard_layout, - .shard_orientation = block_shard_orientation}; + auto shard_orientation = + shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? block_shard_orientation : ShardOrientation::ROW_MAJOR; + ParallelConfig pconfig = {.grid = grid, .shard_scheme = shard_layout, .shard_orientation = block_shard_orientation}; return pconfig; - } ParallelConfig determine_parallel_config( @@ -138,10 +143,10 @@ ParallelConfig determine_parallel_config( ShardOrientation block_shard_orientation, bool enable_channels_padding, bool is_out_tiled) { - uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; uint32_t effective_tile_width = is_out_tiled ? tt::constants::TILE_WIDTH : 1; - uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; + uint32_t out_nhw_ntiles = + tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; uint32_t input_channles_ntiles = tt::div_up(input_channels, effective_tile_width); uint32_t out_channels_ntiles = tt::div_up(output_channels, effective_tile_width); @@ -153,7 +158,7 @@ ParallelConfig determine_parallel_config( grid = num_cores_to_corerangeset(num_cores_nhw, compute_grid_size, true); } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { uint32_t start_divisor = - block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; + block_shard_orientation == ShardOrientation::COL_MAJOR ? 
compute_grid_size.x : compute_grid_size.y; uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); uint32_t start_divisor_c = block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x; @@ -175,11 +180,11 @@ ParallelConfig determine_parallel_config( TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); } - auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? block_shard_orientation : ShardOrientation::ROW_MAJOR; // NOTE: taking ROW_MAJOR as default orientation for HEIGHT_SHARDED and WIDTH_SHARDED - ParallelConfig pconfig = { - .grid = grid, - .shard_scheme = shard_layout, - .shard_orientation = shard_orientation }; + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED + ? block_shard_orientation + : ShardOrientation::ROW_MAJOR; // NOTE: taking ROW_MAJOR as default orientation for + // HEIGHT_SHARDED and WIDTH_SHARDED + ParallelConfig pconfig = {.grid = grid, .shard_scheme = shard_layout, .shard_orientation = shard_orientation}; return pconfig; } @@ -213,7 +218,7 @@ uint32_t get_num_cores_nhw_from_parallel_config(const ParallelConfig& pconfig) { auto grid_size = pconfig.grid.bounding_box().grid_size(); uint32_t num_cores = pconfig.grid.num_cores(); uint32_t num_cores_nhw = 0; - if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + if (pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { return 1; } @@ -240,7 +245,7 @@ uint32_t get_num_cores_channels_from_parallel_config(const ParallelConfig& pconf uint32_t num_cores_channels = 0; if (pconfig.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { num_cores_channels = 1; - } else if(pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + } else if (pconfig.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { num_cores_channels = pconfig.grid.num_cores(); } else if (pconfig.shard_orientation == ShardOrientation::COL_MAJOR) { num_cores_channels = grid_size.y; @@ -254,8 +259,12 @@ uint32_t get_num_cores_channels_from_parallel_config(const ParallelConfig& pconf MemoryConfig create_sharded_memory_config_from_parallel_config( const ttnn::Shape& tensor_shape, const ParallelConfig& parallel_config, uint32_t tile_size) { - - log_debug(tt::LogOp, "create_sharded_memory_config_from_parallel_config: tensor_shape: {}, parallel_config: {}, tile_size: {}", tensor_shape, parallel_config, tile_size); + log_debug( + tt::LogOp, + "create_sharded_memory_config_from_parallel_config: tensor_shape: {}, parallel_config: {}, tile_size: {}", + tensor_shape, + parallel_config, + tile_size); // tensor_shape is [N, H, W, C] TT_ASSERT(tensor_shape[0] == 1 && tensor_shape[1] == 1); // todo: add support for generic non-2d shapes // uint32_t channels = tensor_shape[3]; @@ -267,7 +276,7 @@ MemoryConfig create_sharded_memory_config_from_parallel_config( uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2]; uint32_t nhw_padded = nhw_shape; - if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { + if (shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); } uint32_t nhw_shard = nhw_padded / num_cores_nhw; @@ -278,7 +287,6 @@ MemoryConfig create_sharded_memory_config_from_parallel_config( return MemoryConfig{shard_scheme, BufferType::L1, shard_spec}; } - OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t 
num_cores_nhw, uint32_t num_cores_c) { TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); @@ -336,14 +344,14 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( uint32_t window_w, bool fp32_accum, bool split_reader_enabled) { - if (act_block_h_override > 0) { TT_ASSERT( act_block_h_override % 32 == 0, "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); } - uint32_t act_block_h_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t act_block_h_ntiles = + div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); if (act_block_h_override > 0) { if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { @@ -368,17 +376,20 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.y : grid_size.x; TT_ASSERT(padded_in_channels % act_c_num_blocks == 0); - uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED - ? round_up(padded_in_channels * window_w, 32) - : round_up((padded_in_channels / act_c_num_blocks) * window_h * window_w, tt::constants::TILE_WIDTH); - if(parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { + uint32_t act_block_w = + parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED + ? round_up(padded_in_channels * window_w, 32) + : round_up((padded_in_channels / act_c_num_blocks) * window_h * window_w, tt::constants::TILE_WIDTH); + if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { TT_ASSERT(padded_in_channels % (32 * parallel_config.grid.num_cores() * act_block_w_div) == 0); - act_block_w = (padded_in_channels * window_h * window_w)/(parallel_config.grid.num_cores() * act_block_w_div); + act_block_w = (padded_in_channels * window_h * window_w) / (parallel_config.grid.num_cores() * act_block_w_div); } TT_ASSERT(act_block_w % 32 == 0); uint32_t act_block_w_ntiles = act_block_w / 32; - uint32_t out_block_h_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); - uint32_t weight_block_w_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); + uint32_t out_block_h_ntiles = + div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t weight_block_w_ntiles = + div_up(conv_op_parallel_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled); return { @@ -421,7 +432,8 @@ static TensorMemoryLayout select_shard_spec( out_channels, compute_grid_size, shard_orientation, - !is_mm_conv).grid.num_cores(); + !is_mm_conv) + .grid.num_cores(); }; // 1d convs support only height sharding @@ -442,7 +454,8 @@ static TensorMemoryLayout select_shard_spec( // the cores. 
const uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; const uint32_t tree_quarter_cores = static_cast(0.75f * max_num_cores); - if ((cc_height > max_cc && max_cc < tree_quarter_cores) || (cc_height == max_cc && cc_height <= compute_grid_size.x)) { + if ((cc_height > max_cc && max_cc < tree_quarter_cores) || + (cc_height == max_cc && cc_height <= compute_grid_size.x)) { shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; max_cc = cc_height; } @@ -524,16 +537,21 @@ static std::tuple get_conv_padded_i needs_shard_or_reshard = true; } if (conv_config.override_sharding_config) { - TT_FATAL(conv_config.core_grid.has_value(), "If override_sharding_config is set, core_grid must be set as well."); - TT_FATAL(conv_config.shard_layout.has_value(), "If override_sharding_config is set, shard_layout must be set as well."); + TT_FATAL( + conv_config.core_grid.has_value(), + "If override_sharding_config is set, core_grid must be set as well."); + TT_FATAL( + conv_config.shard_layout.has_value(), + "If override_sharding_config is set, shard_layout must be set as well."); if (conv_config.core_grid.value() != input_shard_grid) { needs_shard_or_reshard = true; } - if(shard_layout!=input_shard_scheme) { + if (shard_layout != input_shard_scheme) { needs_shard_or_reshard = true; } bool input_transpose_shards = input_shard_orientation == ShardOrientation::COL_MAJOR; - if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED && conv_config.transpose_shards != input_transpose_shards) { + if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED && + conv_config.transpose_shards != input_transpose_shards) { needs_shard_or_reshard = true; } } @@ -543,8 +561,11 @@ static std::tuple get_conv_padded_i // shallow conv variant not supported // out_channels <= 256 incorrect output from pack_untilize_dst if output > 256 Tracking --> #14236 // bf8 not supported due to limitation of sharding dim multiple of 32 - bool use_non_tile_height = (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) && out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR && conv_config.input_channels_alignment != 16; //shallow conv variant + bool use_non_tile_height = (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) && out_channels <= 256 && + conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && + conv_config.output_layout == Layout::ROW_MAJOR && + conv_config.input_channels_alignment != 16; // shallow conv variant ParallelConfig parallel_config = input_tensor_parallel_config; if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { @@ -578,9 +599,8 @@ static std::tuple get_conv_padded_i if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Core grid must be provided when overriding sharding config"); // override parallel config - auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED - ? block_shard_orientation - : ShardOrientation::ROW_MAJOR; + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ?
block_shard_orientation + : ShardOrientation::ROW_MAJOR; parallel_config = { .grid = conv_config.core_grid.value(), .shard_scheme = shard_layout, @@ -604,13 +624,13 @@ static std::tuple get_conv_padded_i input_tensor_.layout() == Layout::ROW_MAJOR) { round_up_size = 1; } - uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); + uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); uint32_t input_tensor_width_snapped_to_channels_alignment = tt::round_up(input_tensor.get_shape()[3], input_num_cores_c * conv_config.input_channels_alignment); - if(is_non_tile_mul_width) { + if (is_non_tile_mul_width) { input_tensor_width_snapped_to_channels_alignment = - tt::round_up(input_tensor.get_shape()[3], conv_config.input_channels_alignment); + tt::round_up(input_tensor.get_shape()[3], conv_config.input_channels_alignment); } auto input_padded_shape = ttnn::Shape(std::array{ @@ -618,7 +638,8 @@ static std::tuple get_conv_padded_i 1, input_tensor_height_snapped_to_tile, input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and - // tt::tt_metal::LegacyShape issue to clean up next line + // tt::tt_metal::LegacyShape issue to clean up next + // line MemoryConfig input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{ input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), @@ -663,8 +684,7 @@ std::tuple shard_or_reshard_ ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, - .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation - }; + .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation}; ParallelConfig output_parallel_config = determine_output_parallel_config(parallel_config, compute_grid_size, out_channels, is_mm_conv); @@ -688,10 +708,11 @@ std::tuple shard_or_reshard_ if (input_padded_shape[-2] != tensor_height || input_padded_shape[-1] != tensor_width) { input_tensor = ttnn::pad( input_tensor, - tt::tt_metal::Array4D({input_tensor.get_shape()[0], - input_tensor.get_shape()[1], - input_padded_shape[-2], - input_padded_shape[-1]}), + tt::tt_metal::Array4D( + {input_tensor.get_shape()[0], + input_tensor.get_shape()[1], + input_padded_shape[-2], + input_padded_shape[-1]}), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); } @@ -708,16 +729,17 @@ std::tuple shard_or_reshard_ ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device()); } if (!auto_shard_mm) { - auto resharded_input_tensor = ttnn::to_memory_config( - input_tensor, input_tensor_sharded_memory_config, std::nullopt); + auto resharded_input_tensor = + ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt); if (conv_config.deallocate_activation) { - input_tensor.deallocate(/*force*/true); + input_tensor.deallocate(/*force*/ true); resharded_input_tensor = ttnn::move(resharded_input_tensor); } input_tensor = resharded_input_tensor; } } else { - input_tensor = ttnn::to_device(input_tensor, device, (auto_shard_mm ? ttnn::DRAM_MEMORY_CONFIG : input_tensor_sharded_memory_config)); + input_tensor = ttnn::to_device( + input_tensor, device, (auto_shard_mm ? 
ttnn::DRAM_MEMORY_CONFIG : input_tensor_sharded_memory_config)); } } return {input_tensor, parallel_config, output_parallel_config, use_non_tile_height}; @@ -793,9 +815,10 @@ void adjust_conv_op_config_for_auto_shard_if_necessary( Conv2dConfig& conv_config, Layout input_tensor_layout, std::optional input_memory_config) { - - // If the input tensor is already sharded, or the conv_config has a specified shard layout, we don't need to do anything. - if ((input_memory_config.has_value() && input_memory_config.value().is_sharded()) || conv_config.shard_layout.has_value()) { + // If the input tensor is already sharded, or the conv_config has a specified shard layout, we don't need to do + // anything. + if ((input_memory_config.has_value() && input_memory_config.value().is_sharded()) || + conv_config.shard_layout.has_value()) { return; } @@ -815,7 +838,8 @@ void adjust_conv_op_config_for_auto_shard_if_necessary( if (conv_config.act_block_h_override == 0 && conv_config.shard_layout != TensorMemoryLayout::WIDTH_SHARDED) { if (in_channels <= constants::TILE_WIDTH / 2 && conv_config.input_channels_alignment == constants::TILE_WIDTH && - !is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && input_tensor_layout == Layout::ROW_MAJOR) { + !is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && + input_tensor_layout == Layout::ROW_MAJOR) { log_debug(LogOp, "Auto shard, enable shallow conv"); // height sharded, non matmul conv, with input channels <= 16, and default setting for // input_channels_alignment @@ -843,13 +867,11 @@ void adjust_conv_op_config_for_auto_shard_if_necessary( // Set act_block_w_div to max value to // be conservative with L1 memory usage. // act_block_w_div == 1 is currently the default value. - conv_config.act_block_w_div = - tt::div_up(in_channels, width_sharded_num_cores * constants::TILE_WIDTH); + conv_config.act_block_w_div = tt::div_up(in_channels, width_sharded_num_cores * constants::TILE_WIDTH); } } -template std::tuple -shard_or_reshard_tensor_if_required( +template std::tuple shard_or_reshard_tensor_if_required( Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -862,8 +884,7 @@ shard_or_reshard_tensor_if_required( bool auto_shard, bool is_non_tile_mul_width); -template std::tuple -shard_or_reshard_tensor_if_required( +template std::tuple shard_or_reshard_tensor_if_required( MeshDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -881,5 +902,5 @@ std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config) { return os; } -} // namespace operations +} // namespace operations::conv } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 69ce604a671..59e5e27a0c0 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -26,17 +26,18 @@ struct Conv2dConfig { uint32_t input_channels_alignment = 32; bool deallocate_activation = false; bool reallocate_halo_output = false; - uint32_t act_block_h_override = 0; // This argument is ignored when shard_layout == WIDTH_SHARDED. - uint32_t act_block_w_div = 1; // Amount by which the maximum possible act_block_width is divided. 
Max act_block_w = in_channels / (total_num_cores * TILE_WIDTH); - // Ignored when shard_layout == HEIGHT_SHARDED or BLOCK_SHARDED - bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true - bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true + uint32_t act_block_h_override = 0; // This argument is ignored when shard_layout == WIDTH_SHARDED. + uint32_t act_block_w_div = + 1; // Amount by which the maximum possible act_block_width is divided. Max act_block_w = in_channels / + // (total_num_cores * TILE_WIDTH); Ignored when shard_layout == HEIGHT_SHARDED or BLOCK_SHARDED + bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true + bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true std::optional shard_layout = std::nullopt; - std::optional core_grid = std::nullopt; // used only if override_sharding_config is true - bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false + std::optional core_grid = std::nullopt; // used only if override_sharding_config is true + bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false Layout output_layout = Layout::TILE; bool enable_act_double_buffer = false; - bool enable_weights_double_buffer = false; // Used only for block sharded convolutions + bool enable_weights_double_buffer = false; // Used only for block sharded convolutions bool enable_split_reader = false; bool enable_subblock_padding = false; static constexpr auto attribute_names = std::make_tuple( @@ -106,13 +107,14 @@ sliding_window::ParallelConfig determine_parallel_config( const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, bool enable_channels_padding, - bool is_out_tiled=true); + bool is_out_tiled = true); uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig); uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig); -MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape& tensor_shape, const sliding_window::ParallelConfig& parallel_config, uint32_t tile_size); +MemoryConfig create_sharded_memory_config_from_parallel_config( + const ttnn::Shape& tensor_shape, const sliding_window::ParallelConfig& parallel_config, uint32_t tile_size); OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); @@ -167,7 +169,7 @@ shard_or_reshard_tensor_if_required( uint32_t out_channels, bool is_mm_conv, bool auto_shard, - bool is_non_tile_mul_width=false); + bool is_non_tile_mul_width = false); // Converts convolution weights to tilized 2d matrix layout.
// Returns a new tensor with layout=Tile @@ -186,9 +188,10 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( std::optional output_dtype = std::nullopt); // Converts convolution weights to grouped layout with padded zeros -Tensor convert_conv_weight_tensor_to_grouped_layout(const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); +Tensor convert_conv_weight_tensor_to_grouped_layout( + const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype); std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config); -} // namespace operations::conv -} // namespace ttnn +} // namespace operations::conv +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index c432add1dc1..1f043e85a37 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -51,16 +51,20 @@ std::pair, std::vector> compute_opt_conv_activat return {{1, num_rows_padded, num_cols_padded}, {1, num_rows, num_cols}}; } -} // optimized_conv_op_utils +} // namespace optimized_conv_op_utils namespace ttnn::operations::conv { namespace conv2d { -Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional bias, +Tensor optimized_conv_new( + const Tensor& a, + const Tensor& b, + std::optional bias, const sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t output_channels, uint32_t groups, - bool untilize_out, bool fuse_relu, + bool untilize_out, + bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, const MemoryConfig& memory_config, @@ -72,39 +76,87 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional output_tensors = {Tensor(tt::tt_metal::operation::get_workers_for_op_output({a, b}))}; operation::launch_op( - [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height] - (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - using ttnn::operations::experimental::auto_format::FormatParams; - auto& a = input_tensors.at(0); - auto& b = input_tensors.at(1); - auto& bias = optional_input_tensors.at(0); - TT_FATAL(b.get_layout() == Layout::TILE, "Weights should be in TILE layout."); // Weights should already be formatted - const auto& ashape = tt::tt_metal::LegacyShape(input_tensor_shape); - auto padded_a_shape = ttnn::Shape(std::array{ashape[0], ashape[1], ashape[2], tt::round_up(ashape[3], 16)}); - FormatParams input_a_format_params = {.pad_shape=padded_a_shape.value, .pad_value=0.0, .target_layout=Layout::ROW_MAJOR}; - FormatParams input_b_format_params = {.pad_shape=b.get_legacy_shape(), .pad_value=0.0, .target_layout=Layout::TILE}; - FormatParams input_bias_format_params = {}; - if (bias.has_value()) { - input_bias_format_params = {.pad_shape=bias.value().get_legacy_shape(), .pad_value=0, .target_layout=Layout::TILE}; - } - auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - auto arch = is_tensor_on_device_or_multidevice(a) ? 
a.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - bool fp32_accum = a.device()->arch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false; - return operation::run_without_autoformat( - OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height - ), - input_tensors, - optional_input_tensors); - }, {a, b}, output_tensors, {std::move(bias)}); + [sliding_window_config, + output_channels, + groups, + untilize_out, + fuse_relu, + parallelization_config, + block_config, + memory_config, + dtype, + input_tensor_shape, + use_shallow_conv_variant, + compute_kernel_config, + enable_act_double_buffer, + enable_weights_double_buffer, + enable_split_reader, + enable_subblock_padding, + use_non_tile_height]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + using ttnn::operations::experimental::auto_format::FormatParams; + auto& a = input_tensors.at(0); + auto& b = input_tensors.at(1); + auto& bias = optional_input_tensors.at(0); + TT_FATAL( + b.get_layout() == Layout::TILE, + "Weights should be in TILE layout."); // Weights should already be formatted + const auto& ashape = tt::tt_metal::LegacyShape(input_tensor_shape); + auto padded_a_shape = + ttnn::Shape(std::array{ashape[0], ashape[1], ashape[2], tt::round_up(ashape[3], 16)}); + FormatParams input_a_format_params = { + .pad_shape = padded_a_shape.value, .pad_value = 0.0, .target_layout = Layout::ROW_MAJOR}; + FormatParams input_b_format_params = { + .pad_shape = b.get_legacy_shape(), .pad_value = 0.0, .target_layout = Layout::TILE}; + FormatParams input_bias_format_params = {}; + if (bias.has_value()) { + input_bias_format_params = { + .pad_shape = bias.value().get_legacy_shape(), .pad_value = 0, .target_layout = Layout::TILE}; + } + auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; + auto arch = is_tensor_on_device_or_multidevice(a) + ? a.device()->arch() + : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); + bool fp32_accum = + a.device()->arch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? 
+ // compute_kernel_config.value().fp32_dest_acc_en : false; + return operation::run_without_autoformat( + OptimizedConvNew( + sliding_window_config, + output_channels, + groups, + untilize_out, + bias.has_value(), + fuse_relu, + parallelization_config, + block_config, + memory_config, + dtype, + input_tensor_shape, + use_shallow_conv_variant, + compute_kernel_config, + enable_act_double_buffer, + enable_weights_double_buffer, + enable_split_reader, + enable_subblock_padding, + use_non_tile_height), + input_tensors, + optional_input_tensors); + }, + {a, b}, + output_tensors, + {std::move(bias)}); return output_tensors.at(0); - } -void OptimizedConvNew::validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { +void OptimizedConvNew::validate( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); TT_FATAL(input_tensor_a.memory_config().is_sharded(), "Activation tensor should be sharded."); @@ -113,8 +165,10 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const TT_FATAL((this->dtype == DataType::BFLOAT16) || (this->dtype == DataType::FLOAT32), "Error"); } if (this->memory_config.is_sharded()) { - uint32_t out_block_h_ntiles = optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); - uint32_t per_core_out_matrix_width_ntiles = optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); + uint32_t out_block_h_ntiles = + optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + uint32_t per_core_out_matrix_width_ntiles = + optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape( input_tensor_a.get_legacy_shape(), @@ -122,9 +176,12 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const parallelization_config.num_cores_nhw, out_block_h_ntiles); uint32_t out_width_ntiles = this->compute_output_specs(input_tensors).at(0).padded_shape()[-1] / TILE_WIDTH; - if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { TT_FATAL(per_core_out_matrix_width_ntiles == out_width_ntiles, "Error"); - TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); + TT_FATAL( + this->block_config.out_subblock_w_ntiles == out_width_ntiles || + this->block_config.out_subblock_h_ntiles == 1, + "Error"); } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { // For block sharded, out_width per core is shard width, and this is split along row // TODO: We should clean this up and relax constraints on out_subblock h and w @@ -134,7 +191,10 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.x); } } - TT_FATAL(this->block_config.out_subblock_w_ntiles == per_core_out_matrix_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); + TT_FATAL( + this->block_config.out_subblock_w_ntiles == per_core_out_matrix_width_ntiles || + this->block_config.out_subblock_h_ntiles == 1, + "Error"); } } @@ -149,7 +209,10 @@ std::vector 
OptimizedConvNew::compute_output_specs(const std::vector // Tiled output shape is padded shape. Padded to tile shape. auto shape_w = batch_size * conv_output_h * conv_output_w; auto shape_c = output_channels; - auto padded_shape_w = this->use_non_tile_height ? parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height : parallelization_config.num_cores_nhw * tt::round_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + auto padded_shape_w = this->use_non_tile_height + ? parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height + : parallelization_config.num_cores_nhw * + tt::round_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); auto output_padding = Padding( {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero); @@ -161,51 +224,76 @@ std::vector OptimizedConvNew::compute_output_specs(const std::vector uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT; uint32_t num_cores; std::array shard_shape; - if(this->use_non_tile_height){ + if (this->use_non_tile_height) { num_cores = this->parallelization_config.num_cores_nhw; uint32_t total_height = tt::tt_metal::compute_volume(output_shape) / output_shape[-1]; shard_shape = {(uint32_t)(total_height / num_cores), output_shape[-1]}; - }else{ - num_cores = total_height_tiles / tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); - CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); + } else { + num_cores = total_height_tiles / + tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + CoreRangeSet shard_grid = + tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); - shard_shape = {optimized_conv_op_utils::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, output_shape[-1]}; + shard_shape = { + optimized_conv_op_utils::div_up( + this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * + TILE_HEIGHT, + output_shape[-1]}; } - CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); + CoreRangeSet shard_grid = + tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; auto mem_config = this->memory_config; mem_config.shard_spec = shard_spec; - return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))}; - } else if(this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { + return {TensorSpec( + output_shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape( + dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))}; + } else if (this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT; - std::array shard_shape = {tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH}; + 
std::array shard_shape = { + tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, + tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH}; auto shard_grid = this->memory_config.shard_spec.value().grid; auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation}; auto mem_config = this->memory_config; mem_config.shard_spec = shard_spec; - return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))}; + return {TensorSpec( + output_shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape( + dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))}; } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))}; + return {TensorSpec( + output_shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape( + dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))}; } else { TT_THROW("Unsupported shard scheme"); } } - return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))}; + return {TensorSpec( + output_shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape( + dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))}; } -operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - std::vector& output_tensors) const { +operation::ProgramWithCallbacks OptimizedConvNew::create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); const auto& input_tensor_bias = optional_input_tensors.at(0); auto& output_tensor = output_tensors.at(0); return multi_core_optimized_conv_sharded_v2_new( - input_tensor_a, input_tensor_b, input_tensor_bias, + input_tensor_a, + input_tensor_b, + input_tensor_bias, sliding_window_config, output_channels, groups, - untilize_out, fuse_relu, + untilize_out, + fuse_relu, parallelization_config, block_config, dtype, @@ -220,13 +308,16 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect use_non_tile_height); } -operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const { +operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) const { const auto& input_tensor_a_shape = this->input_tensor_shape; uint32_t batch_size = input_tensor_a_shape[0]; uint32_t conv_activation_h = input_tensor_a_shape[1]; uint32_t conv_activation_w = input_tensor_a_shape[2]; uint32_t conv_activation_c = input_tensor_a_shape[3]; - uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first; // filter_h + uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first; // filter_h uint32_t filter_w = (uint32_t)sliding_window_config.window_hw.second; // filter_W 
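// Editor's note on the rough performance model computed a few lines below: it is only a
// sketch, not a measured figure. With hypothetical values conv_activation_c = 64 and a
// 3x3 filter, num_mul_adds_per_elem = 64 * 3 * 3 * 2 = 1152 (one multiply and one add per
// element); the ideal cycle count then divides the total mul/add count by
// num_cores * tensix_mul_adds_per_cycle_lofi (e.g. 64 * 4096 on WORMHOLE_B0) and applies
// the math-fidelity multiplier. All example numbers here are illustrative only.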
uint32_t stride_h = (uint32_t)sliding_window_config.stride_hw.first; uint32_t stride_w = (uint32_t)sliding_window_config.stride_hw.second; @@ -234,11 +325,13 @@ operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(cons uint32_t pad_w = (uint32_t)sliding_window_config.pad_hw.second; const auto& t = output_tensors.at(0); - if(t.storage_type() != StorageType::DEVICE) { + if (t.storage_type() != StorageType::DEVICE) { tt::log_warning(tt::LogOp, "Output tensor not on DEVICE?!"); } - auto arch = t.storage_type() == StorageType::DEVICE ? t.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); + auto arch = t.storage_type() == StorageType::DEVICE + ? t.device()->arch() + : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); const int num_cores = (arch == tt::ARCH::WORMHOLE_B0) ? 8 * 8 : 9 * 12; const int tensix_mul_adds_per_cycle_lofi = (arch == tt::ARCH::WORMHOLE_B0) ? 4096 : 2048; @@ -248,10 +341,12 @@ operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(cons // Calculate number of mul/add operations // TODO: add bias modeling - int64_t num_mul_adds_per_elem = conv_activation_c * filter_h * filter_w * 2; // 1 multiply and 1 add per element + int64_t num_mul_adds_per_elem = conv_activation_c * filter_h * filter_w * 2; // 1 multiply and 1 add per element int64_t num_mul_adds = num_mul_adds_per_elem * output_height * output_width * this->output_channels * batch_size; - int ideal_dev_clock_cycles = std::ceil(((float)num_mul_adds / (float)(num_cores * tensix_mul_adds_per_cycle_lofi)) * (float)operation::OpPerformanceModel::fidelity_multiplier(get_math_fidelity(this->compute_kernel_config))); + int ideal_dev_clock_cycles = std::ceil( + ((float)num_mul_adds / (float)(num_cores * tensix_mul_adds_per_cycle_lofi)) * + (float)operation::OpPerformanceModel::fidelity_multiplier(get_math_fidelity(this->compute_kernel_config))); operation::OpPerformanceModel result(input_tensors, output_tensors, ideal_dev_clock_cycles); @@ -268,6 +363,6 @@ operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(cons return result; } -} // namespace tt_metal +} // namespace conv2d -} // namespace tt +} // namespace ttnn::operations::conv diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index a31b257fc8d..cd9ab6752a2 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -15,12 +15,10 @@ namespace operations::conv { namespace conv2d { // TODO: Accept parallelization -enum class OptimizedConvOpParallelizationStrategy { - MULTI_CORE, MULTI_CORE_REUSE, MULTI_CORE_REUSE_MCAST, SINGLE_CORE -}; +enum class OptimizedConvOpParallelizationStrategy { MULTI_CORE, MULTI_CORE_REUSE, MULTI_CORE_REUSE_MCAST, SINGLE_CORE }; struct OptimizedConvParallelizationConfig { - CoreCoord grid_size; // (x,y) + CoreCoord grid_size; // (x,y) uint32_t num_cores_nhw = 1; uint32_t num_cores_c = 1; uint32_t per_core_out_matrix_height = 1; @@ -31,9 +29,7 @@ struct OptimizedConvParallelizationConfig { // std::size_t per_core_M; // std::size_t per_core_N; - CoreCoord get_grid_size() const { - return this->grid_size; - } + CoreCoord get_grid_size() const { return this->grid_size; } }; struct OptimizedConvBlockConfig { @@ -43,11 +39,15 @@ struct OptimizedConvBlockConfig { uint32_t out_subblock_w_ntiles; }; -tt::tt_metal::operation::ProgramWithCallbacks 
multi_core_optimized_conv_sharded_v2_new(const Tensor& a, const Tensor &b, const std::optional& bias, +tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( + const Tensor& a, + const Tensor& b, + const std::optional& bias, const sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t output_channels, uint32_t groups, - bool untilize_out, bool fuse_relu, + bool untilize_out, + bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, tt::tt_metal::DataType dtype, @@ -71,7 +71,7 @@ struct OptimizedConvNew { bool untilize_out, has_bias, fuse_relu; tt::tt_metal::MemoryConfig memory_config; const tt::tt_metal::DataType dtype; - std::array input_tensor_shape; // For sharded input, input tensor shape is nonsense + std::array input_tensor_shape; // For sharded input, input tensor shape is nonsense bool use_shallow_conv_variant; const DeviceComputeKernelConfig compute_kernel_config; bool enable_act_double_buffer; @@ -79,39 +79,57 @@ struct OptimizedConvNew { bool enable_split_reader; bool enable_subblock_padding; bool use_non_tile_height; - OptimizedConvNew(const sliding_window::SlidingWindowConfig& sliding_window_config, - uint32_t output_channels, uint32_t groups, + OptimizedConvNew( + const sliding_window::SlidingWindowConfig& sliding_window_config, + uint32_t output_channels, + uint32_t groups, bool untile_out, - bool has_bias, bool fuse_relu, + bool has_bias, + bool fuse_relu, const OptimizedConvParallelizationConfig& p_config, const OptimizedConvBlockConfig& b_config, tt::tt_metal::MemoryConfig memory_config, tt::tt_metal::DataType dtype, - std::array input_tensor_shape, bool use_shallow_conv_variant, - const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) : - output_channels(output_channels), - groups(groups), - sliding_window_config(sliding_window_config), - untilize_out(untile_out), - has_bias(has_bias), - fuse_relu(fuse_relu), - parallelization_config(p_config), - block_config(b_config), - memory_config(memory_config), - dtype(dtype), input_tensor_shape(input_tensor_shape), - use_shallow_conv_variant(use_shallow_conv_variant), - compute_kernel_config(compute_kernel_config), - enable_act_double_buffer(enable_act_double_buffer), - enable_weights_double_buffer(enable_weights_double_buffer), - enable_split_reader(enable_split_reader), - enable_subblock_padding(enable_subblock_padding), - use_non_tile_height(use_non_tile_height) {} - - void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; + std::array input_tensor_shape, + bool use_shallow_conv_variant, + const DeviceComputeKernelConfig compute_kernel_config, + bool enable_act_double_buffer, + bool enable_weights_double_buffer, + bool enable_split_reader, + bool enable_subblock_padding, + bool use_non_tile_height) : + output_channels(output_channels), + groups(groups), + sliding_window_config(sliding_window_config), + untilize_out(untile_out), + has_bias(has_bias), + fuse_relu(fuse_relu), + parallelization_config(p_config), + block_config(b_config), + memory_config(memory_config), + dtype(dtype), + input_tensor_shape(input_tensor_shape), + use_shallow_conv_variant(use_shallow_conv_variant), + compute_kernel_config(compute_kernel_config), + enable_act_double_buffer(enable_act_double_buffer), + 
enable_weights_double_buffer(enable_weights_double_buffer), + enable_split_reader(enable_split_reader), + enable_subblock_padding(enable_subblock_padding), + use_non_tile_height(use_non_tile_height) {} + + void validate( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors) const; std::vector compute_output_specs(const std::vector& input_tensors) const; - tt::tt_metal::operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector &output_tensors) const; + tt::tt_metal::operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector& output_tensors) const; - tt::tt_metal::operation::OpPerformanceModel create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const; + tt::tt_metal::operation::OpPerformanceModel create_op_performance_model( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) const; static constexpr auto attribute_names = std::make_tuple( "parallelization_config", @@ -147,11 +165,15 @@ struct OptimizedConvNew { } }; -Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional bias, +Tensor optimized_conv_new( + const Tensor& a, + const Tensor& b, + std::optional bias, const sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t output_channels, uint32_t groups, - bool untilize_out, bool fuse_relu, + bool untilize_out, + bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, const tt::tt_metal::MemoryConfig& memory_config, @@ -163,8 +185,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional, std::vector> compute_opt_conv_activat uint32_t num_cores_nhw, uint32_t act_block_h_ntiles); -} // optimized_conv_op_utils +} // namespace optimized_conv_op_utils diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp index f2ad3573b3f..c707661817c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp @@ -18,7 +18,6 @@ #include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #define DEBUG_PRINT 0 - // #include "debug_macros.h" // SliceRange srt = SliceRange{.h0 = 0, .h1 = 4, .hs = 1, .w0 = 0, .w1 = 8, .ws = 1}; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 0ba0363a9e6..7e68416803a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -16,13 +16,14 @@ using namespace tt; namespace ttnn { namespace operations::conv { -using sliding_window::SlidingWindowConfig; using sliding_window::ParallelConfig; +using sliding_window::SlidingWindowConfig; namespace conv2d { void validate_weight_tensor(const ttnn::Tensor& weight_tensor) { - TT_FATAL(!ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE), "conv weight should be placed on host"); + TT_FATAL( + !ttnn::has_storage_type_of(weight_tensor, ttnn::DEVICE_STORAGE_TYPE), "conv weight should be 
placed on host"); TT_FATAL(weight_tensor.get_layout() == Layout::ROW_MAJOR, "conv weight layout should be in row_major layout"); TT_FATAL(weight_tensor.get_shape().rank() == 4, "conv weight should be 4D tensor"); } @@ -43,12 +44,9 @@ void validate_weights_format(const std::string& weights_format) { } template -bool check_non_tile_mul_width( - T *device, - const Conv2dConfig& conv_config, - const uint32_t in_channels -){ - auto num_cores_c = conv_config.transpose_shards ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x; +bool check_non_tile_mul_width(T* device, const Conv2dConfig& conv_config, const uint32_t in_channels) { + auto num_cores_c = conv_config.transpose_shards ? device->compute_with_storage_grid_size().y + : device->compute_with_storage_grid_size().x; auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2; bool is_non_tile_mul_width = (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 && @@ -64,26 +62,28 @@ ttnn::Tensor conv_bias_layout_convert( uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const ParallelConfig& parallel_config, - T * device, + T* device, uint32_t out_channels, bool is_non_tile_mul_width) { ttnn::Tensor bias_tensor_ = bias_tensor; validate_bias_tensor(bias_tensor_); if (!is_non_tile_mul_width) { auto bias_shape = bias_tensor_.get_shape(); - TT_FATAL(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, "bias shape is not correct"); + TT_FATAL( + bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, + "bias shape is not correct"); tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( std::array({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)})); - bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0); - bias_tensor_ = ttnn::to_layout( - bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); + bias_tensor_ = + ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0); + bias_tensor_ = ttnn::to_layout(bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); if (bias_tensor_.get_dtype() != bias_dtype) { bias_tensor_ = ttnn::to_dtype(bias_tensor_, bias_dtype); } } else { uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config); - bias_tensor_ = convert_conv_bias_tensor_to_tiled_layout_block_sharded( - bias_tensor_, num_cores_channels, bias_dtype); + bias_tensor_ = + convert_conv_bias_tensor_to_tiled_layout_block_sharded(bias_tensor_, num_cores_channels, bias_dtype); } return bias_tensor_; } @@ -99,7 +99,7 @@ static OptimizedConvBlockConfig get_opt_block_config( uint32_t input_width, std::array kernel_size, std::array stride, - T *device, + T* device, Conv2dConfig& conv_config, Layout input_tensor_layout, const DeviceComputeKernelConfig& compute_config, @@ -123,8 +123,10 @@ static OptimizedConvBlockConfig get_opt_block_config( ShardOrientation shard_orientation = conv_config.transpose_shards ? 
ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR; + bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && + out_channels <= 256 && conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && + conv_config.output_layout == Layout::ROW_MAJOR; use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; ParallelConfig parallel_config = determine_parallel_config( @@ -139,23 +141,30 @@ static OptimizedConvBlockConfig get_opt_block_config( !use_non_tile_height); auto output_parallel_config = parallel_config; - if(conv_config.shard_layout.value() == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !mm_conv) { + if (conv_config.shard_layout.value() == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !mm_conv) { uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; output_parallel_config = { - .grid = num_cores_to_corerangeset( find_closest_largest_divisor(tt::div_up(out_channels, tt::constants::TILE_WIDTH),max_num_cores), compute_grid_size, true), + .grid = num_cores_to_corerangeset( + find_closest_largest_divisor(tt::div_up(out_channels, tt::constants::TILE_WIDTH), max_num_cores), + compute_grid_size, + true), .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED, - .shard_orientation = parallel_config.shard_orientation - }; - log_debug(tt::LogOp, "Changing width sharded output grid to {}",output_parallel_config.grid); + .shard_orientation = parallel_config.shard_orientation}; + log_debug(tt::LogOp, "Changing width sharded output grid to {}", output_parallel_config.grid); } uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), - output_parallel_config, round_up_size); - auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; + ttnn::Shape( + std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), + output_parallel_config, + round_up_size); + auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() + ? 
output_parallel_config + : parallel_config; auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( - conv_out_memory_config, get_num_cores_nhw_from_parallel_config(largest_parallel_config), + conv_out_memory_config, + get_num_cores_nhw_from_parallel_config(largest_parallel_config), get_num_cores_channels_from_parallel_config(parallel_config)); uint32_t in_channels_padded = tt::round_up( @@ -163,7 +172,7 @@ static OptimizedConvBlockConfig get_opt_block_config( get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * - conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; + conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; return determine_per_core_conv_block_config( parallel_config, @@ -178,8 +187,6 @@ static OptimizedConvBlockConfig get_opt_block_config( conv_config.enable_split_reader); } - - template std::pair> prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, @@ -189,13 +196,12 @@ std::pair> prepare_conv_weights_biases uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const ParallelConfig& parallel_config, - T * device, + T* device, uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, const bool parameters_on_device, bool is_non_tile_mul_width) { - validate_weight_tensor(weight_tensor); ttnn::Tensor weight_tensor_; // tensor to return ttnn::Tensor bias_tensor_; @@ -213,15 +219,16 @@ std::pair> prepare_conv_weights_biases // Convert weight tensor to 0 padded shape if groups > 1 if (!is_conv1d and groups > 1) { - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); - } - else if (is_conv1d and groups > 1) { + weight_tensor_ = + tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); + } else if (is_conv1d and groups > 1) { if (is_depthwise_conv) { - weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype); + weight_tensor_ = + convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype); weight_block_h_ntiles = act_block_h_ntiles; - } - else{ - weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); + } else { + weight_tensor_ = + tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); } } @@ -236,11 +243,11 @@ std::pair> prepare_conv_weights_biases uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment); uint32_t out_channel_padding = out_channels_padded - out_channels; - tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( - {out_channels_padded, in_channels_padded, window_h, window_w})); - if(is_non_tile_mul_width) { + tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape( + std::array({out_channels_padded, in_channels_padded, window_h, window_w})); + if (is_non_tile_mul_width) { weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( - {round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, window_w})); + {round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, 
window_w})); out_channels_padded = tt::round_up(out_channels, 32); } if (weights_bias_dtype == DataType::BFLOAT8_B) { @@ -255,13 +262,14 @@ std::pair> prepare_conv_weights_biases TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype); } } - weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); + weight_tensor_ = + ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); // for conv op, pad the weights to block shape if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); - } else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) { + } else if (parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) { weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded( weight_tensor_, num_cores_channels, weights_bias_dtype); } else { @@ -271,25 +279,34 @@ std::pair> prepare_conv_weights_biases uint32_t weight_matrix_height = in_channels * window_h * window_w; int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height; - TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative"); + TT_FATAL(weight_matrix_height_padding >= 0, " Matrix Height Padding can't be negative"); - auto target_shape = ttnn::Shape(std::array{1, 1, weight_matrix_height, out_channels}, + auto target_shape = ttnn::Shape( + std::array{1, 1, weight_matrix_height, out_channels}, std::array, 4>{ std::array{0, 0}, std::array{0, 0}, std::array{0, weight_matrix_height_padding}, - std::array{0, out_channel_padding} - }); + std::array{0, out_channel_padding}}); weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape); - if(parameters_on_device) + if (parameters_on_device) { weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt); + } if (bias_tensor.has_value()) { bias_tensor_ = bias_tensor.value(); bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_); - if(!is_bias_tensor_is_on_device) { - bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, parallel_config, device, out_channels, is_non_tile_mul_width); + if (!is_bias_tensor_is_on_device) { + bias_tensor_ = conv_bias_layout_convert( + bias_tensor_, + weights_bias_dtype, + weight_block_h_ntiles, + weight_block_w_ntiles, + parallel_config, + device, + out_channels, + is_non_tile_mul_width); bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); } } @@ -300,7 +317,7 @@ std::pair> prepare_conv_weights_biases template ttnn::Tensor prepare_conv_weights( const ttnn::Tensor& weight_tensor, - const ttnn::MemoryConfig &input_memory_config, + const ttnn::MemoryConfig& input_memory_config, Layout input_tensor_layout, const std::string& weights_format, uint32_t in_channels, @@ -313,21 +330,18 @@ ttnn::Tensor prepare_conv_weights( std::array padding, std::array dilation, uint32_t groups, - T *device, + T* device, const std::optional& conv_config_, const std::optional& compute_config_) { - TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation."); + TT_FATAL( + !ttnn::is_tensor_on_device_or_multidevice(weight_tensor), + "Error: 
weight tensor must be on host for preparation."); Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( - device->arch(), - std::nullopt, - MathFidelity::HiFi4, - true, - false, - false - )); + DeviceComputeKernelConfig compute_config = compute_config_.value_or( + init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false)); const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); - const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_height = + ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; const uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; auto opt_conv_op_block_config = get_opt_block_config( @@ -344,14 +358,15 @@ ttnn::Tensor prepare_conv_weights( conv_config, input_tensor_layout, compute_config, - input_memory_config - ); + input_memory_config); ShardOrientation shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR; + bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && + out_channels <= 256 && conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && + conv_config.output_layout == Layout::ROW_MAJOR; use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; ParallelConfig parallel_config = determine_parallel_config( @@ -402,26 +417,21 @@ ttnn::Tensor prepare_conv_bias( std::array padding, std::array dilation, uint32_t groups, - T *device, + T* device, const std::optional& conv_config_, const std::optional& compute_config_) { - - TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation."); + TT_FATAL( + !ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation."); const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); - const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; + const uint32_t output_height = + ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; const uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( - device->arch(), - std::nullopt, - MathFidelity::HiFi4, - true, - false, - false - )); + DeviceComputeKernelConfig compute_config = compute_config_.value_or( + init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false)); 
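// Editor's note: worked example for the output-size formula a few lines above
// (hypothetical numbers, not taken from this change): input_height = 32,
// kernel_size[0] = 3, stride[0] = 1, padding[0] = 1, dilation[0] = 1 gives
//   output_height = ((32 - 3 - (3 - 1) * (1 - 1) + 2 * 1) / 1) + 1 = 32,
// i.e. a "same"-size convolution. With dilation > 1, the (kernel_size - 1) * (dilation - 1)
// term widens the effective window and shrinks the output accordingly.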
auto opt_conv_op_block_config = get_opt_block_config( mm_conv, in_channels, @@ -436,15 +446,16 @@ ttnn::Tensor prepare_conv_bias( conv_config, input_tensor_layout, compute_config, - input_memory_config - ); + input_memory_config); uint32_t weight_block_w_ntiles = opt_conv_op_block_config.out_subblock_w_ntiles; ShardOrientation shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && conv_config.output_layout == Layout::ROW_MAJOR; + bool use_non_tile_height = conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED && + out_channels <= 256 && conv_config.act_block_h_override == 0 && + (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && + conv_config.output_layout == Layout::ROW_MAJOR; use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; ParallelConfig parallel_config = determine_parallel_config( @@ -468,8 +479,7 @@ ttnn::Tensor prepare_conv_bias( parallel_config, device, out_channels, - is_non_tile_mul_width - ); + is_non_tile_mul_width); return bias_tensor_; } @@ -483,7 +493,7 @@ template OptimizedConvBlockConfig get_opt_block_config( uint32_t input_width, std::array kernel_size, std::array stride, - Device *device, + Device* device, Conv2dConfig& conv_config, Layout input_tensor_layout, const DeviceComputeKernelConfig& compute_config, @@ -499,7 +509,7 @@ template OptimizedConvBlockConfig get_opt_block_config( uint32_t input_width, std::array kernel_size, std::array stride, - MeshDevice *device, + MeshDevice* device, Conv2dConfig& conv_config, Layout input_tensor_layout, const DeviceComputeKernelConfig& compute_config, @@ -520,7 +530,7 @@ template ttnn::Tensor prepare_conv_weights( std::array padding, std::array dilation, uint32_t groups, - Device *device, + Device* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -539,7 +549,7 @@ template ttnn::Tensor prepare_conv_weights( std::array padding, std::array dilation, uint32_t groups, - MeshDevice *device, + MeshDevice* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -558,7 +568,8 @@ template std::pair> prepare_conv_weigh const bool parameters_on_device, bool is_non_tile_mul_width); -template std::pair> prepare_conv_weights_biases_and_move_to_device( +template std::pair> +prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, std::optional& bias_tensor, uint32_t input_channels_alignment, @@ -587,7 +598,7 @@ template ttnn::Tensor prepare_conv_bias( std::array padding, std::array dilation, uint32_t groups, - Device *device, + Device* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -605,7 +616,7 @@ template ttnn::Tensor prepare_conv_bias( std::array padding, std::array dilation, uint32_t groups, - MeshDevice *device, + MeshDevice* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -615,7 +626,7 @@ template ttnn::Tensor conv_bias_layout_convert( uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const sliding_window::ParallelConfig& parallel_config, - Device * device, + Device* device, uint32_t out_channels, bool is_non_tile_mul_width); @@ -630,17 +641,11 @@ template ttnn::Tensor 
conv_bias_layout_convert( bool is_non_tile_mul_width); template bool check_non_tile_mul_width( - Device *device, - const Conv2dConfig& conv_config, - const uint32_t in_channels -); + Device* device, const Conv2dConfig& conv_config, const uint32_t in_channels); template bool check_non_tile_mul_width( - MeshDevice *device, - const Conv2dConfig& conv_config, - const uint32_t in_channels -); + MeshDevice* device, const Conv2dConfig& conv_config, const uint32_t in_channels); } // namespace conv2d -} // namespace operations +} // namespace operations::conv } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index 221a9d230f5..d901177fe1d 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -11,7 +11,7 @@ #include "ttnn/operations/conv/conv2d/conv2d_utils.hpp" namespace ttnn::operations::sliding_window { - struct ParallelConfig; +struct ParallelConfig; } namespace ttnn { @@ -26,7 +26,7 @@ ttnn::Tensor conv_bias_layout_convert( uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const sliding_window::ParallelConfig& parallel_config, - T * device, + T* device, uint32_t out_channels, bool is_non_tile_mul_width); @@ -46,7 +46,7 @@ ttnn::Tensor prepare_conv_weights( std::array padding, std::array dilation, uint32_t groups, - T *device, + T* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -65,7 +65,7 @@ ttnn::Tensor prepare_conv_bias( std::array padding, std::array dilation, uint32_t groups, - T *device, + T* device, const std::optional& conv_config_, const std::optional& compute_config_); @@ -78,19 +78,16 @@ std::pair> prepare_conv_weights_biases uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const sliding_window::ParallelConfig& parallel_config, - T * device, + T* device, uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, - const bool parameters_on_device=true, - bool is_non_tile_mul_width=false); + const bool parameters_on_device = true, + bool is_non_tile_mul_width = false); template -bool check_non_tile_mul_width( - T* device, - const Conv2dConfig& conv_config, - const uint32_t in_channels); +bool check_non_tile_mul_width(T* device, const Conv2dConfig& conv_config, const uint32_t in_channels); -} // namespace conv2d -} // namespace operations::conv -} // namespace ttnn +} // namespace conv2d +} // namespace operations::conv +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index ce8001b4879..a2700b26e55 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -16,25 +16,23 @@ using namespace tt; namespace ttnn { namespace operations::conv { -using sliding_window::SlidingWindowConfig; using sliding_window::ParallelConfig; +using sliding_window::SlidingWindowConfig; namespace conv_transpose2d { template -Tensor _transform_weights_for_conv_transpose2d( - const Tensor& conv_weight_tensor, - bool mirror_kernel = true) { +Tensor _transform_weights_for_conv_transpose2d(const Tensor& conv_weight_tensor, bool mirror_kernel = true) { auto in_w_shape = conv_weight_tensor.get_legacy_shape(); auto dtype = conv_weight_tensor.dtype(); // in_w_shape = {in_channels, out_channels, kernel_height, kernel_width} 
// out_w_shape = {out_channels, in_channels, kernel_height, kernel_width} - //Flip kernel_height and kernel_width + // Flip kernel_height and kernel_width auto compute = [&in_w_shape, &dtype, mirror_kernel](const auto& input_buffer) { - auto in_channels = in_w_shape[0]; - auto out_channels = in_w_shape[1]; + auto in_channels = in_w_shape[0]; + auto out_channels = in_w_shape[1]; auto kernel_height = in_w_shape[2]; - auto kernel_width = in_w_shape[3]; + auto kernel_width = in_w_shape[3]; ttnn::SimpleShape output_shape{out_channels, in_channels, kernel_height, kernel_width}; auto output_buffer = owned_buffer::create(output_shape.volume()); @@ -42,18 +40,24 @@ Tensor _transform_weights_for_conv_transpose2d( auto output_weight_out_channel_base_idx = out_channels_index * in_channels * kernel_height * kernel_width; auto input_weight_out_channel_base_idx = out_channels_index * kernel_height * kernel_width; for (auto in_channels_index = 0; in_channels_index < in_channels; in_channels_index++) { - auto output_weight_in_channel_base_idx = in_channels_index * kernel_height * kernel_width; + auto output_weight_in_channel_base_idx = in_channels_index * kernel_height * kernel_width; auto input_weight_in_channel_base_idx = in_channels_index * kernel_height * kernel_width * out_channels; - for (auto in_kernel_height_index = 0; in_kernel_height_index < kernel_height; in_kernel_height_index++) { - auto out_buffer_kh_index = mirror_kernel ? kernel_height - in_kernel_height_index - 1 : in_kernel_height_index; + for (auto in_kernel_height_index = 0; in_kernel_height_index < kernel_height; + in_kernel_height_index++) { + auto out_buffer_kh_index = + mirror_kernel ? kernel_height - in_kernel_height_index - 1 : in_kernel_height_index; auto in_height_offset = in_kernel_height_index * kernel_width; auto out_height_offset = out_buffer_kh_index * kernel_width; - for (auto in_kernel_width_index = 0; in_kernel_width_index < kernel_width; in_kernel_width_index++) { - auto out_buffer_kw_index = mirror_kernel ? kernel_width - in_kernel_width_index - 1 : in_kernel_width_index; + for (auto in_kernel_width_index = 0; in_kernel_width_index < kernel_width; + in_kernel_width_index++) { + auto out_buffer_kw_index = + mirror_kernel ? kernel_width - in_kernel_width_index - 1 : in_kernel_width_index; - auto in_idx = input_weight_out_channel_base_idx + input_weight_in_channel_base_idx + in_height_offset + in_kernel_width_index; - auto out_idx = output_weight_out_channel_base_idx + output_weight_in_channel_base_idx + out_height_offset + out_buffer_kw_index; + auto in_idx = input_weight_out_channel_base_idx + input_weight_in_channel_base_idx + + in_height_offset + in_kernel_width_index; + auto out_idx = output_weight_out_channel_base_idx + output_weight_in_channel_base_idx + + out_height_offset + out_buffer_kw_index; output_buffer[out_idx] = input_buffer[in_idx]; } @@ -76,10 +80,10 @@ Tensor _transform_weights_for_conv_transpose2d( }, conv_weight_tensor.get_storage()); }; - return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor); + return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? 
transform(conv_weight_tensor, convert_tensor) + : convert_tensor(conv_weight_tensor); } - Tensor transform_weights_for_conv_transpose2d(const Tensor& conv_weight_tensor, bool mirror_kernel) { switch (conv_weight_tensor.get_dtype()) { case DataType::BFLOAT16: @@ -88,15 +92,17 @@ Tensor transform_weights_for_conv_transpose2d(const Tensor& conv_weight_tensor, return _transform_weights_for_conv_transpose2d(conv_weight_tensor, mirror_kernel); case DataType::UINT32: return _transform_weights_for_conv_transpose2d(conv_weight_tensor, mirror_kernel); - default: TT_THROW("Unsupported data type for transform_weights_for_conv_transpose2d",conv_weight_tensor.get_dtype()); + default: + TT_THROW( + "Unsupported data type for transform_weights_for_conv_transpose2d", conv_weight_tensor.get_dtype()); } }; -template +template Result conv_transpose2d( const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - T * device, + T* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -113,243 +119,238 @@ Result conv_transpose2d( const std::optional& compute_config_, const std::optional& memory_config, bool mirror_kernel) { - Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); - DeviceComputeKernelConfig compute_config = compute_config_.value_or(init_device_compute_kernel_config( - device->arch(), - std::nullopt, - MathFidelity::HiFi4, - true, - false, - false - )); - - //Inverse of sliding_window.get_output_shape() - SlidingWindowConfig sliding_window_config = SlidingWindowConfig{ - .batch_size = batch_size, - .input_hw = {input_height, input_width}, - .window_hw = {kernel_size[0], kernel_size[1]}, - .stride_hw = {stride[0], stride[1]}, - .pad_hw = {padding[0], padding[1]}, - .output_pad_hw = {output_padding[0], output_padding[1]}, - .dilation_hw = {dilation[0], dilation[1]}, - .is_transpose = true - }; - - - // ConvTranspose2d is implemented via the Conv2d u_op with flipped weights. - //The input tensor is first passed to the halo op that paddeds the input. - //In the scenario, where stride > 1, the halo op will add interleaved 0s to the input tensor. - //The Conv2d u_op is then called with stride = 1, padding = 0. - //SlidingWindowConfig has a is_transpose flag that is set to true to indicate that the Conv2d u_op & Halo u_op is being called for ConvTranspose2d. - uint32_t output_height = (input_height - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1; - uint32_t output_width = (input_width - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1; - - //Dimensions of Input to Conv u_op - uint32_t full_input_height = output_height + dilation[0] * (kernel_size[0] - 1); - uint32_t full_input_width = output_width + dilation[1] * (kernel_size[1] - 1); - - //Size of input after adding interleaved 0s. 
- uint32_t strided_input_height = (input_height - 1) * stride[0] + 1; - uint32_t strided_input_width = (input_width - 1) * stride[1] + 1; - - uint32_t input_pad_top = (full_input_height - strided_input_height)/2; - uint32_t input_pad_bottom = full_input_height - strided_input_height - input_pad_top; - - uint32_t input_pad_left = (full_input_width - strided_input_width)/2; - uint32_t input_pad_right = full_input_width - strided_input_width - input_pad_left; - - log_debug(LogOp, "Input : {}x{}", input_height, input_width); - log_debug(LogOp, "Output : {}x{}", output_height, output_width); - - log_debug(LogOp, "Conv Op Input : {}x{}", full_input_height, full_input_width); - log_debug(LogOp, "Strided Input : {}x{}", strided_input_height, strided_input_width); - - log_debug(LogOp, "Padding : ({},{}) ({},{})", input_pad_top, input_pad_bottom, input_pad_left, input_pad_right); - - const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, {1, 1}, {input_pad_top + input_pad_bottom, input_pad_left + input_pad_right}, dilation, groups); - - const auto compute_grid_size = device->compute_with_storage_grid_size(); - - bool auto_shard = false; - if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { - // In this case we deduce the shard layout. - adjust_conv_op_config_for_auto_shard_if_necessary( - mm_conv, - batch_size, - in_channels, - out_channels, - output_height, - output_width, - weight_tensor.get_shape()[3], - full_input_width, - compute_grid_size, - conv_config, - input_tensor.layout(), - ttnn::is_tensor_on_device_or_multidevice(input_tensor) - ? std::make_optional(input_tensor.memory_config()) - : std::nullopt); - auto_shard = true; - } - - - //Call Halo Transpose - auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] = - shard_or_reshard_tensor_if_required( - device, - input_tensor, - conv_config, - batch_size, - output_height, - output_width, - in_channels, - out_channels, - mm_conv, - auto_shard - ); - - uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - - Tensor halo_output; - if (!mm_conv) { - sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); - sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid; - sliding_window_config.snap_to_tile = !use_non_tile_height; - - halo_output = ttnn::halo( - DefaultQueueId, - input_tensor_post_tm, - sliding_window_config, - 0, - false, - parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, - 0, - input_tensor_post_tm.memory_config()); + Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or( + init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false)); + + // Inverse of sliding_window.get_output_shape() + SlidingWindowConfig sliding_window_config = SlidingWindowConfig{ + .batch_size = batch_size, + .input_hw = {input_height, input_width}, + .window_hw = {kernel_size[0], kernel_size[1]}, + .stride_hw = {stride[0], stride[1]}, + .pad_hw = {padding[0], padding[1]}, + .output_pad_hw = {output_padding[0], output_padding[1]}, + .dilation_hw = {dilation[0], dilation[1]}, + .is_transpose = true}; + + // ConvTranspose2d is implemented via the Conv2d u_op with flipped weights. + // The input tensor is first passed to the halo op that paddeds the input. + // In the scenario, where stride > 1, the halo op will add interleaved 0s to the input tensor. 
+ // The Conv2d u_op is then called with stride = 1, padding = 0. + // SlidingWindowConfig has a is_transpose flag that is set to true to indicate that the Conv2d u_op & Halo u_op is + // being called for ConvTranspose2d. + uint32_t output_height = + (input_height - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1; + uint32_t output_width = + (input_width - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1; + + // Dimensions of Input to Conv u_op + uint32_t full_input_height = output_height + dilation[0] * (kernel_size[0] - 1); + uint32_t full_input_width = output_width + dilation[1] * (kernel_size[1] - 1); + + // Size of input after adding interleaved 0s. + uint32_t strided_input_height = (input_height - 1) * stride[0] + 1; + uint32_t strided_input_width = (input_width - 1) * stride[1] + 1; + + uint32_t input_pad_top = (full_input_height - strided_input_height) / 2; + uint32_t input_pad_bottom = full_input_height - strided_input_height - input_pad_top; + + uint32_t input_pad_left = (full_input_width - strided_input_width) / 2; + uint32_t input_pad_right = full_input_width - strided_input_width - input_pad_left; + + log_debug(LogOp, "Input : {}x{}", input_height, input_width); + log_debug(LogOp, "Output : {}x{}", output_height, output_width); + + log_debug(LogOp, "Conv Op Input : {}x{}", full_input_height, full_input_width); + log_debug(LogOp, "Strided Input : {}x{}", strided_input_height, strided_input_width); + + log_debug(LogOp, "Padding : ({},{}) ({},{})", input_pad_top, input_pad_bottom, input_pad_left, input_pad_right); + + const bool mm_conv = use_matmul_for_1x1_conv( + kernel_size, {1, 1}, {input_pad_top + input_pad_bottom, input_pad_left + input_pad_right}, dilation, groups); + + const auto compute_grid_size = device->compute_with_storage_grid_size(); + + bool auto_shard = false; + if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) { + // In this case we deduce the shard layout. + adjust_conv_op_config_for_auto_shard_if_necessary( + mm_conv, + batch_size, + in_channels, + out_channels, + output_height, + output_width, + weight_tensor.get_shape()[3], + full_input_width, + compute_grid_size, + conv_config, + input_tensor.layout(), + ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) + : std::nullopt); + auto_shard = true; + } - if(conv_config.deallocate_activation) { - input_tensor_post_tm.deallocate(/*force*/true); - log_debug(tt::LogOp, "Deallocate Input Tensor"); - } + // Call Halo Transpose + auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] = + shard_or_reshard_tensor_if_required( + device, + input_tensor, + conv_config, + batch_size, + output_height, + output_width, + in_channels, + out_channels, + mm_conv, + auto_shard); - if (conv_config.reallocate_halo_output) { - halo_output = ttnn::move(halo_output); - log_debug(tt::LogOp, "Reallocate Halo Output"); - } - } + uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - //Call Conv2d u_op with Stride = 1, Padding = 0. 
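// Editor's note: worked example for the transpose-conv geometry computed above
// (hypothetical numbers, not taken from this change): input 16x16, kernel 3x3,
// stride 2, padding 1, output_padding 1, dilation 1 gives
//   output_height        = (16 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 32
//   full_input_height    = 32 + 1 * (3 - 1)                           = 34
//   strided_input_height = (16 - 1) * 2 + 1                           = 31
//   input_pad_top = (34 - 31) / 2 = 1,  input_pad_bottom = 34 - 31 - 1 = 2
// so the halo op interleaves zeros up to 31 rows and pads to 34 rows before the
// stride-1 Conv2d micro-op runs.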
- auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), - output_parallel_config, - round_up_size); + Tensor halo_output; + if (!mm_conv) { + sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); + sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid; + sliding_window_config.snap_to_tile = !use_non_tile_height; - auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; + halo_output = ttnn::halo( + DefaultQueueId, + input_tensor_post_tm, + sliding_window_config, + 0, + false, + parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, + 0, + input_tensor_post_tm.memory_config()); - auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( - conv_out_memory_config, - get_num_cores_nhw_from_parallel_config(largest_parallel_config), - get_num_cores_channels_from_parallel_config(largest_parallel_config) - ); + if (conv_config.deallocate_activation) { + input_tensor_post_tm.deallocate(/*force*/ true); + log_debug(tt::LogOp, "Deallocate Input Tensor"); + } - uint32_t in_channels_padded = tt::round_up( - in_channels, - get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); - uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * - conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; - auto opt_conv_op_block_config = determine_per_core_conv_block_config( - parallel_config, - opt_conv_op_parallel_config, - in_channels_padded, - nhw_out_padded_ntile, - conv_config.act_block_h_override, - conv_config.act_block_w_div, - kernel_size[0], - kernel_size[1], - get_fp32_dest_acc_en(compute_config), - conv_config.enable_split_reader); - - //TODO: Flip the Weights - bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor); - ttnn::Tensor weight_tensor_on_device = weight_tensor; - std::optional bias_tensor_on_device = bias_tensor; - if (!weight_is_on_device) { - // prepare weights in desired layout and move to device - tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device( - transform_weights_for_conv_transpose2d(weight_tensor, mirror_kernel), - bias_tensor, - conv_config.input_channels_alignment, - conv_config.weights_dtype, - opt_conv_op_block_config.act_block_w_ntiles, - opt_conv_op_block_config.out_subblock_w_ntiles, - parallel_config, - device, - groups, - opt_conv_op_block_config.act_block_h_ntiles, - input_width); + if (conv_config.reallocate_halo_output) { + halo_output = ttnn::move(halo_output); + log_debug(tt::LogOp, "Reallocate Halo Output"); } - if (mm_conv) { - input_tensor_post_tm = ttnn::to_layout( - input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device); - std::optional program_config = std::nullopt; - std::optional mm_output_memory_config = std::nullopt; - - if (input_tensor_post_tm.is_sharded()) { - uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); - program_config = determine_matmul_op_config_from_conv_op_config( - opt_conv_op_parallel_config, - opt_conv_op_block_config, - parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, - conv_config.activation, - 
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, - num_cores_c); - mm_output_memory_config = conv_out_memory_config; - } - Tensor matmul_output = ttnn::linear( - input_tensor_post_tm, - weight_tensor_on_device, - bias_tensor_on_device, - false, - false, - mm_output_memory_config, - std::nullopt, - program_config); - - if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) { - matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt); - } + } - return {matmul_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device}; + // Call Conv2d u_op with Stride = 1, Padding = 0. + auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( + ttnn::Shape( + std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), + output_parallel_config, + round_up_size); + + auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() + ? output_parallel_config + : parallel_config; + + auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( + conv_out_memory_config, + get_num_cores_nhw_from_parallel_config(largest_parallel_config), + get_num_cores_channels_from_parallel_config(largest_parallel_config)); + + uint32_t in_channels_padded = tt::round_up( + in_channels, + get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); + uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * + conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; + auto opt_conv_op_block_config = determine_per_core_conv_block_config( + parallel_config, + opt_conv_op_parallel_config, + in_channels_padded, + nhw_out_padded_ntile, + conv_config.act_block_h_override, + conv_config.act_block_w_div, + kernel_size[0], + kernel_size[1], + get_fp32_dest_acc_en(compute_config), + conv_config.enable_split_reader); + + // TODO: Flip the Weights + bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor); + ttnn::Tensor weight_tensor_on_device = weight_tensor; + std::optional bias_tensor_on_device = bias_tensor; + if (!weight_is_on_device) { + // prepare weights in desired layout and move to device + tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device( + transform_weights_for_conv_transpose2d(weight_tensor, mirror_kernel), + bias_tensor, + conv_config.input_channels_alignment, + conv_config.weights_dtype, + opt_conv_op_block_config.act_block_w_ntiles, + opt_conv_op_block_config.out_subblock_w_ntiles, + parallel_config, + device, + groups, + opt_conv_op_block_config.act_block_h_ntiles, + input_width); + } + if (mm_conv) { + input_tensor_post_tm = ttnn::to_layout( + input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device); + std::optional program_config = std::nullopt; + std::optional mm_output_memory_config = std::nullopt; + + if (input_tensor_post_tm.is_sharded()) { + uint32_t num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); + program_config = determine_matmul_op_config_from_conv_op_config( + opt_conv_op_parallel_config, + opt_conv_op_block_config, + parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, + conv_config.activation, + parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, + num_cores_c); + mm_output_memory_config = 
conv_out_memory_config; } - // call conv micro op - auto conv_output = optimized_conv_new( - halo_output, + Tensor matmul_output = ttnn::linear( + input_tensor_post_tm, weight_tensor_on_device, bias_tensor_on_device, - sliding_window_config, - out_channels, - groups, - conv_config.output_layout == Layout::ROW_MAJOR, - conv_config.activation == "relu", - opt_conv_op_parallel_config, - opt_conv_op_block_config, - conv_out_memory_config, - conv_config.dtype, - {batch_size, input_height, input_width, in_channels}, - conv_config.input_channels_alignment == 16, - compute_config, - conv_config.enable_act_double_buffer, - conv_config.enable_split_reader, - conv_config.enable_subblock_padding); - if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) { - conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt); + false, + false, + mm_output_memory_config, + std::nullopt, + program_config); + + if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) { + matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt); } - return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device}; + + return {matmul_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device}; + } + // call conv micro op + auto conv_output = optimized_conv_new( + halo_output, + weight_tensor_on_device, + bias_tensor_on_device, + sliding_window_config, + out_channels, + groups, + conv_config.output_layout == Layout::ROW_MAJOR, + conv_config.activation == "relu", + opt_conv_op_parallel_config, + opt_conv_op_block_config, + conv_out_memory_config, + conv_config.dtype, + {batch_size, input_height, input_width, in_channels}, + conv_config.input_channels_alignment == 16, + compute_config, + conv_config.enable_act_double_buffer, + conv_config.enable_split_reader, + conv_config.enable_subblock_padding); + if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) { + conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt); } + return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device}; +} Result ConvTranpose2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - Device * device, + Device* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -366,14 +367,33 @@ Result ConvTranpose2dOperation::invoke( const std::optional& compute_config_, const std::optional& memory_config, bool mirror_kernel) { - return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), std::move(memory_config), mirror_kernel); + return conv_transpose2d( + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + std::move(bias_tensor), + std::move(conv_config_), + std::move(compute_config_), + std::move(memory_config), + mirror_kernel); } Result ConvTranpose2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - MeshDevice * device, + MeshDevice* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ 
-389,10 +409,29 @@ Result ConvTranpose2dOperation::invoke( const std::optional& conv_config_, const std::optional& compute_config_, const std::optional& memory_config, - bool mirror_kernel){ - return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), std::move(memory_config), mirror_kernel); + bool mirror_kernel) { + return conv_transpose2d( + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + std::move(bias_tensor), + std::move(conv_config_), + std::move(compute_config_), + std::move(memory_config), + mirror_kernel); } -} -} -} +} // namespace conv_transpose2d +} // namespace operations::conv +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp index 26b07a668bb..0b595e666a6 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp @@ -16,12 +16,12 @@ using OutputHeight = uint32_t; using OutputWidth = uint32_t; using Result = std::tuple>; -struct ConvTranpose2dOperation{ +struct ConvTranpose2dOperation { static Result invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - Device * device, + Device* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -43,7 +43,7 @@ struct ConvTranpose2dOperation{ uint8_t queue_id, const ttnn::Tensor& input_tensor, const ttnn::Tensor& weight_tensor, - MeshDevice * device, + MeshDevice* device, uint32_t in_channels, uint32_t out_channels, uint32_t batch_size, @@ -66,6 +66,7 @@ struct ConvTranpose2dOperation{ } // namespace operations::conv } // namespace ttnn -namespace ttnn{ - constexpr auto conv_transpose2d = ttnn::register_operation<"ttnn::conv_transpose2d", operations::conv::conv_transpose2d::ConvTranpose2dOperation>(); +namespace ttnn { +constexpr auto conv_transpose2d = + ttnn::register_operation<"ttnn::conv_transpose2d", operations::conv::conv_transpose2d::ConvTranpose2dOperation>(); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp index 53c4bd8d37e..524fe71b581 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp @@ -2,8 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 - - #include "ttnn/cpp/pybind11/decorators.hpp" #include "conv_transpose2d_pybind.hpp" @@ -16,7 +14,6 @@ namespace operations::conv { namespace conv_transpose2d { void py_bind_conv_transpose2d(py::module& module) { - bind_registered_operation( module, ttnn::conv_transpose2d, @@ -92,28 +89,48 @@ void py_bind_conv_transpose2d(py::module& module) { ) )doc", ttnn::pybind_overload_t{ - [](const decltype(ttnn::conv_transpose2d)& self, const ttnn::Tensor& input_tensor, - const ttnn::Tensor& weight_tensor, - ttnn::Device* device, - uint32_t in_channels, - uint32_t out_channels, - uint32_t batch_size, - uint32_t input_height, - uint32_t input_width, - std::array kernel_size, - std::array stride, - std::array padding, - std::array 
output_padding, - std::array dilation, - uint32_t groups, - std::optional bias_tensor, - const std::optional& conv_config, - const std::optional& compute_config, - const std::optional& memory_config, - bool mirror_kernel, - const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config, mirror_kernel); - + [](const decltype(ttnn::conv_transpose2d)& self, + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight_tensor, + ttnn::Device* device, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array output_padding, + std::array dilation, + uint32_t groups, + std::optional bias_tensor, + const std::optional& conv_config, + const std::optional& compute_config, + const std::optional& memory_config, + bool mirror_kernel, + const uint8_t& queue_id) -> Result { + return self( + queue_id, + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + bias_tensor, + conv_config, + compute_config, + memory_config, + mirror_kernel); }, py::kw_only(), py::arg("input_tensor"), @@ -138,28 +155,48 @@ void py_bind_conv_transpose2d(py::module& module) { py::arg("queue_id") = 0}, ttnn::pybind_overload_t{ - [](const decltype(ttnn::conv_transpose2d)& self, const ttnn::Tensor& input_tensor, - const ttnn::Tensor& weight_tensor, - ttnn::MeshDevice* device, - uint32_t in_channels, - uint32_t out_channels, - uint32_t batch_size, - uint32_t input_height, - uint32_t input_width, - std::array kernel_size, - std::array stride, - std::array padding, - std::array output_padding, - std::array dilation, - uint32_t groups, - std::optional bias_tensor, - const std::optional& conv_config, - const std::optional& compute_config, - const std::optional& memory_config, - bool mirror_kernel, - const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config, mirror_kernel); - + [](const decltype(ttnn::conv_transpose2d)& self, + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight_tensor, + ttnn::MeshDevice* device, + uint32_t in_channels, + uint32_t out_channels, + uint32_t batch_size, + uint32_t input_height, + uint32_t input_width, + std::array kernel_size, + std::array stride, + std::array padding, + std::array output_padding, + std::array dilation, + uint32_t groups, + std::optional bias_tensor, + const std::optional& conv_config, + const std::optional& compute_config, + const std::optional& memory_config, + bool mirror_kernel, + const uint8_t& queue_id) -> Result { + return self( + queue_id, + input_tensor, + weight_tensor, + device, + in_channels, + out_channels, + batch_size, + input_height, + input_width, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + bias_tensor, + conv_config, + compute_config, + memory_config, + mirror_kernel); }, py::kw_only(), py::arg("input_tensor"), @@ -181,10 +218,9 @@ void py_bind_conv_transpose2d(py::module& module) { py::arg("compute_config") 
= std::nullopt,
 py::arg("memory_config") = std::nullopt,
 py::arg("mirror_kernel") = true,
- py::arg("queue_id") = 0}
- );
+ py::arg("queue_id") = 0});
 }
-} // namespace conv2d
-} // namespace operations
+} // namespace conv_transpose2d
+} // namespace operations::conv
 } // namespace ttnn
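
For reference, the binding reformatted above exposes `ttnn.conv_transpose2d` to Python with the keyword arguments listed in the `py::arg` calls, and returns the five-element `Result` tuple constructed in the C++ path (output tensor, output height, output width, weights on device, optional bias on device). The snippet below is a minimal usage sketch, not part of this diff: the tensor shapes, the flattened [1, 1, N*H*W, C] activation layout, the IOHW weight layout, and the choice to leave `conv_config`/`compute_config`/`memory_config` at their defaults are illustrative assumptions.

# Illustrative sketch only; shapes and layouts are assumptions, not taken from this PR.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

N, C_in, H, W = 1, 32, 16, 16
C_out, kH, kW = 64, 2, 2

# Activation flattened to [1, 1, N*H*W, C_in] (assumed NHWC-flattened input layout).
torch_input = torch.randn(N, C_in, H, W, dtype=torch.bfloat16)
torch_input_nhwc = torch_input.permute(0, 2, 3, 1).reshape(1, 1, N * H * W, C_in)

# Transposed-conv weights in the PyTorch [C_in, C_out, kH, kW] layout (assumed).
# Host weights are fine: the op prepares them and moves them to device itself,
# as the weight_is_on_device branch above shows.
torch_weight = torch.randn(C_in, C_out, kH, kW, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_nhwc, dtype=ttnn.bfloat16)
weight_tensor = ttnn.from_torch(torch_weight, dtype=ttnn.bfloat16)

# Returns the 5-tuple Result seen in the C++ code above.
output, out_h, out_w, weight_on_device, bias_on_device = ttnn.conv_transpose2d(
    input_tensor=input_tensor,
    weight_tensor=weight_tensor,
    device=device,
    in_channels=C_in,
    out_channels=C_out,
    batch_size=N,
    input_height=H,
    input_width=W,
    kernel_size=(kH, kW),
    stride=(2, 2),
    padding=(0, 0),
    output_padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    bias_tensor=None,
    mirror_kernel=True,
)

ttnn.close_device(device)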