Skip to content

Commit

Permalink
#9389: Add support for integer type in sum operation (#9548)
Browse files Browse the repository at this point in the history
Add int32 type support to moreh_sum op
  • Loading branch information
dongjin-na authored Jun 27, 2024
1 parent e93cfb5 commit b3b22b7
Show file tree
Hide file tree
Showing 26 changed files with 1,763 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from loguru import logger

import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc
from models.utility_functions import (
comp_allclose_and_pcc,
skip_for_grayskull,
)
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
Expand Down Expand Up @@ -46,11 +49,26 @@ def not_in_dims(index_value_pair):
return final_output_shape


def get_tensors(input_shape, dim, device, *, with_padding=True, use_randint=True, keep_batch_dim=False):
npu_dtype = ttl.tensor.DataType.BFLOAT16
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE
def is_npu_dtype_uint32(data_type):
    """Return whether ``data_type`` is the device ``UINT32`` data type."""
    uint32_dtype = ttl.tensor.DataType.UINT32
    return data_type == uint32_dtype


def is_npu_dtype_float(data_type):
    """Return whether ``data_type`` is a device floating-point type.

    The floating-point types recognised here are ``FLOAT32`` and ``BFLOAT16``.
    """
    float_dtypes = (ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT16)
    return any(data_type == candidate for candidate in float_dtypes)


