Skip to content

Commit

Permalink
#0: changed read and compute kernels for height reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
ipotkonjak-tt committed Nov 29, 2024
1 parent 53c32c0 commit 4bb43d2
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 18 deletions.
156 changes: 156 additions & 0 deletions tests/ttnn/unit_tests/operations/test_reduction_h_interleaved.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch
from functools import partial

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random

from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt


@pytest.mark.parametrize(
"batch_size",
[
1,
],
)
@pytest.mark.parametrize("h", [2 * 32])
@pytest.mark.parametrize("w", [32, 48, 64, 80, 96, 112, 128])
@pytest.mark.parametrize("c", [9 * 64])
@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("dim", [-2])
@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_3D_tensor(device, batch_size, h, w, c, n, dim, input_dtype, input_memory_config, output_memory_config):
torch.manual_seed(0)

torch_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)((n, c, h, w))
golden_function = ttnn.get_golden_function(ttnn.sum)
torch_output_tensor = golden_function(torch_input_tensor, dim=dim, memory_config=output_memory_config)

input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=input_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=input_memory_config
)

output_tensor = ttnn.sum(input_tensor, dim=dim, memory_config=output_memory_config)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

if dim:
rank = 4
if isinstance(dim, tuple):
for d in dim:
if d < 0:
d += rank
else:
if dim < 0:
dim += rank
output_tensor = ttnn.to_torch(output_tensor).squeeze(dim=dim)
else:
output_tensor = ttnn.to_torch(output_tensor).squeeze()
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize(
"batch_size",
[
1,
],
)
@pytest.mark.parametrize("h", [2 * 32])
@pytest.mark.parametrize(
"w", [7 * 64 * 32, 7 * 64 * 48, 7 * 64 * 64, 7 * 64 * 80, 7 * 64 * 96, 7 * 64 * 112, 7 * 64 * 128]
)
@pytest.mark.parametrize("c", [1])
@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("dim", [-2])
@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_2D_tensor_full_grid(
device, batch_size, h, w, c, n, dim, input_dtype, input_memory_config, output_memory_config
):
torch.manual_seed(0)

torch_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)((n, c, h, w))
golden_function = ttnn.get_golden_function(ttnn.sum)
torch_output_tensor = golden_function(torch_input_tensor, dim=dim, memory_config=output_memory_config)

input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=input_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=input_memory_config
)

output_tensor = ttnn.sum(input_tensor, dim=dim, memory_config=output_memory_config)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

if dim:
rank = 4
if isinstance(dim, tuple):
for d in dim:
if d < 0:
d += rank
else:
if dim < 0:
dim += rank
output_tensor = ttnn.to_torch(output_tensor).squeeze(dim=dim)
else:
output_tensor = ttnn.to_torch(output_tensor).squeeze()
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize(
"batch_size",
[
1,
],
)
@pytest.mark.parametrize("h", [2 * 32])
@pytest.mark.parametrize("w", [32, 64, 96, 128])
@pytest.mark.parametrize("c", [1])
@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("dim", [-2])
@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_2D_tensor(device, batch_size, h, w, c, n, dim, input_dtype, input_memory_config, output_memory_config):
torch.manual_seed(0)

torch_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)((n, c, h, w))
golden_function = ttnn.get_golden_function(ttnn.sum)
torch_output_tensor = golden_function(torch_input_tensor, dim=dim, memory_config=output_memory_config)

input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=input_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=input_memory_config
)

output_tensor = ttnn.sum(input_tensor, dim=dim, memory_config=output_memory_config)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

if dim:
rank = 4
if isinstance(dim, tuple):
for d in dim:
if d < 0:
d += rank
else:
if dim < 0:
dim += rank
output_tensor = ttnn.to_torch(output_tensor).squeeze(dim=dim)
else:
output_tensor = ttnn.to_torch(output_tensor).squeeze()
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/reduce.h"

