From 903bed2d96f61d9cbab42b327a679336d9bafc1f Mon Sep 17 00:00:00 2001
From: Pavle Josipovic
Date: Fri, 20 Sep 2024 11:19:41 +0000
Subject: [PATCH] #0: Optimize untilize_with_unpad for W 16

In case the tensor's unpadded W=16 and H % 32 == 0, skip untilize in
compute and just copy face 0 and face 2 of each tilized input tile,
skipping face 1 and face 3.
---
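Background on the face layout this relies on: a 32x32 tile is stored as four
16x16 faces in the order f0, f1, f2, f3, where f0/f1 hold rows 0-15 and f2/f3
hold rows 16-31, with f0/f2 covering columns 0-15. With only 16 valid columns,
f1 and f3 are pure padding, so concatenating f0 and f2 already yields the
row-major result. A minimal host-side sketch of the copy pattern (numpy;
float32 stands in for the on-device formats, and all names are illustrative):

    import numpy as np

    TILE, FACE = 32, 16
    logical = np.arange(TILE * FACE, dtype=np.float32).reshape(TILE, FACE)
    padded = np.zeros((TILE, TILE), dtype=np.float32)
    padded[:, :FACE] = logical  # W=16 data; columns 16..31 are padding

    # Tilize: flatten the four 16x16 faces in f0, f1, f2, f3 order.
    faces = [padded[:FACE, :FACE], padded[:FACE, FACE:],
             padded[FACE:, :FACE], padded[FACE:, FACE:]]
    tilized = np.concatenate([f.ravel() for f in faces])

    quarter = TILE * TILE // 4  # one face is a quarter of the tile
    # What the writer kernel does: keep face 0 and face 2, skip face 1 and face 3.
    out = np.concatenate([tilized[:quarter], tilized[2 * quarter:3 * quarter]])
    assert np.array_equal(out.reshape(TILE, FACE), logical)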
 .../functional_unet/tests/test_unet_perf.py   |  2 +-
 tests/ttnn/unit_tests/test_to_layout.py       | 36 ++++++++++-
 .../writer_unary_unpad_width_16_sharded.cpp   | 61 +++++++++++++++++++
 ...ntilize_with_unpadding_program_factory.cpp | 38 ++++++++----
 4 files changed, 123 insertions(+), 14 deletions(-)
 create mode 100644 ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_unpad_width_16_sharded.cpp

diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py
index e6b08e2219c1..7cf9554d5d1d 100644
--- a/models/experimental/functional_unet/tests/test_unet_perf.py
+++ b/models/experimental/functional_unet/tests/test_unet_perf.py
@@ -33,7 +33,7 @@
 @pytest.mark.models_device_performance_bare_metal
 @pytest.mark.parametrize(
     "batch, groups, expected_device_perf_fps",
-    ((2, 1, 683.0),),
+    ((2, 1, 755.0),),
 )
 def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
     command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"

diff --git a/tests/ttnn/unit_tests/test_to_layout.py b/tests/ttnn/unit_tests/test_to_layout.py
index dca50ea0aed8..fafab9674a1a 100644
--- a/tests/ttnn/unit_tests/test_to_layout.py
+++ b/tests/ttnn/unit_tests/test_to_layout.py
@@ -2,13 +2,14 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
+from loguru import logger
 import pytest
 import torch
 
 import ttnn
 
-from tests.ttnn.utils_for_testing import assert_with_pcc
+from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
 
 
 @pytest.mark.parametrize("height", [32, 30])
@@ -91,3 +92,36 @@
     assert_with_pcc(torch_input_tensor, output_tensor)
 
     assert torch.allclose(torch_input_tensor, output_tensor)
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat8_b, ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("use_pack_untilize", [False, True])
+def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_untilize):
+    tile_height = 32
+    core_count = 56
+    tiles_per_core = 4
+    H = tile_height * core_count * tiles_per_core
+    W = 16
+
+    torch_input_shape = [1, 1, H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    sharded_memory_config = ttnn.create_sharded_memory_config(
+        [tile_height * tiles_per_core, 2 * W],
+        core_grid=ttnn.CoreGrid(y=7, x=8),
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        use_height_and_width_as_shard_shape=True,
+    )
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)
+    ttnn_input = ttnn.to_memory_config(ttnn_input, sharded_memory_config)
+
+    output_tt = ttnn.untilize_with_unpadding(
+        ttnn_input, [0, 0, H - 1, W - 1], use_multicore=use_multicore, use_pack_untilize=use_pack_untilize
+    )
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing

diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_unpad_width_16_sharded.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_unpad_width_16_sharded.cpp
new file mode 100644
index 000000000000..34ea417a07c7
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_unpad_width_16_sharded.cpp
@@ -0,0 +1,61 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <stdint.h>
+#include "dataflow_api.h"
+
+// Special-case writer for tensors with unpadded width 16.
+// Skip untilize and just copy faces f0 and f2 of each input tile, skipping f1 and f3.
+void kernel_main() {
+    uint32_t num_unpadded_output_rows = get_arg_val<uint32_t>(0);
+    uint32_t num_padded_tiles_per_core = get_arg_val<uint32_t>(1);
+
+    constexpr uint32_t cb_id_untilize_out = get_compile_time_arg_val(0);
+    constexpr uint32_t cb_id_out = get_compile_time_arg_val(1);
+
+    constexpr uint32_t tile_size_in_bytes = get_tile_size(cb_id_out);
+    constexpr uint32_t quarter_tile_size_in_bytes = tile_size_in_bytes / 4;
+
+    const uint32_t batches_of_8 = num_padded_tiles_per_core / 8;
+    const uint32_t remaining_tiles = num_padded_tiles_per_core % 8;
+
+    cb_reserve_back(cb_id_out, num_unpadded_output_rows);
+    uint32_t l1_write_addr = get_write_ptr(cb_id_out);
+
+    for (uint32_t i = 0; i < batches_of_8; i++) {
+        cb_wait_front(cb_id_untilize_out, 8);
+        uint64_t noc_l1_read_addr = get_noc_addr(get_read_ptr(cb_id_untilize_out));
+
+        for (uint32_t j = 0; j < 8; j++) {
+            noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
+            noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
+            l1_write_addr += quarter_tile_size_in_bytes;
+
+            noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
+            noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
+            l1_write_addr += quarter_tile_size_in_bytes;
+        }
+
+        noc_async_read_barrier();
+        cb_pop_front(cb_id_untilize_out, 8);
+    }
+
+    for (uint32_t i = 0; i < remaining_tiles; i++) {
+        cb_wait_front(cb_id_untilize_out, 1);
+        uint64_t noc_l1_read_addr = get_noc_addr(get_read_ptr(cb_id_untilize_out));
+
+        noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
+        noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
+        l1_write_addr += quarter_tile_size_in_bytes;
+
+        noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
+        noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
+        l1_write_addr += quarter_tile_size_in_bytes;
+
+        noc_async_read_barrier();
+        cb_pop_front(cb_id_untilize_out, 1);
+    }
+
+    cb_push_back(cb_id_out, num_unpadded_output_rows);
+}

diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp
index 1f4d9b92f137..18769c37971a 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp
@@ -377,7 +377,10 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
     bool src_sharded = a.memory_config().is_sharded();
     bool out_sharded = output.memory_config().is_sharded();
-
+    // Special handling for tensors with unpadded W=16 and H % 32 == 0:
+    // skip untilize on compute, and in the writer kernel just copy face 0 and
+    // face 2, skipping face 1 and face 3.
+    bool unpad_tensor_w_16 = output.get_legacy_shape()[-1] == 16 && output.get_legacy_shape()[-2] % TILE_HEIGHT == 0;
     tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());
     uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
     tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
@@ -439,7 +442,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
         input_cb_data_format,
         src_sharded ? a.buffer() : nullptr);
 
-    uint32_t num_output_tiles = out_sharded ? ntiles_per_batch * 2 : ntiles_per_block * 2;
+    uint32_t num_output_tiles = out_sharded ? (unpad_tensor_w_16 ? 16 : ntiles_per_batch * 2) : ntiles_per_block * 2;
     auto [output_cb_index, cb_output] = create_cb(
         tt::CB::c_out0, program, all_cores, output_single_tile_size, num_output_tiles, output_cb_data_format);
@@ -475,8 +478,10 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
         vector<uint32_t> writer_ct_args = {(uint32_t)output_cb_index, (uint32_t)sharded_output_cb_index};
         unary_writer_kernel_id = CreateKernel(
             program,
-            "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
-            "writer_unary_unpad_batch_rows_sharded.cpp",
+            unpad_tensor_w_16 ? "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
+                                "writer_unary_unpad_width_16_sharded.cpp"
+                              : "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
+                                "writer_unary_unpad_batch_rows_sharded.cpp",
             all_cores,
             WriterDataMovementConfig(writer_ct_args));
     } else {
@@ -499,7 +504,11 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
     std::string compute_kernel(
         "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/pack_untilize.cpp");
-    if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
+    if (unpad_tensor_w_16) {
+        // Use the copy compute kernel, for a potential data type conversion only.
+        compute_kernel = "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/eltwise_copy.cpp";
+        compute_args[0] = (uint32_t)num_input_tiles;  // per_core_tile_cnt
+    } else if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
         log_debug(tt::LogOp, "Using slow untilize.");
         compute_kernel = "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize.cpp";
     } else {
@@ -520,13 +529,18 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
     std::vector<CoreCoord> cores;
 
     if (out_sharded) {
-        vector<uint32_t> writer_rt_args = {
-            num_output_rows_unpadded,
-            ntiles_per_batch,
-            out_shard_spec.shape[0] / batch,
-            shard_spec.shape[1] * output.element_size(),
-            block_row_size,
-            batch};
+        vector<uint32_t> writer_rt_args;
+        if (unpad_tensor_w_16) {
+            writer_rt_args = {num_output_rows_unpadded, num_input_tiles};
+        } else {
+            writer_rt_args = {
+                num_output_rows_unpadded,
+                ntiles_per_batch,
+                out_shard_spec.shape[0] / batch,
+                shard_spec.shape[1] * output.element_size(),
+                block_row_size,
+                batch};
+        }
         tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, writer_rt_args);
     } else {
         uint32_t tile_start_id = 0;
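
The fast path is gated entirely by the output shape; a quick host-side sketch
of the dispatch condition (hypothetical Python helper, not part of the
codebase; TILE_HEIGHT is 32):

    TILE_HEIGHT = 32

    def uses_w16_fast_path(output_shape) -> bool:
        # Mirrors `unpad_tensor_w_16` in untilize_with_unpadding_program_factory.cpp:
        # the fast path needs unpadded W == 16 and H a multiple of the tile height.
        return output_shape[-1] == 16 and output_shape[-2] % TILE_HEIGHT == 0

    assert uses_w16_fast_path([1, 1, 7168, 16])      # shape used by the new unit test
    assert not uses_w16_fast_path([1, 1, 7168, 32])  # full-width tensors take the old path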