Skip to content

Commit

Permalink
#0: Update files
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Apr 11, 2024
1 parent edc2e84 commit 80e0ade
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@
"dim",
(3, 2, 1, 0, -1, -2, -3, -4),
)
@pytest.mark.parametrize("all_dimensions", (False, True))
@pytest.mark.parametrize("all_dimensions", [False, True])
@pytest.mark.parametrize(
"input_shapes",
[
[[1, 1, 32, 32]],
[[4, 3, 32, 32]],
# [[1, 1, 320, 320]], #Fails for all_dimensions = True
# [[1, 3, 320, 64]],
[[2, 2, 32, 32]],
# [[6, 4, 32, 32]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
# [[1, 1, 320, 320]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
# [[1, 3, 320, 64]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
],
)
@pytest.mark.parametrize(
Expand All @@ -50,8 +52,8 @@ def test_run_prod_op(
dst_mem_config,
device,
):
datagen_func = [ # "prod_cpu" not implemented for 'BFloat16'
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1.5), torch.float32)
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1.5), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
def get_tensors(input_shape, output_shape, device):
torch.manual_seed(2023)
npu_dtype = ttl.tensor.DataType.BFLOAT16
cpu_dtype = torch.float32
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE

torch_input = torch.randint(1, 5, input_shape, dtype=cpu_dtype)
torch_output = torch.randint(1, 5, output_shape, dtype=cpu_dtype)
torch.set_printoptions(threshold=10000, precision=5, sci_mode=False)
tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

Expand All @@ -40,9 +39,9 @@ def get_tensors(input_shape, output_shape, device):
([1, 1, 32, 32]),
([1, 4, 32, 32]),
([2, 2, 32, 32]),
([6, 4, 32, 32]),
# [[1, 1, 320, 320]], #Fails for all_dimensions = True
# [[1, 3, 320, 64]], #Fails for all_dimensions = True
# ([6, 4, 32, 32]), #Fails : expected result is inf but the result generated in nan
# ([1, 1, 320, 320]), #Fails : expected result is inf but the result generated in nan
# ([1, 3, 320, 64]), #Fails : expected result is inf but the result generated in nan
),
)
def test_prod(shapes, device):
Expand Down
11 changes: 8 additions & 3 deletions tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
def get_tensors(input_shape, output_shape, device):
torch.manual_seed(2023)
npu_dtype = ttl.tensor.DataType.BFLOAT16
# prod_cpu" not implemented for 'BFloat16'
cpu_dtype = torch.float32
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE

torch_input = torch.randint(-100, 100, input_shape, dtype=cpu_dtype, requires_grad=True)
torch_input = torch.randint(-100, 100, input_shape, dtype=cpu_dtype)
torch_output = torch.randint(-100, 100, output_shape, dtype=cpu_dtype)

tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
Expand All @@ -32,12 +31,18 @@ def get_tensors(input_shape, output_shape, device):
@pytest.mark.parametrize(
"input_shape",
(
([2, 3, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 7 - 1]),
([9, 16, TILE_HEIGHT * 13 - 1, TILE_WIDTH * 19 - 1]),
([4, 3, TILE_HEIGHT * 3 - 1, TILE_WIDTH * 11 - 1]),
([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]),
([4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1]),
([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1]),
([8, 8, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1]),
),
ids=[
"2, 3, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 7 - 1",
"9, 16, TILE_HEIGHT * 13 - 1, TILE_WIDTH * 19 - 1",
"4, 3, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 11 - 1",
"1, 1, TILE_HEIGHT-1,TILE_WIDTH - 1",
"4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1",
"4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1",
Expand Down
21 changes: 8 additions & 13 deletions tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
#include "compute_kernel_api/eltwise_unary/negative.h"

ALWI void ACQ() {
tile_regs_acquire();
tile_regs_wait();
}
ALWI void REL() {
tile_regs_commit();
tile_regs_release();
}

namespace NAMESPACE {
void MAIN {

Expand All @@ -35,25 +26,29 @@ void MAIN {
}
cb_reserve_back(tt::CB::c_out0, 1);
for(uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) {
ACQ();
cb_wait_front(tt::CB::c_in0, 1);
if (once)
{
cb_reserve_back(tt::CB::c_intermed0, 1);
tile_regs_acquire();
copy_tile_to_dst_init_short();
copy_tile(tt::CB::c_in0, 0, 0); // copy from c_in[0] to DST[0]
tile_regs_commit();
tile_regs_wait();
if constexpr (num_tiles == 1)
pack_tile(0, tt::CB::c_out0);
else
{
pack_tile(0, tt::CB::c_intermed0);
cb_push_back(tt::CB::c_intermed0, 1);
}
tile_regs_release();
}else {
REL();
ACQ();
tile_regs_acquire();
mul_tiles_init();
mul_tiles(tt::CB::c_in0, tt::CB::c_intermed0, 0, 0, 0);
tile_regs_commit();
tile_regs_wait();
if (last_tile)
{
pack_tile(0, tt::CB::c_out0);
Expand All @@ -65,10 +60,10 @@ void MAIN {
pack_tile(0, tt::CB::c_intermed0);
cb_push_back(tt::CB::c_intermed0, 1);
}
tile_regs_release();
}
once = false;
cb_pop_front(tt::CB::c_in0, 1);
REL();
}
cb_push_back(tt::CB::c_out0, 1);
}
Expand Down
15 changes: 4 additions & 11 deletions tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"

ALWI void ACQ() {
tile_regs_acquire();
tile_regs_wait();
}
ALWI void REL() {
tile_regs_commit();
tile_regs_release();
}

namespace NAMESPACE {
void MAIN {
const auto num_input_tiles = get_arg_val<uint32_t>(0);
Expand All @@ -39,14 +30,15 @@ void MAIN {
bool last_out = (j == num_input_tiles - 1);
uint32_t cb_add = (enable_reload) ? (cb_intermed0) : (cb_in1);

ACQ();
cb_wait_front(cb_in0, onetile);
if (enable_reload) {
cb_wait_front(cb_intermed0, onetile);
}

tile_regs_acquire();
mul_tiles_init();
mul_tiles(cb_in0, cb_add, first_tile, first_tile, dst0);
tile_regs_commit();

cb_pop_front(cb_in0, onetile);
if (enable_reload) {
Expand All @@ -55,9 +47,10 @@ void MAIN {

uint32_t cb_out = (last_out) ? (cb_out0) : (cb_intermed0);
cb_reserve_back(cb_out, onetile);
tile_regs_wait();
pack_tile(dst0, cb_out);
tile_regs_release();
cb_push_back(cb_out, onetile);
REL();
enable_reload = true;
}
}
Expand Down
38 changes: 18 additions & 20 deletions tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@
#include "dataflow_api.h"
#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp"

inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t input_tile_offset, uint32_t HtWt) {
if(dim == 0){
return tile_id;
}else {
uint32_t a = 0 ;
while (tile_id >= HtWt){
tile_id-= HtWt;
a = a + 1;
}
uint32_t b = 0;
for (uint32_t i = 0; i < input_tile_offset; ++i) {
b = b + a;
}
return b + tile_id;
}
}

void kernel_main() {
const auto input_addr = get_arg_val<uint32_t>(0);
const auto num_input_tiles = get_arg_val<uint32_t>(1);
Expand All @@ -33,7 +16,7 @@ void kernel_main() {
const auto input_is_dram = get_compile_time_arg_val(0) == 1;
const auto HtWt = get_arg_val<uint32_t>(6);
const auto CHtWt = get_arg_val<uint32_t>(7);
const auto dim = get_arg_val<uint32_t>(8);
const auto dim = get_compile_time_arg_val(1);

constexpr uint32_t onetile = 1;
constexpr uint32_t cb_id_in0 = 0;
Expand All @@ -52,8 +35,14 @@ void kernel_main() {
const InterleavedAddrGenFast<input_is_dram> dram_input_addrg = {
.bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format};

uint32_t read_tile_id_temp = (dim == 0 ) ? (start_id) : (start_id / HtWt * CHtWt) + (start_id % HtWt);
uint32_t start_tile_id = start_id / HtWt * CHtWt;
uint32_t end_tile_id = start_tile_id + HtWt - 1 ;
uint32_t read_tile_id = read_tile_id_temp;
for (uint32_t i = start_id; i < start_id + num_output_tiles; i++) {
auto read_tile_id = get_read_tile_id(i, dim, CHtWt, HtWt);
if constexpr (dim == 0){
read_tile_id = i;
}
for (uint32_t j = 0; j < num_input_tiles; ++j) {
cb_reserve_back(cb_id_in0, onetile);
l1_write_addr_in0 = get_write_ptr(cb_id_in0);
Expand All @@ -62,6 +51,15 @@ void kernel_main() {
cb_push_back(cb_id_in0, onetile);
read_tile_id += input_tile_offset;
}
if constexpr (dim != 0){
if(read_tile_id_temp == end_tile_id){
start_tile_id = start_tile_id + CHtWt;
read_tile_id_temp = start_tile_id;
end_tile_id = read_tile_id_temp + HtWt - 1;
}else{
read_tile_id_temp = read_tile_id_temp + 1;
}
read_tile_id = read_tile_id_temp;
}
}

}

This file was deleted.

7 changes: 4 additions & 3 deletions tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,15 @@ operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor

tt_metal::Buffer *input_buffer_type = input.buffer();
bool input_is_dram = input_buffer_type->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> reader_compile_time_args = {(std::uint32_t) input_is_dram};
std::vector<uint32_t> reader_compile_time_args = {(std::uint32_t) input_is_dram, static_cast<uint32_t>(dim)};

tt_metal::Buffer *output_buffer_type = output.buffer();
constexpr uint32_t cb_id_out = 16;
bool output_is_dram = output_buffer_type->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t) output_is_dram};
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t) cb_id_out, (std::uint32_t) output_is_dram};

const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp";
const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/dataflow/writer_prod_nc.cpp";
const auto writer_kernel_file = "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp";
const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args);
const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args);

Expand Down
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Tensor prod_all(const Tensor& input, const MemoryConfig& output_mem_config ) {
}
//else --> GS Arch
return tt::numpy::prod_result_computation_GS<bfloat16>(result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config);
return operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0);
}

}
Expand Down
2 changes: 0 additions & 2 deletions tt_eager/tt_dnn/op_library/prod/prod_op_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ namespace operations{

namespace primary{

using namespace constants;
using namespace tt_metal;
/*
* prod product
*/
Expand Down
4 changes: 4 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1703,10 +1703,14 @@ namespace tt::tt_metal::detail{
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("all_dimensions") , py::arg("dim") , py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for prod on ``input_a`` along ``all_dimensions`` or a particular ``dim``.
If ``all_dimensions`` is set to ``true``, irrespective of given dimension it will perform backward prod for all dimensions.
Input tensor must have BFLOAT16 data type.
Output tensors will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"all_dimensions", "Consider all dimension (ignores ``dim`` param)", "bool", "", "Yes"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ inline void llk_math_eltwise_unary_sfpu_tiled_prod_init() {
llk_math_eltwise_unary_sfpu_init<APPROXIMATE>();
}

template <bool APPROXIMATE, DstSync Dst = DstSync::SyncFull>
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_tiled_prod(uint dst_index, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_0_param<APPROXIMATE, Dst>
llk_math_eltwise_unary_sfpu_0_param<APPROXIMATE>
(ckernel::sfpu::calculate_tiled_prod<APPROXIMATE>,
ckernel::sfpu::calculate_tiled_prod<APPROXIMATE>,
dst_index, vector_mode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ inline void llk_math_eltwise_unary_sfpu_tiled_prod_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::tiled_prod, APPROXIMATE>();
}

template <bool APPROXIMATE, DstSync Dst = DstSync::SyncFull>
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_tiled_prod(uint dst_index, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_0_param<APPROXIMATE, Dst>
llk_math_eltwise_unary_sfpu_0_param<APPROXIMATE>
(ckernel::sfpu::calculate_tiled_prod<APPROXIMATE>,
ckernel::sfpu::calculate_tiled_prod<APPROXIMATE>,
dst_index, vector_mode);
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/include/compute_kernel_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ ALWI void lez_tile_init() {
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
*/
ALWI void tiled_prod_tile(uint32_t idst) {
MATH(( llk_math_eltwise_unary_sfpu_tiled_prod<APPROX, SyncHalf>(idst) ));
MATH(( llk_math_eltwise_unary_sfpu_tiled_prod<APPROX>(idst) ));
}

/**
Expand Down

0 comments on commit 80e0ade

Please sign in to comment.