namespace NAMESPACE {
void MAIN {

uint32_t Ht = get_compile_time_arg_val(0);
uint32_t Wt = get_compile_time_arg_val(1);
uint32_t NC = get_compile_time_arg_val(2);

reduce_init<true>(tt::CB::c_in0, tt::CB::c_in2);
cb_wait_front(tt::CB::c_in2, 1); // scaler tile from the reader

constexpr int onetile = 1;
for (uint32_t nc = 0; nc < NC; ++nc) {
uint32_t row_chunk = 8;
for(uint32_t wt = 0; wt < Wt; wt += row_chunk) {
uint32_t chunk_end = std::min(wt + row_chunk, Wt);
uint32_t tile_num = std::min(row_chunk, Wt - wt);
int reduce_dst_idx = 0;

//reduce a chunk of columns(max 8)
acquire_dst();
for(uint32_t ht = 0; ht < Ht; ++ht) {
reduce_dst_idx = 0;
for(uint32_t i = wt; i < chunk_end; ++i) {
cb_wait_front(tt::CB::c_in0, onetile);
reduce_tile(tt::CB::c_in0, tt::CB::c_in2, 0, 0, reduce_dst_idx);
cb_pop_front(tt::CB::c_in0, onetile);
++reduce_dst_idx;
}
}
for(uint32_t i = 0; i < tile_num; i++) {
cb_reserve_back(tt::CB::c_out0, onetile);
pack_tile(i, tt::CB::c_out0);
cb_push_back(tt::CB::c_out0, onetile);
}
release_dst();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,40 @@ void kernel_main() {

uint32_t w = curr_col_in_batch;

// this reader will read a NHW tensor in NWH order
for (uint32_t i = 0; i < num_cols; i++) {
uint32_t row_chunk = 8;
for (uint32_t i = 0; i < num_cols; i += row_chunk) {
uint32_t chunk_end = std::min(i + row_chunk, num_cols);
uint32_t curr_id = col_start_tile_id;
for (uint32_t j = 0; j < Ht; j++) {
cb_reserve_back(cb_id_in0, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
noc_async_read_tile(curr_id, s, l1_write_addr);
noc_async_read_barrier();
cb_push_back(cb_id_in0, onetile);
curr_id += Wt; // stride in H
}
w++;
if (w == Wt) {
col_start_tile_id = curr_id - Wt + 1;
w = 0;
} else {
col_start_tile_id++;
uint32_t reset_curr_id = curr_id;
uint32_t reset_w = w;
uint32_t reset_col_start = col_start_tile_id;

// row wise read for a chunk of columns(max 8)
for (uint32_t j = 0; j < Ht; ++j) {
w = reset_w;
col_start_tile_id = reset_col_start;
for (uint32_t k = i; k < chunk_end; ++k) {


cb_reserve_back(cb_id_in0, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
noc_async_read_tile(curr_id, s, l1_write_addr);
noc_async_read_barrier();
cb_push_back(cb_id_in0, onetile);

++w;

if (w == Wt) {
col_start_tile_id = curr_id + (Ht - j - 1) * Wt + 1;
curr_id = col_start_tile_id + j*Wt;
w = 0;
}
else {
++curr_id;
++col_start_tile_id;
}
}
curr_id = reset_curr_id + (j+1) * Wt; // stride in H
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,13 @@ operation::ProgramWithCallbacks reduce_multi_core_h(
1, // NC
};

std::string compute_kernel_path;
if(out_sharded) compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp";
else compute_kernel_path = "ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h_interleaved.cpp";

auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp",
compute_kernel_path,
core_group_1,
tt_metal::ComputeConfig{
.math_fidelity = math_fidelity,
Expand All @@ -193,7 +197,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h(

auto reduce_compute_kernel_group_2_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp",
compute_kernel_path,
core_group_2,
tt_metal::ComputeConfig{
.math_fidelity = math_fidelity,
Expand Down

0 comments on commit 4bb43d2

Please sign in to comment.