#4686: add fp32 to untilize
yugaoTT committed Feb 21, 2024
1 parent ea2c0c9 commit 7e4f7a3
Showing 11 changed files with 336 additions and 456 deletions.
@@ -1,7 +1,7 @@
matmul_block
============

.. doxygenfunction:: mm_block_init(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t out_cb_id = 16, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: mm_block_init_short_with_dt(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t old_in1_cb_id=2, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: mm_block_init(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t out_cb_id = 16, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: mm_block_init_short_with_dt(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t old_in1_cb_id=2, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1)
.. doxygenfunction:: matmul_block(uint32_t in0_cb_id, uint32_t in1_cb_id, uint32_t in0_tile_index, uint32_t in1_tile_index, uint32_t idst, const uint32_t transpose, uint32_t ct_dim, uint32_t rt_dim, uint32_t kt_dim)
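The change above threads an explicit transpose argument through the whole mm_block_init family so the init calls match matmul_block. A minimal sketch of how a compute kernel might drive the updated API follows; the circular-buffer indices and block dimensions are illustrative placeholders, not values taken from this commit.

// Sketch only: CB indices and block dimensions are hypothetical.
constexpr uint32_t cb_in0 = 0, cb_in1 = 1, cb_out = 16;
constexpr uint32_t transpose_hw = 0;                    // the newly exposed transpose argument
constexpr uint32_t ct_dim = 1, rt_dim = 1, kt_dim = 2;

mm_block_init(cb_in0, cb_in1, cb_out, transpose_hw, ct_dim, rt_dim, kt_dim);
// One call computes an rt_dim x ct_dim sub-block of output tiles; kt_dim is used
// internally to stride in0 along the inner dimension.
matmul_block(cb_in0, cb_in1, /*in0_tile_index=*/0, /*in1_tile_index=*/0,
             /*idst=*/0, transpose_hw, ct_dim, rt_dim, kt_dim);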
639 changes: 293 additions & 346 deletions tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py

Large diffs are not rendered by default.

10 changes: 2 additions & 8 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp
@@ -374,10 +374,7 @@ inline Tensor matmul (const Tensor &input_tensor_a, const Tensor &input_tensor_b
.program_config=matmul_program_config,
.output_mem_config=mem_config,
.output_dtype=input_tensor_a.dtype(),
.math_fidelity=MathFidelity::HiFi4,
.fp32_dest_acc_en=false,
.math_approx_mode=false,
.packer_l1_acc=false
.compute_kernel_config=kernel_config_val
}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0);
} else {
return operation::run_with_autoformat(Matmul{.bcast_batch=true, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), .compute_kernel_config=kernel_config_val}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0);
@@ -398,10 +395,7 @@ inline Tensor bmm (const Tensor &input_tensor_a, const Tensor &input_tensor_b
.program_config=matmul_program_config,
.output_mem_config=mem_config,
.output_dtype=input_tensor_a.dtype(),
.math_fidelity=MathFidelity::HiFi4,
.fp32_dest_acc_en=false,
.math_approx_mode=false,
.packer_l1_acc=false
.compute_kernel_config=kernel_config_val
}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0);
} else {
return operation::run_with_autoformat(Matmul{.bcast_batch=false, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), .compute_kernel_config=kernel_config_val}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0);
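Both wrappers above now forward a single compute_kernel_config (kernel_config_val) instead of the four hard-coded fidelity/accumulation flags. A sketch of what constructing such a config could look like on the caller side is below; the WormholeComputeKernelConfig variant and its field names are assumptions drawn from the wider tt-metal codebase, not from this diff.

// Sketch under assumptions: config type and field names are not shown in this commit.
DeviceComputeKernelConfig kernel_config_val = WormholeComputeKernelConfig{
    .math_fidelity = MathFidelity::HiFi4,
    .math_approx_mode = false,
    .fp32_dest_acc_en = true,    // the flag this commit threads through to the untilize path
    .packer_l1_acc = false,
};
// kernel_config_val is then passed as Matmul{ ..., .compute_kernel_config = kernel_config_val }.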
@@ -8,8 +8,6 @@
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/pack_untilize.h"

#include "compute_kernel_api/bcast.h"
#include "debug/dprint.h"

using std::uint32_t;

@@ -58,66 +56,38 @@ void MAIN {
constexpr uint32_t in0_num_blocks_w = 1; // TODO: Generalize


mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
// mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w );

// UNPACK(( DPRINT << num_rows_in_one_tile << ENDL() ));
// UNPACK(( DPRINT << in1_num_blocks << ENDL() ));
// UNPACK(( DPRINT << in1_num_blocks << ENDL() ));
// UNPACK(( DPRINT << in0_block_num_tiles << ENDL() ));

// mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w );

for (uint32_t b = 0; b < batch; b++) {

for (uint32_t m = 0; m < Mt; m++) { // TODO: Must be 1; generalize to support batch > 32 (ie. Mt > 1)
for (uint32_t in0_block = 0; in0_block < in0_num_blocks_w; in0_block++) { // TODO: Must be 1; generalize to support inner dim blocking
cb_wait_front(cb_in0, in0_block_num_tiles);

// UNPACK(( DPRINT << TSLICE(cb_in0, 0, SliceRange::h0_32_w31()) << ENDL() ));

for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
uint32_t in0_index_subblock_offset = 0;
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
cb_wait_front(cb_in1, in1_block_num_tiles);
cb_pop_front(cb_in1, num_kv_heads_skip);


// UNPACK(( DPRINT << TSLICE(cb_in1, 0, SliceRange::h0_32_w31()) << ENDL() ));

for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { // TODO: Must be 1; need to review inner dim blocking and the untilizing
uint32_t in1_index_subblock_offset = 0;

tile_regs_acquire();


// Compute output sub-block
// uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
// uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
// uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// // inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
// for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// // accumulation is done by iterating matmul_block across inner dim
// // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
// matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
// in0_index ++; // stride right by 1
// in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
// }

uint32_t dst_index = 0;
uint32_t in0_index_h_offset = 0;
for (uint32_t h = 0; h < out_subblock_h; h++) {
for (uint32_t w = 0; w < out_subblock_w; w++) {
uint32_t in1_index_inner_dim_offset = 0;
for (uint32_t inner_dim = 0; inner_dim < in0_block_w; inner_dim++) {
uint32_t in0_index = in0_index_subblock_offset + in0_index_h_offset + inner_dim;
uint32_t in1_index = in1_index_subblock_offset + in1_index_inner_dim_offset + w;
matmul_tiles(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw);
in1_index_inner_dim_offset += in1_per_core_w;
}
dst_index++;
}
in0_index_h_offset += in0_block_w;
uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// inner dim that we accumulate over is the inner dim of in0/in1, which is in0_block_w
for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// accumulation is done by iterating matmul_block across inner dim
// in0_block_w is passed as inner dim (kt) to matmul_block, internally used to stride in0
matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
in0_index++; // stride right by 1
in1_index += in1_per_core_w; // to stride down by 1, stride by in1_per_core_w (should be called in1_block_w)
}

tile_regs_commit();
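For context on the tile_regs_* calls that frame the sub-block computation above and the packing in the next hunk, the dest-register handshake in these kernels follows a fixed acquire/commit/wait/release order. A condensed sketch (loop bodies elided; all calls are ones already used in this file):

tile_regs_acquire();     // MATH: take ownership of the dest registers
/* ... matmul_block accumulation over in0_block_w ... */
tile_regs_commit();      // MATH: signal that results are ready for the packer

tile_regs_wait();        // PACK: block until MATH has committed
/* ... pack_untilize_dst into cb_intermed0 ... */
tile_regs_release();     // PACK: hand the dest registers back for the next sub-block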
Expand All @@ -128,34 +98,30 @@ void MAIN {

// TODO: Review inner dim blocking, untilizing, and in1_num_subblocks > 1 (with pack_untilize, can only untilize up to dst num tiles)
// This should normally be inside subblock loop and we pack out out_subblock_num_tiles
// pack_untilize_dst_init_short<intermediate_num_tiles>();
pack_untilize_dst_init_short<intermediate_num_tiles>();
cb_reserve_back(cb_intermed0, intermediate_num_tiles);
tile_regs_wait();
// pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);
pack_tile(0, cb_intermed0);
pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);
pack_untilize_uninit();

// pack_untilize_uninit();
tile_regs_release();
cb_push_back(cb_intermed0, intermediate_num_tiles);


cb_wait_front(cb_intermed0, intermediate_num_tiles);
UNPACK(( DPRINT << intermediate_num_tiles << ENDL() ));
UNPACK(( DPRINT << TSLICE(cb_intermed0, 0, SliceRange::h0_32_w31()) << ENDL() ));
} // 32 tiles loop

in0_index_subblock_offset += in0_subblock_num_tiles;
} // in1_num_blocks loop
} // in0_num_blocks_w

// cb_intermed1 comes from reader; untilized row-major tile
unpack_reconfig_data_format_srca(cb_in1, cb_intermed1);
pack_reconfig_data_format(cb_intermed0, out_cb_id);
cb_wait_front(cb_intermed1, out_num_tiles);

cb_reserve_back(out_cb_id, out_num_tiles);

// tilize CB::intermed1 and write to CB::c_out0
tilize_init_short_with_dt(cb_in1, cb_intermed1, out_num_tiles);
pack_reconfig_data_format(cb_intermed0, out_cb_id);
tilize_init_short(cb_intermed1, out_num_tiles);
tilize_block(cb_intermed1, out_num_tiles, out_cb_id);
cb_push_back(out_cb_id, out_num_tiles);

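Because the rendering above interleaves removed and added lines without markers, here is one reading of the added output path, condensed into a single sequence; treat the exact line attribution as an interpretation of the diff rather than a verbatim extract. The point of the change is that the intermediate leaving dest can now be fp32, so both the untilize out of dest and the re-tilize of the reader's row-major intermediate reconfigure data formats explicitly.

// Condensed sketch of the (interpreted) new output stage of this kernel.
pack_untilize_dst_init_short<intermediate_num_tiles>();
cb_reserve_back(cb_intermed0, intermediate_num_tiles);
tile_regs_wait();
pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);   // dest -> row-major intermediate (fp32 or bf16)
pack_untilize_uninit();
tile_regs_release();
cb_push_back(cb_intermed0, intermediate_num_tiles);

// cb_intermed1 comes back from the reader as untilized rows; re-tilize into the output CB.
unpack_reconfig_data_format_srca(cb_in1, cb_intermed1);
pack_reconfig_data_format(cb_intermed0, out_cb_id);
cb_wait_front(cb_intermed1, out_num_tiles);
cb_reserve_back(out_cb_id, out_num_tiles);
tilize_init_short_with_dt(cb_in1, cb_intermed1, out_num_tiles);
tilize_block(cb_intermed1, out_num_tiles, out_cb_id);
cb_push_back(out_cb_id, out_num_tiles);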
@@ -4,8 +4,6 @@

#include "dataflow_api.h"

#include "debug/dprint.h"

void kernel_main() {
uint32_t i = 0;

@@ -33,10 +31,6 @@ void kernel_main() {
uint32_t bfloat16_Nt_bytes = get_arg_val<uint32_t>(i++);
uint32_t bfloat16_last_row_bytes_read = get_arg_val<uint32_t>(i++);

DPRINT << "bfloat16_row_bytes " <<bfloat16_row_bytes <<ENDL();
DPRINT << "bfloat16_Nt_bytes " <<bfloat16_Nt_bytes <<ENDL();


constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
constexpr uint32_t cb_id_out = get_compile_time_arg_val(2);
@@ -110,8 +104,6 @@ void kernel_main() {

cb_push_back(cb_id_in0, in0_block_w);



cb_reserve_back(cb_id_intermed1, out_num_tiles);
uint32_t cb_intermed1_addr = get_write_ptr(cb_id_intermed1);
for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
@@ -53,7 +53,7 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te

tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
tt::DataFormat interm_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
tt::DataFormat interm_data_format = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
Expand All @@ -68,6 +68,10 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
log_debug("math_approx_mode: {}", math_approx_mode);
log_debug("fp32_dest_acc_en: {}", fp32_dest_acc_en);
log_debug("packer_l1_acc: {}", packer_l1_acc);
log_debug("in0_data_format: {}", in0_data_format);
log_debug("in1_data_format: {}", in1_data_format);
log_debug("interm_data_format: {}", interm_data_format);
log_debug("output_data_format: {}", output_data_format);

tt_metal::Buffer *src0_buffer = a.buffer();
tt_metal::Buffer *src1_buffer = b.buffer();
@@ -149,7 +153,7 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
(uint32_t) src0_is_dram,
(uint32_t) src1_is_dram,
(uint32_t) transpose_hw_bool,
(uint32_t) fp32_dest_acc_en
(uint32_t) (fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32)
};

bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
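In this op, the intermediate format is promoted to Float32 only when fp32 dest accumulation is enabled and the input itself is Float32; otherwise it stays Float16_b, and the same condition is handed to the kernel as a compile-time argument. For reference, the per-tile memory cost of that promotion on standard 32x32 tiles (simple arithmetic, not taken from the diff; variables are the ones in the listing above):

// Per-tile payload for the two intermediate formats (32 x 32 elements):
//   Float16_b: 32 * 32 * 2 bytes = 2048 bytes per tile
//   Float32:   32 * 32 * 4 bytes = 4096 bytes per tile
bool promote_interm = fp32_dest_acc_en && in0_data_format == tt::DataFormat::Float32;
tt::DataFormat interm_data_format =
    promote_interm ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;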
@@ -54,7 +54,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
tt::DataFormat interm_data_format = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
// interm_data_format=tt::DataFormat::Float16_b;
tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
Expand All @@ -65,8 +64,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
TT_ASSERT(fp32_dest_acc_en == true, "when inputs/output are in fp32 format, fp32_dest_acc_en must be set");
}



tt_metal::Buffer *src0_buffer = a.buffer();
tt_metal::Buffer *src1_buffer = b.buffer();
tt_metal::Buffer *dst_buffer = output.buffer();
@@ -98,7 +95,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
const uint32_t in1_per_core_w = in1_num_subblocks * out_block_w;
const uint32_t in1_block_w_tile_bytes = out_subblock_w * in1_single_tile_size;
uint32_t ONE_ROW_BFLOAT16_BYTES = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? 128 : 64;
// ONE_ROW_BFLOAT16_BYTES = 64;
const uint32_t bfloat16_row_bytes = ONE_ROW_BFLOAT16_BYTES * out_block_w; // TODO: Generalize

log_debug("in0_block_w: {}", in0_block_w);
@@ -143,17 +139,13 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size).set_globally_allocated_address(*src0_buffer);
cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);

std::cout << cb0_num_input_tiles << std::endl;
} else {
uint32_t cb0_num_input_tiles = in0_block_w; // TODO: Generalize; double buffer and add blocking along inner dim if we have Mt > 1
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size);
cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);
}



// CB for interleaved/sharded KV heads for mcasting; mcasts to same CB
// Then, push all KV_HEADS to compute and compute chooses which head to use for matmul
uint32_t src1_cb_index = CB::c_in1;
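The row stride used to address the untilized intermediate doubles on the fp32 path: a tile row holds 32 values, so it is 64 bytes in bfloat16 and 128 bytes in fp32 (the constant keeps its ONE_ROW_BFLOAT16_BYTES name even though it now also covers fp32). The arithmetic behind the two constants, restated as a sketch using the variables from the listing above:

// Row-size arithmetic behind the 64 / 128 byte constants (32 values per tile row):
constexpr uint32_t ROW_BYTES_BF16 = 32 * sizeof(uint16_t); // 64 bytes
constexpr uint32_t ROW_BYTES_FP32 = 32 * sizeof(float);    // 128 bytes
uint32_t one_row_bytes = (fp32_dest_acc_en && in0_data_format == tt::DataFormat::Float32)
                             ? ROW_BYTES_FP32 : ROW_BYTES_BF16;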
19 changes: 3 additions & 16 deletions tt_eager/tt_lib/csrc/operations/primary/module.hpp
@@ -213,10 +213,6 @@ void py_module(py::module& m_primary) {
const MemoryConfig& out_mem_config,
std::optional<DataType> output_dtype,
std::optional<DeviceComputeKernelConfig> compute_kernel_config
// const MathFidelity math_fidelity,
// const bool fp32_dest_acc_en,
// const bool math_approx_mode,
// const bool packer_l1_acc
) {
return matmul(
input_tensor_a, input_tensor_b, bias, program_config, out_mem_config, output_dtype, compute_kernel_config);
@@ -251,21 +247,15 @@ void py_module(py::module& m_primary) {
const MatmulMultiCoreReuseProgramConfig& program_config,
const MemoryConfig& out_mem_config,
std::optional<DataType> output_dtype,
const MathFidelity math_fidelity,
const bool fp32_dest_acc_en,
const bool math_approx_mode,
const bool packer_l1_acc) {
std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
return matmul(
input_tensor_a,
input_tensor_b,
bias,
program_config,
out_mem_config,
output_dtype,
math_fidelity,
fp32_dest_acc_en,
math_approx_mode,
packer_l1_acc);
compute_kernel_config);
},
py::arg("input_tensor_a").noconvert(),
py::arg("input_tensor_b").noconvert(),
Expand All @@ -274,10 +264,7 @@ void py_module(py::module& m_primary) {
py::arg("program_config").noconvert() = MatmulDefaultProgramConfig(),
py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
py::arg("output_dtype").noconvert() = std::nullopt,
py::arg("math_fidelity").noconvert() = MathFidelity::LoFi,
py::arg("fp32_dest_acc_en").noconvert() = false,
py::arg("math_approx_mode").noconvert() = true,
py::arg("packer_l1_acc").noconvert() = false,
py::arg("compute_kernel_config").noconvert() = std::nullopt,
R"doc(
Perform a matrix multiplication ``input_tensor_a x input_tensor_b``.
2 changes: 0 additions & 2 deletions tt_metal/hw/inc/debug/dprint_tile.h
@@ -19,10 +19,8 @@
static inline SliceRange hw0_32_4() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 4, .w0 = 0, .w1 = 32, .ws = 4 }; }
// [0, 0:32]
static inline SliceRange h0_w0_32() { return SliceRange{ .h0 = 0, .h1 = 1, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; }
static inline SliceRange h1_w0_32() { return SliceRange{ .h0 = 1, .h1 = 2, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; }
// [0:32, 0]
static inline SliceRange h0_32_w0() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1 }; }
static inline SliceRange h0_32_w31() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 31, .w1 = 32, .ws = 1 }; }
// [0:32:1, 1]
static inline SliceRange h0_32_w1() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 1, .w1 = 2, .ws = 1 }; }
// [0:4:1, 0:4:1]