diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py
new file mode 100644
index 00000000000..97f2b1980ed
--- /dev/null
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import tt_lib as ttl
+import pytest
+from loguru import logger
+
+from models.utility_functions import tt2torch_tensor, comp_pcc, skip_for_grayskull
+
+
+def sequential_prefix_scan(a, bx):
+    (_, _, L, EN) = bx.shape
+    hidden_states = torch.zeros((1, 1, L, EN), device=a.device)
+    for i in range(L):
+        hidden_states[:, :, i] = a[:, :, i] * hidden_states[:, :, i - 1] + bx[:, :, i]
+    return hidden_states
+
+
+def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
+    torch.manual_seed(0)
+
+    a = torch.randn((1, 1, L, E * N))
+    bx = torch.randn((1, 1, L, E * N))
+
+    expected = sequential_prefix_scan(a, bx)
+
+    compute_grid_size = device.compute_with_storage_grid_size()
+    shard_grid = ttl.tensor.CoreRangeSet(ttl.tensor.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
+    shard_spec = ttl.tensor.ShardSpec(
+        shard_grid,
+        [L, E * N // num_cores],
+        ttl.tensor.ShardOrientation.ROW_MAJOR,
+        False,
+    )
+    memory_config = ttl.tensor.MemoryConfig(
+        ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.L1, shard_spec
+    )
+    a = ttl.tensor.Tensor(a, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)
+    bx = ttl.tensor.Tensor(bx, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)
+
+    actual = ttl.operations.primary.transformers.ssm_prefix_scan(
+        a, bx, output_mem_config=memory_config, output_dtype=dtype
+    )
+    assert list(actual.get_legacy_shape()) == list(expected.shape)
+    assert actual.dtype == dtype
+
+    actual = tt2torch_tensor(actual)
+
+    passing_pcc, output_pcc = comp_pcc(actual, expected, 0.999)
+    logger.debug(f"Out passing={passing_pcc}")
+    logger.debug(f"Output pcc={output_pcc}")
+
+    assert passing_pcc
+
+
+@skip_for_grayskull("Grayskull not supported")
+@pytest.mark.parametrize(
+    "dtype",
+    (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
+)
+@pytest.mark.parametrize(
+    "L, E, N, num_cores",
+    (
+        (32, 32, 32, 1),
+        (32, 64, 32, 1),
+        (32, 2560, 32, 32),
+        (32, 5120, 32, 40),
+        # (32, 5120, 32, 64) -> 8x8 grid not supported on CI
+    ),
+)
+def test_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
+    run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)
+
+
+@skip_for_grayskull("Grayskull not supported")
+def test_ssm_prefix_scan_with_program_cache(device, use_program_cache):
+    L, E, N = 32, 64, 32
+    num_cores = 1
+    dtype = ttl.tensor.DataType.BFLOAT8_B
+    run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)
+
+    dummy_memory_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1)
+    dummy_shape = [1, 1, 128, 128]
+
+    for _ in range(2):
+        run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)
+        py_dummy_tensor = torch.randn(dummy_shape)
+        tt_dummy_tensor = (
+            ttl.tensor.Tensor(py_dummy_tensor, dtype).to(ttl.tensor.Layout.TILE).to(device, dummy_memory_config)
+        )
+
+    assert device.num_program_cache_entries() == 1
diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt
index d5f96ded232..60643ee3e6a 100644
--- a/tt_eager/tt_dnn/op_library/CMakeLists.txt
+++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt
@@ -173,6 +173,7 @@ set(TT_DNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/run_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/split/split_tiled.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/split/split_last_dim_two_chunks_tiled.cpp
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp
new file mode 100644
index 00000000000..9734b41498a
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp
@@ -0,0 +1,184 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/eltwise_binary.h"
+#include "compute_kernel_api/tile_move_copy.h"
+#include "compute_kernel_api/tilize.h"
+#include "compute_kernel_api/untilize.h"
+
+constexpr uint32_t NUM_TILES_IN_TILIZED_CHUNK = 32;
+
+constexpr uint32_t cb_a_in = get_compile_time_arg_val(0);
+constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1);
+
+constexpr uint32_t cb_a_tilize_in = get_compile_time_arg_val(2);
+constexpr uint32_t cb_bx_tilize_in = get_compile_time_arg_val(3);
+
+constexpr uint32_t cb_h_prev = get_compile_time_arg_val(4);
+constexpr uint32_t cb_ah = get_compile_time_arg_val(5);
+constexpr uint32_t cb_h = get_compile_time_arg_val(6);
+
+constexpr uint32_t cb_tilize_out = get_compile_time_arg_val(7);
+constexpr uint32_t cb_out = get_compile_time_arg_val(8);
+
+constexpr uint32_t cb_zeros = get_compile_time_arg_val(9);
+
+// This function relies on untilizing NUM_TILES_IN_TILIZED_CHUNK tiles so we pad up to that amount
+FORCE_INLINE void pack_block_rows_into_tiles(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
+    unpack_reconfig_data_format_srca(cb_in);
+    pack_reconfig_data_format(cb_out);
+
+    untilize_init_short(cb_in);
+
+    cb_wait_front(cb_in, num_tiles);
+    cb_reserve_back(cb_out, NUM_TILES_IN_TILIZED_CHUNK);
+
+    untilize_block(cb_in, NUM_TILES_IN_TILIZED_CHUNK, cb_out);
+
+    cb_push_back(cb_out, NUM_TILES_IN_TILIZED_CHUNK);
+    cb_pop_front(cb_in, num_tiles);
+
+    untilize_uninit(cb_in);
+}
+
+// This function relies on tilizing NUM_TILES_IN_TILIZED_CHUNK tiles so we pad up to that amount
+FORCE_INLINE void pack_block_tiles_into_rows(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
+    unpack_reconfig_data_format_srca(cb_in);
+    pack_reconfig_data_format(cb_out);
+
+    tilize_init_short(cb_in, NUM_TILES_IN_TILIZED_CHUNK);
+
+    cb_wait_front(cb_in, NUM_TILES_IN_TILIZED_CHUNK);
+    cb_reserve_back(cb_out, num_tiles);
+
+    tilize_block(cb_in, NUM_TILES_IN_TILIZED_CHUNK, cb_out);
+
+    cb_push_back(cb_out, num_tiles);
+    cb_pop_front(cb_in, NUM_TILES_IN_TILIZED_CHUNK);
+
+    tilize_uninit(cb_in);
+}
+
+FORCE_INLINE void mul(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
+    unpack_reconfig_data_format(cb_a, cb_b);
+    pack_reconfig_data_format(cb_out);
+
+    mul_tiles_init();
+
+    cb_wait_front(cb_a, 1);
+    cb_wait_front(cb_b, 1);
+    cb_reserve_back(cb_out, 1);
+
+    tile_regs_acquire();
+    mul_tiles(cb_a, cb_b, 0, 0, 0);
+    tile_regs_commit();
+    tile_regs_wait();
+    pack_tile(0, cb_out);
+    tile_regs_release();
+
+    cb_push_back(cb_out, 1);
+    cb_pop_front(cb_a, 1);
+    cb_pop_front(cb_b, 1);
+}
+
+FORCE_INLINE void sum(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
+    unpack_reconfig_data_format(cb_a, cb_b);
+    pack_reconfig_data_format(cb_out);
+
+    add_tiles_init();
+
+    cb_wait_front(cb_a, 1);
+    cb_wait_front(cb_b, 1);
+    cb_reserve_back(cb_out, 1);
+
+    tile_regs_acquire();
+    add_tiles(cb_a, cb_b, 0, 0, 0);
+    tile_regs_commit();
+    tile_regs_wait();
+    pack_tile(0, cb_out);
+    tile_regs_release();
+
+    cb_push_back(cb_out, 1);
+    cb_pop_front(cb_a, 1);
+    cb_pop_front(cb_b, 1);
+}
+
+FORCE_INLINE void copy(uint32_t cb_in, uint32_t cb_out) {
+    unpack_reconfig_data_format_srca(cb_in);
+    pack_reconfig_data_format(cb_out);
+
+    copy_tile_to_dst_init_short();
+
+    cb_wait_front(cb_in, 1);
+    cb_reserve_back(cb_out, 1);
+
+    tile_regs_acquire();
+    copy_tile(cb_in, 0, 0);
+    tile_regs_commit();
+    tile_regs_wait();
+    pack_tile(0, cb_out);
+    tile_regs_release();
+
+    // Don't pop the copied tile - the caller is responsible for popping it
+    cb_push_back(cb_out, 1);
+}
+
+FORCE_INLINE void setup_cb_zeros() {
+    cb_reserve_back(cb_zeros, 1);
+    cb_push_back(cb_zeros, 1);
+}
+
+FORCE_INLINE void fill_tile_zeros(uint32_t cb_id) { copy(cb_zeros, cb_id); }
+
+FORCE_INLINE void compute_ht(uint32_t cb_a, uint32_t cb_bx, uint32_t cb_out, uint32_t num_tiles) {
+    for (uint32_t idx = 0; idx < num_tiles; idx++) {
+        mul(cb_a, cb_h_prev, cb_ah);
+        sum(cb_ah, cb_bx, cb_h);
+        copy(cb_h, cb_h_prev);
+        copy(cb_h, cb_out); // TODO: Get rid of this extraneous copy
+        cb_pop_front(cb_h, 1);
+    }
+    // Remove the last hidden state so cb_h_prev is empty for the next chunk
+    cb_wait_front(cb_h_prev, 1);
+    cb_pop_front(cb_h_prev, 1);
+}
+
+namespace NAMESPACE {
+void MAIN {
+    const uint32_t total_tiles = get_arg_val<uint32_t>(0);
+    const uint32_t total_tiles_per_row = get_arg_val<uint32_t>(1);
+    const uint32_t total_tiles_per_col = get_arg_val<uint32_t>(2);
+
+    const uint32_t num_tilize_per_row =
+        (total_tiles_per_row + NUM_TILES_IN_TILIZED_CHUNK - 1) / NUM_TILES_IN_TILIZED_CHUNK;  // ceil(x/y)
+
+    untilize_init(cb_a_in);
+    binary_op_init_common(cb_a_in, cb_bx_in);
+
+    setup_cb_zeros();
+
+    // For each row of tiles, tilize chunks of 32 tiles to pack the rows into tiles
+    for (uint32_t row_idx = 0; row_idx < total_tiles_per_col; row_idx++) {
+        for (uint32_t tilized_chunk_idx = 0; tilized_chunk_idx < num_tilize_per_row; tilized_chunk_idx++) {
+            fill_tile_zeros(cb_h_prev);
+
+            // If we don't have a full chunk (NUM_TILES_IN_TILIZED_CHUNK tiles) we need to figure out how many
+            // tiles are left. This only applies to the last 2-3 tiles per shard, so there is no need to unroll.
+            const uint32_t remaining_tiles_in_chunk =
+                tilized_chunk_idx == num_tilize_per_row - 1 && total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK != 0
+                    ? total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK
+                    : NUM_TILES_IN_TILIZED_CHUNK;
+
+            pack_block_rows_into_tiles(cb_a_in, cb_a_tilize_in, remaining_tiles_in_chunk);
+            pack_block_rows_into_tiles(cb_bx_in, cb_bx_tilize_in, remaining_tiles_in_chunk);
+
+            compute_ht(cb_a_tilize_in, cb_bx_tilize_in, cb_tilize_out, NUM_TILES_IN_TILIZED_CHUNK);
+
+            pack_block_tiles_into_rows(cb_tilize_out, cb_out, remaining_tiles_in_chunk);
+        }
+    }
+}
+} // namespace NAMESPACE
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp
new file mode 100644
index 00000000000..f298185dca0
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp
@@ -0,0 +1,14 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "dataflow_api.h"
+
+void kernel_main() {
+    uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
+    constexpr uint32_t cb_a_in = get_compile_time_arg_val(0);
+    constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1);
+
+    cb_push_back(cb_a_in, num_tiles_per_core);
+    cb_push_back(cb_bx_in, num_tiles_per_core);
+}
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp
new file mode 100644
index 00000000000..764bb9e3b32
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp
@@ -0,0 +1,11 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "dataflow_api.h"
+
+void kernel_main() {
+    uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
+    constexpr uint32_t cb_out = get_compile_time_arg_val(0);
+    cb_wait_front(cb_out, num_tiles_per_core);
+}
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp
new file mode 100644
index 00000000000..35ec8648bb8
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp
@@ -0,0 +1,190 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "tensor/tensor.hpp"
+#include "tt_dnn/op_library/operation.hpp"
+#include "tt_metal/host_api.hpp"
+
+using namespace tt::constants;
+using namespace tt;
+
+namespace tt {
+namespace operations {
+namespace primary {
+namespace transformers {
+
+operation::ProgramWithCallbacks multi_core_ssm_prefix_scan(
+    const Tensor& a,
+    const Tensor& bx,
+    Tensor& output,
+    MathFidelity math_fidelity,
+    CoreCoord compute_with_storage_grid_size) {
+    tt_metal::Program program = tt_metal::CreateProgram();
+
+    auto* a_buffer = a.buffer();
+    auto* bx_buffer = bx.buffer();
+    auto* output_buffer = output.buffer();
+    TT_ASSERT(output_buffer != nullptr, "Output buffer should be allocated on device");
+
+    const tt::DataFormat input_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
+    const uint32_t input_tile_size = tt_metal::detail::TileSize(input_format);
+
+    const tt::DataFormat intermediary_format = tt::DataFormat::Float16_b;
+    const uint32_t intermediary_tile_size = tt_metal::detail::TileSize(intermediary_format);
+
+    const auto all_cores = a.shard_spec()->grid;
+    const auto create_circular_buffer = [&program, &all_cores](
+                                            uint32_t index,
+                                            uint32_t num_tiles,
+                                            uint32_t tile_size,
+                                            const tt::DataFormat& format,
+                                            Buffer* buffer = nullptr) -> tt_metal::CBHandle {
+        auto config = CircularBufferConfig(num_tiles * tile_size, {{index, format}}).set_page_size(index, tile_size);
+        if (buffer != nullptr) {
+            config = config.set_globally_allocated_address(*buffer);
+        }
+        return tt_metal::CreateCircularBuffer(program, all_cores, config);
+    };
+
+    const uint32_t sharded_sequence_length = a.shard_spec()->shape[0];
+    const uint32_t sharded_hidden_state_length = a.shard_spec()->shape[1];
+
+    const uint32_t total_tiles_per_row = sharded_hidden_state_length / TILE_HEIGHT;
+    const uint32_t total_tiles_per_col = sharded_sequence_length / TILE_HEIGHT;
+    const uint32_t total_tiles = total_tiles_per_row * total_tiles_per_col;
+
+    const uint32_t cb_a_in_id = tt::CB::c_in0;
+    const auto cb_a_in = create_circular_buffer(cb_a_in_id, total_tiles, input_tile_size, input_format, a_buffer);
+
+    const uint32_t cb_bx_in_id = tt::CB::c_in1;
+    const auto cb_bx_in = create_circular_buffer(cb_bx_in_id, total_tiles, input_tile_size, input_format, bx_buffer);
+
+    const uint32_t cb_out_id = tt::CB::c_out0;
+    const auto cb_out = create_circular_buffer(cb_out_id, total_tiles, input_tile_size, input_format, output_buffer);
+
+    const uint32_t num_tiles_in_row_to_tile_cb = 32;  // Tilizing 32 tiles will pack tensor rows into separate tiles
+    const uint32_t cb_a_tilize_in_id = tt::CB::c_intermed0;
+    const auto cb_a_tilize_in = create_circular_buffer(
+        cb_a_tilize_in_id, num_tiles_in_row_to_tile_cb, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_bx_tilize_in_id = tt::CB::c_intermed1;
+    const auto cb_bx_tilize_in = create_circular_buffer(
+        cb_bx_tilize_in_id, num_tiles_in_row_to_tile_cb, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_tilize_out_id = tt::CB::c_intermed2;
+    const auto cb_tilize_out = create_circular_buffer(
+        cb_tilize_out_id, num_tiles_in_row_to_tile_cb, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_h_prev_id = tt::CB::c_intermed3;
+    const auto cb_h_prev = create_circular_buffer(cb_h_prev_id, 2, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_ah_id = tt::CB::c_intermed4;
+    const auto cb_ah = create_circular_buffer(cb_ah_id, 2, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_h_id = tt::CB::c_intermed5;
+    const auto cb_h = create_circular_buffer(cb_h_id, 2, intermediary_tile_size, intermediary_format);
+
+    const uint32_t cb_zeros_id = tt::CB::c_intermed6;
+    const auto cb_zeros = create_circular_buffer(cb_zeros_id, 1, intermediary_tile_size, intermediary_format);
+
+    std::vector<uint32_t> reader_compile_time_args = {cb_a_in_id, cb_bx_in_id};
+    std::vector<uint32_t> writer_compile_time_args = {cb_out_id};
+    std::vector<uint32_t> compute_compile_time_args = {
+        cb_a_in_id,
+        cb_bx_in_id,
+        cb_a_tilize_in_id,
+        cb_bx_tilize_in_id,
+        cb_h_prev_id,
+        cb_ah_id,
+        cb_h_id,
+        cb_tilize_out_id,
+        cb_out_id,
+        cb_zeros_id};
+
+    auto reader_kernel_id = tt_metal::CreateKernel(
+        program,
+        "tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp",
+        all_cores,
+        tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
+
+    auto writer_kernel_id = tt_metal::CreateKernel(
+        program,
+        "tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp",
+        all_cores,
+        tt_metal::WriterDataMovementConfig(writer_compile_time_args));
+
+    auto compute_kernel_id = tt_metal::CreateKernel(
+        program,
+        "tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp",
+        all_cores,
+        tt_metal::ComputeConfig{
+            .math_fidelity = math_fidelity,
+            .fp32_dest_acc_en = false,
+            .math_approx_mode = false,
+            .compile_args = compute_compile_time_args});
+
+    std::vector<CoreCoord> cores =
+        grid_to_cores(all_cores.num_cores(), compute_with_storage_grid_size.x, compute_with_storage_grid_size.y, true);
+
+    auto set_runtime_args = [reader_kernel_id,
+                             writer_kernel_id,
+                             compute_kernel_id,
+                             total_tiles,
+                             total_tiles_per_col,
+                             total_tiles_per_row,
+                             all_cores,
+                             cores,
+                             cb_a_in,
+                             cb_bx_in,
+                             cb_out](Program& program, const Tensor& a, const Tensor& bx, const Tensor& output) {
+        tt_metal::Buffer* a_buffer = a.buffer();
+        tt_metal::Buffer* bx_buffer = bx.buffer();
+        tt_metal::Buffer* output_buffer = output.buffer();
+
+        UpdateDynamicCircularBufferAddress(program, cb_a_in, *a_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_bx_in, *bx_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_out, *output_buffer);
+
+        std::vector<std::vector<uint32_t>> reader_runtime_args = {cores.size(), {0}};  // (num_tiles_per_core)
+        std::vector<std::vector<uint32_t>> writer_runtime_args = {cores.size(), {0}};  // (num_tiles_per_core)
+        std::vector<std::vector<uint32_t>> compute_runtime_args = {
+            cores.size(), {0, 0, 0}};  // (total_tiles, total_tiles_per_row, total_tiles_per_col)
+
+        for (uint32_t i = 0, num_blocks_written = 0; i < cores.size(); i++) {
+            const CoreCoord& core = cores.at(i);
+
+            reader_runtime_args[i][0] = total_tiles;
+
+            writer_runtime_args[i][0] = total_tiles;
+
+            compute_runtime_args[i][0] = total_tiles;
+            compute_runtime_args[i][1] = total_tiles_per_row;
+            compute_runtime_args[i][2] = total_tiles_per_col;
+        }
+        SetRuntimeArgs(program, reader_kernel_id, cores, reader_runtime_args);
+        SetRuntimeArgs(program, writer_kernel_id, cores, writer_runtime_args);
+        SetRuntimeArgs(program, compute_kernel_id, cores, compute_runtime_args);
+    };
+
+    set_runtime_args(program, a, bx, output);
+
+    auto override_runtime_arguments_callback = [set_runtime_args](
+                                                   const void* operation,
+                                                   Program& program,
+                                                   const std::vector<Tensor>& input_tensors,
+                                                   const std::vector<std::optional<const Tensor>>&,
+                                                   const std::vector<Tensor>& output_tensors) {
+        auto& a = input_tensors.at(0);
+        auto& bx = input_tensors.at(1);
+        auto& out = output_tensors.at(0);
+        set_runtime_args(program, a, bx, out);
+    };
+
+    return {
+        .program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
+}
+
+} // namespace transformers
+} // namespace primary
+} // namespace operations
+} // namespace tt
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
index f30ca5f9df1..ed8ff976089 100644
--- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
@@ -668,6 +668,59 @@ tt::stl::reflection::Attributes SSM1DSumReduce::attributes() const {
     };
 }
 
+void SSMPrefixScan::validate(const std::vector<Tensor>& input_tensors) const {
+    TT_FATAL(input_tensors.size() == 2, "Expected 2 input tensors");
+
+    const auto& a = input_tensors.at(0);
+    const auto& bx = input_tensors.at(1);
+
+    TT_FATAL(a.dtype() == bx.dtype(), "Expected input tensors to have the same data type");
+
+    TT_FATAL(a.layout() == Layout::TILE && bx.layout() == Layout::TILE, "Expected input tensors to be tile layout");
+
+    TT_FATAL(a.get_legacy_shape() == bx.get_legacy_shape(), "Expected input tensors to have the same shape");
+
+    const auto& shape = a.get_legacy_shape();
+    TT_FATAL(shape.rank() == 4, "Expected input tensors to be rank 4");
+    TT_FATAL(shape[0] == 1 && shape[1] == 1, "Dimension 0 and 1 should be size 1");
+    TT_FATAL(shape[2] >= TILE_HEIGHT && shape[2] % TILE_HEIGHT == 0, "Sequence length should be a multiple of 32");
+
+    TT_FATAL(a.is_sharded() && bx.is_sharded(), "Expected input tensors to be sharded");
+    TT_FATAL(a.shard_spec().has_value() && bx.shard_spec().has_value(), "Expected input tensors to be sharded");
+    TT_FATAL(
+        a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
+        "Expected A tensor to be row major orientation");
+    TT_FATAL(
+        bx.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR,
+        "Expected Bx tensor to be row major orientation");
+}
+
+std::vector<Shape> SSMPrefixScan::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+    const auto& a = input_tensors.at(0);
+    return {a.get_legacy_shape()};
+}
+
+std::vector<Tensor> SSMPrefixScan::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    return operation::generic_create_output_tensors(
+        *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
+}
+
+operation::ProgramWithCallbacks SSMPrefixScan::create_program(
+    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
+    const auto& a = input_tensors.at(0);
+    const auto& bx = input_tensors.at(1);
+    auto& output = output_tensors.at(0);
+    auto device_compute_with_storage_grid_size = a.device()->compute_with_storage_grid_size();
+    return multi_core_ssm_prefix_scan(a, bx, output, math_fidelity, device_compute_with_storage_grid_size);
+}
+
+tt::stl::reflection::Attributes SSMPrefixScan::attributes() const {
+    return {
+        {"output_mem_config", this->output_mem_config},
+        {"output_dtype", this->output_dtype},
+    };
+}
+
 } // namespace transformers
 } // namespace primary
 } // namespace operations
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
index fafcbcb9b9b..e27af4e1e70 100644
--- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
@@ -23,11 +23,19 @@ namespace transformers {
 operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
 operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads_sharded(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
 operation::ProgramWithCallbacks multi_core_concat_heads(const Tensor &input_tensor, Tensor &output_tensor, CoreCoord compute_with_storage_grid_size);
+
 // TODO: Group attention matmul will support sharding, mcasting, and should be faster; we should make attn_matmul (ie. KV heads = 1) a special case of group_attn_matmul and run the same op
 operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size, DeviceComputeKernelConfig compute_kernel_config);
+
 operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, const uint32_t out_subblock_w, CoreCoord compute_with_storage_grid_size, const bool row_major, DeviceComputeKernelConfig compute_kernel_config);
 operation::ProgramWithCallbacks multi_core_ssm_eltwise_mul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, const uint32_t hidden_size, MathFidelity math_fidelity, CoreCoord compute_with_storage_grid_size);
 operation::ProgramWithCallbacks multi_core_ssm_1d_sum_reduce(const Tensor &input_tensor_a, Tensor &output_tensor, MathFidelity math_fidelity, CoreCoord compute_with_storage_grid_size);
+operation::ProgramWithCallbacks multi_core_ssm_prefix_scan(
+    const Tensor& a,
+    const Tensor& bx,
+    Tensor& output,
+    MathFidelity math_fidelity,
+    CoreCoord compute_with_storage_grid_size);
 
 struct SplitFusedQKVAndSplitHeads {
     CoreCoord compute_with_storage_grid_size;
@@ -216,6 +224,41 @@ inline Tensor ssm_1d_sum_reduce(const Tensor &input_tensor_a, const MemoryConfig
     return output_tensors.at(0);
 }
 
+struct SSMPrefixScan {
+    MemoryConfig output_mem_config;
+    DataType output_dtype;
+    MathFidelity math_fidelity;
+
+    void validate(const std::vector<Tensor>& input_tensors) const;
+    std::vector<Shape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    operation::ProgramWithCallbacks create_program(
+        const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
+    tt::stl::reflection::Attributes attributes() const;
+};
+
+inline Tensor ssm_prefix_scan(
+    const Tensor& a,
+    const Tensor& bx,
+    const MemoryConfig& mem_config,
+    std::optional<DataType> output_dtype = std::nullopt,
+    MathFidelity math_fidelity = MathFidelity::HiFi4) {
+    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({a, bx}))};
+    operation::launch_op(
+        [mem_config, output_dtype, math_fidelity](
+            const std::vector<Tensor>& input_tensors,
+            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+            const auto& a = input_tensors.at(0);
+            const auto& bx = input_tensors.at(1);
+            return operation::run(
+                SSMPrefixScan{mem_config, output_dtype.value_or(a.get_dtype()), math_fidelity}, input_tensors);
+        },
+        {a, bx},
+        output_tensors);
+    return output_tensors.at(0);
+}
+
 } // namespace transformers
 } // namespace primary
diff --git a/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
index 701f53253ca..d155b99419f 100644
--- a/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
+++ b/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
@@ -58,6 +58,16 @@ void py_module(py::module& m_transformers) {
             Performs a custom reduction along dim 3 which is used in the SSM block of the Mamba architecture. Performs the following PyTorch equivalent (where latent_size = 32):
                 x = torch.sum(x.reshape(1, 1, shape[2], shape[3] // latent_size, latent_size), dim=-1).reshape(1, 1, shape[2], shape[3] // latent_size)
         )doc");
+    m_transformers.def(
+        "ssm_prefix_scan",
+        &ssm_prefix_scan,
+        py::arg().noconvert(),
+        py::arg().noconvert(),
+        py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+        py::arg("output_dtype").noconvert() = std::nullopt,
+        py::arg("math_fidelity").noconvert() = MathFidelity::HiFi4,
+        R"doc(
+            Performs a prefix scan to produce the SSM hidden states across an entire sequence. All input and output tensors are expected to be of shape [1, 1, L, 2EN] where E = 2560 and N = 32. L can be any multiple of 32.)doc");
 
     py::class_<SoftmaxProgramConfig>(m_transformers, "SoftmaxProgramConfig").def(py::init<>());
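
Example usage (an illustrative sketch, not part of the diff): the call below mirrors test_ssm_prefix_scan.py above; the shard shape, the WIDTH_SHARDED L1 memory config, the helper name, and the `device` handle are assumptions carried over from that unit test rather than additional requirements of the op.

import torch
import tt_lib as ttl


def run_prefix_scan_example(device, L=32, E=2560, N=32, num_cores=32):
    # Width-shard the [1, 1, L, E*N] inputs across `num_cores` cores in L1, as SSMPrefixScan::validate expects.
    grid = ttl.tensor.CoreRangeSet(
        ttl.tensor.num_cores_to_corerange_set(num_cores, device.compute_with_storage_grid_size(), True)
    )
    shard_spec = ttl.tensor.ShardSpec(grid, [L, E * N // num_cores], ttl.tensor.ShardOrientation.ROW_MAJOR, False)
    mem_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.L1, shard_spec
    )

    dtype = ttl.tensor.DataType.BFLOAT16
    a = ttl.tensor.Tensor(torch.randn(1, 1, L, E * N), dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config)
    bx = ttl.tensor.Tensor(torch.randn(1, 1, L, E * N), dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config)

    # Computes h[t] = a[t] * h[t-1] + bx[t] over the whole sequence on device.
    return ttl.operations.primary.transformers.ssm_prefix_scan(
        a, bx, output_mem_config=mem_config, output_dtype=dtype
    )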