Skip to content

Commit

Permalink
#0: Optimize untilize_with_unpad for W 16
Browse files Browse the repository at this point in the history
In case tensor unpadded W=16 and H%32 ==0
just skip untialize in compute and copy face 0 and face 2
of tilized input tile and skip face 1 and face 2.
  • Loading branch information
Pavle Josipovic committed Sep 27, 2024
1 parent 17d2a39 commit 903bed2
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch, groups, expected_device_perf_fps",
((2, 1, 683.0),),
((2, 1, 755.0),),
)
def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
Expand Down
36 changes: 35 additions & 1 deletion tests/ttnn/unit_tests/test_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


@pytest.mark.parametrize("height", [32, 30])
Expand Down Expand Up @@ -91,3 +92,36 @@ def test_to_layout_wide_tensor(device, shape, on_device, from_layout, to_layout)

assert_with_pcc(torch_input_tensor, output_tensor)
assert torch.allclose(torch_input_tensor, output_tensor)


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat8_b, ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("use_pack_untilize", [False, True])
def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_untilize):
tile_height = 32
core_count = 56
tiles_per_core = 4
H = tile_height * core_count * tiles_per_core
W = 16

torch_input_shape = [1, 1, H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

sharded_memory_config = ttnn.create_sharded_memory_config(
[tile_height * tiles_per_core, 2 * W],
core_grid=ttnn.CoreGrid(y=7, x=8),
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)
ttnn_input = ttnn.to_memory_config(ttnn_input, sharded_memory_config)

output_tt = ttnn.untilize_with_unpadding(
ttnn_input, [0, 0, H - 1, W - 1], use_multicore=use_multicore, use_pack_untilize=use_pack_untilize
)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

// Special case writer for unpad width 16 tensors
// Skip untilize and just copy f0 and f2 from input tiles to output tiles
void kernel_main() {
uint32_t num_unpadded_output_rows = get_arg_val<uint32_t>(0);
uint32_t num_padded_tiles_per_core = get_arg_val<uint32_t>(1);

constexpr uint32_t cb_id_untilize_out = get_compile_time_arg_val(0);
constexpr uint32_t cb_id_out = get_compile_time_arg_val(1);

constexpr uint32_t tile_size_in_bytes = get_tile_size(cb_id_out);
constexpr uint32_t quarter_tile_size_in_bytes = tile_size_in_bytes / 4;

const uint32_t batches_of_8 = num_padded_tiles_per_core / 8;
const uint32_t remaining_tiles = num_padded_tiles_per_core % 8;

cb_reserve_back(cb_id_out, num_unpadded_output_rows);
uint32_t l1_write_addr = get_write_ptr(cb_id_out);

for (uint32_t i = 0; i < batches_of_8; i++) {
cb_wait_front(cb_id_untilize_out, 8);
uint64_t noc_l1_read_addr = get_noc_addr(get_read_ptr(cb_id_untilize_out));

for (uint32_t j = 0; j < 8; j++) {
noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
l1_write_addr += quarter_tile_size_in_bytes;

noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
l1_write_addr += quarter_tile_size_in_bytes;
}

noc_async_read_barrier();
cb_pop_front(cb_id_untilize_out, 8);
}

for (uint32_t i = 0; i < remaining_tiles; i++) {
cb_wait_front(cb_id_untilize_out, 1);
uint64_t noc_l1_read_addr = get_noc_addr(get_read_ptr(cb_id_untilize_out));

noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
l1_write_addr += quarter_tile_size_in_bytes;

noc_async_read(noc_l1_read_addr, l1_write_addr, quarter_tile_size_in_bytes);
noc_l1_read_addr += 2 * quarter_tile_size_in_bytes;
l1_write_addr += quarter_tile_size_in_bytes;

noc_async_read_barrier();
cb_pop_front(cb_id_untilize_out, 1);
}

cb_push_back(cb_id_out, num_unpadded_output_rows);
}
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,10 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(

bool src_sharded = a.memory_config().is_sharded();
bool out_sharded = output.memory_config().is_sharded();

// Special handling for tensors of W=16 and H%32==0
// In this case skip untilizing on compute and in writer kernel just copy face0 and face2,
// and skip face1 and face3.
bool unpad_tensor_w_16 = output.get_legacy_shape()[-1] == 16 && output.get_legacy_shape()[-2] % TILE_HEIGHT == 0;
tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());
uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
Expand Down Expand Up @@ -439,7 +442,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
input_cb_data_format,
src_sharded ? a.buffer() : nullptr);

uint32_t num_output_tiles = out_sharded ? ntiles_per_batch * 2 : ntiles_per_block * 2;
uint32_t num_output_tiles = out_sharded ? (unpad_tensor_w_16 ? 16 : ntiles_per_batch * 2) : ntiles_per_block * 2;
auto [output_cb_index, cb_output] =
create_cb(tt::CB::c_out0, program, all_cores, output_single_tile_size, num_output_tiles, output_cb_data_format);

Expand Down Expand Up @@ -475,8 +478,10 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
vector<uint32_t> writer_ct_args = {(uint32_t)output_cb_index, (uint32_t)sharded_output_cb_index};
unary_writer_kernel_id = CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
"writer_unary_unpad_batch_rows_sharded.cpp",
unpad_tensor_w_16 ? "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
"writer_unary_unpad_width_16_sharded.cpp"
: "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/"
"writer_unary_unpad_batch_rows_sharded.cpp",
all_cores,
WriterDataMovementConfig(writer_ct_args));
} else {
Expand All @@ -499,7 +504,11 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(

std::string compute_kernel(
"ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/pack_untilize.cpp");
if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
if (unpad_tensor_w_16) {
// Use copy compute kernel just potential data type conversion.
compute_kernel = "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/eltwise_copy.cpp";
compute_args[0] = (uint32_t)num_input_tiles; // per_core_tile_cnt
} else if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
log_debug(tt::LogOp, "Using slow untilize.");
compute_kernel = "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize.cpp";
} else {
Expand All @@ -520,13 +529,18 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
std::vector<CoreCoord> cores;

if (out_sharded) {
vector<uint32_t> writer_rt_args = {
num_output_rows_unpadded,
ntiles_per_batch,
out_shard_spec.shape[0] / batch,
shard_spec.shape[1] * output.element_size(),
block_row_size,
batch};
vector<uint32_t> writer_rt_args;
if (unpad_tensor_w_16) {
writer_rt_args = {num_output_rows_unpadded, num_input_tiles};
} else {
writer_rt_args = {
num_output_rows_unpadded,
ntiles_per_batch,
out_shard_spec.shape[0] / batch,
shard_spec.shape[1] * output.element_size(),
block_row_size,
batch};
}
tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, writer_rt_args);
} else {
uint32_t tile_start_id = 0;
Expand Down

0 comments on commit 903bed2

Please sign in to comment.