From 28c64c0f89c5241ec8f3a81c96b265f51163d7ea Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Sun, 8 Dec 2024 07:46:13 +0000
Subject: [PATCH] Refactor tensor_utils.cpp.

Extract the tilize-conversion logic that was duplicated across the
convolution weight/bias converters into two shared helpers,
convert_tensor_to_tiled_layout_common() and
create_tensor_from_owned_buffer(), and replace the per-dtype switch and
if/else chains with static dtype-to-function dispatch maps. Also pass
the weight tensor by const reference in
convert_conv_weight_tensor_to_depthwise_layout().

Signed-off-by: Nilaykumar Patel
---
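Reviewer note (not part of the commit): the core of this refactor is
replacing per-dtype switch/if-else chains with a static map from DataType
to std::function, plus one shared driver that checks layout/dtype and
dispatches on the input tensor's dtype. The standalone C++ sketch below
illustrates the pattern with simplified stand-in types; Tensor, DataType,
to_tile_layout, and convert_common here are hypothetical stand-ins, not
the ttnn definitions:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <optional>
    #include <stdexcept>
    #include <unordered_map>

    enum class DataType { BFLOAT16, FLOAT32, UINT32 };
    struct Tensor { DataType dtype; };

    // Stand-in for a per-dtype converter such as to_weight_tile_layout<T>.
    template <typename T>
    Tensor to_tile_layout(const Tensor& in, uint32_t block_h, uint32_t block_w, DataType out_dtype) {
        // A real converter would repack the buffer; here we only tag the result.
        return Tensor{out_dtype};
    }

    // Shared driver: look up the converter by input dtype and forward the args.
    template <typename... Args>
    Tensor convert_common(
        const Tensor& input,
        std::optional<DataType> out_dtype,
        const std::unordered_map<DataType, std::function<Tensor(const Tensor&, Args..., DataType)>>& map,
        Args... args) {
        auto it = map.find(input.dtype);
        if (it == map.end()) {
            throw std::runtime_error("Unsupported data type");
        }
        return it->second(input, args..., out_dtype.value_or(input.dtype));
    }

    int main() {
        const std::unordered_map<DataType, std::function<Tensor(const Tensor&, uint32_t, uint32_t, DataType)>>
            converters = {
                {DataType::BFLOAT16, &to_tile_layout<uint16_t>},
                {DataType::FLOAT32, &to_tile_layout<float>},
                {DataType::UINT32, &to_tile_layout<uint32_t>}};

        Tensor weights{DataType::FLOAT32};
        Tensor tiled = convert_common<uint32_t, uint32_t>(weights, std::nullopt, converters, 2u, 4u);
        std::cout << "tiled dtype tag: " << static_cast<int>(tiled.dtype) << "\n";
    }

One static map per converter family keeps the dtype dispatch in one place,
so supporting a new dtype means adding a map entry instead of another
switch case.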
 ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 408 ++++++++------------------
 ttnn/cpp/ttnn/tensor/tensor_utils.hpp |   2 +-
 2 files changed, 130 insertions(+), 280 deletions(-)

diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index ef1e685998d..670be6d27c0 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -6,7 +6,7 @@
 
 #include "ttnn/distributed/api.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
-#include "ttnn/tensor/host_buffer/types.hpp"
+#include "ttnn/tensor/types.hpp"
 
 namespace tt {
 
@@ -32,6 +32,53 @@ Tensor convert_tensor(const Tensor& input_tensor, compute_& compute) {
     return ttnn::distributed::is_multi_device_tensor(input_tensor) ? transform(input_tensor, convert_tensor)
                                                                    : convert_tensor(input_tensor);
 }
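+// Shared driver for the tilize conversions below: validates that the input is
+// row-major (and float32 when a block-float output is requested), then looks
+// up the per-dtype converter in function_map and forwards the trailing args.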
+template <typename... Args>
+Tensor convert_tensor_to_tiled_layout_common(
+    const Tensor& input_tensor,
+    std::optional<DataType> output_dtype,
+    const std::unordered_map<DataType, std::function<Tensor(const Tensor&, Args..., DataType)>>& function_map,
+    Args&&... args) {
+    TT_ASSERT(
+        input_tensor.get_layout() == Layout::ROW_MAJOR &&
+        "Tensor(weight/bias) should be in row major layout for conversion to tilized layout.");
+
+    if (output_dtype.has_value()) {
+        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
+            TT_ASSERT(input_tensor.get_dtype() == DataType::FLOAT32);
+        } else {
+            TT_ASSERT(input_tensor.get_dtype() == input_tensor.get_dtype());
+        }
+    }
+    auto entry = function_map.find(input_tensor.get_dtype());
+    if (entry == function_map.end()) {
+        TT_THROW("Unsupported data type");
+    }
+    return entry->second(input_tensor, std::forward<Args>(args)..., output_dtype.value_or(input_tensor.get_dtype()));
+}
+
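+// Tilizes a freshly built row-major owned buffer. For float input with a
+// BFLOAT8_B/BFLOAT4_B target, the data is tilized as float32 first and then
+// packed into block-float tiles; otherwise the buffer is tilized directly.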
+template <typename T>
+Tensor create_tensor_from_owned_buffer(
+    owned_buffer::Buffer<T>& buf, DataType& output_dtype, ttnn::SimpleShape& output_shape) {
+    if constexpr (std::is_same<T, float>::value) {
+        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
+            auto tensor =
+                Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, DataType::FLOAT32, Layout::ROW_MAJOR)
+                    .to(Layout::TILE);
+            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
+            auto output_packed_data =
+                output_dtype == DataType::BFLOAT8_B
+                    ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false)
+                    : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
+            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
+            return Tensor(
+                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
+        }
+    } else {
+        TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));
+    }
+    auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, output_dtype, Layout::ROW_MAJOR);
+    return rm_tensor.to(Layout::TILE);
+}
 
 template <typename T>
 Tensor to_weight_special_padding_tile_layout(
@@ -65,29 +112,7 @@ Tensor to_weight_special_padding_tile_layout(
             }
         }
     }
-    if constexpr (std::is_same<T, float>::value) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                output_dtype == DataType::BFLOAT8_B
-                    ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false)
-                    : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-    } else {
-        TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));
-    }
-    auto rm_tensor =
-        Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR);
-    return rm_tensor.to(Layout::TILE);
+    return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape);
 };
     return convert_tensor<T>(conv_weight_tensor, compute);
 }
 
@@ -126,29 +151,7 @@ Tensor to_weight_tile_layout(
             }
         }
     }
-    if constexpr (std::is_same<T, float>::value) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                output_dtype == DataType::BFLOAT8_B
-                    ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false)
-                    : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-    } else {
-        TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));
-    }
-    auto rm_tensor =
-        Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR);
-    return rm_tensor.to(Layout::TILE);
+    return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape);
 };
 
     return convert_tensor<T>(conv_weight_tensor, compute);
 }
 
@@ -161,30 +164,14 @@ Tensor convert_conv_weight_tensor_to_tiled_layout(
     uint32_t in1_block_h,
     uint32_t in1_block_w,
     std::optional<DataType> output_dtype) {
-    TT_ASSERT(
-        conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for conversion to tilized layout.");
-
-    if (output_dtype.has_value()) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32);
-        } else {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype());
-        }
-    }
-
-    switch (conv_weight_tensor.get_dtype()) {
-        case DataType::BFLOAT16:
-            return to_weight_tile_layout<bfloat16>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        case DataType::FLOAT32:
-            return to_weight_tile_layout<float>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        case DataType::UINT32:
-            return to_weight_tile_layout<uint32_t>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        default: TT_THROW("Unsupported data type");
-    }
+    const static std::unordered_map<DataType, std::function<Tensor(const Tensor&, uint32_t, uint32_t, DataType)>>
+        to_w_tile_layout_map = {
+            {DataType::BFLOAT16, &to_weight_tile_layout<bfloat16>},
+            {DataType::FLOAT32, &to_weight_tile_layout<float>},
+            {DataType::UINT32, &to_weight_tile_layout<uint32_t>}};
+
+    return convert_tensor_to_tiled_layout_common(
+        conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w);
 }
 
 template <typename T>
 Tensor to_weight_tile_layout_block_sharded(
@@ -236,41 +223,7 @@ Tensor to_weight_tile_layout_block_sharded(
             }
         }
     }
-    if constexpr (std::is_same<T, float>::value) {
-        if (output_dtype == DataType::BFLOAT8_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-        if (output_dtype == DataType::BFLOAT4_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-    } else {
-        TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));
-    }
-    auto rm_tensor =
-        Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR);
-    return rm_tensor.to(Layout::TILE);
+    return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape);
 };
     return convert_tensor<T>(conv_weight_tensor, compute);
 }
 
@@ -279,25 +232,14 @@ Tensor to_weight_tile_layout_block_sharded(
 // Converts convolution weights to tilized 2d matrix layout for block sharded conv.
 // Returns a new tensor with layout=Tile
 Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded(
     const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional<DataType> output_dtype) {
-    TT_ASSERT(
-        conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for conversion to tilized layout.");
-    const static std::
-        map<DataType, std::function<Tensor(const Tensor&, uint32_t, DataType)>>
-            to_w_tile_layout_map = {
-                {DataType::BFLOAT16, &to_weight_tile_layout_block_sharded<bfloat16>},
-                {DataType::FLOAT32, &to_weight_tile_layout_block_sharded<float>},
-                {DataType::UINT32, &to_weight_tile_layout_block_sharded<uint32_t>},
-            };
-    if (output_dtype.has_value()) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32);
-        } else {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype());
-        }
-    }
-    return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())(
-        conv_weight_tensor, num_channel_shards, output_dtype.value_or(conv_weight_tensor.get_dtype()));
+    const static std::unordered_map<DataType, std::function<Tensor(const Tensor&, uint32_t, DataType)>>
+        to_w_tile_layout_map = {
+            {DataType::BFLOAT16, &to_weight_tile_layout_block_sharded<bfloat16>},
+            {DataType::FLOAT32, &to_weight_tile_layout_block_sharded<float>},
+            {DataType::UINT32, &to_weight_tile_layout_block_sharded<uint32_t>}};
+
+    return convert_tensor_to_tiled_layout_common(
+        conv_weight_tensor, output_dtype, to_w_tile_layout_map, num_channel_shards);
 }
 
 template <typename T>
 Tensor to_bias_tile_layout_block_sharded(
@@ -327,41 +269,7 @@ Tensor to_bias_tile_layout_block_sharded(
             output_buffer[matrix_idx] = input_buffer[idx];
         }
     }
-    if constexpr (std::is_same<T, float>::value) {
-        if (output_dtype == DataType::BFLOAT8_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-        if (output_dtype == DataType::BFLOAT4_B) {
-            auto tensor = Tensor(
-                std::move(OwnedStorage{std::move(output_buffer)}),
-                output_shape,
-                DataType::FLOAT32,
-                Layout::ROW_MAJOR)
-                .to(Layout::TILE);
-            auto output_float_data = owned_buffer::get_as<float>(tensor).get();
-            auto output_packed_data =
-                pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false);
-            auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
-            return Tensor(
-                std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE);
-        }
-    } else {
-        TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));
-    }
-    auto rm_tensor =
-        Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR);
-    return rm_tensor.to(Layout::TILE);
+    return create_tensor_from_owned_buffer(output_buffer, output_dtype, output_shape);
 };
 
     return convert_tensor<T>(conv_bias_tensor, compute);
 }
 
 // Converts convolution bias to tilized layout for block sharded conv
@@ -371,25 +279,16 @@ Tensor to_bias_tile_layout_block_sharded(
 // Returns a new tensor with layout=Tile
 Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded(
     const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional<DataType> output_dtype) {
-    TT_ASSERT(
-        conv_bias_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for conversion to tilized layout.");
-    const static std::
-        map<DataType, std::function<Tensor(const Tensor&, uint32_t, DataType)>>
-            to_b_tile_layout_map = {
-                {DataType::BFLOAT16, &to_bias_tile_layout_block_sharded<bfloat16>},
-                {DataType::FLOAT32, &to_bias_tile_layout_block_sharded<float>},
-                {DataType::UINT32, &to_bias_tile_layout_block_sharded<uint32_t>},
-            };
-    if (output_dtype.has_value()) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            TT_ASSERT(conv_bias_tensor.get_dtype() == DataType::FLOAT32);
-        } else {
-            TT_ASSERT(conv_bias_tensor.get_dtype() == conv_bias_tensor.get_dtype());
-        }
-    }
-    return to_b_tile_layout_map.at(conv_bias_tensor.get_dtype())(
-        conv_bias_tensor, num_channel_shards, output_dtype.value_or(conv_bias_tensor.get_dtype()));
+    const static std::unordered_map<
+        DataType,
+        std::function<Tensor(const Tensor&, uint32_t, DataType)>>
+        to_b_tile_layout_map = {
+            {DataType::BFLOAT16, &to_bias_tile_layout_block_sharded<bfloat16>},
+            {DataType::FLOAT32, &to_bias_tile_layout_block_sharded<float>},
+            {DataType::UINT32, &to_bias_tile_layout_block_sharded<uint32_t>},
+        };
+    return convert_tensor_to_tiled_layout_common(
+        conv_bias_tensor, output_dtype, to_b_tile_layout_map, num_channel_shards);
 }
 
 // Converts convolution weights to tilized 2d matrix layout.
 // Returns a new tensor with layout=Tile
@@ -399,30 +298,14 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(
     uint32_t in1_block_h,
     uint32_t in1_block_w,
     std::optional<DataType> output_dtype) {
-    TT_ASSERT(
-        conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for conversion to tilized layout.");
-
-    if (output_dtype.has_value()) {
-        if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32);
-        } else {
-            TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype());
-        }
-    }
-
-    switch (conv_weight_tensor.get_dtype()) {
-        case DataType::BFLOAT16:
-            return to_weight_special_padding_tile_layout<bfloat16>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        case DataType::FLOAT32:
-            return to_weight_special_padding_tile_layout<float>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        case DataType::UINT32:
-            return to_weight_special_padding_tile_layout<uint32_t>(
-                conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
-        default: TT_THROW("Unsupported data type");
-    }
+    const static std::unordered_map<DataType, std::function<Tensor(const Tensor&, uint32_t, uint32_t, DataType)>>
+        to_w_tile_layout_map = {
+            {DataType::BFLOAT16, &to_weight_special_padding_tile_layout<bfloat16>},
+            {DataType::FLOAT32, &to_weight_special_padding_tile_layout<float>},
+            {DataType::UINT32, &to_weight_special_padding_tile_layout<uint32_t>}};
+
+    return convert_tensor_to_tiled_layout_common(
+        conv_weight_tensor, output_dtype, to_w_tile_layout_map, in1_block_h, in1_block_w);
 }
 
 /*
@@ -478,7 +361,7 @@ Helper function to aid in converting depthwise weight tensor to broadcasted weig
 */
 template <typename T>
 static Tensor conv_depthwise_weight_bcast_helper(
-    Tensor& conv_weight_tensor,
+    const Tensor& conv_weight_tensor,
     const ttnn::SimpleShape& original_weight_shape,
     const ttnn::SimpleShape& output_weight_shape,
     DataType output_dtype) {
 
@@ -514,10 +397,6 @@ divided into num_groups for each groupped filter
 */
 Tensor convert_conv_weight_tensor_to_grouped_layout(
     const Tensor& conv_weight_tensor, uint32_t num_groups, DataType output_dtype) {
-    TT_ASSERT(
-        conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for adding the required padding");
-
     // Define output tensor shape. This is going to be channel dimension of weight tensor * num_groups - this value
     // should match number of input channels being convolved with the weight tensor
     auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape();
 
@@ -532,52 +411,27 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(
         original_conv_weight_tensor_shape[2],
         original_conv_weight_tensor_shape[3]};
 
-    // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor
-    if (output_dtype == DataType::INT32) {
-        return conv_group_weight_zero_pad_helper<int32_t>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            output_dtype);
-    } else if (output_dtype == DataType::FLOAT32) {
-        return conv_group_weight_zero_pad_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            output_dtype);
-    } else if (output_dtype == DataType::BFLOAT16) {
-        return conv_group_weight_zero_pad_helper<bfloat16>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            output_dtype);
-    } else if (output_dtype == DataType::UINT16) {
-        return conv_group_weight_zero_pad_helper<uint16_t>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            output_dtype);
-    } else if (output_dtype == DataType::BFLOAT8_B) {
-        return conv_group_weight_zero_pad_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            DataType::FLOAT32);
-    } else {
-        return conv_group_weight_zero_pad_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            num_groups,
-            output_dtype);
-    }
-
-    TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
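+    // Dispatch table keyed by the weight tensor's dtype; the BFLOAT8_B/BFLOAT4_B
+    // entries reuse the float32 helper since block-float data is built from floats.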
+    const static std::unordered_map<
+        DataType,
+        std::function<Tensor(const Tensor&, const ttnn::SimpleShape&, const ttnn::SimpleShape&, uint32_t, DataType)>>
+        to_w_tile_layout_map = {
+            {DataType::INT32, &conv_group_weight_zero_pad_helper<int32_t>},
+            {DataType::FLOAT32, &conv_group_weight_zero_pad_helper<float>},
+            {DataType::BFLOAT16, &conv_group_weight_zero_pad_helper<bfloat16>},
+            {DataType::UINT16, &conv_group_weight_zero_pad_helper<uint16_t>},
+            {DataType::BFLOAT8_B, &conv_group_weight_zero_pad_helper<float>},
+            {DataType::UINT32, &conv_group_weight_zero_pad_helper<uint32_t>},
+            {DataType::BFLOAT4_B, &conv_group_weight_zero_pad_helper<float>},
+        };
+    output_dtype = output_dtype == DataType::BFLOAT8_B ? DataType::FLOAT32 : output_dtype;
+
+    return convert_tensor_to_tiled_layout_common(
+        conv_weight_tensor,
+        output_dtype,
+        to_w_tile_layout_map,
+        original_conv_weight_tensor_shape,
+        output_conv_weight_tensor_shape,
+        num_groups);
 }
 
 /*
@@ -587,10 +441,7 @@ allocated output tensor with shape [out_channels, act_block_h, H, W] The extra c
 Converts convolution weights to depthwise layout
 This function will take in a weight tensor with shape [out_channels, in_channels, H, W] and return a newly
 allocated output tensor with shape [out_channels, act_block_h, H, W] The extra channels in shape[1] are repeated
 from the original weight tensor - it would be convolving act_block in conv_matrix in one go
 */
 Tensor convert_conv_weight_tensor_to_depthwise_layout(
-    Tensor conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype) {
-    TT_ASSERT(
-        conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
-        "Convolution weights should be in row major layout for repeating the required dimensions");
+    const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype) {
     auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape();
     uint32_t num_input_channels_to_repeat = act_block_h_ntiles * constants::TILE_HEIGHT;
     ttnn::SimpleShape original_conv_weight_tensor_shape{
 
@@ -605,27 +456,26 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout(
         original_conv_weight_tensor_shape[2],
         original_conv_weight_tensor_shape[3]};
 
     // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor
-    if (output_dtype == DataType::INT32) {
-        return conv_depthwise_weight_bcast_helper<int32_t>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
-    } else if (output_dtype == DataType::FLOAT32) {
-        return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
-    } else if (output_dtype == DataType::BFLOAT16) {
-        return conv_depthwise_weight_bcast_helper<bfloat16>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
-    } else if (output_dtype == DataType::UINT16) {
-        return conv_depthwise_weight_bcast_helper<uint16_t>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
-    } else if (output_dtype == DataType::BFLOAT8_B) {
-        return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
-    } else {
-        return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
-    }
-
-    TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
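+    // Same dispatch-table pattern as the grouped-layout conversion above;
+    // block-float outputs are produced from float32 data, hence the remap below.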
+    const static std::
+        unordered_map<DataType, std::function<Tensor(const Tensor&, const ttnn::SimpleShape&, const ttnn::SimpleShape&, DataType)>>
+            to_w_tile_layout_map = {
+                {DataType::INT32, &conv_depthwise_weight_bcast_helper<int32_t>},
+                {DataType::FLOAT32, &conv_depthwise_weight_bcast_helper<float>},
+                {DataType::BFLOAT16, &conv_depthwise_weight_bcast_helper<bfloat16>},
+                {DataType::UINT16, &conv_depthwise_weight_bcast_helper<uint16_t>},
+                {DataType::BFLOAT8_B, &conv_depthwise_weight_bcast_helper<float>},
+                {DataType::UINT32, &conv_depthwise_weight_bcast_helper<uint32_t>},
+                {DataType::BFLOAT4_B, &conv_depthwise_weight_bcast_helper<float>},
+            };
+    output_dtype = ((output_dtype == DataType::BFLOAT8_B) || (output_dtype == DataType::BFLOAT4_B)) ? DataType::FLOAT32
+                                                                                                    : output_dtype;
+
+    return convert_tensor_to_tiled_layout_common(
+        conv_weight_tensor,
+        output_dtype,
+        to_w_tile_layout_map,
+        original_conv_weight_tensor_shape,
+        output_conv_weight_tensor_shape);
 }
 
 const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span<const int32_t> shape) {
diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp
index 3c2565299b9..f4c9b1ae537 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp
@@ -44,7 +44,7 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(
 
 // Converts convolution weights to depthwise layout with broadcasted weights
 Tensor convert_conv_weight_tensor_to_depthwise_layout(
-    Tensor conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype);
+    const Tensor& conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype);
 
 const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span<const int32_t> shape);