diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 05abc879bed3..ff94196d37e0 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -626,7 +626,7 @@ def __init__( conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16) conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) - _, self.override_fold_mem_config, _ = ttnn.get_conv_padded_input_shape_and_mem_config( + _, self.override_fold_mem_config, _, _ = ttnn.get_conv_padded_input_shape_and_mem_config( device=device, input_tensor=conv_dummy_tensor, conv_config=self.conv1_config, diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index a2a514ed9750..e761e5a75a6b 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -135,6 +135,7 @@ def run_conv( enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, + output_layout=output_layout, ) if config_override and "act_block_h" in config_override: conv_config.act_block_h_override = config_override["act_block_h"] @@ -2203,3 +2204,116 @@ def test_conv_for_vanilla_unet( output_layout=output_layout, has_bias=False, ) + + +@skip_for_blackhole() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + ( + # unique convs in rn50 (complete list) + # first conv post folding and input_channels padding to tile width + (16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer1 + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer2 + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None), + (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat8_b, ttnn.bfloat16], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat16], +) +@pytest.mark.parametrize("fp32_accum", [False, True], ids=["no_fp32_accum", "fp32_accum"]) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"]) +def test_non_tile_multiple_height_conv_wh( + device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + fp32_accum, + packer_l1_acc, + has_bias, +): + if device.core_grid.y == 7: + pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") + + if ( + is_grayskull() + and activations_dtype == ttnn.bfloat16 + and batch_size == 20 + and ( + output_channels == 64 + or ( + stride_h == 2 + and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) + ) + ) + ): + pytest.skip("Skipping test because it won't fit in L1!") + + if ( + (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 128 and input_height == 56) + or (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 64) + or (weights_dtype == ttnn.bfloat8_b and batch_size == 20 and output_channels == 128 and input_height == 56) + ): + pytest.skip("Skipping test because it won't fit in L1!") + + if has_bias and packer_l1_acc and fp32_accum: + pytest.skip("bug!") + + use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0 + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override=config_override, + use_shallow_conv_variant=use_shallow_conv_variant, + transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH + packer_l1_acc=packer_l1_acc, + fp32_accum=fp32_accum, + has_bias=has_bias, + output_layout=ttnn.ROW_MAJOR_LAYOUT, + ) diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index b0e1c839ca92..dbdd7268c968 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -188,9 +188,7 @@ MemoryConfig create_sharded_memory_config_from_parallel_config( uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2]; uint32_t nhw_padded = nhw_shape; - if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { - nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); - } + nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); uint32_t nhw_shard = nhw_padded / num_cores_nhw; TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels); uint32_t channel_shard = channels / num_cores_channels; @@ -204,14 +202,16 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); const auto& shard_spec = conv_output_mem_config.shard_spec.value(); const auto& shard_shape = shard_spec.shape; - TT_ASSERT(conv_output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || shard_shape[0] % 32 == 0); TT_ASSERT(shard_shape[1] % 32 == 0); + uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); return { .grid_size = shard_spec.grid.bounding_box().grid_size(), .num_cores_nhw = num_cores_nhw, .num_cores_c = num_cores_c, - .per_core_out_matrix_height_ntiles = tt::round_up(shard_shape[0], 32) / 32, + .per_core_out_matrix_height_ntiles = per_core_out_matrix_height_ntiles, .per_core_out_matrix_width_ntiles = shard_shape[1] / 32, + .per_core_out_matrix_height = shard_shape[0], + .per_core_out_matrix_width = shard_shape[1], }; } @@ -382,7 +382,7 @@ static TensorMemoryLayout select_shard_spec( } template -std::tuple get_conv_padded_input_shape_and_mem_config( +std::tuple get_conv_padded_input_shape_and_mem_config( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -429,6 +429,10 @@ std::tuple get_conv_padded_input_shape_an dilation, device); } + bool use_non_tile_height = shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 && + conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR; + use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; //shalow conv varient + ParallelConfig input_tensor_parallel_config; if (!input_tensor_on_device) { needs_shard_or_reshard = true; @@ -474,8 +478,16 @@ std::tuple get_conv_padded_input_shape_an if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { auto block_shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - ParallelConfig optimal_parallel_config = determine_parallel_config( - shard_layout, batch_size, in_channels, height, width, out_channels, device, block_shard_orientation); + const ParallelConfig& optimal_parallel_config = determine_parallel_config( + shard_layout, + batch_size, + in_channels, + height, + width, + out_channels, + device, + block_shard_orientation, + !use_non_tile_height); if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Error"); @@ -498,7 +510,8 @@ std::tuple get_conv_padded_input_shape_an uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape()); uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; - uint32_t input_tensor_height_snapped_to_tile = (shard_layout == TensorMemoryLayout::WIDTH_SHARDED)? tensor_height : tt::round_up(tensor_height, input_num_cores_nhw * 32); + uint32_t round_up_size = (use_non_tile_height || conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) ? 1 : tt::constants::TILE_HEIGHT; + uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); uint32_t tensor_width = input_tensor.get_shape()[3]; uint32_t input_tensor_width_snapped_to_channels_alignment = @@ -510,20 +523,18 @@ std::tuple get_conv_padded_input_shape_an 1, input_tensor_height_snapped_to_tile, input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and - // tt::tt_metal::LegacyShape issue to clean up next line auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{ - input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), - parallel_config, - 32); - return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard}; + input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), + parallel_config, round_up_size); + return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; } else { - return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard}; + return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height}; } } template -std::tuple shard_or_reshard_tensor_if_required( +std::tuple shard_or_reshard_tensor_if_required( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -542,7 +553,7 @@ std::tuple shard_or_reshard_tensor_if_requir ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); - auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard] = + auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height] = get_conv_padded_input_shape_and_mem_config( device, input_tensor_, @@ -620,7 +631,7 @@ std::tuple shard_or_reshard_tensor_if_requir } } } - return {input_tensor, parallel_config, needs_shard_or_reshard}; + return {input_tensor, parallel_config, needs_shard_or_reshard, use_non_tile_height}; } void validate_weight_and_bias_tensors( @@ -817,7 +828,7 @@ std::tuple{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), - parallel_config, - 32); + parallel_config, round_up_size); auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( conv_out_memory_config, get_num_cores_nhw_from_parallel_config(parallel_config), get_num_cores_channels_from_parallel_config(parallel_config)); + TT_ASSERT(use_non_tile_height || conv_out_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || opt_conv_op_parallel_config.per_core_out_matrix_height % 32 == 0); auto opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, @@ -910,7 +922,7 @@ std::tuple( ShardOrientation block_shard_orientation, bool is_out_tiled); -template std::tuple get_conv_padded_input_shape_and_mem_config( +template std::tuple get_conv_padded_input_shape_and_mem_config( Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1070,8 +1084,8 @@ template std::tuple get_conv_padded_input uint32_t input_width, uint32_t groups); -template std::tuple get_conv_padded_input_shape_and_mem_config( - MeshDevice* device, +template std::tuple get_conv_padded_input_shape_and_mem_config( + MeshDevice * device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, @@ -1087,7 +1101,7 @@ template std::tuple get_conv_padded_input uint32_t input_width, uint32_t groups); -template std::tuple shard_or_reshard_tensor_if_required( +template std::tuple shard_or_reshard_tensor_if_required( Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1104,8 +1118,8 @@ template std::tuple shard_or_reshard_tensor_ uint32_t input_width, uint32_t groups); -template std::tuple shard_or_reshard_tensor_if_required( - MeshDevice* device, +template std::tuple shard_or_reshard_tensor_if_required( + MeshDevice * device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index 0e86e223eafe..c1946630657a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -126,7 +126,7 @@ std::pair determine_largest_subblock_size(uint32_t block_hei OptimizedConvBlockConfig determine_per_core_conv_block_config(const sliding_window::ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, uint32_t act_block_h_override, uint32_t window_w, bool fp32_accum, bool use_shallow_conv_variant); template -std::tuple get_conv_padded_input_shape_and_mem_config( +std::tuple get_conv_padded_input_shape_and_mem_config( T * device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -144,7 +144,7 @@ std::tuple get_conv_padded_input_shape_an uint32_t groups); template -std::tuple shard_or_reshard_tensor_if_required( +std::tuple shard_or_reshard_tensor_if_required( T * device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index f3a7b6e920ff..02f3f1ee7aec 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -136,7 +136,7 @@ void py_bind_conv2d(py::module& module) { std::array dilation, uint32_t weights_width, uint32_t input_width, - uint32_t groups) -> std::tuple { + uint32_t groups) -> std::tuple { return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( device, input_tensor, @@ -187,7 +187,7 @@ void py_bind_conv2d(py::module& module) { std::array dilation, uint32_t weights_width, uint32_t input_width, - uint32_t groups) -> std::tuple { + uint32_t groups) -> std::tuple { return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( device, input_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index 692e015c546a..fc2f98557eaf 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -61,11 +61,12 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, - bool enable_subblock_padding + bool enable_subblock_padding, + bool use_non_tile_height ) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; operation::launch_op( - [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding] + [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { using ttnn::operations::experimental::auto_format::FormatParams; auto& a = input_tensors.at(0); @@ -86,7 +87,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optionalarch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false; auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); return operation::run_without_autoformat( - OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding + OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height ), input_tensors, optional_input_tensors); @@ -100,6 +101,9 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const const auto& input_tensor_b = input_tensors.at(1); // TODO: ... TT_FATAL(!input_tensor_b.memory_config().is_sharded(), "Error"); + if(this->use_non_tile_height){ + TT_FATAL(this->output_channels <= 256, "use_non_tile_height uses row major order. having more than 8 tiles in output_channel in dst registers is not supported"); + } if (this->untilize_out) { TT_FATAL((this->dtype == DataType::BFLOAT16) || (this->dtype == DataType::FLOAT32), "Error"); } @@ -143,8 +147,7 @@ std::vector OptimizedConvNew::compute_output_shapes(c // Tiled output shape is padded shape. Padded to tile shape. auto shape_w = batch_size * conv_output_h * conv_output_w; auto shape_c = output_channels; - auto padded_shape_w = - parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT; + auto padded_shape_w = this->use_non_tile_height ? parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height : parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT; auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); auto output_padding = Padding( {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero); @@ -160,10 +163,17 @@ std::vector OptimizedConvNew::create_output_tensors(const std::vectorcompute_output_shapes(input_tensors).at(0); if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT; - uint32_t num_cores = total_height_tiles / this->parallelization_config.per_core_out_matrix_height_ntiles; + uint32_t num_cores; + std::array shard_shape; + if(this->use_non_tile_height){ + num_cores = this->parallelization_config.num_cores_nhw; + uint32_t total_height = tt::tt_metal::compute_volume(output_shape) / output_shape[-1]; + shard_shape = {(uint32_t)(total_height / num_cores), output_shape[-1]}; + }else{ + num_cores = total_height_tiles / this->parallelization_config.per_core_out_matrix_height_ntiles; + shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, output_shape[-1]}; + } CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerange_set(num_cores, this->parallelization_config.grid_size, true); - - std::array shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, output_shape[-1]}; auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; auto mem_config = this->memory_config; mem_config.shard_spec = shard_spec; @@ -228,7 +238,8 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect output_tensor, enable_act_double_buffer, enable_split_reader, - enable_subblock_padding); + enable_subblock_padding, + use_non_tile_height); } operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index cf938e1da131..4f5eb3e394f4 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -25,6 +25,8 @@ struct OptimizedConvParallelizationConfig { uint32_t num_cores_c = 1; uint32_t per_core_out_matrix_height_ntiles = 1; uint32_t per_core_out_matrix_width_ntiles = 1; + uint32_t per_core_out_matrix_height = 1; + uint32_t per_core_out_matrix_width = 1; // std::size_t in0_block_w; // std::size_t out_subblock_h; // std::size_t out_subblock_w; @@ -57,7 +59,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T Tensor& output, bool enable_act_double_buffer, bool enable_split_reader, - bool enable_subblock_padding); + bool enable_subblock_padding, + bool use_non_tile_height); // new micro op struct OptimizedConvNew { @@ -76,6 +79,7 @@ struct OptimizedConvNew { bool enable_act_double_buffer; bool enable_split_reader; bool enable_subblock_padding; + bool use_non_tile_height; OptimizedConvNew(const sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t output_channels, uint32_t groups, bool untile_out, @@ -85,7 +89,7 @@ struct OptimizedConvNew { MemoryConfig out_mem_config, DataType dtype, std::array input_tensor_shape, bool use_shallow_conv_variant, - const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding) : + const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) : output_channels(output_channels), groups(groups), sliding_window_config(sliding_window_config), @@ -101,7 +105,8 @@ struct OptimizedConvNew { compute_kernel_config(compute_kernel_config), enable_act_double_buffer(enable_act_double_buffer), enable_split_reader(enable_split_reader), - enable_subblock_padding(enable_subblock_padding) {} + enable_subblock_padding(enable_subblock_padding), + use_non_tile_height(use_non_tile_height) {} void validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; std::vector compute_output_shapes(const std::vector& input_tensors) const; @@ -158,7 +163,8 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional compute_kernel_config = std::nullopt, bool enable_act_double_buffer = false, bool enable_split_reader = false, - bool enable_subblock_padding = false + bool enable_subblock_padding = false, + bool use_non_tile_height = false ); } // namespace conv2d diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index a8a7b9e47141..055e57f69a29 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -80,7 +80,8 @@ std::tuple create_CBs_for_sharded_input_v2( bool with_bias, bool split_reader, bool fp32_dest_acc_en, - bool packer_l1_acc_en) { + bool packer_l1_acc_en, + bool use_non_tile_height) { tt::DataFormat interm0_df = packer_l1_acc_en ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : out_df; @@ -178,9 +179,12 @@ std::tuple create_CBs_for_sharded_input_v2( auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); log_debug(LogOp, "Reblock CB: {}, npages: {}, pagesize: {}", untilize_mode_reblock_cb, num_reblock_cb_tiles, out_tile_size); - CircularBufferConfig cb_output_config = - CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{out0_cb, out_df}}) - .set_page_size(out0_cb, out_tile_size); + auto shard_shape = output.shard_spec().value().shape; + uint32_t aligned_output_stick_nbytes = use_non_tile_height ? shard_shape[1] * output.element_size() : out_tile_size; + uint32_t aligned_output_num_pages = use_non_tile_height ? shard_shape[0] : num_writer_output_tiles; + CircularBufferConfig cb_output_config = CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) + .set_page_size(out0_cb, aligned_output_stick_nbytes); + if (output.is_sharded()) { cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); } @@ -349,7 +353,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, - bool enable_subblock_padding) { + bool enable_subblock_padding, + bool use_non_tile_height) { bool pass = true; tt_metal::Device* device = a.device(); TT_FATAL(a.get_layout() == Layout::ROW_MAJOR, "Conv activation should be in row major layout"); @@ -812,7 +817,13 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( log_debug(LogOp, "num_blocks_out_h_per_core: {}", num_blocks_out_h_per_core); TT_FATAL(act_matrix_height_ntiles % per_core_out_matrix_height_ntiles == 0, "Error"); - uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; + uint32_t total_active_num_cores_per_weight_slice; + if(use_non_tile_height){ + uint32_t input_height_padded_per_core = shard_shape[0]; + total_active_num_cores_per_weight_slice = act_matrix_height / parallelization_config.per_core_out_matrix_height; + } else { + total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; + } TT_FATAL(total_active_num_cores_per_weight_slice <= total_num_cores_per_weight_slice, "Error"); uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; @@ -947,12 +958,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered } } - - uint32_t out_block_h_ntiles_padded = num_blocks_act_h_per_core * act_block_h_ntiles; uint32_t writer_output_block_num_tiles = out_block_h_ntiles_padded * weight_block_w_ntiles; uint32_t output_block_num_tiles = enable_subblock_padding ? (act_block_h_ntiles_padded * weight_block_w_ntiles) : writer_output_block_num_tiles; + uint32_t aligned_output_num_pages = use_non_tile_height ? output.shard_spec().value().shape[0] : writer_output_block_num_tiles; + std::vector reader_rt_args; std::vector reader_compile_time_args; std::vector writer_rt_args; @@ -1030,7 +1041,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( has_bias, split_reader, fp32_dest_acc_en, - packer_l1_acc_en); + packer_l1_acc_en, + use_non_tile_height); } CBHandle cb_sharded_act = std::get<0>(input_output_cbs); CBHandle cb_output = std::get<1>(input_output_cbs); @@ -1241,6 +1253,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( out_dram_addr, weight_dram_addr, bias_dram_addr, + aligned_output_num_pages, }; if (split_reader) { std::vector split_reader_args = { @@ -1281,7 +1294,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( tilize_in0, untilize_out, - bias_ntiles_per_core}; + bias_ntiles_per_core, + aligned_output_num_pages, + use_non_tile_height}; auto writer_mcast_noc = NOC::NOC_0; auto reader_noc = writer_mcast_noc == NOC::NOC_0 ? NOC::NOC_1 : NOC::NOC_0; @@ -1663,7 +1678,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( Tensor& output, bool enable_act_double_buffer, bool enable_split_reader, - bool enable_subblock_padding) { + bool enable_subblock_padding, + bool use_non_tile_height) { tt_metal::Program program = tt_metal::CreateProgram(); ttnn::operations::sliding_window::ParallelConfig parallel_config; @@ -1735,7 +1751,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( compute_kernel_config.value(), enable_act_double_buffer, enable_split_reader, - enable_subblock_padding); + enable_subblock_padding, + use_non_tile_height); } } // namespace tt_metal diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp index 05bd35d88b89..4c629dd879f1 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp @@ -575,7 +575,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( if (false) { compute_defines["PACKER_L1_ACC"] = "1"; } - + uint32_t output_rows_h = output.shard_spec().value().shape[0]; + uint32_t use_non_tile_height = false; compute_kernel_args = { act_block_w_ntiles, //in0_block_w act_num_subblocks, //in0_num_sublocks @@ -602,6 +603,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( bias_ntiles_per_core, + output_rows_h, + use_non_tile_height, + total_num_cores, //in0_nblocks_w_tilize. Repeat tilize after all cores have done one round of MCAST. }; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp index 714881eb564e..5f22f248601e 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp @@ -46,22 +46,24 @@ inline void tilize_in( tilize_uninit(in_cb_id); } // tilize_in() -template +template inline void reblock_and_untilize( uint32_t num_out_subblocks_in_col, uint32_t out_subblock_num_tiles, uint32_t out_subblock_h, + uint32_t output_rows_h, uint32_t interm_cb_id, uint32_t out_cb_id) { - + constexpr bool is_non_tile_height_= is_non_tile_height; + uint32_t TILE_SIZE = is_non_tile_height_ ? 32 : out_subblock_w; uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col); cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks); - uint32_t within_block_index = 0; for (uint32_t h = 0; h < out_subblock_h; h++) { uint32_t block_offset = 0; - - cb_reserve_back(out_cb_id, out_block_w); + uint32_t out_sub_block_rows_h = output_rows_h <= TILE_SIZE ? output_rows_h : TILE_SIZE; + uint32_t rows_to_copy = is_non_tile_height_ ? out_sub_block_rows_h : 16; + cb_reserve_back(out_cb_id, out_sub_block_rows_h); for (uint32_t n = 0; n < num_out_subblocks_in_col; n++) { tile_regs_acquire(); for (uint32_t w = 0; w < out_subblock_w; w++) { @@ -70,12 +72,12 @@ inline void reblock_and_untilize( } tile_regs_commit(); tile_regs_wait(); - pack_untilize_dst(out_cb_id, 1, n); + pack_untilize_dst(out_cb_id, 1, n, rows_to_copy); tile_regs_release(); block_offset += out_subblock_num_tiles; } - cb_push_back(out_cb_id, out_block_w); - + cb_push_back(out_cb_id, out_sub_block_rows_h); + output_rows_h -= out_sub_block_rows_h; within_block_index += out_subblock_w; } cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks); @@ -101,10 +103,11 @@ void MAIN { constexpr uint32_t out_subblock_num_tiles = get_compile_time_arg_val(13); // out_subblock_h * out_subblock_w; constexpr bool tilize_in0 = get_compile_time_arg_val(14); constexpr bool untilize_out = get_compile_time_arg_val(15); - + uint32_t output_rows_h = get_compile_time_arg_val(17); + constexpr bool is_non_tile_height = get_compile_time_arg_val(18); #ifdef WIDTH_SHARDED - constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(17); + constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(19); #endif constexpr uint32_t out_block_num_tiles = in0_num_subblocks * in1_num_subblocks * out_subblock_num_tiles; @@ -407,13 +410,19 @@ void MAIN { #endif pack_untilize_dst_init_short(out_cb_id); copy_tile_to_dst_init_short(); + uint32_t curr_tile_output_rows_h = 0; + uint32_t TILE_SIZE = is_non_tile_height ? 32 : out_subblock_w; + TILE_SIZE = TILE_SIZE*out_subblock_h; for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - reblock_and_untilize ( - in1_num_subblocks, - out_subblock_num_tiles, - out_subblock_h, - matmul_partials_cb, - out_cb_id); + curr_tile_output_rows_h = output_rows_h < TILE_SIZE ? output_rows_h : TILE_SIZE; + reblock_and_untilize ( + in1_num_subblocks, + out_subblock_num_tiles, + out_subblock_h, + curr_tile_output_rows_h, + matmul_partials_cb, + out_cb_id); + output_rows_h -= curr_tile_output_rows_h; } pack_untilize_uninit(matmul_partials_cb); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index 0358ccbae593..030eb482cb0e 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -43,16 +43,18 @@ void kernel_main() { constexpr uint32_t out_width_num_tiles = get_compile_time_arg_val(28); constexpr uint32_t out_addr = get_compile_time_arg_val(29); + constexpr uint32_t output_rows_tiles = get_compile_time_arg_val(32); + // MCAST args - constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(32); - constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(33); - constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(34); - constexpr uint32_t coalesced_read_bytes = get_compile_time_arg_val(35); - constexpr uint32_t window_outer_offset = get_compile_time_arg_val(36); - constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(37); - constexpr uint32_t act_block_h_datums_first_reader = get_compile_time_arg_val(38); - constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(39); + constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(33); + constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(34); + constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(35); + constexpr uint32_t coalesced_read_bytes = get_compile_time_arg_val(36); + constexpr uint32_t window_outer_offset = get_compile_time_arg_val(37); + constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(38); + constexpr uint32_t act_block_h_datums_first_reader = get_compile_time_arg_val(39); + constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(40); constexpr uint32_t act_block_h_datums_read_last_block = act_block_h_datums_last_block > act_block_h_datums @@ -251,6 +253,6 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + cb_wait_front(cb_id_out0, output_rows_tiles); #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index 6ecc4f7b71f9..7bcb357c1843 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -42,16 +42,17 @@ void kernel_main() { constexpr uint32_t out_width_num_tiles = get_compile_time_arg_val(28); constexpr uint32_t out_addr = get_compile_time_arg_val(29); + constexpr uint32_t output_rows_tiles = get_compile_time_arg_val(32); // MCAST args - constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(32); - constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(33); - constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(34); - constexpr uint32_t coalesced_read_bytes = get_compile_time_arg_val(35); - constexpr uint32_t window_outer_offset = get_compile_time_arg_val(36); - constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(37); - constexpr uint32_t act_block_h_datums_first_reader = get_compile_time_arg_val(38); - constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(39); + constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(33); + constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(34); + constexpr uint32_t conv_act_size_c_bytes = get_compile_time_arg_val(35); + constexpr uint32_t coalesced_read_bytes = get_compile_time_arg_val(36); + constexpr uint32_t window_outer_offset = get_compile_time_arg_val(37); + constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(38); + constexpr uint32_t act_block_h_datums_first_reader = get_compile_time_arg_val(39); + constexpr uint32_t act_block_h_datums_last_block = get_compile_time_arg_val(40); constexpr uint32_t act_block_h_datums_read_last_block = act_block_h_datums_last_block > act_block_h_datums @@ -369,6 +370,6 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + cb_wait_front(cb_id_out0, output_rows_tiles); #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index 361870bb3a3d..fc0035fa2054 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -6,7 +6,6 @@ // #include "debug/dprint.h" - void kernel_main() { // This writer is for output tensor in tile format constexpr bool out_in_dram = get_compile_time_arg_val(0) == 1; @@ -43,6 +42,7 @@ void kernel_main() { constexpr uint32_t out_width_num_tiles = get_compile_time_arg_val(28); constexpr uint32_t out_addr = get_compile_time_arg_val(29); + constexpr uint32_t output_rows_tiles = get_compile_time_arg_val(32); constexpr uint32_t total_weight_num_tiles = weight_block_height_num_outer * num_blocks_weight_h * weight_block_num_tiles; @@ -212,6 +212,6 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + cb_wait_front(cb_id_out0, output_rows_tiles); #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index b7cbe738a2c0..4bc752b50eae 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -43,7 +43,7 @@ void kernel_main() { constexpr uint32_t out_width_num_tiles = get_compile_time_arg_val(28); constexpr uint32_t out_addr = get_compile_time_arg_val(29); - + constexpr uint32_t output_rows_tiles = get_compile_time_arg_val(32); constexpr uint32_t total_weight_num_tiles = weight_block_height_num_outer * num_blocks_weight_h * weight_block_num_tiles; uint32_t i = 0; @@ -331,6 +331,6 @@ void kernel_main() { weight_start_tile_id += weight_next_block_stride_w; } // out_num_blocks_w #ifdef SHARDED_OUT - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + cb_wait_front(cb_id_out0, output_rows_tiles); #endif }