def get_tensors(
input_shape,
dim,
device,
*,
with_padding=True,
use_randint=True,
keep_batch_dim=False,
npu_dtype=ttl.tensor.DataType.BFLOAT16,
cpu_dtype=torch.bfloat16,
):
npu_layout = ttl.tensor.Layout.TILE
output_shape = input_shape.copy()
if dim is None or dim == []:
dim = list(range(len(input_shape)))
Expand All @@ -69,8 +87,11 @@ def get_tensors(input_shape, dim, device, *, with_padding=True, use_randint=True
tt_output_shape = filter_indices_with_last_two(output_shape, dim)

if use_randint:
torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype, requires_grad=True)
torch_output = torch.randint(-2, 3, tt_output_shape, dtype=cpu_dtype)
int_min = 0 if is_npu_dtype_uint32(npu_dtype) else -2
int_max = 10 if is_npu_dtype_uint32(npu_dtype) else 3
requires_grad = True if is_npu_dtype_float(npu_dtype) else False
torch_input = torch.randint(int_min, int_max, input_shape, dtype=cpu_dtype, requires_grad=requires_grad)
torch_output = torch.randint(int_min, int_max, tt_output_shape, dtype=cpu_dtype)
else:
torch_input = torch.rand(input_shape, dtype=cpu_dtype, requires_grad=True)
torch_output = torch.rand(tt_output_shape, dtype=cpu_dtype)
Expand Down Expand Up @@ -442,7 +463,7 @@ def test_moreh_sum_backward_enable_cache(input_shape, dim, device, use_program_c
"input_shape",
([2, 3, 2, 4, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 6 - 1],),
ids=[
"2, 3, 2, 4, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1",
"2, 3, 2, 4, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 4 - 1",
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -492,3 +513,57 @@ def test_moreh_sum_backward_fp32_dest_acc(input_shape, dim, compute_kernel_optio
logger.debug(f"mean={torch.abs(torch_input.grad - tt_input_grad_cpu).mean()}")

assert passing


@skip_for_grayskull()
@pytest.mark.parametrize(
    "input_shape",
    [
        [TILE_HEIGHT, TILE_WIDTH],
        [3, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1],
        [2, 2, 3, TILE_HEIGHT * 8, TILE_WIDTH * 8],
        [3, TILE_HEIGHT * 20 - 2, TILE_WIDTH * 20 - 2],
        [10, 3, TILE_HEIGHT * 8 - 1, TILE_WIDTH * 8 - 1],
    ],
    ids=[
        "TILE_HEIGHT, TILE_WIDTH",
        "3, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1",
        "2, 2, 3, TILE_HEIGHT * 8, TILE_WIDTH * 8",
        "3, TILE_HEIGHT * 20 - 2, TILE_WIDTH * 20 - 2",
        "10, 3, TILE_HEIGHT * 8 - 1, TILE_WIDTH * 8 - 1",
    ],
)
@pytest.mark.parametrize(
    "dim",
    [-1, -2, 0, 1],
    ids=["dim-w", "dim-h", "dim-b0", "dim-b1"],
)
@pytest.mark.parametrize(
    "data_type",
    [ttl.tensor.DataType.INT32],
    ids=["int32"],
)
def test_moreh_sum_integer(input_shape, dim, data_type, device):
    """Check that ``moreh_sum`` over a single dim matches ``torch.sum`` exactly for integer input.

    The host reference is built from an int64 torch tensor and compared with
    ``torch.equal``, so no tolerance is applied.
    """
    rank = len(input_shape)

    # A non-negative "batch" dim that lands on one of the last two axes is not
    # a batch reduction for this shape, so that combination is skipped.
    targets_batch_dim = dim in (0, 1)
    if targets_batch_dim and rank - dim <= 2:
        pytest.skip(f"skip sum for batch-dim with this config. {input_shape} and {dim}")

    torch.manual_seed(3072)

    compute_kernel_config = get_compute_kernel_options(True)
    tt_input, tt_output, tt_output_shape, _, torch_input = get_tensors(
        input_shape,
        dim,
        device,
        use_randint=True,
        keep_batch_dim=True,
        npu_dtype=data_type,
        cpu_dtype=torch.int64,
    )

    # Convert a negative dim into its non-negative equivalent.
    if dim < 0:
        normalized_dim = rank + dim
    else:
        normalized_dim = dim

    # Host-side reference result (reduced dim is kept).
    torch_output = torch.sum(torch_input, normalized_dim, keepdim=True)
    cpu_layout = ttl.tensor.Layout.ROW_MAJOR

    tt_output = ttl.operations.primary.moreh_sum(
        tt_input,
        dim=normalized_dim,
        keep_batch_dim=True,
        output=tt_output,
        compute_kernel_config=compute_kernel_config,
    )

    # Bring the device result back to the host and strip the tile padding.
    host_tensor = tt_output.cpu().to(cpu_layout)
    tt_output_cpu = host_tensor.unpad_from_tile(tt_output_shape).to_torch()

    passing = torch.equal(torch_output, tt_output_cpu)
    logger.debug(f"{passing}")

    assert passing
22 changes: 22 additions & 0 deletions tt_eager/tt_dnn/kernels/compute/moreh_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "compute_kernel_api.h"
#include "compute_kernel_api/bcast.h"
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/eltwise_unary/negative.h"
#include "compute_kernel_api/eltwise_unary/exp.h"
#include "compute_kernel_api/eltwise_unary/recip.h"
Expand Down Expand Up @@ -999,4 +1000,25 @@ ALWI void power_and_recip_tile_to_cb(
REL();
}

// Copies tile `itile` of circular buffer `icb` into destination index `dst`.
//
// icb             - input circular buffer to read from.
// itile           - index of the tile within `icb` to copy (default 0).
// dst             - destination index passed to copy_tile (default 0).
// cb_wait_and_pop - when true, this helper waits for one tile at the front of `icb`
//                   before the copy and pops one tile afterwards; when false the caller
//                   is responsible for the wait/pop of `icb`.
//
// NOTE(review): presumably the caller must already hold the destination registers
// (e.g. between ACQ()/REL()) when calling this - confirm against call sites.
ALWI void copy_tile_to_dst(uint32_t icb, uint32_t itile = 0, uint32_t dst = 0, bool cb_wait_and_pop = true) {
constexpr uint32_t onetile = 1;
if (cb_wait_and_pop) {
cb_wait_front(icb, onetile);
}
// Reconfigure the srcA unpack data format for `icb` and re-initialize the copy
// before issuing it, so the copy does not rely on whatever was configured earlier.
unpack_reconfig_data_format_srca(icb);
copy_tile_to_dst_init_short(icb);
copy_tile(icb, itile, dst);
if (cb_wait_and_pop) {
cb_pop_front(icb, onetile);
}
}

// Packs the tile at destination index `dst` into circular buffer `ocb`.
//
// ocb - output circular buffer that receives the tile.
// dst - destination index passed to pack_tile (default 0).
//
// The helper reserves space for one tile at the back of `ocb`, reconfigures the pack
// data format for `ocb`, packs the tile, and then pushes that one tile.
ALWI void pack_tile_from_dst(uint32_t ocb, uint32_t dst = 0) {
constexpr uint32_t onetile = 1;
cb_reserve_back(ocb, onetile);
// Reconfigure the pack data format for `ocb` before packing into it.
pack_reconfig_data_format(ocb);
pack_tile(dst, ocb);
cb_push_back(ocb, onetile);
}

} // namespace ckernel
135 changes: 135 additions & 0 deletions tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,141 @@ FORCE_INLINE void generate_mask_w(uint32_t cb_mask, uint32_t mask_w) {
cb_push_back(cb_mask, 1);
}

// TODO: Template the generate_mask function to support different data types
// Writes one int32 row-mask tile into circular buffer `cb_mask` and pushes it.
//
// The tile is written as four consecutive 16x16 sub tiles (256 elements each):
// sub tiles 0 and 1 hold rows 0..15, sub tiles 2 and 3 hold rows 16..31.
// Every element whose row index is below `mask_h` is set to 1, every other element to 0.
//
// cb_mask - circular buffer that receives the mask tile (one tile is reserved and pushed).
// mask_h  - number of leading rows to keep. Values above 32 are treated as 32 so that
//           the writes can never run past the reserved tile.
FORCE_INLINE void generate_int_mask_h(uint32_t cb_mask, uint32_t mask_h) {
    constexpr uint32_t sub_tile_dim = 16;
    constexpr uint32_t sub_tile_elems = sub_tile_dim * sub_tile_dim;
    constexpr uint32_t num_sub_tiles = 4;

    // Rows kept in the upper sub tiles (0, 1) and in the lower sub tiles (2, 3).
    // Both counts are clamped to the sub tile height; without the clamp on the lower
    // count a mask_h above 32 would write outside the sub tile and the tile itself.
    const uint32_t upper_rows = (mask_h < sub_tile_dim) ? mask_h : sub_tile_dim;
    uint32_t lower_rows = (mask_h < sub_tile_dim) ? 0 : mask_h - sub_tile_dim;
    if (lower_rows > sub_tile_dim) {
        lower_rows = sub_tile_dim;
    }

    cb_reserve_back(cb_mask, 1);
    auto ptr = reinterpret_cast<int32_t *>(get_write_ptr(cb_mask));

    for (uint32_t sub_tile = 0; sub_tile < num_sub_tiles; sub_tile++) {
        const uint32_t valid_rows = (sub_tile < 2) ? upper_rows : lower_rows;
        int32_t *sub_tile_ptr = ptr + sub_tile * sub_tile_elems;

        for (uint32_t h = 0; h < sub_tile_dim; h++) {
            // The mask value only depends on the row, so compute it once per row.
            const int32_t value = (h < valid_rows) ? 1 : 0;
            for (uint32_t w = 0; w < sub_tile_dim; w++) {
                sub_tile_ptr[h * sub_tile_dim + w] = value;
            }
        }
    }

    cb_push_back(cb_mask, 1);
}

// Writes one int32 column-mask tile into circular buffer `cb_mask` and pushes it.
//
// The tile is written as four consecutive 16x16 sub tiles (256 elements each):
// sub tiles 0 and 2 hold columns 0..15, sub tiles 1 and 3 hold columns 16..31.
// Every element whose column index is below `mask_w` is set to 1, every other element to 0.
//
// cb_mask - circular buffer that receives the mask tile (one tile is reserved and pushed).
// mask_w  - number of leading columns to keep. Values above 32 are treated as 32 so that
//           the writes can never run past the reserved tile.
FORCE_INLINE void generate_int_mask_w(uint32_t cb_mask, uint32_t mask_w) {
    constexpr uint32_t sub_tile_dim = 16;
    constexpr uint32_t sub_tile_elems = sub_tile_dim * sub_tile_dim;
    constexpr uint32_t num_sub_tiles = 4;

    // Columns kept in the left sub tiles (0, 2) and in the right sub tiles (1, 3).
    // Both counts are clamped to the sub tile width; without the clamp on the right
    // count a mask_w above 32 would spill into neighbouring rows/sub tiles and, for
    // the last sub tile, past the end of the tile.
    const uint32_t left_cols = (mask_w < sub_tile_dim) ? mask_w : sub_tile_dim;
    uint32_t right_cols = (mask_w < sub_tile_dim) ? 0 : mask_w - sub_tile_dim;
    if (right_cols > sub_tile_dim) {
        right_cols = sub_tile_dim;
    }

    cb_reserve_back(cb_mask, 1);
    auto ptr = reinterpret_cast<int32_t *>(get_write_ptr(cb_mask));

    for (uint32_t sub_tile = 0; sub_tile < num_sub_tiles; sub_tile++) {
        // Even sub tiles are the left half of the tile, odd sub tiles the right half.
        const uint32_t valid_cols = (sub_tile % 2 == 0) ? left_cols : right_cols;
        int32_t *sub_tile_ptr = ptr + sub_tile * sub_tile_elems;

        for (uint32_t h = 0; h < sub_tile_dim; h++) {
            for (uint32_t w = 0; w < sub_tile_dim; w++) {
                sub_tile_ptr[h * sub_tile_dim + w] = (w < valid_cols) ? 1 : 0;
            }
        }
    }

    cb_push_back(cb_mask, 1);
}

FORCE_INLINE void generate_mask_h_w(
uint32_t cb_mask_h_w, uint32_t mask_h, uint32_t mask_w, uint32_t single_tile_size = 2048) {
Scalar one;
Expand Down
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,11 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_w_impl/moreh_int_sum_w_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_nc_impl/moreh_int_sum_nc_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/prod/prod_nc/prod_nc.cpp
Expand Down
Loading

0 comments on commit b3b22b7

Please sign in to comment.