diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index b9a3de1393b..bf7a56ee3b5 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -292,6 +292,7 @@ std::pair> prepare_conv_weights_biases bias_tensor_ = bias_tensor.value(); bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_); if(!is_bias_tensor_is_on_device) { + TT_FATAL(bias_tensor_.shape()[3]==out_channels, "Bias must have the same length as output channels"); bias_tensor_ = conv_bias_layout_convert(bias_tensor_, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, output_parallel_config, device, out_channels_padded, is_non_tile_mul_width); bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); } @@ -471,6 +472,8 @@ ttnn::Tensor prepare_conv_bias( bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels); ttnn::Tensor bias_tensor_ = bias_tensor; + TT_FATAL(bias_tensor_.shape()[3]==out_channels, "Bias must have the same length as output channels"); + bias_tensor_ = conv_bias_layout_convert( bias_tensor_, conv_config.weights_dtype,